use crate::transport::path::PATH_CHALLENGE_LEN;
use crate::transport::types::{
PacketFlags, PacketHeader, PhantomPacket, SequenceNumber, SessionId, StreamId,
};
/// Which half of a path-validation exchange a packet plays.
///
/// The wire format does not carry this distinction. `build_path_validation_packet`
/// sets only the `PATH_VALIDATION` flag, and `parse_path_validation` neither takes
/// nor returns this type, so a challenge and a response are indistinguishable on
/// the wire. NOTE(review): presumably callers track the role from their own state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PathValidationKind {
/// The initiating side's freshly issued challenge bytes.
Challenge,
/// The peer's echo of previously received challenge bytes.
Response,
}
/// Builds a packet carrying path-validation bytes for `path_id`.
///
/// The packet is sent on stream 0 with only the `PATH_VALIDATION` flag set, and its
/// payload is exactly the `PATH_CHALLENGE_LEN` bytes of `payload`. The same function
/// serves both challenges and responses; nothing in the produced packet records
/// which one it is.
pub fn build_path_validation_packet(
    session_id: SessionId,
    path_id: u8,
    sequence: SequenceNumber,
    payload: [u8; PATH_CHALLENGE_LEN],
) -> PhantomPacket {
    // Path-validation packets are always emitted on stream 0.
    const PATH_VALIDATION_STREAM: StreamId = 0;

    let validation_flags = PacketFlags::new(PacketFlags::PATH_VALIDATION);
    let header = PacketHeader::new(session_id, PATH_VALIDATION_STREAM, sequence, validation_flags)
        .with_path_id(path_id);
    let body: Vec<u8> = payload.to_vec();

    PhantomPacket::new(header, body)
}
/// Fields extracted from a packet whose header carries the `PATH_VALIDATION` flag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedPathValidation {
/// Path identifier copied from the packet header.
pub path_id: u8,
/// The `PATH_CHALLENGE_LEN` bytes copied from the packet payload.
pub payload: [u8; PATH_CHALLENGE_LEN],
}
/// Extracts path-validation data from `packet`, if it is flagged as such.
///
/// Returns `Ok(None)` when the header lacks the `PATH_VALIDATION` flag. This lets
/// callers try the parser on any packet without a separate flag check.
///
/// # Errors
///
/// Returns [`PathValidationParseError::WrongPayloadLength`] when the flag is set
/// but the payload is not exactly `PATH_CHALLENGE_LEN` bytes long.
pub fn parse_path_validation(
    packet: &PhantomPacket,
) -> Result<Option<ParsedPathValidation>, PathValidationParseError> {
    let is_validation = packet.header.flags.contains(PacketFlags::PATH_VALIDATION);
    if !is_validation {
        return Ok(None);
    }

    let body: &[u8] = &packet.payload;
    if body.len() == PATH_CHALLENGE_LEN {
        // Length was just checked, so copy_from_slice cannot panic here.
        let mut challenge = [0u8; PATH_CHALLENGE_LEN];
        challenge.copy_from_slice(body);
        Ok(Some(ParsedPathValidation {
            path_id: packet.header.path_id,
            payload: challenge,
        }))
    } else {
        Err(PathValidationParseError::WrongPayloadLength { got: body.len() })
    }
}
/// Reasons a `PATH_VALIDATION`-flagged packet could not be parsed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PathValidationParseError {
/// The payload was not exactly `PATH_CHALLENGE_LEN` bytes; `got` is the actual length.
WrongPayloadLength { got: usize },
}
impl std::fmt::Display for PathValidationParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::WrongPayloadLength { got } => write!(
f,
"PATH_VALIDATION payload length is {}, expected {}",
got, PATH_CHALLENGE_LEN
),
}
}
}
/// Uses the trait's default `source()` (returns `None`); this error wraps no underlying cause.
impl std::error::Error for PathValidationParseError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic session id so packets built in separate calls compare equal.
    fn fixed_session_id() -> SessionId {
        SessionId::from_bytes([0x42; 32])
    }

    #[test]
    fn build_round_trip_preserves_path_id_and_payload() {
        let payload = [0xAA; PATH_CHALLENGE_LEN];
        let packet = build_path_validation_packet(fixed_session_id(), 7, 42, payload);
        assert_eq!(packet.header.path_id, 7);
        assert!(packet.header.flags.contains(PacketFlags::PATH_VALIDATION));
        assert_eq!(packet.header.stream_id, 0u16);
        assert_eq!(packet.header.sequence, 42u32);
        assert_eq!(packet.payload, payload.to_vec());
    }

    #[test]
    fn parse_path_validation_returns_payload_on_match() {
        let payload = [0xCC; PATH_CHALLENGE_LEN];
        let packet = build_path_validation_packet(fixed_session_id(), 3, 1, payload);
        let parsed = parse_path_validation(&packet).expect("ok").expect("some");
        assert_eq!(parsed.path_id, 3);
        assert_eq!(parsed.payload, payload);
    }

    #[test]
    fn parse_returns_none_when_flag_missing() {
        let header = PacketHeader::new(
            fixed_session_id(),
            0u16,
            0u32,
            PacketFlags::new(PacketFlags::ENCRYPTED),
        );
        let packet = PhantomPacket::new(header, vec![0u8; PATH_CHALLENGE_LEN]);
        let parsed = parse_path_validation(&packet).expect("no error");
        assert!(parsed.is_none());
    }

    #[test]
    fn parse_errors_on_wrong_payload_length() {
        // Lengths are derived from PATH_CHALLENGE_LEN so the test stays valid if the
        // constant changes. A hard-coded length could coincide with it.
        for &wrong_len in &[0usize, PATH_CHALLENGE_LEN + 1] {
            let header = PacketHeader::new(
                fixed_session_id(),
                0u16,
                0u32,
                PacketFlags::new(PacketFlags::PATH_VALIDATION),
            );
            let packet = PhantomPacket::new(header, vec![0u8; wrong_len]);
            let err = parse_path_validation(&packet).expect_err("err");
            assert_eq!(
                err,
                PathValidationParseError::WrongPayloadLength { got: wrong_len }
            );
            assert_eq!(
                err.to_string(),
                format!(
                    "PATH_VALIDATION payload length is {}, expected {}",
                    wrong_len, PATH_CHALLENGE_LEN
                )
            );
        }
    }

    #[test]
    fn challenge_and_response_are_wire_identical() {
        // The builder has no challenge/response parameter, so identical inputs must
        // serialize identically regardless of which role the caller intends.
        let payload = [0x55; PATH_CHALLENGE_LEN];
        let challenge = build_path_validation_packet(fixed_session_id(), 1, 5, payload);
        let response = build_path_validation_packet(fixed_session_id(), 1, 5, payload);
        assert_eq!(challenge.to_wire(), response.to_wire());
    }

    #[test]
    fn kind_variants_are_distinct() {
        assert_ne!(PathValidationKind::Challenge, PathValidationKind::Response);
    }

    #[test]
    fn full_challenge_response_round_trip_via_codec() {
        use crate::transport::path::{PathRegistry, PathStateKind, RegistrationResult};

        let validator = PathRegistry::new();
        let path_id: u8 = 5;
        assert_eq!(validator.register(path_id), RegistrationResult::Created);
        let challenge = validator.issue_challenge(path_id).expect("challenge issued");
        let session_id = fixed_session_id();

        // Validator -> peer: serialize the challenge and decode it on the peer side.
        let outgoing = build_path_validation_packet(session_id, path_id, 0, challenge);
        let received = PhantomPacket::from_wire(&outgoing.to_wire()).expect("deserialize");
        let parsed = parse_path_validation(&received)
            .expect("ok")
            .expect("flag matched");
        assert_eq!(parsed.path_id, path_id);
        assert_eq!(parsed.payload, challenge);

        // Peer -> validator: echo the same bytes back over the wire.
        let response = build_path_validation_packet(session_id, path_id, 0, parsed.payload);
        let echoed = PhantomPacket::from_wire(&response.to_wire()).expect("deserialize");
        let echoed_parsed = parse_path_validation(&echoed)
            .expect("ok")
            .expect("flag matched");

        let accepted = validator.verify_response(echoed_parsed.path_id, &echoed_parsed.payload);
        assert!(accepted, "responder's echo must validate");
        assert_eq!(validator.state(path_id), Some(PathStateKind::Validated));
    }

    #[test]
    fn tampered_response_fails_validation() {
        use crate::transport::path::{PathRegistry, PathStateKind, RegistrationResult};

        let validator = PathRegistry::new();
        assert_eq!(validator.register(2), RegistrationResult::Created);
        let challenge = validator.issue_challenge(2).expect("challenge");

        // Flip the first byte; index 0 is valid for any non-empty challenge length.
        let mut tampered = challenge;
        tampered[0] ^= 0xFF;

        let packet = build_path_validation_packet(fixed_session_id(), 2, 0, tampered);
        let parsed = parse_path_validation(&packet).unwrap().unwrap();
        assert!(!validator.verify_response(parsed.path_id, &parsed.payload));
        assert_eq!(validator.state(2), Some(PathStateKind::Failed));
    }
}