use crate::packet::{Layer, LayerContext};
use crate::protocols::ipsec::ikev2::payload::{
write_generic_payload_header, IkePayload, PayloadHeaderFields, PayloadType,
};
use crate::protocols::transport::common::{impl_layer_div, impl_layer_object};
use crate::Result;
/// Layer name returned by [`Layer::name`] for [`IkeVendorIdPayload`].
pub const IKE_VENDOR_ID_PAYLOAD_NAME: &str = "IkeVendorIdPayload";
/// IKEv2 Vendor ID payload (payload type `VendorId`, codepoint 43).
///
/// On the wire this is a generic payload header followed by the vendor ID
/// bytes, copied verbatim.
#[derive(Debug, Clone)]
pub struct IkeVendorIdPayload {
    /// Opaque vendor identifier; emitted unchanged as the payload body.
    vendor_id: Vec<u8>,
    /// Optional generic-header overrides: next payload, critical flag and
    /// payload length.
    header: PayloadHeaderFields,
}
impl IkeVendorIdPayload {
pub fn new(vendor_id: impl Into<Vec<u8>>) -> Self {
Self {
vendor_id: vendor_id.into(),
header: PayloadHeaderFields::new(),
}
}
pub fn vendor_id(mut self, vendor_id: impl Into<Vec<u8>>) -> Self {
self.vendor_id = vendor_id.into();
self
}
pub fn next_payload(mut self, next_payload: u8) -> Self {
self.header.set_next_payload(next_payload);
self
}
pub fn payload_length(mut self, length: u16) -> Self {
self.header.set_length(length);
self
}
pub fn critical(mut self, critical: bool) -> Self {
self.header.set_critical(critical);
self
}
pub fn vendor_id_bytes(&self) -> &[u8] {
&self.vendor_id
}
}
impl IkePayload for IkeVendorIdPayload {
    /// Always `PayloadType::VendorId`.
    fn payload_type(&self) -> PayloadType {
        PayloadType::VendorId
    }

    /// The body is an owned copy of the vendor ID bytes; the context is unused.
    fn payload_body(&self, _ctx: &LayerContext<'_>) -> Result<Vec<u8>> {
        let body = self.vendor_id.to_vec();
        Ok(body)
    }

    /// Explicit Next Payload value, if one was set via the builder.
    fn next_payload_override(&self) -> Option<u8> {
        let fields = &self.header;
        fields.next_payload_override()
    }

    /// Explicit Payload Length value, if one was set via the builder.
    fn payload_length_override(&self) -> Option<u16> {
        let fields = &self.header;
        fields.payload_length_override()
    }

    /// Current state of the critical flag stored in the header fields.
    fn critical(&self) -> bool {
        let fields = &self.header;
        fields.critical()
    }
}
impl Layer for IkeVendorIdPayload {
fn name(&self) -> &'static str {
IKE_VENDOR_ID_PAYLOAD_NAME
}
fn summary(&self) -> String {
format!("IkeVendorIdPayload(vendor_id_len={})", self.vendor_id.len())
}
fn inspection_fields(&self) -> Vec<(&'static str, String)> {
vec![("vendor_id_len", self.vendor_id.len().to_string())]
}
fn encoded_len(&self) -> usize {
super::GENERIC_PAYLOAD_HEADER_LEN + self.vendor_id.len()
}
fn compile(&self, ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
let body = self.payload_body(ctx)?;
write_generic_payload_header(
out,
ctx,
self.next_payload_override(),
self.critical(),
self.payload_length_override(),
body.len(),
)?;
out.extend_from_slice(&body);
Ok(())
}
impl_layer_object!(IkeVendorIdPayload);
}
// Presumably provides the `/` operator used to stack this payload onto a
// `Packet` (as in the tests); confirm against `impl_layer_div!` in
// `transport::common`.
impl_layer_div!(IkeVendorIdPayload);
/// Builds an [`IkeVendorIdPayload`] from a payload body.
///
/// `bytes` is the data that follows the generic header. The body is treated as
/// opaque, so every input, including an empty slice, yields `Ok`.
pub(crate) fn parse_vendor_id_payload_body(bytes: &[u8]) -> Result<IkeVendorIdPayload> {
    let payload = IkeVendorIdPayload::new(bytes);
    Ok(payload)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{LayerContext, Packet, Raw};
    use crate::protocols::ipsec::ikev2::payload::GENERIC_PAYLOAD_HEADER_LEN;

    /// The 16-byte vendor ID `0x00..=0x0F` shared by most tests.
    fn sample_vendor_id() -> Vec<u8> {
        (0u8..16).collect()
    }

    /// Compiles `payload` as the only layer of a fresh packet.
    fn compile_payload(payload: IkeVendorIdPayload) -> Vec<u8> {
        let packet = Packet::from_layer(payload);
        let ctx = LayerContext::new(&packet, 0);
        let mut encoded = Vec::new();
        packet
            .get(0)
            .unwrap()
            .compile(&ctx, &mut encoded)
            .unwrap();
        encoded
    }

    fn vendor_payload() -> IkeVendorIdPayload {
        IkeVendorIdPayload::new(sample_vendor_id())
    }

    /// Decodes the big-endian Payload Length field of a compiled generic header.
    fn header_length(encoded: &[u8]) -> u16 {
        u16::from_be_bytes([encoded[2], encoded[3]])
    }

    #[test]
    fn vendor_payload_type_and_name() {
        let payload = vendor_payload();
        assert_eq!(payload.payload_type(), PayloadType::VendorId);
        assert_eq!(PayloadType::VendorId.codepoint(), 43);
        assert_eq!(payload.name(), IKE_VENDOR_ID_PAYLOAD_NAME);
    }

    #[test]
    fn body_is_vendor_id_bytes() {
        let payload = vendor_payload();
        let packet = Packet::from_layer(Raw::from_bytes([0u8; 0]));
        let ctx = LayerContext::new(&packet, 0);
        let expected = sample_vendor_id();

        let body = payload.payload_body(&ctx).unwrap();
        assert_eq!(body, expected);
        assert_eq!(payload.vendor_id_bytes(), expected.as_slice());
    }

    #[test]
    fn payload_compiles_generic_header_then_body() {
        let payload = vendor_payload();
        let encoded = compile_payload(payload.clone());

        // Last payload in the chain, critical bit clear.
        assert_eq!(encoded[0], 0);
        assert_eq!(encoded[1], 0);

        let declared = usize::from(header_length(&encoded));
        assert_eq!(declared, encoded.len());
        assert_eq!(declared, payload.encoded_len());
        assert_eq!(
            &encoded[GENERIC_PAYLOAD_HEADER_LEN..],
            sample_vendor_id().as_slice()
        );
    }

    #[test]
    fn payload_honors_generic_header_overrides() {
        let encoded = compile_payload(
            vendor_payload()
                .next_payload(40)
                .critical(true)
                .payload_length(0xBEEF),
        );
        assert_eq!(encoded[0], 40);
        assert_eq!(encoded[1], 0x80);
        assert_eq!(header_length(&encoded), 0xBEEF);
    }

    #[test]
    fn chain_next_payload_points_at_vendor_id() {
        use crate::protocols::ipsec::ikev2::payload::{
            following_next_payload, payload_type_for_layer_name, PAYLOAD_VENDOR_ID,
        };

        assert_eq!(
            payload_type_for_layer_name(IKE_VENDOR_ID_PAYLOAD_NAME),
            Some(PayloadType::VendorId)
        );

        let packet: Packet = Packet::from_layer(Raw::from_bytes([0u8; 0])) / vendor_payload();
        let ctx = LayerContext::new(&packet, 0);
        assert_eq!(following_next_payload(&ctx), PAYLOAD_VENDOR_ID);
    }

    #[test]
    fn round_trip_preserves_vendor_id() {
        let encoded = compile_payload(vendor_payload());
        let parsed =
            parse_vendor_id_payload_body(&encoded[GENERIC_PAYLOAD_HEADER_LEN..]).unwrap();
        assert_eq!(parsed.vendor_id_bytes(), sample_vendor_id().as_slice());
        assert_eq!(compile_payload(parsed), encoded);
    }

    #[test]
    fn empty_vendor_id_round_trips() {
        let encoded = compile_payload(IkeVendorIdPayload::new(Vec::<u8>::new()));
        assert_eq!(encoded.len(), GENERIC_PAYLOAD_HEADER_LEN);

        let parsed =
            parse_vendor_id_payload_body(&encoded[GENERIC_PAYLOAD_HEADER_LEN..]).unwrap();
        assert!(parsed.vendor_id_bytes().is_empty());
        assert_eq!(compile_payload(parsed), encoded);
    }
}