use super::{ConnInitSettings, ConnectError, ConnectionReject};
use crate::{
    accesscontrol::StreamAcceptor,
    crypto::CryptoManager,
    packet::{
        CoreRejectReason, HSV5Info, HandshakeControlInfo, HandshakeVSInfo, ServerRejectReason,
        SrtControlPacket, SrtHandshake, SrtShakeFlags,
    },
    ConnectionSettings, SrtVersion,
};
use std::{
    net::SocketAddr,
    time::{Duration, Instant},
};
/// Outcome of [`gen_hsv5_response`], the responder side of an HSv5 handshake.
pub enum GenHsv5Result {
    /// The handshake is accepted: carries the HSv5 info to send back to the
    /// initiator together with the negotiated connection settings.
    Accept(HandshakeVSInfo, ConnectionSettings),
    /// The handshake could not be processed as-is (for example its SRT
    /// handshake extension was missing or was not a handshake request).
    NotHandled(ConnectError),
    /// The connection is refused; carries the rejection to report.
    Reject(ConnectionReject),
}
pub fn gen_hsv5_response(
settings: &mut ConnInitSettings,
with_hsv5: &HandshakeControlInfo,
from: SocketAddr,
acceptor: &mut impl StreamAcceptor,
) -> GenHsv5Result {
let incoming = match &with_hsv5.info {
HandshakeVSInfo::V5(hs) => hs,
_ => {
return GenHsv5Result::Reject(ConnectionReject::Rejecting(
ServerRejectReason::Version.into(), ));
}
};
let mut accept_params = match acceptor.accept(incoming.sid.as_deref(), from) {
Ok(ap) => ap,
Err(rr) => return GenHsv5Result::Reject(ConnectionReject::Rejecting(rr)),
};
if let Some(co) = accept_params.take_crypto_options() {
settings.crypto = Some(co);
}
let hs = match incoming.ext_hs {
Some(SrtControlPacket::HandshakeRequest(hs)) => hs,
Some(_) => return GenHsv5Result::NotHandled(ConnectError::ExpectedHSReq),
None => return GenHsv5Result::NotHandled(ConnectError::ExpectedExtFlags),
};
let cm = match (&settings.crypto, &incoming.ext_km) {
(Some(co), Some(SrtControlPacket::KeyManagerRequest(km))) => {
if co.size != incoming.crypto_size {
unimplemented!("Key size mismatch");
}
Some(match CryptoManager::new_from_kmreq(co.clone(), km) {
Ok(cm) => cm,
Err(rr) => return GenHsv5Result::Reject(rr),
})
}
(None, None) => None,
(Some(_), Some(_)) => unimplemented!("Expected kmreq"),
(Some(_), None) | (None, Some(_)) => unimplemented!("Crypto mismatch"),
};
let outgoing_ext_km = if let Some(cm) = &cm {
Some(cm.generate_km())
} else {
None
};
let sid = if let HandshakeVSInfo::V5(info) = &with_hsv5.info {
info.sid.clone()
} else {
None
};
GenHsv5Result::Accept(
HandshakeVSInfo::V5(HSV5Info {
crypto_size: cm.as_ref().map(|c| c.key_length()).unwrap_or(0),
ext_hs: Some(SrtControlPacket::HandshakeResponse(SrtHandshake {
version: SrtVersion::CURRENT,
flags: SrtShakeFlags::SUPPORTED,
send_latency: settings.send_latency,
recv_latency: settings.recv_latency,
})),
ext_km: outgoing_ext_km.map(SrtControlPacket::KeyManagerResponse),
sid,
}),
ConnectionSettings {
remote: from,
remote_sockid: with_hsv5.socket_id,
local_sockid: settings.local_sockid,
socket_start_time: Instant::now(), init_send_seq_num: settings.starting_send_seqnum,
init_recv_seq_num: with_hsv5.init_seq_num,
max_packet_size: 1500, max_flow_size: 8192,
send_tsbpd_latency: Duration::max(settings.send_latency, hs.recv_latency),
recv_tsbpd_latency: Duration::max(settings.recv_latency, hs.send_latency),
crypto_manager: cm,
stream_id: incoming.sid.clone(),
},
)
}
/// State the initiator keeps between sending its HSv5 handshake request
/// and processing the peer's response.
#[derive(Debug, Clone)]
pub struct StartedInitiator {
    /// Crypto manager created when the handshake was started, if encryption
    /// was configured.
    cm: Option<CryptoManager>,
    /// Settings the handshake was started with.
    settings: ConnInitSettings,
    /// Stream id sent to the peer, if any.
    streamid: Option<String>,
}
/// Initiator side of the HSv5 handshake.
///
/// Builds the HSv5 info to send to the peer and returns it together with a
/// [`StartedInitiator`] that finishes the handshake once the response
/// arrives.
///
/// If `settings.crypto` is set, a fresh random [`CryptoManager`] is created
/// and its key material is attached as a key-manager request. The advertised
/// crypto size is then the configured key size. Without crypto, no
/// key-manager extension is sent and the crypto size is 0.
pub fn start_hsv5_initiation(
    settings: ConnInitSettings,
    streamid: Option<String>,
) -> (HandshakeVSInfo, StartedInitiator) {
    let (crypto_size, cm, ext_km) = match &settings.crypto {
        Some(co) => {
            let manager = CryptoManager::new_random(co.clone());
            let request = SrtControlPacket::KeyManagerRequest(manager.generate_km());
            (co.size, Some(manager), Some(request))
        }
        None => (0, None, None),
    };

    let request_hs = SrtHandshake {
        version: SrtVersion::CURRENT,
        flags: SrtShakeFlags::SUPPORTED,
        send_latency: settings.send_latency,
        recv_latency: settings.recv_latency,
    };
    let info = HandshakeVSInfo::V5(HSV5Info {
        crypto_size,
        ext_hs: Some(SrtControlPacket::HandshakeRequest(request_hs)),
        ext_km,
        sid: streamid.clone(),
    });

    let initiator = StartedInitiator {
        cm,
        settings,
        streamid,
    };
    (info, initiator)
}
impl StartedInitiator {
pub fn finish_hsv5_initiation(
self,
response: &HandshakeControlInfo,
from: SocketAddr,
) -> Result<ConnectionSettings, ConnectError> {
let incoming = match &response.info {
HandshakeVSInfo::V5(hs) => hs,
i => return Err(ConnectError::UnsupportedProtocolVersion(i.version())),
};
let hs = match incoming.ext_hs {
Some(SrtControlPacket::HandshakeResponse(hs)) => hs,
Some(_) => return Err(ConnectError::ExpectedHSResp),
None => return Err(ConnectError::ExpectedExtFlags),
};
Ok(ConnectionSettings {
remote: from,
remote_sockid: response.socket_id,
local_sockid: self.settings.local_sockid,
socket_start_time: Instant::now(), init_send_seq_num: self.settings.starting_send_seqnum,
init_recv_seq_num: response.init_seq_num,
max_packet_size: 1500, max_flow_size: 8192,
send_tsbpd_latency: Duration::max(self.settings.send_latency, hs.recv_latency),
recv_tsbpd_latency: Duration::max(self.settings.recv_latency, hs.send_latency),
crypto_manager: self.cm,
stream_id: self.streamid,
})
}
}