use std::time::Duration;
use {
futures::StreamExt,
quinn::{Connection, IncomingBiStreams, RecvStream, SendStream},
tokio::time::timeout,
};
/// Time limit applied to individual waiting steps of the protocol-checksum
/// handshake; a step that takes longer causes the handshake to fail.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(1);
/// Server side of the protocol-checksum handshake.
///
/// Opens a new bidirectional stream on `connection` and sends
/// `protocol_checksum` as 4 big-endian bytes. It then waits for the peer to
/// send 4 bytes back and compares them with `protocol_checksum`.
///
/// Returns the opened `(SendStream, RecvStream)` pair when the peer's checksum
/// matches. Returns `None` in any of these cases:
/// - opening the stream, writing or reading fails;
/// - any of those steps takes longer than [`HANDSHAKE_TIMEOUT`];
/// - the checksums differ.
pub async fn verify_handshake_server(
    connection: &Connection,
    protocol_checksum: u32,
) -> Option<(SendStream, RecvStream)> {
    let checksum_data = protocol_checksum.to_be_bytes();

    // Every step that waits on the peer is bounded. Opening a stream can block
    // on stream-count limits and writing can block on flow control, so without
    // a timeout a misbehaving client could stall this task indefinitely.
    let (mut send, mut recv) = timeout(HANDSHAKE_TIMEOUT, connection.open_bi())
        .await
        .ok()?
        .ok()?;
    timeout(HANDSHAKE_TIMEOUT, send.write_all(&checksum_data))
        .await
        .ok()?
        .ok()?;

    let mut client_checksum_data = [0u8; 4];
    timeout(
        HANDSHAKE_TIMEOUT,
        recv.read_exact(&mut client_checksum_data),
    )
    .await
    .ok()?
    .ok()?;

    let client_checksum = u32::from_be_bytes(client_checksum_data);
    if protocol_checksum != client_checksum {
        return None;
    }
    Some((send, recv))
}
/// Client side of the protocol-checksum handshake.
///
/// Waits for the next bidirectional stream opened by the peer and reads 4
/// big-endian bytes from it. Those bytes are compared with
/// `protocol_checksum`. On a match, the same 4 bytes are echoed back and the
/// stream pair is returned.
///
/// Returns `None` in any of these cases:
/// - no stream arrives;
/// - accepting, reading or writing fails;
/// - any of those steps takes longer than [`HANDSHAKE_TIMEOUT`];
/// - the checksums differ.
///
/// On a mismatch nothing is written back, so the peer only observes the
/// absence of a reply.
pub async fn verify_handshake_client(
    bi_streams: &mut IncomingBiStreams,
    protocol_checksum: u32,
) -> Option<(SendStream, RecvStream)> {
    // `None` from `next()` means no further incoming streams; `Err` is a
    // connection error. Either one aborts the handshake.
    let stream_result = timeout(HANDSHAKE_TIMEOUT, bi_streams.next()).await.ok()?;
    let (mut send, mut recv) = stream_result?.ok()?;

    let mut server_checksum_data = [0u8; 4];
    timeout(
        HANDSHAKE_TIMEOUT,
        recv.read_exact(&mut server_checksum_data),
    )
    .await
    .ok()?
    .ok()?;

    let server_checksum = u32::from_be_bytes(server_checksum_data);
    if protocol_checksum != server_checksum {
        return None;
    }

    // Bounded like the other peer-dependent steps, so a peer that withholds
    // flow-control credit cannot stall the client indefinitely.
    timeout(HANDSHAKE_TIMEOUT, send.write_all(&server_checksum_data))
        .await
        .ok()?
        .ok()?;
    Some((send, recv))
}