use core::net::{Ipv4Addr, Ipv6Addr};
use crate::field::Field;
use crate::packet::{Layer, LayerContext};
use crate::protocols::ipsec::ikev2::payload::{
write_generic_payload_header, IkePayload, PayloadHeaderFields, PayloadType,
};
use crate::protocols::transport::common::{impl_layer_div, impl_layer_object};
use crate::CrafterError;
use crate::Result;
/// Layer name reported by a traffic selector payload in the initiator (TSi) role.
pub const IKE_TSI_PAYLOAD_NAME: &str = "IkeTsiPayload";
/// Layer name reported by a traffic selector payload in the responder (TSr) role.
pub const IKE_TSR_PAYLOAD_NAME: &str = "IkeTsrPayload";
/// Octets at the start of a TS payload body before the first selector:
/// the Number of TSs octet followed by three reserved (zero) octets.
pub const TS_PAYLOAD_FIXED_LEN: usize = 4;
/// Fixed octets at the start of every traffic selector: TS type (1), IP protocol (1),
/// selector length (2), start port (2), end port (2). Addresses follow.
pub const TS_FIXED_LEN: usize = 8;
/// TS Type code used for an IPv4 address-range selector (4-octet addresses).
pub const TS_IPV4_ADDR_RANGE: u8 = 7;
/// TS Type code used for an IPv6 address-range selector (16-octet addresses).
pub const TS_IPV6_ADDR_RANGE: u8 = 8;
/// A single traffic selector entry carried inside a TSi/TSr payload.
///
/// On the wire it is laid out as: TS type (1 octet), IP protocol (1), selector length
/// (2, big-endian), start port (2, big-endian), end port (2, big-endian), then the raw
/// start-address octets followed by the raw end-address octets.
#[derive(Debug, Clone)]
pub struct TrafficSelector {
    /// TS Type code, e.g. [`TS_IPV4_ADDR_RANGE`] or [`TS_IPV6_ADDR_RANGE`].
    ts_type: u8,
    /// IP protocol number placed in the selector.
    ip_protocol: u8,
    /// First port of the selected range.
    start_port: u16,
    /// Last port of the selected range.
    end_port: u16,
    /// Start address octets, emitted verbatim (no length validation).
    start_addr: Vec<u8>,
    /// End address octets, emitted verbatim (no length validation).
    end_addr: Vec<u8>,
    /// Optional user override for the Selector Length field; unset means "use `encoded_len`".
    selector_length: Field<u16>,
}
impl TrafficSelector {
    /// Builds a selector from raw parts. Addresses are stored exactly as given.
    pub fn new(
        ts_type: u8,
        ip_protocol: u8,
        start_port: u16,
        end_port: u16,
        start_addr: impl Into<Vec<u8>>,
        end_addr: impl Into<Vec<u8>>,
    ) -> Self {
        let first: Vec<u8> = start_addr.into();
        let last: Vec<u8> = end_addr.into();
        Self {
            selector_length: Field::unset(),
            ts_type,
            ip_protocol,
            start_port,
            end_port,
            start_addr: first,
            end_addr: last,
        }
    }
    /// Builds a [`TS_IPV4_ADDR_RANGE`] selector covering `start..=end` and the given ports.
    pub fn ipv4_range(
        ip_protocol: u8,
        start_port: u16,
        end_port: u16,
        start: Ipv4Addr,
        end: Ipv4Addr,
    ) -> Self {
        let (lo, hi) = (start.octets(), end.octets());
        Self::new(
            TS_IPV4_ADDR_RANGE,
            ip_protocol,
            start_port,
            end_port,
            Vec::from(lo),
            Vec::from(hi),
        )
    }
    /// Builds a [`TS_IPV6_ADDR_RANGE`] selector covering `start..=end` and the given ports.
    pub fn ipv6_range(
        ip_protocol: u8,
        start_port: u16,
        end_port: u16,
        start: Ipv6Addr,
        end: Ipv6Addr,
    ) -> Self {
        let (lo, hi) = (start.octets(), end.octets());
        Self::new(
            TS_IPV6_ADDR_RANGE,
            ip_protocol,
            start_port,
            end_port,
            Vec::from(lo),
            Vec::from(hi),
        )
    }
    /// Forces the Selector Length field to `selector_length` instead of the computed length.
    pub fn selector_length(self, selector_length: u16) -> Self {
        let mut ts = self;
        ts.selector_length.set_user(selector_length);
        ts
    }
    /// TS Type code.
    pub fn ts_type(&self) -> u8 {
        self.ts_type
    }
    /// IP protocol number.
    pub fn ip_protocol(&self) -> u8 {
        self.ip_protocol
    }
    /// First port of the range.
    pub fn start_port(&self) -> u16 {
        self.start_port
    }
    /// Last port of the range.
    pub fn end_port(&self) -> u16 {
        self.end_port
    }
    /// Raw start-address octets.
    pub fn start_addr(&self) -> &[u8] {
        self.start_addr.as_slice()
    }
    /// Raw end-address octets.
    pub fn end_addr(&self) -> &[u8] {
        self.end_addr.as_slice()
    }
    /// Number of octets `write` emits: the fixed header plus both addresses.
    pub fn encoded_len(&self) -> usize {
        let addresses = self.start_addr.len() + self.end_addr.len();
        addresses + TS_FIXED_LEN
    }
    /// Appends the wire encoding of this selector to `out`.
    ///
    /// The Selector Length field uses the user override when set; otherwise it is
    /// `encoded_len()` cast with `as u16` (which truncates lengths above `u16::MAX`).
    fn write(&self, out: &mut Vec<u8>) {
        let length_field = match self.selector_length.value() {
            Some(&user_value) => user_value,
            None => self.encoded_len() as u16,
        };
        out.extend_from_slice(&[self.ts_type, self.ip_protocol]);
        out.extend_from_slice(&length_field.to_be_bytes());
        for port in [self.start_port, self.end_port] {
            out.extend_from_slice(&port.to_be_bytes());
        }
        out.extend_from_slice(&self.start_addr);
        out.extend_from_slice(&self.end_addr);
    }
}
/// Which side of the exchange a traffic selector payload describes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TsRole {
    /// Initiator-side selectors (TSi).
    Initiator,
    /// Responder-side selectors (TSr).
    Responder,
}
impl TsRole {
    /// Payload type and layer name for this role, kept in one table so they cannot diverge.
    fn descriptor(self) -> (PayloadType, &'static str) {
        match self {
            TsRole::Initiator => (PayloadType::TrafficSelectorInitiator, IKE_TSI_PAYLOAD_NAME),
            TsRole::Responder => (PayloadType::TrafficSelectorResponder, IKE_TSR_PAYLOAD_NAME),
        }
    }
    /// Generic payload type emitted for this role.
    fn payload_type(self) -> PayloadType {
        let (kind, _) = self.descriptor();
        kind
    }
    /// Layer name reported for this role.
    fn layer_name(self) -> &'static str {
        let (_, label) = self.descriptor();
        label
    }
}
#[derive(Debug, Clone)]
pub struct IkeTsPayload {
role: TsRole,
selectors: Vec<TrafficSelector>,
number_of_ts: Field<u8>,
header: PayloadHeaderFields,
}
impl IkeTsPayload {
pub fn new(role: TsRole) -> Self {
Self {
role,
selectors: Vec::new(),
number_of_ts: Field::unset(),
header: PayloadHeaderFields::new(),
}
}
pub fn initiator() -> Self {
Self::new(TsRole::Initiator)
}
pub fn responder() -> Self {
Self::new(TsRole::Responder)
}
pub fn initiator_ipv4_range(
ip_protocol: u8,
start_port: u16,
end_port: u16,
start: Ipv4Addr,
end: Ipv4Addr,
) -> Self {
Self::initiator().with_selector(TrafficSelector::ipv4_range(
ip_protocol,
start_port,
end_port,
start,
end,
))
}
pub fn responder_ipv4_range(
ip_protocol: u8,
start_port: u16,
end_port: u16,
start: Ipv4Addr,
end: Ipv4Addr,
) -> Self {
Self::responder().with_selector(TrafficSelector::ipv4_range(
ip_protocol,
start_port,
end_port,
start,
end,
))
}
pub fn with_selector(mut self, selector: TrafficSelector) -> Self {
self.selectors.push(selector);
self
}
pub fn push_selector(&mut self, selector: TrafficSelector) {
self.selectors.push(selector);
}
pub fn number_of_ts(mut self, number_of_ts: u8) -> Self {
self.number_of_ts.set_user(number_of_ts);
self
}
pub fn next_payload(mut self, next_payload: u8) -> Self {
self.header.set_next_payload(next_payload);
self
}
pub fn payload_length(mut self, length: u16) -> Self {
self.header.set_length(length);
self
}
pub fn critical(mut self, critical: bool) -> Self {
self.header.set_critical(critical);
self
}
pub fn role(&self) -> TsRole {
self.role
}
pub fn selectors(&self) -> &[TrafficSelector] {
&self.selectors
}
fn ts_body(&self) -> Vec<u8> {
let number_of_ts = self
.number_of_ts
.value()
.copied()
.unwrap_or(self.selectors.len() as u8);
let mut out = Vec::with_capacity(TS_PAYLOAD_FIXED_LEN);
out.push(number_of_ts);
out.extend_from_slice(&[0u8, 0u8, 0u8]); for selector in &self.selectors {
selector.write(&mut out);
}
out
}
}
impl IkePayload for IkeTsPayload {
    /// TSi or TSr, depending on the payload's role.
    fn payload_type(&self) -> PayloadType {
        let role = self.role;
        role.payload_type()
    }
    /// The TS body. The layer context is not consulted.
    fn payload_body(&self, _ctx: &LayerContext<'_>) -> Result<Vec<u8>> {
        let body = self.ts_body();
        Ok(body)
    }
    /// User-specified Next Payload value, if any.
    fn next_payload_override(&self) -> Option<u8> {
        let header = &self.header;
        header.next_payload_override()
    }
    /// User-specified Payload Length value, if any.
    fn payload_length_override(&self) -> Option<u16> {
        let header = &self.header;
        header.payload_length_override()
    }
    /// Whether the critical bit is set in the generic header.
    fn critical(&self) -> bool {
        let header = &self.header;
        header.critical()
    }
}
impl Layer for IkeTsPayload {
fn name(&self) -> &'static str {
self.role.layer_name()
}
fn summary(&self) -> String {
format!(
"{}(traffic_selectors={})",
self.role.layer_name(),
self.selectors.len()
)
}
fn inspection_fields(&self) -> Vec<(&'static str, String)> {
let mut fields = vec![("traffic_selectors", self.selectors.len().to_string())];
for selector in &self.selectors {
fields.push((
"traffic_selector",
format!(
"type={} protocol={} ports={}-{} addr_len={}",
selector.ts_type,
selector.ip_protocol,
selector.start_port,
selector.end_port,
selector.start_addr.len()
),
));
}
fields
}
fn encoded_len(&self) -> usize {
let selectors: usize = self.selectors.iter().map(|s| s.encoded_len()).sum();
super::GENERIC_PAYLOAD_HEADER_LEN + TS_PAYLOAD_FIXED_LEN + selectors
}
fn compile(&self, ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
let body = self.payload_body(ctx)?;
write_generic_payload_header(
out,
ctx,
self.next_payload_override(),
self.critical(),
self.payload_length_override(),
body.len(),
)?;
out.extend_from_slice(&body);
Ok(())
}
impl_layer_object!(IkeTsPayload);
}
impl_layer_div!(IkeTsPayload);
/// Parses one traffic selector from the front of `bytes`.
///
/// Returns the selector together with the number of octets it occupies, taken from its
/// Selector Length field, so the caller can advance to the next selector.
///
/// The address area (`selector_length - TS_FIXED_LEN` octets) is split at its midpoint into
/// the start and end addresses. If that area has an odd length (malformed input), the extra
/// octet goes to the tail of the end address instead of being dropped. Re-encoding the parsed
/// selector therefore reproduces the original octets and Selector Length.
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short` in three cases:
/// - fewer than `TS_FIXED_LEN` octets are present;
/// - the Selector Length is smaller than `TS_FIXED_LEN`;
/// - the Selector Length exceeds `bytes.len()`.
pub(crate) fn parse_traffic_selector(bytes: &[u8]) -> Result<(TrafficSelector, usize)> {
    if bytes.len() < TS_FIXED_LEN {
        return Err(CrafterError::buffer_too_short(
            "ikev2.ts.selector",
            TS_FIXED_LEN,
            bytes.len(),
        ));
    }
    let ts_type = bytes[0];
    let ip_protocol = bytes[1];
    let selector_length = usize::from(u16::from_be_bytes([bytes[2], bytes[3]]));
    let start_port = u16::from_be_bytes([bytes[4], bytes[5]]);
    let end_port = u16::from_be_bytes([bytes[6], bytes[7]]);
    // A length below the fixed part would also yield zero progress for the caller's loop.
    if selector_length < TS_FIXED_LEN || bytes.len() < selector_length {
        return Err(CrafterError::buffer_too_short(
            "ikev2.ts.selector.length",
            selector_length.max(TS_FIXED_LEN),
            bytes.len(),
        ));
    }
    let addresses = &bytes[TS_FIXED_LEN..selector_length];
    // Midpoint split: equal halves for well-formed selectors; for odd lengths the end address
    // keeps the extra octet so no input byte is lost.
    let (start_addr, end_addr) = addresses.split_at(addresses.len() / 2);
    Ok((
        TrafficSelector::new(
            ts_type,
            ip_protocol,
            start_port,
            end_port,
            start_addr.to_vec(),
            end_addr.to_vec(),
        ),
        selector_length,
    ))
}
/// Parses a TSi/TSr payload body (everything after the generic payload header).
///
/// Reads the Number of TSs octet, skips the rest of the fixed part, then parses exactly that
/// many selectors back to back. Octets after the last declared selector are ignored, and the
/// reserved octets are not retained.
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short` when the body is shorter than
/// `TS_PAYLOAD_FIXED_LEN`. Otherwise propagates the error from `parse_traffic_selector` when a
/// declared selector is missing or truncated.
pub(crate) fn parse_ts_payload_body(role: TsRole, bytes: &[u8]) -> Result<IkeTsPayload> {
    if TS_PAYLOAD_FIXED_LEN > bytes.len() {
        return Err(CrafterError::buffer_too_short(
            "ikev2.ts",
            TS_PAYLOAD_FIXED_LEN,
            bytes.len(),
        ));
    }
    let declared_count = usize::from(bytes[0]);
    let mut remaining = &bytes[TS_PAYLOAD_FIXED_LEN..];
    let mut parsed = IkeTsPayload::new(role);
    for _ in 0..declared_count {
        let (selector, consumed) = parse_traffic_selector(remaining)?;
        // `consumed` never exceeds `remaining.len()`; parse_traffic_selector rejects that case.
        remaining = &remaining[consumed..];
        parsed.push_selector(selector);
    }
    Ok(parsed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{LayerContext, Packet, Raw};
    use crate::protocols::ipsec::ikev2::payload::GENERIC_PAYLOAD_HEADER_LEN;

    /// Compiles `payload` as the only layer of a fresh packet and returns the emitted octets.
    fn compile_payload(payload: IkeTsPayload) -> Vec<u8> {
        let packet = Packet::from_layer(payload);
        let ctx = LayerContext::new(&packet, 0);
        let mut wire = Vec::new();
        packet.get(0).unwrap().compile(&ctx, &mut wire).unwrap();
        wire
    }

    /// TSi payload with one TCP selector: 192.0.2.1-192.0.2.254, ports 1024-65535.
    fn tsi_ipv4_payload() -> IkeTsPayload {
        IkeTsPayload::initiator_ipv4_range(
            6,
            1024,
            65535,
            Ipv4Addr::new(192, 0, 2, 1),
            Ipv4Addr::new(192, 0, 2, 254),
        )
    }

    #[test]
    fn ts_constants_match_manifest() {
        assert_eq!(TS_PAYLOAD_FIXED_LEN, 4);
        assert_eq!(TS_FIXED_LEN, 8);
        assert_eq!(TS_IPV4_ADDR_RANGE, 7);
        assert_eq!(TS_IPV6_ADDR_RANGE, 8);
        assert_eq!(PayloadType::TrafficSelectorInitiator.codepoint(), 44);
        assert_eq!(PayloadType::TrafficSelectorResponder.codepoint(), 45);
    }

    #[test]
    fn role_selects_payload_type_and_name() {
        let initiator = IkeTsPayload::initiator();
        assert_eq!(initiator.role(), TsRole::Initiator);
        assert_eq!(initiator.payload_type(), PayloadType::TrafficSelectorInitiator);
        assert_eq!(initiator.name(), IKE_TSI_PAYLOAD_NAME);

        let responder = IkeTsPayload::responder();
        assert_eq!(responder.role(), TsRole::Responder);
        assert_eq!(responder.payload_type(), PayloadType::TrafficSelectorResponder);
        assert_eq!(responder.name(), IKE_TSR_PAYLOAD_NAME);
    }

    #[test]
    fn ipv4_selector_is_sixteen_octets_with_auto_length() {
        let selector = TrafficSelector::ipv4_range(
            6,
            1024,
            65535,
            Ipv4Addr::new(192, 0, 2, 1),
            Ipv4Addr::new(192, 0, 2, 254),
        );
        assert_eq!(selector.ts_type(), TS_IPV4_ADDR_RANGE);
        assert_eq!(selector.encoded_len(), 16);

        let mut wire = Vec::new();
        selector.write(&mut wire);
        assert_eq!(wire.len(), 16);
        assert_eq!(wire[0], TS_IPV4_ADDR_RANGE);
        // IP protocol.
        assert_eq!(wire[1], 6);
        // Auto-computed selector length.
        assert_eq!(u16::from_be_bytes([wire[2], wire[3]]), 16);
        // Start / end ports.
        assert_eq!(u16::from_be_bytes([wire[4], wire[5]]), 1024);
        assert_eq!(u16::from_be_bytes([wire[6], wire[7]]), 65535);
        // Start / end addresses.
        assert_eq!(&wire[8..12], &[192, 0, 2, 1]);
        assert_eq!(&wire[12..16], &[192, 0, 2, 254]);
    }

    #[test]
    fn ipv6_selector_is_forty_octets() {
        let start: Ipv6Addr = "2001:db8::1".parse().unwrap();
        let end: Ipv6Addr = "2001:db8::ffff".parse().unwrap();
        let selector = TrafficSelector::ipv6_range(17, 53, 53, start, end);
        assert_eq!(selector.ts_type(), TS_IPV6_ADDR_RANGE);
        assert_eq!(selector.encoded_len(), 40);

        let mut wire = Vec::new();
        selector.write(&mut wire);
        assert_eq!(wire.len(), 40);
        assert_eq!(u16::from_be_bytes([wire[2], wire[3]]), 40);
        assert_eq!(&wire[8..24], &start.octets());
        assert_eq!(&wire[24..40], &end.octets());
    }

    #[test]
    fn ts_body_lays_out_count_reserved_then_selectors() {
        let body = tsi_ipv4_payload().ts_body();
        // Number of TSs.
        assert_eq!(body[0], 1);
        // Reserved octets.
        assert_eq!(&body[1..4], &[0, 0, 0]);
        assert_eq!(body.len(), TS_PAYLOAD_FIXED_LEN + 16);
        assert_eq!(body[TS_PAYLOAD_FIXED_LEN], TS_IPV4_ADDR_RANGE);
    }

    #[test]
    fn payload_compiles_generic_header_then_body() {
        let payload = tsi_ipv4_payload();
        let bytes = compile_payload(payload.clone());
        // Next payload: nothing follows.
        assert_eq!(bytes[0], 0);
        // Flags: not critical.
        assert_eq!(bytes[1], 0);
        let payload_len = u16::from_be_bytes([bytes[2], bytes[3]]) as usize;
        assert_eq!(payload_len, bytes.len());
        assert_eq!(payload_len, payload.encoded_len());
        assert_eq!(&bytes[GENERIC_PAYLOAD_HEADER_LEN..], &payload.ts_body()[..]);
    }

    #[test]
    fn payload_honors_generic_header_overrides() {
        let payload = tsi_ipv4_payload()
            .next_payload(45)
            .critical(true)
            .payload_length(0xBEEF);
        let bytes = compile_payload(payload);
        assert_eq!(bytes[0], 45);
        // Critical bit.
        assert_eq!(bytes[1], 0x80);
        assert_eq!(u16::from_be_bytes([bytes[2], bytes[3]]), 0xBEEF);
    }

    #[test]
    fn number_of_ts_and_selector_length_overrides_emit_verbatim() {
        let selector = TrafficSelector::ipv4_range(
            0,
            0,
            65535,
            Ipv4Addr::new(192, 0, 2, 0),
            Ipv4Addr::new(192, 0, 2, 255),
        )
        .selector_length(0x00FF);
        let payload = IkeTsPayload::initiator()
            .with_selector(selector)
            .number_of_ts(9);
        let bytes = compile_payload(payload);
        let body = &bytes[GENERIC_PAYLOAD_HEADER_LEN..];
        // Overridden Number of TSs.
        assert_eq!(body[0], 9);
        // Overridden Selector Length of the first selector (body offset 4 + 2).
        assert_eq!(u16::from_be_bytes([body[6], body[7]]), 0x00FF);
    }

    #[test]
    fn payload_chain_next_payload_points_at_tsi_and_tsr() {
        use crate::protocols::ipsec::ikev2::payload::{
            following_next_payload, payload_type_for_layer_name, PAYLOAD_TSI, PAYLOAD_TSR,
        };
        assert_eq!(
            payload_type_for_layer_name(IKE_TSI_PAYLOAD_NAME),
            Some(PayloadType::TrafficSelectorInitiator)
        );
        assert_eq!(
            payload_type_for_layer_name(IKE_TSR_PAYLOAD_NAME),
            Some(PayloadType::TrafficSelectorResponder)
        );

        let tsi_packet: Packet = Packet::from_layer(Raw::from_bytes([0u8; 0])) / tsi_ipv4_payload();
        let tsi_ctx = LayerContext::new(&tsi_packet, 0);
        assert_eq!(following_next_payload(&tsi_ctx), PAYLOAD_TSI);

        let tsr_packet: Packet = Packet::from_layer(Raw::from_bytes([0u8; 0]))
            / IkeTsPayload::responder_ipv4_range(
                0,
                0,
                65535,
                Ipv4Addr::new(198, 51, 100, 1),
                Ipv4Addr::new(198, 51, 100, 254),
            );
        let tsr_ctx = LayerContext::new(&tsr_packet, 0);
        assert_eq!(following_next_payload(&tsr_ctx), PAYLOAD_TSR);
    }

    #[test]
    fn round_trip_tsi_ipv4_preserves_all_fields() {
        let original = tsi_ipv4_payload();
        let bytes = compile_payload(original.clone());
        let parsed =
            parse_ts_payload_body(TsRole::Initiator, &bytes[GENERIC_PAYLOAD_HEADER_LEN..]).unwrap();
        assert_eq!(parsed.role(), TsRole::Initiator);
        assert_eq!(parsed.selectors().len(), 1);

        let selector = &parsed.selectors()[0];
        assert_eq!(selector.ts_type(), TS_IPV4_ADDR_RANGE);
        assert_eq!(selector.ip_protocol(), 6);
        assert_eq!(selector.start_port(), 1024);
        assert_eq!(selector.end_port(), 65535);
        assert_eq!(selector.start_addr(), &[192, 0, 2, 1]);
        assert_eq!(selector.end_addr(), &[192, 0, 2, 254]);

        assert_eq!(compile_payload(parsed), bytes);
    }

    #[test]
    fn round_trip_multiple_selectors() {
        let payload = IkeTsPayload::responder()
            .with_selector(TrafficSelector::ipv4_range(
                0,
                0,
                65535,
                Ipv4Addr::new(203, 0, 113, 0),
                Ipv4Addr::new(203, 0, 113, 255),
            ))
            .with_selector(TrafficSelector::ipv6_range(
                0,
                0,
                65535,
                "2001:db8::".parse().unwrap(),
                "2001:db8::ffff".parse().unwrap(),
            ));
        let bytes = compile_payload(payload);
        let body = &bytes[GENERIC_PAYLOAD_HEADER_LEN..];
        assert_eq!(body[0], 2);

        let parsed = parse_ts_payload_body(TsRole::Responder, body).unwrap();
        let selectors = parsed.selectors();
        assert_eq!(selectors.len(), 2);
        assert_eq!(selectors[0].ts_type(), TS_IPV4_ADDR_RANGE);
        assert_eq!(selectors[0].encoded_len(), 16);
        assert_eq!(selectors[1].ts_type(), TS_IPV6_ADDR_RANGE);
        assert_eq!(selectors[1].encoded_len(), 40);
        assert_eq!(compile_payload(parsed), bytes);
    }

    #[test]
    fn parse_rejects_truncated_body() {
        let err = parse_ts_payload_body(TsRole::Initiator, &[1u8, 0, 0]).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }

    #[test]
    fn parse_rejects_truncated_selector() {
        // Declares one selector but only 4 of its 8 fixed octets are present.
        let err =
            parse_ts_payload_body(TsRole::Initiator, &[1u8, 0, 0, 0, 7, 6, 0, 16]).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }
}