use std::{
cmp::{max, min},
net::SocketAddr,
time::Instant,
};
use crate::{connection::ConnectionSettings, options::*, packet::*, settings::*};
use super::{ConnectError, ConnectionReject};
/// Outcome of evaluating an incoming HSv5 handshake on the responding side.
// NOTE(review): the lint is silenced presumably because `Accept` is far larger
// than the other variants — confirm the sizes before boxing or removing this.
#[allow(clippy::large_enum_variant)]
pub enum GenHsv5Result {
/// The handshake was accepted: the V5 handshake info built in reply and the
/// settings for the resulting connection.
Accept(HandshakeVsInfo, ConnectionSettings),
/// The handshake could not be processed, e.g. the SRT handshake extension was
/// missing or was not a handshake request.
NotHandled(ConnectError),
/// The handshake was refused; carries the rejection to report.
Reject(ConnectionReject),
}
/// Produces the responder's answer to a received handshake that is expected to
/// carry HSv5 information.
///
/// Anything other than a V5 handshake is rejected with
/// `ServerRejectReason::Version`. Otherwise a copy of the handshake and of its
/// V5 info is forwarded to [`gen_access_control_response`] with no key
/// settings override.
pub fn gen_hsv5_response(
    settings: &mut ConnInitSettings,
    with_hsv5: &HandshakeControlInfo,
    from: SocketAddr,
    induction_time: Instant,
    now: Instant,
) -> GenHsv5Result {
    if let HandshakeVsInfo::V5(v5_info) = &with_hsv5.info {
        // The callee takes ownership of both values, so hand it clones.
        let handshake = with_hsv5.clone();
        let v5_info = v5_info.clone();
        gen_access_control_response(
            now,
            settings,
            from,
            induction_time,
            handshake,
            v5_info,
            None,
        )
    } else {
        GenHsv5Result::Reject(ConnectionReject::Rejecting(
            ServerRejectReason::Version.into(),
        ))
    }
}
/// Builds the responder's answer to an HSv5 handshake whose V5 information has
/// already been extracted into `incoming`.
///
/// When `key_settings` is `Some`, it replaces `settings.key_settings` before
/// the handshake is evaluated.
///
/// # Returns
///
/// * [`GenHsv5Result::Accept`] with the V5 handshake info to reply with and
///   the negotiated [`ConnectionSettings`].
/// * [`GenHsv5Result::NotHandled`] when the SRT handshake extension is absent
///   (`ExpectedExtFlags`) or is not a handshake request (`ExpectedHsReq`).
/// * [`GenHsv5Result::Reject`] when the encryption setup of the two sides
///   cannot be reconciled:
///   * `BadSecret` — the advertised key size differs from the local one, or
///     the cipher could not be created from the received keying material;
///   * `Unsecure` — exactly one side supplied keying configuration, or the
///     key extension was not a key refresh request.
///
/// Mismatched or unexpected keying material yields a rejection instead of a
/// panic, because `incoming` is derived from a received handshake and must not
/// be able to abort the responder.
pub fn gen_access_control_response(
    now: Instant,
    settings: &mut ConnInitSettings,
    from: SocketAddr,
    induction_time: Instant,
    with_hsv5: HandshakeControlInfo,
    incoming: HsV5Info,
    key_settings: Option<KeySettings>,
) -> GenHsv5Result {
    if let Some(ks) = key_settings {
        settings.key_settings = Some(ks);
    }

    let hs = match incoming.ext_hs {
        Some(SrtControlPacket::HandshakeRequest(hs)) => hs,
        Some(_) => return GenHsv5Result::NotHandled(ConnectError::ExpectedHsReq),
        None => return GenHsv5Result::NotHandled(ConnectError::ExpectedExtFlags),
    };

    let cipher = match (&settings.key_settings, &incoming.ext_km) {
        (Some(key_settings), Some(SrtControlPacket::KeyRefreshRequest(km))) => {
            // NOTE(review): a key size disagreement is reported as `BadSecret`;
            // adjust if a more specific rejection reason is preferred.
            if key_settings.key_size.as_usize() != incoming.crypto_size as usize {
                return GenHsv5Result::Reject(ConnectionReject::Rejecting(
                    CoreRejectReason::BadSecret.into(),
                ));
            }
            match CipherSettings::new(key_settings, &settings.key_refresh, km) {
                Ok(cipher) => Some(cipher),
                Err(_) => {
                    return GenHsv5Result::Reject(ConnectionReject::Rejecting(
                        CoreRejectReason::BadSecret.into(),
                    ))
                }
            }
        }
        // Neither side wants encryption.
        (None, None) => None,
        // Every remaining combination means the two sides disagree about
        // encryption: the key extension is not a key refresh request, keying
        // material is missing although it is required locally, or keying
        // material was sent although none is configured locally.
        (Some(_), Some(_)) | (Some(_), None) | (None, Some(_)) => {
            return GenHsv5Result::Reject(ConnectionReject::Rejecting(
                CoreRejectReason::Unsecure.into(),
            ))
        }
    };

    let outgoing_ext_km = cipher
        .as_ref()
        .and_then(CipherSettings::wrap_keying_material);

    // The stream id echoed in the reply is taken from the handshake packet
    // itself, while the connection's stream id below comes from `incoming`.
    let sid = if let HandshakeVsInfo::V5(info) = &with_hsv5.info {
        info.sid.clone()
    } else {
        None
    };

    let rtt = now - induction_time;

    GenHsv5Result::Accept(
        HandshakeVsInfo::V5(HsV5Info {
            crypto_size: cipher
                .as_ref()
                .map(|c| c.key_settings.key_size.as_usize())
                .unwrap_or(0) as u8,
            ext_hs: Some(SrtControlPacket::HandshakeResponse(SrtHandshake {
                version: SrtVersion::CURRENT,
                flags: SrtShakeFlags::SUPPORTED,
                send_latency: settings.send_latency,
                recv_latency: settings.recv_latency,
            })),
            ext_km: outgoing_ext_km.map(SrtControlPacket::KeyRefreshResponse),
            ext_group: None,
            sid,
        }),
        ConnectionSettings {
            remote: from,
            rtt,
            // Start time is placed half a round trip before `now`.
            socket_start_time: now - rtt / 2,
            remote_sockid: with_hsv5.socket_id,
            init_seq_num: with_hsv5.init_seq_num,
            cipher,
            stream_id: incoming.sid,
            max_flow_size: max(settings.max_flow_size, with_hsv5.max_flow_size),
            max_packet_size: min(settings.max_packet_size, with_hsv5.max_packet_size),
            // Each direction uses the larger of the two latencies requested
            // for it (our send side pairs with the peer's receive side).
            send_tsbpd_latency: max(settings.send_latency, hs.recv_latency),
            recv_tsbpd_latency: max(settings.recv_latency, hs.send_latency),
            bandwidth: settings.bandwidth.clone(),
            local_sockid: settings.local_sockid,
            recv_buffer_size: settings.recv_buffer_size,
            send_buffer_size: settings.send_buffer_size,
            statistics_interval: settings.statistics_interval,
        },
    )
}
/// State carried by the initiating side between `start_hsv5_initiation` and
/// `StartedInitiator::finish_hsv5_initiation`.
#[derive(Debug, Clone)] pub struct StartedInitiator {
/// Cipher created when the initiation started; `None` when no key settings
/// were configured.
cipher: Option<CipherSettings>,
/// The initial settings the initiation was started with.
settings: ConnInitSettings,
/// Stream id given to `start_hsv5_initiation`, if any.
streamid: Option<String>,
/// The `now` value given to `start_hsv5_initiation`; the RTT is measured
/// from this instant when the initiation is finished.
initiate_time: Instant,
}
/// Begins an HSv5 handshake from the initiating side.
///
/// Returns the V5 handshake info carrying this side's handshake request (plus
/// a key refresh request when key settings are configured and keying material
/// could be wrapped), together with the [`StartedInitiator`] state that
/// `finish_hsv5_initiation` later consumes.
pub fn start_hsv5_initiation(
    settings: ConnInitSettings,
    streamid: Option<String>,
    now: Instant,
) -> (HandshakeVsInfo, StartedInitiator) {
    // Advertised key size; zero when no key settings are configured.
    let crypto_size = match &settings.key_settings {
        Some(ks) => ks.key_size.as_usize() as u8,
        None => 0,
    };

    // A cipher, and therefore keying material, only exists with key settings.
    let cipher = settings
        .key_settings
        .as_ref()
        .map(|ks| CipherSettings::new_random(ks, &settings.key_refresh));
    let ext_km = cipher
        .as_ref()
        .and_then(CipherSettings::wrap_keying_material)
        .map(SrtControlPacket::KeyRefreshRequest);

    let request = SrtHandshake {
        version: SrtVersion::CURRENT,
        flags: SrtShakeFlags::SUPPORTED,
        send_latency: settings.send_latency,
        recv_latency: settings.recv_latency,
    };

    let handshake = HandshakeVsInfo::V5(HsV5Info {
        crypto_size,
        ext_hs: Some(SrtControlPacket::HandshakeRequest(request)),
        ext_km,
        ext_group: None,
        sid: streamid.clone(),
    });

    let initiator = StartedInitiator {
        cipher,
        settings,
        streamid,
        initiate_time: now,
    };

    (handshake, initiator)
}
impl StartedInitiator {
    /// Completes the initiation using the handshake `response` received from
    /// `from`, yielding the settings of the resulting connection.
    ///
    /// # Errors
    ///
    /// * `ConnectError::UnsupportedProtocolVersion` when `response` does not
    ///   carry V5 handshake info.
    /// * `ConnectError::ExpectedExtFlags` when the SRT handshake extension is
    ///   absent.
    /// * `ConnectError::ExpectedHsResp` when that extension is not a handshake
    ///   response.
    pub fn finish_hsv5_initiation(
        self,
        response: &HandshakeControlInfo,
        from: SocketAddr,
        now: Instant,
    ) -> Result<ConnectionSettings, ConnectError> {
        let v5_info = match &response.info {
            HandshakeVsInfo::V5(v5_info) => v5_info,
            other => return Err(ConnectError::UnsupportedProtocolVersion(other.version())),
        };

        let peer_hs = match v5_info.ext_hs.as_ref() {
            None => return Err(ConnectError::ExpectedExtFlags),
            Some(SrtControlPacket::HandshakeResponse(peer_hs)) => peer_hs,
            Some(_) => return Err(ConnectError::ExpectedHsResp),
        };

        let Self {
            cipher,
            settings,
            streamid,
            initiate_time,
        } = self;

        // Round trip measured from the moment the initiation was started.
        let rtt = now - initiate_time;

        // Each direction uses the larger of the two latencies requested for it
        // (our send side pairs with the peer's receive side).
        let send_tsbpd_latency = max(settings.send_latency, peer_hs.recv_latency);
        let recv_tsbpd_latency = max(settings.recv_latency, peer_hs.send_latency);

        let max_flow_size = max(settings.max_flow_size, response.max_flow_size);
        let max_packet_size = min(settings.max_packet_size, response.max_packet_size);

        Ok(ConnectionSettings {
            remote: from,
            rtt,
            socket_start_time: initiate_time,
            init_seq_num: response.init_seq_num,
            remote_sockid: response.socket_id,
            cipher,
            stream_id: streamid,
            max_flow_size,
            max_packet_size,
            send_tsbpd_latency,
            recv_tsbpd_latency,
            bandwidth: settings.bandwidth,
            local_sockid: settings.local_sockid,
            recv_buffer_size: settings.recv_buffer_size,
            send_buffer_size: settings.send_buffer_size,
            statistics_interval: settings.statistics_interval,
        })
    }
}