use crate::field::Field;
use crate::packet::{Layer, LayerContext, Packet};
use crate::protocols::ipsec::ikev2::payload::{
write_generic_payload_header, IkePayload, PayloadHeaderFields, PayloadType,
};
use crate::protocols::ipsec::sa::{iv_requirement, open, seal, SecurityAssociation};
use crate::protocols::transport::common::{impl_layer_div, impl_layer_object};
use crate::{CrafterError, Result};
/// Layer name reported by `IkeEncryptedPayload` through `Layer::name`.
pub const IKE_ENCRYPTED_PAYLOAD_NAME: &str = "IkeEncryptedPayload";
/// Size, in octets, of the trailing Pad Length field that closes the SK plaintext.
pub const SK_PAD_LENGTH_FIELD_LEN: usize = 1;
/// IKEv2 Encrypted (SK) payload: a chain of inner payloads that is compiled,
/// padded, sealed with a security association and emitted as
/// `IV || ciphertext || ICV` behind a generic payload header.
///
/// Every wire field starts unset and is derived at compile time unless the
/// caller supplies an explicit override through the builder methods.
#[derive(Debug, Clone)]
pub struct IkeEncryptedPayload {
/// Inner payload chain whose compiled bytes form the start of the plaintext.
inner: Packet,
/// Security association handed to `seal`; also supplies the cipher block
/// size and the IV length requirement.
sa: SecurityAssociation,
/// Optional IV override; when unset an all-zero IV of the required length is used.
iv: Field<Vec<u8>>,
/// Optional padding override; when unset the padding is derived from the
/// cipher block size.
pad: Field<Vec<u8>>,
/// Optional override for the Pad Length octet; when unset it is the length
/// of the padding actually written.
pad_length: Field<u8>,
/// Optional ICV override; when unset the ICV returned by `seal` is emitted.
icv: Field<Vec<u8>>,
/// Generic payload header overrides (next payload, payload length, critical flag).
header: PayloadHeaderFields,
}
impl IkeEncryptedPayload {
    /// Creates an empty SK payload protected by `sa`.
    ///
    /// The inner payload chain starts empty and the IV, padding, Pad Length,
    /// ICV and generic-header fields are all left unset, so they are derived
    /// when the payload is compiled.
    pub fn new(sa: SecurityAssociation) -> Self {
        Self {
            inner: Packet::new(),
            sa,
            iv: Field::unset(),
            pad: Field::unset(),
            pad_length: Field::unset(),
            icv: Field::unset(),
            header: PayloadHeaderFields::new(),
        }
    }

    /// Replaces the whole inner payload chain with `inner`.
    pub fn payloads(mut self, inner: impl Into<Packet>) -> Self {
        self.inner = inner.into();
        self
    }

    /// Pushes a single `layer` onto the inner payload chain.
    pub fn payload<L>(mut self, layer: L) -> Self
    where
        L: Layer,
    {
        self.inner = self.inner.push(layer);
        self
    }

    /// Supplies the IV to seal with and to emit, instead of the all-zero default.
    pub fn iv(mut self, iv: impl Into<Vec<u8>>) -> Self {
        self.iv.set_user(iv.into());
        self
    }

    /// Supplies the padding bytes verbatim, bypassing block-size alignment.
    pub fn pad(mut self, pad: impl Into<Vec<u8>>) -> Self {
        self.pad.set_user(pad.into());
        self
    }

    /// Overrides the Pad Length octet without touching the padding bytes themselves.
    pub fn pad_length(mut self, pad_length: u8) -> Self {
        self.pad_length.set_user(pad_length);
        self
    }

    /// Overrides the emitted ICV; the ciphertext is still produced by `seal`.
    pub fn icv(mut self, icv: impl Into<Vec<u8>>) -> Self {
        self.icv.set_user(icv.into());
        self
    }

    /// Overrides the Next Payload value of the generic header.
    pub fn next_payload(mut self, next_payload: u8) -> Self {
        self.header.set_next_payload(next_payload);
        self
    }

    /// Overrides the Payload Length value handed to the generic header writer.
    pub fn payload_length(mut self, length: u16) -> Self {
        self.header.set_length(length);
        self
    }

    /// Sets the critical flag of the generic header.
    pub fn critical(mut self, critical: bool) -> Self {
        self.header.set_critical(critical);
        self
    }

    /// Returns the security association used to seal this payload.
    pub fn security_association(&self) -> &SecurityAssociation {
        &self.sa
    }

    /// Returns the inner payload chain that gets encrypted.
    pub fn inner_payloads(&self) -> &Packet {
        &self.inner
    }

    /// Returns the IV held by the IV field, if any.
    pub fn iv_value(&self) -> Option<&[u8]> {
        self.iv.value().map(Vec::as_slice)
    }

    /// Returns the padding held by the padding field, if any.
    pub fn pad_value(&self) -> Option<&[u8]> {
        self.pad.value().map(Vec::as_slice)
    }

    /// Returns the ICV held by the ICV field, if any.
    pub fn icv_value(&self) -> Option<&[u8]> {
        self.icv.value().map(Vec::as_slice)
    }

    /// Compiles every inner payload, in order, into one contiguous buffer.
    ///
    /// Each layer is compiled with a context rooted at the inner packet, so it
    /// sees its inner siblings rather than the outer message.
    fn inner_payload_bytes(&self) -> Result<Vec<u8>> {
        let mut out = Vec::new();
        for (index, layer) in self.inner.iter().enumerate() {
            let ctx = LayerContext::new(&self.inner, index);
            layer.compile(&ctx, &mut out)?;
        }
        Ok(out)
    }

    /// Resolves the Next Payload value: the explicit override if present,
    /// otherwise the payload type of the first inner layer, otherwise
    /// `PAYLOAD_TYPE_NONE`.
    fn effective_next_payload(&self) -> u8 {
        if let Some(next_payload) = self.header.next_payload_override() {
            return next_payload;
        }
        self.inner
            .get(0)
            .and_then(|layer| super::payload_type_for_layer_name(layer.name()))
            .map(PayloadType::codepoint)
            .unwrap_or(super::PAYLOAD_TYPE_NONE)
    }

    /// Returns the padding to place between the inner payloads and the Pad
    /// Length octet.
    ///
    /// A user-supplied pad is returned verbatim. Otherwise the pad is the
    /// shortest run `1, 2, 3, ...` that makes `inner_len + pad + 1` a multiple
    /// of `block_size` (a block size of 0 is treated as 1, i.e. no padding).
    fn effective_pad(&self, inner_len: usize, block_size: usize) -> Vec<u8> {
        if let Some(pad) = self.pad.value() {
            return pad.clone();
        }
        let alignment = block_size.max(1);
        let unaligned = inner_len + SK_PAD_LENGTH_FIELD_LEN;
        let pad_len = (alignment - unaligned % alignment) % alignment;
        // Build the sequence without a narrowing cast: a pad longer than 255
        // octets must keep its real length so the Pad Length check in
        // `sk_seal_inputs` can reject it, instead of being silently truncated
        // into a misaligned plaintext.
        (1..=u8::MAX).cycle().take(pad_len).collect()
    }

    /// Builds the plaintext (`inner payloads || padding || Pad Length`) and
    /// picks the IV that will be handed to `seal`.
    ///
    /// # Errors
    ///
    /// Fails if an inner payload cannot be compiled, or if no Pad Length
    /// override is set and the padding is longer than 255 octets.
    fn sk_seal_inputs(&self) -> Result<(Vec<u8>, Vec<u8>)> {
        let inner = self.inner_payload_bytes()?;
        let block_size = self.sa.enc.block_size();
        let pad = self.effective_pad(inner.len(), block_size);
        let pad_len = match self.pad_length.value().copied() {
            Some(pad_len) => pad_len,
            None => u8::try_from(pad.len()).map_err(|_| {
                CrafterError::invalid_field_value("ikev2.sk.pad", "SK padding exceeds 255 octets")
            })?,
        };

        let mut plaintext = Vec::with_capacity(inner.len() + pad.len() + SK_PAD_LENGTH_FIELD_LEN);
        plaintext.extend_from_slice(&inner);
        plaintext.extend_from_slice(&pad);
        plaintext.push(pad_len);

        // Without an explicit IV, fall back to a zero-filled IV of the length
        // the security association requires.
        let iv = match self.iv.value() {
            Some(iv) => iv.clone(),
            None => vec![0u8; iv_requirement(&self.sa).iv_len],
        };
        Ok((plaintext, iv))
    }

    /// Produces the payload body `IV || ciphertext || ICV`.
    ///
    /// The plaintext is sealed with an empty slice as the third `seal`
    /// argument. An ICV override replaces only the emitted ICV; the ciphertext
    /// always comes from `seal`.
    fn sk_body(&self) -> Result<Vec<u8>> {
        let (plaintext, iv) = self.sk_seal_inputs()?;
        let sealed = seal(&self.sa, &iv, &[], &plaintext)?;
        let icv = match self.icv.value() {
            Some(icv) => icv.clone(),
            None => sealed.icv,
        };

        let mut body = Vec::with_capacity(iv.len() + sealed.ciphertext.len() + icv.len());
        body.extend_from_slice(&iv);
        body.extend_from_slice(&sealed.ciphertext);
        body.extend_from_slice(&icv);
        Ok(body)
    }
}
// Generic IKEv2 payload behaviour for the SK payload: every method delegates
// either to the SK body builder or to the stored header overrides.
impl IkePayload for IkeEncryptedPayload {
// This layer always identifies itself as the Encrypted payload type.
fn payload_type(&self) -> PayloadType {
PayloadType::Encrypted
}
// The body is `IV || ciphertext || ICV`; the layer context is not consulted.
fn payload_body(&self, _ctx: &LayerContext<'_>) -> Result<Vec<u8>> {
self.sk_body()
}
// Always returns `Some`: the value is the explicit override if set, else
// the type of the first inner payload, else `PAYLOAD_TYPE_NONE`.
fn next_payload_override(&self) -> Option<u8> {
Some(self.effective_next_payload())
}
// Payload Length override as stored in the header fields, if any.
fn payload_length_override(&self) -> Option<u16> {
self.header.payload_length_override()
}
// Critical flag as stored in the header fields.
fn critical(&self) -> bool {
self.header.critical()
}
}
impl Layer for IkeEncryptedPayload {
    /// Stable layer name used to look this payload up by name.
    fn name(&self) -> &'static str {
        IKE_ENCRYPTED_PAYLOAD_NAME
    }

    /// One-line description: inner payload count plus the SA summary.
    fn summary(&self) -> String {
        let inner_count = self.inner.len();
        let sa_summary = self.sa.summary();
        format!("IkeEncryptedPayload(inner_payloads={inner_count}, {sa_summary})")
    }

    /// Key/value pairs shown when the layer is inspected.
    fn inspection_fields(&self) -> Vec<(&'static str, String)> {
        let next_payload = self.effective_next_payload().to_string();
        let inner_count = self.inner.len().to_string();
        let spi = format!("0x{:08x}", self.sa.spi);
        let mode = self.sa.mode.label().to_string();
        vec![
            ("next_payload", next_payload),
            ("inner_payloads", inner_count),
            ("spi", spi),
            ("mode", mode),
        ]
    }

    /// Generic header length plus the sealed body length.
    ///
    /// If the body cannot be built, only the header length is reported.
    fn encoded_len(&self) -> usize {
        let body_len = self.sk_body().map_or(0, |body| body.len());
        super::GENERIC_PAYLOAD_HEADER_LEN + body_len
    }

    /// The SK payload carries its own inner chain and never swallows the
    /// layers that follow it in the outer packet.
    fn consumes_following(&self) -> bool {
        false
    }

    /// Writes the generic payload header followed by the sealed body to `out`.
    fn compile(&self, ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
        let body = self.payload_body(ctx)?;
        let next_payload = self.next_payload_override();
        let critical = self.critical();
        let declared_len = self.payload_length_override();
        write_generic_payload_header(out, ctx, next_payload, critical, declared_len, body.len())?;
        out.extend_from_slice(&body);
        Ok(())
    }

    impl_layer_object!(IkeEncryptedPayload);
}
// Macro-generated glue for this layer type.
// NOTE(review): presumably provides the `/` stacking operator for
// `IkeEncryptedPayload` — confirm against the `impl_layer_div!` definition.
impl_layer_div!(IkeEncryptedPayload);
/// Result of decrypting an SK payload with `decode_sk_payload_with_sa`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DecodedSk {
/// First octet of the SK generic payload header, i.e. its Next Payload
/// value, which names the type of the first inner payload.
pub first_inner_payload: u8,
/// Decrypted inner payload bytes with the padding and the Pad Length octet
/// stripped.
pub inner_payloads: Vec<u8>,
}
/// Decrypts one SK payload and returns its inner payload bytes.
///
/// `bytes` must start at the SK payload's generic header. Only the first
/// header octet (Next Payload) is read; everything after the header is treated
/// as the body `IV || ciphertext || ICV`, so the declared Payload Length is
/// not consulted. The IV length comes from `iv_requirement(sa)` and the ICV
/// length from the encryption algorithm when it is AEAD, otherwise from the
/// integrity algorithm.
///
/// # Errors
///
/// * a buffer-too-short error if `bytes` is shorter than the generic header,
///   if the body cannot hold both IV and ICV, or if the decrypted plaintext is
///   empty (no Pad Length octet);
/// * whatever `open` reports when it rejects the ciphertext or ICV;
/// * an invalid-field-value error if Pad Length claims more padding than the
///   plaintext contains.
pub fn decode_sk_payload_with_sa(bytes: &[u8], sa: &SecurityAssociation) -> Result<DecodedSk> {
    use super::GENERIC_PAYLOAD_HEADER_LEN;

    if bytes.len() < GENERIC_PAYLOAD_HEADER_LEN {
        return Err(CrafterError::buffer_too_short(
            "ikev2.sk",
            GENERIC_PAYLOAD_HEADER_LEN,
            bytes.len(),
        ));
    }
    let first_inner_payload = bytes[0];
    let body = &bytes[GENERIC_PAYLOAD_HEADER_LEN..];

    let iv_len = iv_requirement(sa).iv_len;
    let icv_len = sk_icv_len(sa);
    let fixed = iv_len + icv_len;
    if body.len() < fixed {
        return Err(CrafterError::buffer_too_short(
            "ikev2.sk.body",
            fixed,
            body.len(),
        ));
    }
    // The length check above guarantees both split points are in range.
    let (iv, sealed) = body.split_at(iv_len);
    let (ciphertext, icv) = sealed.split_at(sealed.len() - icv_len);

    let plaintext = open(sa, iv, &[], ciphertext, icv)?;

    // The last plaintext octet is Pad Length; an empty plaintext has none.
    let Some((&pad_len_byte, padded_inner)) = plaintext.split_last() else {
        return Err(CrafterError::buffer_too_short(
            "ikev2.sk.plaintext",
            SK_PAD_LENGTH_FIELD_LEN,
            0,
        ));
    };

    let pad_len = usize::from(pad_len_byte);
    let Some(inner_len) = padded_inner.len().checked_sub(pad_len) else {
        return Err(CrafterError::invalid_field_value(
            "ikev2.sk.pad_length",
            "SK Pad Length exceeds the decrypted plaintext",
        ));
    };

    Ok(DecodedSk {
        first_inner_payload,
        inner_payloads: padded_inner[..inner_len].to_vec(),
    })
}
/// Length of the ICV trailing an SK body protected by `sa`.
///
/// AEAD ciphers carry their own tag, so the encryption algorithm decides;
/// otherwise the integrity algorithm does. A missing length counts as 0.
fn sk_icv_len(sa: &SecurityAssociation) -> usize {
    if !sa.enc.is_aead() {
        return sa.integ.icv_len().unwrap_or(0);
    }
    sa.enc.icv_len().unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::ipsec::ikev2::header::IkeHeader;
    use crate::protocols::ipsec::ikev2::payload::notify::{
        IkeNotifyPayload, NotifyType, IKE_NOTIFY_PAYLOAD_NAME, NOTIFY_PROTOCOL_ESP,
    };
    use crate::protocols::ipsec::ikev2::payload::{
        payload_type_for_layer_name, GENERIC_PAYLOAD_HEADER_LEN, PAYLOAD_NOTIFY,
    };
    use crate::protocols::ipsec::sa::{EncryptionAlgorithm, IntegrityAlgorithm};

    /// AES-CBC + HMAC-SHA2-256-128 security association with fixed keys.
    fn cbc_hmac_sa() -> SecurityAssociation {
        SecurityAssociation::new(0x0000_2000)
            .encryption(EncryptionAlgorithm::AesCbc, vec![0x11u8; 16])
            .integrity(IntegrityAlgorithm::HmacSha2_256_128, vec![0x33u8; 32])
    }

    /// AES-GCM-16 security association with a fixed key and salt.
    fn gcm_sa() -> SecurityAssociation {
        SecurityAssociation::new(0x0000_2000)
            .encryption(EncryptionAlgorithm::AesGcm16, vec![0x22u8; 16])
            .salt(vec![0xAA, 0xBB, 0xCC, 0xDD])
    }

    /// Notify payload used as the inner payload in most tests.
    fn inner_notify() -> IkeNotifyPayload {
        IkeNotifyPayload::new(NOTIFY_PROTOCOL_ESP, NotifyType::RekeySa, vec![0xDE, 0xAD])
            .spi(vec![0x10u8, 0x20, 0x30, 0x40])
    }

    /// Compiles only the layer at `index`, using `packet` as its context.
    fn compile_layer_at(packet: &Packet, index: usize) -> Vec<u8> {
        let ctx = LayerContext::new(packet, index);
        let mut out = Vec::new();
        packet.get(index).unwrap().compile(&ctx, &mut out).unwrap();
        out
    }

    /// Compiles every layer of `packet`, in order, into one buffer.
    fn compile_packet(packet: &Packet) -> Vec<u8> {
        let mut out = Vec::new();
        for (index, layer) in packet.iter().enumerate() {
            let ctx = LayerContext::new(packet, index);
            layer.compile(&ctx, &mut out).unwrap();
        }
        out
    }

    /// Compiles an SK payload as the sole layer of a packet.
    fn compile_sk(payload: IkeEncryptedPayload) -> Vec<u8> {
        compile_layer_at(&Packet::from_layer(payload), 0)
    }

    /// Compiled SK payload wrapping `inner_notify()` under `sa`.
    fn sealed_notify(sa: &SecurityAssociation) -> Vec<u8> {
        compile_sk(IkeEncryptedPayload::new(sa.clone()).payload(inner_notify()))
    }

    #[test]
    fn payload_type_is_encrypted_and_name_registered() {
        let sk = IkeEncryptedPayload::new(cbc_hmac_sa()).payload(inner_notify());
        assert_eq!(sk.payload_type(), PayloadType::Encrypted);
        assert_eq!(sk.name(), IKE_ENCRYPTED_PAYLOAD_NAME);
        assert_eq!(
            payload_type_for_layer_name(IKE_ENCRYPTED_PAYLOAD_NAME),
            Some(PayloadType::Encrypted)
        );
    }

    #[test]
    fn sk_does_not_consume_following_siblings() {
        let sk = IkeEncryptedPayload::new(cbc_hmac_sa()).payload(inner_notify());
        assert!(!sk.consumes_following());
    }

    #[test]
    fn generic_header_next_payload_names_first_inner_payload() {
        let bytes = sealed_notify(&cbc_hmac_sa());
        assert_eq!(bytes[0], PAYLOAD_NOTIFY);
    }

    #[test]
    fn cbc_round_trip_recovers_inner_notify() {
        let sa = cbc_hmac_sa();
        let inner = inner_notify();
        let inner_bytes = compile_layer_at(&Packet::from_layer(inner.clone()), 0);

        let header = IkeHeader::new().exchange(35).initiator();
        let sk = IkeEncryptedPayload::new(sa.clone()).payload(inner);
        let message: Packet = Packet::from_layer(header) / sk;

        // The IKE header must announce the SK payload as its next payload.
        let header_bytes = compile_layer_at(&message, 0);
        assert_eq!(header_bytes[16], PayloadType::Encrypted.codepoint());

        let sk_bytes = compile_layer_at(&message, 1);
        let decoded = decode_sk_payload_with_sa(&sk_bytes, &sa).unwrap();
        assert_eq!(decoded.first_inner_payload, PAYLOAD_NOTIFY);
        assert_eq!(decoded.inner_payloads, inner_bytes);
        assert_eq!(
            payload_type_for_layer_name(IKE_NOTIFY_PAYLOAD_NAME),
            Some(PayloadType::Notify)
        );
    }

    #[test]
    fn gcm_round_trip_recovers_inner_notify() {
        let sa = gcm_sa();
        let inner = inner_notify();
        let inner_bytes = compile_layer_at(&Packet::from_layer(inner.clone()), 0);

        let sk_bytes = compile_sk(IkeEncryptedPayload::new(sa.clone()).payload(inner));
        let decoded = decode_sk_payload_with_sa(&sk_bytes, &sa).unwrap();
        assert_eq!(decoded.first_inner_payload, PAYLOAD_NOTIFY);
        assert_eq!(decoded.inner_payloads, inner_bytes);
    }

    #[test]
    fn body_layout_is_iv_ciphertext_icv() {
        let bytes = sealed_notify(&cbc_hmac_sa());
        let body = &bytes[GENERIC_PAYLOAD_HEADER_LEN..];
        let iv_len = 16;
        let icv_len = 16;
        assert!(body.len() > iv_len + icv_len);

        let ciphertext_len = body.len() - iv_len - icv_len;
        assert_eq!(ciphertext_len % 16, 0);

        let payload_len = usize::from(u16::from_be_bytes([bytes[2], bytes[3]]));
        assert_eq!(payload_len, bytes.len());
    }

    #[test]
    fn iv_and_icv_overrides_are_emitted_verbatim() {
        let iv: Vec<u8> = (0u8..16).collect();
        let icv = vec![0xEEu8; 16];
        let sk = IkeEncryptedPayload::new(cbc_hmac_sa())
            .payload(inner_notify())
            .iv(iv.clone())
            .icv(icv.clone());

        let bytes = compile_sk(sk);
        let body = &bytes[GENERIC_PAYLOAD_HEADER_LEN..];
        assert_eq!(&body[..16], &iv[..]);
        assert_eq!(&body[body.len() - 16..], &icv[..]);
    }

    #[test]
    fn flipped_icv_bit_makes_decode_error() {
        let sa = cbc_hmac_sa();
        let mut bytes = sealed_notify(&sa);
        *bytes.last_mut().unwrap() ^= 0x01;

        let result = decode_sk_payload_with_sa(&bytes, &sa);
        assert!(result.is_err(), "a tampered ICV must make decode error");
    }

    #[test]
    fn flipped_ciphertext_bit_makes_decode_error() {
        let sa = cbc_hmac_sa();
        let mut bytes = sealed_notify(&sa);
        // First ciphertext octet: right after the generic header and the 16-octet IV.
        let ct_index = GENERIC_PAYLOAD_HEADER_LEN + 16;
        bytes[ct_index] ^= 0x01;

        let result = decode_sk_payload_with_sa(&bytes, &sa);
        assert!(
            result.is_err(),
            "a tampered ciphertext must make decode error"
        );
    }

    #[test]
    fn gcm_flipped_icv_bit_makes_decode_error() {
        let sa = gcm_sa();
        let mut bytes = sealed_notify(&sa);
        *bytes.last_mut().unwrap() ^= 0x01;
        assert!(decode_sk_payload_with_sa(&bytes, &sa).is_err());
    }

    #[test]
    fn decode_rejects_truncated_body() {
        let sa = cbc_hmac_sa();
        // Generic header followed by a body far too short for IV + ICV.
        let bytes = vec![PAYLOAD_NOTIFY, 0, 0, 8, 0xAA, 0xBB, 0xCC, 0xDD];
        let err = decode_sk_payload_with_sa(&bytes, &sa).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }

    #[test]
    fn decode_rejects_truncated_generic_header() {
        let sa = cbc_hmac_sa();
        let err = decode_sk_payload_with_sa(&[0u8, 0, 0], &sa).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }

    #[test]
    fn multiple_inner_payloads_chain_and_round_trip() {
        let sa = cbc_hmac_sa();
        let first = inner_notify();
        let second = IkeNotifyPayload::new(
            NOTIFY_PROTOCOL_ESP,
            NotifyType::InitialContact,
            Vec::<u8>::new(),
        );

        let inner_packet: Packet = Packet::from_layer(first.clone()) / second.clone();
        let inner_bytes = compile_packet(&inner_packet);
        assert_eq!(inner_bytes[0], PAYLOAD_NOTIFY);

        let sk = IkeEncryptedPayload::new(sa.clone())
            .payload(first)
            .payload(second);
        let sk_bytes = compile_sk(sk);
        let decoded = decode_sk_payload_with_sa(&sk_bytes, &sa).unwrap();
        assert_eq!(decoded.first_inner_payload, PAYLOAD_NOTIFY);
        assert_eq!(decoded.inner_payloads, inner_bytes);
    }
}