use crate::endian::read_u32_be;
use crate::error::{CrafterError, Result};
use crate::packet::{Packet, Raw};
use crate::registry::ProtocolRegistry;
use super::header::{IkeHeader, IKE_HEADER_LEN};
use super::payload::auth::parse_auth_payload_body;
use super::payload::cert::{parse_cert_payload_body, parse_certreq_payload_body};
use super::payload::config::parse_config_payload_body;
use super::payload::delete::parse_delete_payload_body;
use super::payload::eap::parse_eap_payload_body;
use super::payload::id::{parse_id_payload_body, IdRole};
use super::payload::ke::parse_ke_payload_body;
use super::payload::nonce::parse_nonce_payload_body;
use super::payload::notify::parse_notify_payload_body;
use super::payload::sa::parse_sa_payload_body;
use super::payload::ts::{parse_ts_payload_body, TsRole};
use super::payload::vendor::parse_vendor_id_payload_body;
use super::payload::{PayloadType, GENERIC_PAYLOAD_HEADER_LEN};
/// Reads a big-endian `u64` from the `range` window of `bytes`.
///
/// `context` is forwarded to the error so the failing field can be identified.
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short(context, range.end, bytes.len())`
/// when `range` does not fit inside `bytes`, or when the selected window is
/// not exactly eight bytes wide.
fn read_u64_be_field(
    bytes: &[u8],
    range: std::ops::Range<usize>,
    context: &'static str,
) -> Result<u64> {
    let required = range.end;
    let available = bytes.len();
    // Both failure modes report the same (context, required, available) triple.
    let too_short = || CrafterError::buffer_too_short(context, required, available);

    let window = bytes.get(range).ok_or_else(too_short)?;
    let octets = <[u8; 8]>::try_from(window).map_err(|_| too_short())?;
    Ok(u64::from_be_bytes(octets))
}
/// Parses the fixed-size IKE header at the start of `bytes` into an
/// [`IkeHeader`].
///
/// Field offsets as read here: initiator SPI `0..8`, responder SPI `8..16`,
/// next payload `16`, version `17`, exchange type `18`, flags `19`,
/// message id `20..24` and length `24..28` (the last two via `read_u32_be`).
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short` when `bytes` holds fewer than
/// `IKE_HEADER_LEN` bytes, and propagates any error from the field readers.
fn parse_ike_header(bytes: &[u8]) -> Result<IkeHeader> {
    let available = bytes.len();
    if available < IKE_HEADER_LEN {
        return Err(CrafterError::buffer_too_short("ikev2.header", IKE_HEADER_LEN, available));
    }

    // NOTE(review): the fixed indices below rely on IKE_HEADER_LEN being at
    // least 28 — confirm against the constant's definition.
    let spi_i = read_u64_be_field(bytes, 0..8, "ikev2.header.initiator_spi")?;
    let spi_r = read_u64_be_field(bytes, 8..16, "ikev2.header.responder_spi")?;
    let (next, version, exchange, flags) = (bytes[16], bytes[17], bytes[18], bytes[19]);
    let msg_id = read_u32_be(&bytes[20..24])?;
    let total_len = read_u32_be(&bytes[24..28])?;

    let header = IkeHeader::new()
        .initiator_spi(spi_i)
        .responder_spi(spi_r)
        .next_payload(next)
        .version(version)
        .exchange(exchange)
        .flags(flags)
        .message_id(msg_id)
        .length(total_len);
    Ok(header)
}
/// Decodes one payload and appends the resulting layer to `packet`.
///
/// `payload` is the complete payload, generic header included. For every
/// payload type with a dedicated parser, the bytes after the generic header
/// are handed to that parser. `Encrypted`, `None` and `Unknown(_)` payloads
/// are kept verbatim (generic header included) as a [`Raw`] layer.
///
/// # Errors
///
/// Propagates whatever error the selected body parser returns.
///
/// # Panics
///
/// Panics if `payload` is shorter than `GENERIC_PAYLOAD_HEADER_LEN`; the
/// caller is expected to have validated the length beforehand.
fn push_typed_payload(packet: Packet, payload_type: PayloadType, payload: &[u8]) -> Result<Packet> {
    // Sliced up front, before dispatch, so the length precondition applies to
    // every payload type — including the ones stored as `Raw`.
    let data = &payload[GENERIC_PAYLOAD_HEADER_LEN..];

    Ok(match payload_type {
        PayloadType::SecurityAssociation => packet.push(parse_sa_payload_body(data)?),
        PayloadType::KeyExchange => packet.push(parse_ke_payload_body(data)?),
        PayloadType::Nonce => packet.push(parse_nonce_payload_body(data)?),
        PayloadType::Notify => packet.push(parse_notify_payload_body(data)?),
        PayloadType::Delete => packet.push(parse_delete_payload_body(data)?),
        PayloadType::VendorId => packet.push(parse_vendor_id_payload_body(data)?),
        PayloadType::Configuration => packet.push(parse_config_payload_body(data)?),
        PayloadType::ExtensibleAuthentication => packet.push(parse_eap_payload_body(data)?),

        // Identification payloads share one parser, distinguished by role.
        PayloadType::IdInitiator => packet.push(parse_id_payload_body(IdRole::Initiator, data)?),
        PayloadType::IdResponder => packet.push(parse_id_payload_body(IdRole::Responder, data)?),

        // Certificate-related and authentication payloads.
        PayloadType::Certificate => packet.push(parse_cert_payload_body(data)?),
        PayloadType::CertificateRequest => packet.push(parse_certreq_payload_body(data)?),
        PayloadType::Authentication => packet.push(parse_auth_payload_body(data)?),

        // Traffic selector payloads share one parser, distinguished by role.
        PayloadType::TrafficSelectorInitiator => {
            packet.push(parse_ts_payload_body(TsRole::Initiator, data)?)
        }
        PayloadType::TrafficSelectorResponder => {
            packet.push(parse_ts_payload_body(TsRole::Responder, data)?)
        }

        // No typed parser: preserve the whole payload byte-for-byte.
        PayloadType::Encrypted | PayloadType::None | PayloadType::Unknown(_) => {
            packet.push(Raw::from_bytes(payload))
        }
    })
}
/// Walks the chained payloads in `bytes` and appends one layer per payload to
/// `packet`.
///
/// `first_payload_type` is the type code of the first payload. Each payload
/// starts with a generic header from which this function reads byte 0 (the
/// type code of the *following* payload) and bytes 2..4 (big-endian length of
/// the whole payload, header included). The walk ends when the pending type
/// code equals `PayloadType::None.codepoint()`, or right after an `Encrypted`
/// payload has been pushed. Bytes left over once the walk ends are ignored.
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short` when fewer bytes than a generic
/// header remain while another payload is expected, or when a payload's
/// declared length is smaller than the generic header or larger than the
/// remaining bytes. Errors from the individual payload parsers are propagated.
fn decode_payload_chain(
    mut packet: Packet,
    first_payload_type: u8,
    bytes: &[u8],
) -> Result<Packet> {
    // `cursor` always starts at the payload that is about to be decoded.
    let mut cursor = bytes;
    let mut pending_type = first_payload_type;

    while pending_type != PayloadType::None.codepoint() {
        let available = cursor.len();
        if available < GENERIC_PAYLOAD_HEADER_LEN {
            return Err(CrafterError::buffer_too_short(
                "ikev2.payload.header",
                GENERIC_PAYLOAD_HEADER_LEN,
                available,
            ));
        }

        let following_type = cursor[0];
        let declared_len = usize::from(u16::from_be_bytes([cursor[2], cursor[3]]));
        if !(GENERIC_PAYLOAD_HEADER_LEN..=available).contains(&declared_len) {
            return Err(CrafterError::buffer_too_short(
                "ikev2.payload.length",
                declared_len.max(GENERIC_PAYLOAD_HEADER_LEN),
                available,
            ));
        }

        let (payload, rest) = cursor.split_at(declared_len);
        let kind = PayloadType::from(pending_type);
        packet = push_typed_payload(packet, kind, payload)?;

        // Nothing after an Encrypted payload is decoded by this function.
        if matches!(kind, PayloadType::Encrypted) {
            break;
        }

        cursor = rest;
        pending_type = following_type;
    }

    Ok(packet)
}
/// Decodes a complete IKE message from `bytes` and appends its layers to
/// `packet`: first the [`IkeHeader`], then one layer per chained payload.
///
/// The header's length field bounds the payload area; bytes beyond the
/// declared length are not examined. `_registry` is currently unused.
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short` when the header cannot be parsed,
/// when the declared length is smaller than `IKE_HEADER_LEN`, or when it
/// exceeds `bytes.len()`. Errors from payload decoding are propagated.
pub(crate) fn decode_ike_message(
    _registry: &ProtocolRegistry,
    packet: Packet,
    bytes: &[u8],
) -> Result<Packet> {
    let header = parse_ike_header(bytes)?;

    // A missing length is treated as 0; a value that does not fit in `usize`
    // saturates so that it is rejected by the bounds check below.
    let declared_len = header.length_value().unwrap_or(0);
    let declared_len = usize::try_from(declared_len).unwrap_or(usize::MAX);

    if declared_len < IKE_HEADER_LEN {
        return Err(CrafterError::buffer_too_short(
            "ikev2.header.length",
            IKE_HEADER_LEN,
            declared_len,
        ));
    }

    let available = bytes.len();
    if declared_len > available {
        return Err(CrafterError::buffer_too_short("ikev2.message", declared_len, available));
    }

    // Read the first payload type before `header` is moved into the packet.
    let first_type = header.next_payload_value().unwrap_or(0);
    let payload_area = &bytes[IKE_HEADER_LEN..declared_len];
    decode_payload_chain(packet.push(header), first_type, payload_area)
}
/// Appends the layers decoded from the IKE message in `bytes` to `packet`.
///
/// This is a direct delegate to [`decode_ike_message`]; arguments and result
/// are passed through unchanged.
///
/// # Errors
///
/// Returns whatever error [`decode_ike_message`] reports.
pub(crate) fn append_ikev2_packet_with_registry(
    registry: &ProtocolRegistry,
    packet: Packet,
    bytes: &[u8],
) -> Result<Packet> {
    decode_ike_message(registry, packet, bytes)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{NetworkLayer, Packet, Raw};
    use crate::protocols::ipsec::ikev2::header::{IkeHeader, IKE_SA_INIT};
    use crate::protocols::ipsec::ikev2::payload::ke::IkeKePayload;
    use crate::protocols::ipsec::ikev2::payload::nonce::IkeNoncePayload;
    use crate::protocols::ipsec::ikev2::payload::notify::NOTIFY_PROTOCOL_NONE;
    use crate::protocols::ipsec::ikev2::payload::sa::{IkeSaPayload, Proposal, Transform};
    use crate::protocols::ipsec::ikev2::payload::{
        IkeNotifyPayload, IkeVendorIdPayload, NotifyType, PAYLOAD_SA,
    };
    use crate::protocols::ipv4::{Ipv4, IPPROTO_UDP};
    use crate::protocols::transport::Udp;

    /// UDP port used for both source and destination in the test packets.
    const IKE_UDP_PORT: u16 = 500;

    /// Builds the SA, KE and nonce payloads shared by the IKE_SA_INIT tests.
    fn ike_sa_init_payloads() -> (IkeSaPayload, IkeKePayload, IkeNoncePayload) {
        let proposal = Proposal::new(1, 3).with_transform(Transform::new(1, 20));
        let sa = IkeSaPayload::new().with_proposal(proposal);
        let ke = IkeKePayload::new(14, vec![0xAB; 32]);
        let ni = IkeNoncePayload::new(vec![0x5A; 16]);
        (sa, ke, ni)
    }

    /// Stacks IPv4 / UDP / IKE header / SA / KE / nonce into one packet.
    fn compile_ike_sa_init_packet() -> Packet {
        let (sa, ke, ni) = ike_sa_init_payloads();
        let ipv4 = Ipv4::new()
            .protocol(IPPROTO_UDP)
            .src("192.0.2.1".parse().unwrap())
            .dst("192.0.2.2".parse().unwrap());
        let udp = Udp::new().sport(IKE_UDP_PORT).dport(IKE_UDP_PORT);
        let header = IkeHeader::new().exchange(IKE_SA_INIT).initiator();
        Packet::from_layer(ipv4) / udp / header / sa / ke / ni
    }

    #[test]
    fn full_message_decodes_all_four_payload_layers_and_round_trips() {
        let packet = compile_ike_sa_init_packet();
        let wire = packet.compile().expect("compile IKE message").into_bytes();

        let decoded =
            Packet::decode_from_l3(NetworkLayer::Ipv4, &wire).expect("decode IKE_SA_INIT from L3");

        assert!(
            decoded.layer::<IkeHeader>().is_some(),
            "IkeHeader must be present"
        );
        assert!(decoded.layer::<IkeSaPayload>().is_some(), "SA present");
        assert!(decoded.layer::<IkeKePayload>().is_some(), "KE present");
        assert!(decoded.layer::<IkeNoncePayload>().is_some(), "Ni present");

        let header = decoded.layer::<IkeHeader>().unwrap();
        assert_eq!(header.next_payload_value(), Some(PAYLOAD_SA));
        assert_eq!(header.exchange_type_value(), Some(IKE_SA_INIT));

        let recompiled = decoded.compile().expect("recompile decoded").into_bytes();
        assert_eq!(recompiled, wire, "round-trip must be byte-exact");
    }

    #[test]
    fn decode_ike_message_pushes_header_then_payload_chain() {
        let packet = compile_ike_sa_init_packet();
        let wire = packet.compile().expect("compile").into_bytes();

        // Skip the IPv4 header (IHL nibble counts 32-bit words) and the
        // 8-byte UDP header to reach the IKE message.
        let ip_header_len = usize::from(wire[0] & 0x0f) * 4;
        let ike_bytes = &wire[ip_header_len + 8..];

        let registry = ProtocolRegistry::with_builtin_bindings();
        let decoded =
            decode_ike_message(&registry, Packet::new(), ike_bytes).expect("decode IKE message");

        assert_eq!(decoded.len(), 4);
        assert_eq!(decoded.get(0).unwrap().name(), "IkeHeader");
        assert_eq!(decoded.get(1).unwrap().name(), "IkeSaPayload");
        assert_eq!(decoded.get(2).unwrap().name(), "IkeKePayload");
        assert_eq!(decoded.get(3).unwrap().name(), "IkeNoncePayload");
    }

    #[test]
    fn unknown_payload_type_decodes_as_raw() {
        let unknown_type = 60u8;
        let payload_body = [0x11u8, 0x22, 0x33, 0x44];
        let payload_len = GENERIC_PAYLOAD_HEADER_LEN + payload_body.len();
        let total_len = IKE_HEADER_LEN + payload_len;

        let mut message = Vec::new();
        // IKE header, in the field order `parse_ike_header` reads it.
        message.extend_from_slice(&0x0102_0304_0506_0708u64.to_be_bytes()); // initiator SPI
        message.extend_from_slice(&0u64.to_be_bytes()); // responder SPI
        message.push(unknown_type); // next payload
        message.push(0x20); // version
        message.push(IKE_SA_INIT); // exchange type
        message.push(0x08); // flags
        message.extend_from_slice(&0u32.to_be_bytes()); // message id
        let wire_total_len = u32::try_from(total_len).expect("total length fits in u32");
        message.extend_from_slice(&wire_total_len.to_be_bytes()); // length
        // Generic payload header followed by the payload body.
        message.push(0); // next payload: end of chain
        message.push(0); // second header byte, not read by the decoder
        let wire_payload_len = u16::try_from(payload_len).expect("payload length fits in u16");
        message.extend_from_slice(&wire_payload_len.to_be_bytes());
        message.extend_from_slice(&payload_body);

        let registry = ProtocolRegistry::with_builtin_bindings();
        let decoded = decode_ike_message(&registry, Packet::new(), &message)
            .expect("decode message with unknown payload");

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded.get(0).unwrap().name(), "IkeHeader");
        let raw = decoded
            .get(1)
            .unwrap()
            .as_any()
            .downcast_ref::<Raw>()
            .expect("unknown payload preserved as Raw");
        assert_eq!(
            raw.as_bytes(),
            &message[IKE_HEADER_LEN..],
            "unknown payload preserved verbatim"
        );

        let recompiled = decoded.compile().expect("recompile").into_bytes();
        assert_eq!(recompiled, message);
    }

    #[test]
    fn truncated_header_is_structured_error() {
        let registry = ProtocolRegistry::with_builtin_bindings();
        let err = decode_ike_message(&registry, Packet::new(), &[0u8; 10])
            .expect_err("must reject truncated header");
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }

    #[test]
    fn payload_running_off_the_end_is_structured_error() {
        let mut message = Vec::new();
        // IKE header declaring a total length of 36, which matches the number
        // of bytes pushed in this test.
        message.extend_from_slice(&0u64.to_be_bytes()); // initiator SPI
        message.extend_from_slice(&0u64.to_be_bytes()); // responder SPI
        message.push(PAYLOAD_SA); // next payload
        message.push(0x20); // version
        message.push(IKE_SA_INIT); // exchange type
        message.push(0x08); // flags
        message.extend_from_slice(&0u32.to_be_bytes()); // message id
        message.extend_from_slice(&36u32.to_be_bytes()); // length
        // Generic payload header claiming 100 bytes while only 8 are present.
        message.push(0);
        message.push(0);
        message.extend_from_slice(&100u16.to_be_bytes());
        message.extend_from_slice(&[0u8; 4]);

        let registry = ProtocolRegistry::with_builtin_bindings();
        let err = decode_ike_message(&registry, Packet::new(), &message)
            .expect_err("oversized payload length must error");
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }

    #[test]
    fn notify_and_vendor_chain_decodes_to_typed_layers() {
        let notify =
            IkeNotifyPayload::new(NOTIFY_PROTOCOL_NONE, NotifyType::InitialContact, Vec::new());
        let vendor = IkeVendorIdPayload::new(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let header = IkeHeader::new().exchange(IKE_SA_INIT).initiator();
        let udp = Udp::new().sport(IKE_UDP_PORT).dport(IKE_UDP_PORT);
        let ipv4 = Ipv4::new()
            .protocol(IPPROTO_UDP)
            .src("192.0.2.1".parse().unwrap())
            .dst("192.0.2.2".parse().unwrap());
        let packet = Packet::from_layer(ipv4) / udp / header / notify / vendor;
        let wire = packet.compile().expect("compile").into_bytes();

        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, &wire).expect("decode from L3");

        assert!(
            decoded.layer::<IkeNotifyPayload>().is_some(),
            "Notify present"
        );
        assert!(
            decoded.layer::<IkeVendorIdPayload>().is_some(),
            "Vendor ID present"
        );

        let recompiled = decoded.compile().expect("recompile").into_bytes();
        assert_eq!(recompiled, wire, "round-trip byte-exact");
    }
}