use crate::field::Field;
use crate::packet::{Layer, LayerContext};
use crate::protocols::ipsec::ikev2::payload::{
write_generic_payload_header, IkePayload, PayloadHeaderFields, PayloadType,
};
use crate::protocols::transport::common::{impl_layer_div, impl_layer_object};
use crate::CrafterError;
use crate::Result;
/// Layer name reported by `Layer::name` for [`IkeKePayload`].
pub const IKE_KE_PAYLOAD_NAME: &str = "IkeKePayload";
/// Length of the fixed KE body prefix: a 2-byte big-endian DH group number
/// followed by 2 reserved bytes (written as zero, skipped when parsing).
pub const KE_FIXED_LEN: usize = 4;
/// Diffie-Hellman group 2: 1024-bit MODP group.
pub const DH_GROUP_MODP_1024: u16 = 2;
/// Diffie-Hellman group 14: 2048-bit MODP group.
pub const DH_GROUP_MODP_2048: u16 = 14;
/// Diffie-Hellman group 19: 256-bit random ECP group.
pub const DH_GROUP_ECP_256: u16 = 19;
/// Diffie-Hellman group 31: Curve25519.
pub const DH_GROUP_CURVE25519: u16 = 31;
/// IKEv2 Key Exchange (KE) payload layer.
///
/// The body is encoded as a big-endian DH group number, two zero reserved
/// bytes, and then the raw key exchange data. It is preceded on the wire by
/// the generic IKEv2 payload header, whose fields can be overridden through
/// the builder methods.
#[derive(Debug, Clone)]
pub struct IkeKePayload {
    /// Diffie-Hellman group number (e.g. [`DH_GROUP_MODP_2048`]).
    dh_group: Field<u16>,
    /// Raw key exchange data, emitted after the fixed 4-byte body prefix.
    key_exchange_data: Vec<u8>,
    /// Generic payload header overrides (next payload, critical bit, length).
    header: PayloadHeaderFields,
}
impl IkeKePayload {
    /// Creates a KE payload for `dh_group` carrying `key_exchange_data`.
    ///
    /// The generic header starts from `PayloadHeaderFields::new()`.
    pub fn new(dh_group: u16, key_exchange_data: impl Into<Vec<u8>>) -> Self {
        let dh_group = Field::user(dh_group);
        let key_exchange_data: Vec<u8> = key_exchange_data.into();
        Self {
            dh_group,
            key_exchange_data,
            header: PayloadHeaderFields::new(),
        }
    }

    /// Builder: replaces the DH group number.
    pub fn dh_group(mut self, dh_group: u16) -> Self {
        self.dh_group.set_user(dh_group);
        self
    }

    /// Builder: replaces the key exchange data.
    pub fn key_exchange_data(self, key_exchange_data: impl Into<Vec<u8>>) -> Self {
        Self {
            key_exchange_data: key_exchange_data.into(),
            ..self
        }
    }

    /// Returns the DH group number, or `0` when the field holds no value.
    pub fn dh_group_num(&self) -> u16 {
        self.dh_group.value().map_or(0, |group| *group)
    }

    /// Borrows the raw key exchange data.
    pub fn key_exchange_data_bytes(&self) -> &[u8] {
        self.key_exchange_data.as_slice()
    }

    /// Builder: overrides the generic header's Next Payload byte.
    pub fn next_payload(mut self, next_payload: u8) -> Self {
        self.header.set_next_payload(next_payload);
        self
    }

    /// Builder: overrides the generic header's Payload Length field.
    pub fn payload_length(mut self, length: u16) -> Self {
        self.header.set_length(length);
        self
    }

    /// Builder: sets or clears the generic header's critical flag.
    pub fn critical(mut self, critical: bool) -> Self {
        self.header.set_critical(critical);
        self
    }

    /// Encodes the KE body (everything after the generic payload header).
    ///
    /// The body is the DH group (big-endian), two zero reserved bytes, and
    /// then the key exchange data.
    fn ke_body(&self) -> Vec<u8> {
        let [group_hi, group_lo] = self.dh_group_num().to_be_bytes();
        // Fixed prefix: group number followed by the two reserved bytes.
        let prefix: [u8; KE_FIXED_LEN] = [group_hi, group_lo, 0, 0];
        let mut body = Vec::with_capacity(KE_FIXED_LEN + self.key_exchange_data.len());
        body.extend_from_slice(&prefix);
        body.extend_from_slice(&self.key_exchange_data);
        body
    }
}
impl IkePayload for IkeKePayload {
    /// KE payloads always identify as `PayloadType::KeyExchange`.
    fn payload_type(&self) -> PayloadType {
        PayloadType::KeyExchange
    }

    /// Critical flag stored in the generic header overrides.
    fn critical(&self) -> bool {
        self.header.critical()
    }

    /// Next Payload override from the header fields, if one was set.
    fn next_payload_override(&self) -> Option<u8> {
        self.header.next_payload_override()
    }

    /// Payload Length override from the header fields, if one was set.
    fn payload_length_override(&self) -> Option<u16> {
        self.header.payload_length_override()
    }

    /// Encoded KE body. The layer context is not needed to build it.
    fn payload_body(&self, _ctx: &LayerContext<'_>) -> Result<Vec<u8>> {
        let body = self.ke_body();
        Ok(body)
    }
}
impl Layer for IkeKePayload {
    /// Static layer name (`IKE_KE_PAYLOAD_NAME`).
    fn name(&self) -> &'static str {
        IKE_KE_PAYLOAD_NAME
    }

    /// One-line description with the DH group and key exchange data length.
    fn summary(&self) -> String {
        format!(
            "IkeKePayload(dh_group={group}, kex_len={kex_len})",
            group = self.dh_group_num(),
            kex_len = self.key_exchange_data.len(),
        )
    }

    /// Key/value pairs shown when inspecting the layer.
    fn inspection_fields(&self) -> Vec<(&'static str, String)> {
        let group = self.dh_group_num().to_string();
        let kex_len = self.key_exchange_data.len().to_string();
        vec![("dh_group", group), ("kex_len", kex_len)]
    }

    /// Actual on-wire size: generic header + fixed KE prefix + data.
    ///
    /// This is independent of any Payload Length override, which only
    /// changes the value written into the header.
    fn encoded_len(&self) -> usize {
        let body_len = KE_FIXED_LEN + self.key_exchange_data.len();
        super::GENERIC_PAYLOAD_HEADER_LEN + body_len
    }

    /// Appends the generic payload header and then the KE body to `out`.
    ///
    /// Errors from building the body or writing the header are propagated.
    fn compile(&self, ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
        let body = self.payload_body(ctx)?;
        let next_payload = self.next_payload_override();
        let critical = self.critical();
        let length_override = self.payload_length_override();
        write_generic_payload_header(
            out,
            ctx,
            next_payload,
            critical,
            length_override,
            body.len(),
        )?;
        out.extend_from_slice(&body);
        Ok(())
    }

    impl_layer_object!(IkeKePayload);
}
// NOTE(review): presumably provides the `/` layer-stacking operator for this
// type (the tests build `Packet::from_layer(..) / sample_payload()`); confirm
// against the macro definition.
impl_layer_div!(IkeKePayload);
/// Parses a KE payload body (the bytes after the generic payload header).
///
/// The first two bytes are the big-endian DH group. The next two reserved
/// bytes are skipped. Everything after them becomes the key exchange data.
///
/// # Errors
///
/// Returns a buffer-too-short error when `bytes` is shorter than
/// [`KE_FIXED_LEN`].
pub(crate) fn parse_ke_payload_body(bytes: &[u8]) -> Result<IkeKePayload> {
    let available = bytes.len();
    if available < KE_FIXED_LEN {
        return Err(CrafterError::buffer_too_short(
            "ikev2.ke",
            KE_FIXED_LEN,
            available,
        ));
    }
    let (fixed, data) = bytes.split_at(KE_FIXED_LEN);
    let group = u16::from_be_bytes([fixed[0], fixed[1]]);
    Ok(IkeKePayload::new(group, data.to_vec()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{LayerContext, Packet, Raw};

    /// Size of the generic IKEv2 payload header preceding the KE body.
    /// Named so the offsets below match `encoded_len` instead of using a bare `4`.
    const HEADER_LEN: usize = super::super::GENERIC_PAYLOAD_HEADER_LEN;

    /// Compiles `payload` as the only layer of a packet and returns its bytes.
    fn compile_payload(payload: IkeKePayload) -> Vec<u8> {
        let packet = Packet::from_layer(payload);
        let ctx = LayerContext::new(&packet, 0);
        let mut out = Vec::new();
        packet.get(0).unwrap().compile(&ctx, &mut out).unwrap();
        out
    }

    /// MODP-2048 payload with six bytes of recognisable key exchange data.
    fn sample_payload() -> IkeKePayload {
        IkeKePayload::new(
            DH_GROUP_MODP_2048,
            vec![0xDEu8, 0xAD, 0xBE, 0xEF, 0x01, 0x02],
        )
    }

    #[test]
    fn ke_constants_match_rfc() {
        assert_eq!(HEADER_LEN, 4);
        assert_eq!(KE_FIXED_LEN, 4);
        assert_eq!(DH_GROUP_MODP_1024, 2);
        assert_eq!(DH_GROUP_MODP_2048, 14);
        assert_eq!(DH_GROUP_ECP_256, 19);
        assert_eq!(DH_GROUP_CURVE25519, 31);
    }

    #[test]
    fn payload_type_is_key_exchange() {
        assert_eq!(sample_payload().payload_type(), PayloadType::KeyExchange);
        assert_eq!(sample_payload().name(), IKE_KE_PAYLOAD_NAME);
    }

    #[test]
    fn body_lays_out_group_reserved_then_data() {
        let payload = sample_payload();
        let body = payload.ke_body();
        assert_eq!(u16::from_be_bytes([body[0], body[1]]), DH_GROUP_MODP_2048);
        assert_eq!(&body[2..4], &[0, 0]);
        assert_eq!(&body[4..], &[0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x02]);
        assert_eq!(body.len(), KE_FIXED_LEN + 6);
    }

    #[test]
    fn payload_compiles_generic_header_then_body() {
        let payload = sample_payload();
        let bytes = compile_payload(payload.clone());
        // No following layer and no overrides: next payload 0, flags 0.
        assert_eq!(bytes[0], 0);
        assert_eq!(bytes[1], 0);
        let payload_len = u16::from_be_bytes([bytes[2], bytes[3]]) as usize;
        assert_eq!(payload_len, bytes.len());
        assert_eq!(payload_len, payload.encoded_len());
        // The body must follow the header verbatim.
        assert_eq!(&bytes[HEADER_LEN..], payload.ke_body().as_slice());
    }

    #[test]
    fn payload_honors_generic_header_overrides() {
        let payload = sample_payload()
            .next_payload(40)
            .critical(true)
            .payload_length(0xBEEF);
        let bytes = compile_payload(payload);
        assert_eq!(bytes[0], 40);
        assert_eq!(bytes[1], 0x80);
        assert_eq!(u16::from_be_bytes([bytes[2], bytes[3]]), 0xBEEF);
    }

    #[test]
    fn payload_chain_next_payload_points_at_ke() {
        use crate::protocols::ipsec::ikev2::payload::{
            following_next_payload, payload_type_for_layer_name, PAYLOAD_KE,
        };
        assert_eq!(
            payload_type_for_layer_name(IKE_KE_PAYLOAD_NAME),
            Some(PayloadType::KeyExchange)
        );
        let packet: Packet = Packet::from_layer(Raw::from_bytes([0u8; 0])) / sample_payload();
        let ctx = LayerContext::new(&packet, 0);
        assert_eq!(following_next_payload(&ctx), PAYLOAD_KE);
    }

    #[test]
    fn round_trip_preserves_group_and_data() {
        let payload = sample_payload();
        let bytes = compile_payload(payload.clone());
        let parsed = parse_ke_payload_body(&bytes[HEADER_LEN..]).unwrap();
        assert_eq!(parsed.dh_group_num(), DH_GROUP_MODP_2048);
        assert_eq!(
            parsed.key_exchange_data_bytes(),
            payload.key_exchange_data_bytes()
        );
    }

    #[test]
    fn round_trip_recompiles_byte_for_byte() {
        let payload = sample_payload();
        let bytes = compile_payload(payload);
        let parsed = parse_ke_payload_body(&bytes[HEADER_LEN..]).unwrap();
        let recompiled = compile_payload(parsed);
        assert_eq!(recompiled, bytes);
    }

    #[test]
    fn parse_rejects_truncated_body() {
        let err = parse_ke_payload_body(&[0u8, 14, 0]).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }

    #[test]
    fn parse_rejects_empty_body() {
        let err = parse_ke_payload_body(&[0u8; 0]).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }

    #[test]
    fn parse_accepts_exact_fixed_length_with_empty_data() {
        // Exactly KE_FIXED_LEN bytes is a valid body with no key exchange data.
        let parsed = parse_ke_payload_body(&[0u8, 31, 0, 0]).unwrap();
        assert_eq!(parsed.dh_group_num(), DH_GROUP_CURVE25519);
        assert!(parsed.key_exchange_data_bytes().is_empty());
    }

    #[test]
    fn parse_ignores_reserved_bytes() {
        // RESERVED is not stored, so re-encoding writes zeros in its place.
        let parsed = parse_ke_payload_body(&[0u8, 14, 0xAB, 0xCD, 0x01]).unwrap();
        assert_eq!(parsed.dh_group_num(), DH_GROUP_MODP_2048);
        assert_eq!(parsed.ke_body(), vec![0u8, 14, 0, 0, 0x01]);
    }
}