use crate::bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::net::atp::handshake::state_machine::{HandshakeError, QuicVersion};
use crate::types::outcome::Outcome;
/// A QUIC Version Negotiation packet: a long-header packet whose version field
/// is `0`, followed by both connection IDs and the list of versions the sender
/// offers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VersionNegotiationPacket {
/// Source connection ID carried in the packet (at most 255 bytes to be encodable).
pub source_cid: Bytes,
/// Destination connection ID carried in the packet (at most 255 bytes to be encodable).
pub dest_cid: Bytes,
/// Versions listed in the packet, in wire order.
pub supported_versions: Vec<u32>,
}
impl VersionNegotiationPacket {
    /// Largest connection ID length that fits in the single length byte that
    /// precedes each connection ID on the wire.
    const MAX_CID_LEN: usize = u8::MAX as usize;

    /// First byte written by `encode`.
    ///
    /// `0x80` is the Header Form bit (long header); the remaining bits are
    /// unused in a Version Negotiation packet. RFC 9000 §17.2.1 says servers
    /// SHOULD also set `0x40` so the packet appears to carry the Fixed Bit,
    /// which keeps protocol demultiplexers (RFC 7983) from misclassifying it.
    const FIRST_BYTE: u8 = 0x80 | 0x40;

    /// Creates a packet from its parts.
    ///
    /// Nothing is validated here; connection ID lengths are checked by
    /// [`Self::encode`].
    pub fn new(source_cid: Bytes, dest_cid: Bytes, supported_versions: Vec<u32>) -> Self {
        Self {
            source_cid,
            dest_cid,
            supported_versions,
        }
    }

    /// Serializes the packet to its wire form.
    ///
    /// Layout: first byte (`0xC0`), a 32-bit version of `0`, the destination
    /// CID length byte and bytes, the source CID length byte and bytes, then
    /// every entry of `supported_versions` as a 32-bit value.
    ///
    /// # Errors
    ///
    /// Returns `HandshakeError::ConnectionIdError` if either connection ID is
    /// longer than 255 bytes. The destination CID is checked first.
    pub fn encode(&self) -> Outcome<Bytes, HandshakeError> {
        // Validate before writing anything so no half-built buffer is produced.
        if self.dest_cid.len() > Self::MAX_CID_LEN {
            return Outcome::err(HandshakeError::ConnectionIdError {
                reason: "destination CID too long".to_string(),
            });
        }
        if self.source_cid.len() > Self::MAX_CID_LEN {
            return Outcome::err(HandshakeError::ConnectionIdError {
                reason: "source CID too long".to_string(),
            });
        }

        let mut buf = BytesMut::new();
        buf.put_u8(Self::FIRST_BYTE);
        // Version 0 is what marks a long-header packet as Version Negotiation.
        buf.put_u32(0);
        // `as u8` cannot truncate: both lengths were bounded by MAX_CID_LEN above.
        buf.put_u8(self.dest_cid.len() as u8);
        buf.put_slice(&self.dest_cid);
        buf.put_u8(self.source_cid.len() as u8);
        buf.put_slice(&self.source_cid);
        for &version in &self.supported_versions {
            buf.put_u32(version);
        }
        Outcome::ok(buf.freeze())
    }

    /// Parses a Version Negotiation packet from `data`.
    ///
    /// Any first byte with the Header Form bit (`0x80`) set is accepted; the
    /// other bits of that byte are ignored. An empty version list is rejected.
    ///
    /// # Errors
    ///
    /// Returns `HandshakeError::InvalidPacket` if `data` is shorter than 7
    /// bytes, the Header Form bit is clear, the version is not `0`, a
    /// connection ID runs past the end of `data`, the trailing version list is
    /// not a whole number of 4-byte entries, or that list is empty.
    pub fn decode(data: &[u8]) -> Outcome<Self, HandshakeError> {
        // Minimum size: first byte + 4-byte version + two CID length bytes.
        if data.len() < 7 {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "version negotiation packet too short".to_string(),
            });
        }
        let mut buf = data;
        let first_byte = buf.get_u8();
        if (first_byte & 0x80) == 0 {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "not a long header packet".to_string(),
            });
        }
        let version = buf.get_u32();
        if version != 0 {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "version negotiation must have version 0".to_string(),
            });
        }

        let dest_cid_len = buf.get_u8() as usize;
        if buf.remaining() < dest_cid_len {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "insufficient data for destination CID".to_string(),
            });
        }
        let dest_cid = Bytes::copy_from_slice(&buf[..dest_cid_len]);
        buf.advance(dest_cid_len);

        // The length check above only guaranteed room for the destination CID;
        // a long destination CID can leave no byte for the source CID length.
        if buf.is_empty() {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "missing source CID length".to_string(),
            });
        }
        let source_cid_len = buf.get_u8() as usize;
        if buf.remaining() < source_cid_len {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "insufficient data for source CID".to_string(),
            });
        }
        let source_cid = Bytes::copy_from_slice(&buf[..source_cid_len]);
        buf.advance(source_cid_len);

        // Everything left is the version list: whole 4-byte entries only.
        if buf.remaining() % 4 != 0 {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "invalid version list length".to_string(),
            });
        }
        let mut supported_versions = Vec::with_capacity(buf.remaining() / 4);
        while buf.remaining() >= 4 {
            supported_versions.push(buf.get_u32());
        }
        if supported_versions.is_empty() {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "no supported versions".to_string(),
            });
        }

        Outcome::ok(Self {
            source_cid,
            dest_cid,
            supported_versions,
        })
    }

    /// Returns `true` if `version` appears in the packet's version list.
    pub fn supports_version(&self, version: u32) -> bool {
        self.supported_versions.contains(&version)
    }

    /// Chooses the version to continue with after receiving this packet.
    ///
    /// Returns `attempted_version` if it is listed. Otherwise returns the
    /// numerically highest listed version that is not reserved: `0x00000000`
    /// or the GREASE pattern `0x?a?a?a?a`, neither of which may ever be
    /// negotiated. Returns `None` when no such version is listed.
    ///
    /// NOTE(review): RFC 9000 §6.2 says a client MUST discard a Version
    /// Negotiation packet that lists the version it attempted. The early
    /// return below does the opposite. It is kept for compatibility with
    /// existing callers; confirm how callers use it.
    pub fn select_version(&self, attempted_version: u32) -> Option<u32> {
        if self.supports_version(attempted_version) {
            return Some(attempted_version);
        }
        self.supported_versions
            .iter()
            .copied()
            .filter(|&version| !Self::is_reserved_version(version))
            .max()
    }

    /// Returns `true` for versions that are reserved and must not be selected:
    /// `0x00000000` (Version Negotiation itself) and the GREASE pattern
    /// `0x?a?a?a?a` used to exercise negotiation (RFC 9000 §15).
    fn is_reserved_version(version: u32) -> bool {
        version == 0 || (version & 0x0f0f_0f0f) == 0x0a0a_0a0a
    }
}
/// Stateless helpers for both sides of QUIC version negotiation.
pub struct VersionNegotiation;

impl VersionNegotiation {
    /// Reports whether the server must answer with a Version Negotiation
    /// packet, i.e. whether `client_version` is absent from `server_versions`.
    pub fn is_negotiation_needed(client_version: u32, server_versions: &[u32]) -> bool {
        server_versions
            .iter()
            .all(|&offered| offered != client_version)
    }

    /// Builds the server's Version Negotiation packet, advertising
    /// `QuicVersion::supported_versions()`.
    ///
    /// `client_dest_cid` becomes the packet's destination CID and
    /// `server_source_cid` its source CID. To pass
    /// [`Self::validate_server_response`] on the client, `client_dest_cid`
    /// must be the client's original source CID and `server_source_cid` the
    /// client's original destination CID.
    pub fn create_server_response(
        client_dest_cid: Bytes,
        server_source_cid: Bytes,
    ) -> VersionNegotiationPacket {
        let advertised = QuicVersion::supported_versions();
        VersionNegotiationPacket::new(server_source_cid, client_dest_cid, advertised)
    }

    /// Client-side validation of a received Version Negotiation packet.
    ///
    /// `original_dest_cid` and `original_source_cid` are the connection IDs
    /// the client used in the packet that triggered this response.
    ///
    /// # Errors
    ///
    /// - `HandshakeError::InvalidPacket` if the packet's destination CID differs
    ///   from `original_source_cid` (checked first), or its source CID differs
    ///   from `original_dest_cid`.
    /// - `HandshakeError::UnsupportedVersion` with `version: 0` if none of the
    ///   listed versions satisfies `QuicVersion::is_supported`.
    pub fn validate_server_response(
        packet: &VersionNegotiationPacket,
        original_dest_cid: &[u8],
        original_source_cid: &[u8],
    ) -> Outcome<(), HandshakeError> {
        let echoed_dest: &[u8] = packet.dest_cid.as_ref();
        if echoed_dest != original_source_cid {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "version negotiation destination CID mismatch".to_string(),
            });
        }

        let echoed_source: &[u8] = packet.source_cid.as_ref();
        if echoed_source != original_dest_cid {
            return Outcome::err(HandshakeError::InvalidPacket {
                reason: "version negotiation source CID mismatch".to_string(),
            });
        }

        let any_usable = packet
            .supported_versions
            .iter()
            .any(|offered| QuicVersion::is_supported(*offered));
        if any_usable {
            Outcome::ok(())
        } else {
            // No single version is at fault, so the error carries 0.
            Outcome::err(HandshakeError::UnsupportedVersion { version: 0 })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version_negotiation_packet_roundtrip() {
        let source_cid = Bytes::from_static(b"server_cid");
        let dest_cid = Bytes::from_static(b"client_cid");
        let supported_versions = vec![0x00000001, 0x12345678];
        let packet = VersionNegotiationPacket::new(
            source_cid.clone(),
            dest_cid.clone(),
            supported_versions.clone(),
        );
        let encoded = packet.encode().unwrap();
        let decoded = VersionNegotiationPacket::decode(&encoded).unwrap();
        assert_eq!(decoded.source_cid, source_cid);
        assert_eq!(decoded.dest_cid, dest_cid);
        assert_eq!(decoded.supported_versions, supported_versions);
    }

    #[test]
    fn test_roundtrip_with_empty_and_max_length_cids() {
        // 255 bytes is the longest CID a single length byte can describe.
        let packet = VersionNegotiationPacket::new(
            Bytes::from_static(b""),
            Bytes::copy_from_slice(&[0x5a; 255]),
            vec![0x0000_0001],
        );
        let encoded = packet.encode().unwrap();
        let decoded = VersionNegotiationPacket::decode(&encoded).unwrap();
        assert_eq!(decoded, packet);
    }

    #[test]
    fn test_encode_rejects_oversized_cids() {
        let oversized = Bytes::copy_from_slice(&[0u8; 256]);
        let bad_dest = VersionNegotiationPacket::new(
            Bytes::from_static(b"src"),
            oversized.clone(),
            vec![0x0000_0001],
        );
        assert!(bad_dest.encode().is_err());
        let bad_src = VersionNegotiationPacket::new(
            oversized,
            Bytes::from_static(b"dst"),
            vec![0x0000_0001],
        );
        assert!(bad_src.encode().is_err());
    }

    #[test]
    fn test_version_support_check() {
        let packet = VersionNegotiationPacket::new(
            Bytes::from_static(b"src"),
            Bytes::from_static(b"dst"),
            vec![0x00000001, 0x12345678],
        );
        assert!(packet.supports_version(0x00000001));
        assert!(packet.supports_version(0x12345678));
        assert!(!packet.supports_version(0xabcdef00));
    }

    #[test]
    fn test_version_selection() {
        let packet = VersionNegotiationPacket::new(
            Bytes::from_static(b"src"),
            Bytes::from_static(b"dst"),
            vec![0x00000001, 0x12345678],
        );
        assert_eq!(packet.select_version(0x00000001), Some(0x00000001));
        assert_eq!(packet.select_version(0xabcdef00), Some(0x12345678));
    }

    #[test]
    fn test_version_selection_with_empty_list() {
        let packet = VersionNegotiationPacket::new(
            Bytes::from_static(b"src"),
            Bytes::from_static(b"dst"),
            Vec::new(),
        );
        assert_eq!(packet.select_version(0x0000_0001), None);
    }

    #[test]
    fn test_negotiation_needed() {
        let server_versions = vec![0x00000001, 0x12345678];
        assert!(!VersionNegotiation::is_negotiation_needed(
            0x00000001,
            &server_versions
        ));
        assert!(VersionNegotiation::is_negotiation_needed(
            0xabcdef00,
            &server_versions
        ));
        // A server that offers nothing always needs to negotiate.
        assert!(VersionNegotiation::is_negotiation_needed(0x00000001, &[]));
    }

    #[test]
    fn test_invalid_packet_decode() {
        let result = VersionNegotiationPacket::decode(&[0x80, 0x00]);
        assert!(result.is_err());
        let result = VersionNegotiationPacket::decode(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        assert!(result.is_err());
        let result = VersionNegotiationPacket::decode(&[0x80, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_rejects_malformed_bodies() {
        // Destination CID length runs past the end of the data.
        assert!(VersionNegotiationPacket::decode(&[0x80, 0, 0, 0, 0, 5, 0]).is_err());
        // Destination CID consumes everything; no source CID length byte left.
        assert!(VersionNegotiationPacket::decode(&[0x80, 0, 0, 0, 0, 2, 0xaa, 0xbb]).is_err());
        // Source CID length runs past the end of the data.
        assert!(VersionNegotiationPacket::decode(&[0x80, 0, 0, 0, 0, 0, 3, 0xaa]).is_err());
        // Version list is not a whole number of 4-byte entries.
        assert!(
            VersionNegotiationPacket::decode(&[0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0xff]).is_err()
        );
        // Well-formed header but an empty version list.
        assert!(VersionNegotiationPacket::decode(&[0x80, 0, 0, 0, 0, 0, 0]).is_err());
    }

    #[test]
    fn test_decode_ignores_unused_first_byte_bits() {
        // Only the Header Form bit matters; 0x01010101 reads the same in any byte order.
        let decoded =
            VersionNegotiationPacket::decode(&[0xff, 0, 0, 0, 0, 0, 0, 0x01, 0x01, 0x01, 0x01])
                .unwrap();
        assert!(decoded.dest_cid.is_empty());
        assert!(decoded.source_cid.is_empty());
        assert_eq!(decoded.supported_versions, vec![0x0101_0101]);
    }

    #[test]
    fn test_validate_rejects_cid_mismatch() {
        let packet = VersionNegotiation::create_server_response(
            Bytes::from_static(b"client_scid"),
            Bytes::from_static(b"client_dcid"),
        );
        // Destination CID does not echo the client's original source CID.
        assert!(VersionNegotiation::validate_server_response(
            &packet,
            b"client_scid",
            b"client_dcid"
        )
        .is_err());
        // Destination CID matches, but source CID does not echo the original destination CID.
        assert!(
            VersionNegotiation::validate_server_response(&packet, b"other", b"client_scid")
                .is_err()
        );
    }
}