use crate::field::Field;
use crate::packet::{Layer, LayerContext};
use crate::protocols::ipsec::ikev2::payload::{
write_generic_payload_header, IkePayload, PayloadHeaderFields, PayloadType,
};
use crate::protocols::transport::common::{impl_layer_div, impl_layer_object};
use crate::CrafterError;
use crate::Result;
/// Layer name reported by `IkeAuthPayload` through `Layer::name`.
pub const IKE_AUTH_PAYLOAD_NAME: &str = "IkeAuthPayload";

/// Size of the fixed prefix of the AUTH payload body: one Auth Method octet
/// followed by three reserved octets (written as zero when compiling).
pub const AUTH_FIXED_LEN: usize = 1 + 3;

/// Auth Method codepoint for RSA Digital Signature.
pub const AUTH_RSA_DIGITAL_SIGNATURE: u8 = 0x01;

/// Auth Method codepoint for Shared Key Message Integrity Code.
pub const AUTH_SHARED_KEY_MIC: u8 = 0x02;

/// Auth Method codepoint for DSS Digital Signature.
pub const AUTH_DSS_DIGITAL_SIGNATURE: u8 = 0x03;

/// Auth Method codepoint for the generic Digital Signature method.
pub const AUTH_DIGITAL_SIGNATURE: u8 = 0x0E;
/// Authentication method carried in the first octet of an AUTH payload body.
///
/// Named variants cover the codepoints this module defines constants for;
/// any other octet is preserved verbatim in [`AuthMethod::Unknown`] so that
/// converting from `u8` and back never loses information.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum AuthMethod {
    /// Codepoint `AUTH_RSA_DIGITAL_SIGNATURE` (1).
    RsaDigitalSignature,
    /// Codepoint `AUTH_SHARED_KEY_MIC` (2).
    SharedKeyMic,
    /// Codepoint `AUTH_DSS_DIGITAL_SIGNATURE` (3).
    DssDigitalSignature,
    /// Codepoint `AUTH_DIGITAL_SIGNATURE` (14).
    DigitalSignature,
    /// Any codepoint without a named variant, kept as the raw octet.
    Unknown(u8),
}
impl AuthMethod {
    /// Returns the Auth Method octet written into the AUTH payload body.
    ///
    /// Named variants map to their `AUTH_*` constants; [`AuthMethod::Unknown`]
    /// returns the wrapped octet unchanged. Declared `const` because the mapping
    /// is pure, which lets callers use it in constant contexts as well.
    #[must_use]
    pub const fn codepoint(self) -> u8 {
        match self {
            Self::RsaDigitalSignature => AUTH_RSA_DIGITAL_SIGNATURE,
            Self::SharedKeyMic => AUTH_SHARED_KEY_MIC,
            Self::DssDigitalSignature => AUTH_DSS_DIGITAL_SIGNATURE,
            Self::DigitalSignature => AUTH_DIGITAL_SIGNATURE,
            Self::Unknown(value) => value,
        }
    }
}
impl From<u8> for AuthMethod {
fn from(value: u8) -> Self {
match value {
AUTH_RSA_DIGITAL_SIGNATURE => Self::RsaDigitalSignature,
AUTH_SHARED_KEY_MIC => Self::SharedKeyMic,
AUTH_DSS_DIGITAL_SIGNATURE => Self::DssDigitalSignature,
AUTH_DIGITAL_SIGNATURE => Self::DigitalSignature,
other => Self::Unknown(other),
}
}
}
impl From<AuthMethod> for u8 {
fn from(auth_method: AuthMethod) -> Self {
auth_method.codepoint()
}
}
/// IKEv2 AUTH payload layer: an Auth Method octet, reserved padding, and
/// opaque authentication data, preceded by the generic payload header.
#[derive(Clone, Debug)]
pub struct IkeAuthPayload {
    /// Raw Auth Method octet, stored as a user-set `Field` so unrecognised
    /// codepoints are emitted unchanged.
    auth_method: Field<u8>,
    /// Authentication data that follows the fixed-length body prefix.
    auth_data: Vec<u8>,
    /// Generic payload header overrides (next payload, length, critical bit).
    header: PayloadHeaderFields,
}
impl IkeAuthPayload {
    /// Creates an AUTH payload with the given method and authentication data.
    ///
    /// `auth_method` accepts an [`AuthMethod`] or a raw `u8` codepoint;
    /// unrecognised codepoints are kept verbatim through [`AuthMethod::Unknown`].
    /// Generic header fields start from `PayloadHeaderFields::new()`.
    pub fn new(auth_method: impl Into<AuthMethod>, auth_data: impl Into<Vec<u8>>) -> Self {
        Self {
            auth_method: Field::user(auth_method.into().codepoint()),
            auth_data: auth_data.into(),
            header: PayloadHeaderFields::new(),
        }
    }

    /// Shorthand for `new(AuthMethod::SharedKeyMic, auth_data)`.
    pub fn shared_key_mic(auth_data: impl Into<Vec<u8>>) -> Self {
        Self::new(AuthMethod::SharedKeyMic, auth_data)
    }

    /// Shorthand for `new(AuthMethod::RsaDigitalSignature, auth_data)`.
    pub fn rsa_digital_signature(auth_data: impl Into<Vec<u8>>) -> Self {
        Self::new(AuthMethod::RsaDigitalSignature, auth_data)
    }

    /// Shorthand for `new(AuthMethod::DssDigitalSignature, auth_data)`.
    ///
    /// Added so every named [`AuthMethod`] has a matching constructor.
    pub fn dss_digital_signature(auth_data: impl Into<Vec<u8>>) -> Self {
        Self::new(AuthMethod::DssDigitalSignature, auth_data)
    }

    /// Shorthand for `new(AuthMethod::DigitalSignature, auth_data)`.
    pub fn digital_signature(auth_data: impl Into<Vec<u8>>) -> Self {
        Self::new(AuthMethod::DigitalSignature, auth_data)
    }

    /// Replaces the Auth Method (an [`AuthMethod`] or a raw `u8`).
    pub fn auth_method(mut self, auth_method: impl Into<AuthMethod>) -> Self {
        self.auth_method.set_user(auth_method.into().codepoint());
        self
    }

    /// Replaces the authentication data.
    pub fn auth_data(mut self, auth_data: impl Into<Vec<u8>>) -> Self {
        self.auth_data = auth_data.into();
        self
    }

    /// Forces the Next Payload octet of the generic header; the value is
    /// handed to the header writer as an override when compiling.
    pub fn next_payload(mut self, next_payload: u8) -> Self {
        self.header.set_next_payload(next_payload);
        self
    }

    /// Forces the Payload Length field of the generic header, even when it
    /// does not match the number of bytes actually emitted.
    pub fn payload_length(mut self, length: u16) -> Self {
        self.header.set_length(length);
        self
    }

    /// Sets or clears the critical bit of the generic header.
    pub fn critical(mut self, critical: bool) -> Self {
        self.header.set_critical(critical);
        self
    }

    /// Returns the raw Auth Method octet, or 0 if the field holds no value.
    pub fn auth_method_value(&self) -> u8 {
        self.auth_method.value().copied().unwrap_or(0)
    }

    /// Returns the Auth Method decoded into an [`AuthMethod`].
    pub fn auth_method_kind(&self) -> AuthMethod {
        AuthMethod::from(self.auth_method_value())
    }

    /// Borrows the authentication data.
    pub fn auth_data_bytes(&self) -> &[u8] {
        &self.auth_data
    }

    /// Serializes the payload body, without the generic header: the method
    /// octet, zeroed reserved octets up to [`AUTH_FIXED_LEN`], then the auth data.
    fn auth_body(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(AUTH_FIXED_LEN + self.auth_data.len());
        out.push(self.auth_method_value());
        // Zero-fill the reserved octets that complete the fixed-length prefix.
        out.resize(AUTH_FIXED_LEN, 0);
        out.extend_from_slice(&self.auth_data);
        out
    }
}
impl IkePayload for IkeAuthPayload {
    /// AUTH payloads always report [`PayloadType::Authentication`].
    fn payload_type(&self) -> PayloadType {
        PayloadType::Authentication
    }

    /// Body bytes: method octet, zeroed reserved octets, then auth data.
    /// The layer context is not consulted and this never fails.
    fn payload_body(&self, _ctx: &LayerContext<'_>) -> Result<Vec<u8>> {
        let body = self.auth_body();
        Ok(body)
    }

    /// Whether the critical bit should be set in the generic header.
    fn critical(&self) -> bool {
        self.header.critical()
    }

    /// User-forced Next Payload octet, if any.
    fn next_payload_override(&self) -> Option<u8> {
        self.header.next_payload_override()
    }

    /// User-forced Payload Length field, if any.
    fn payload_length_override(&self) -> Option<u16> {
        self.header.payload_length_override()
    }
}
impl Layer for IkeAuthPayload {
    /// Returns [`IKE_AUTH_PAYLOAD_NAME`].
    fn name(&self) -> &'static str {
        IKE_AUTH_PAYLOAD_NAME
    }

    /// One-line description with the raw method octet and the auth data length.
    fn summary(&self) -> String {
        let method = self.auth_method_value();
        let data_len = self.auth_data.len();
        format!(
            "IkeAuthPayload(auth_method={}, auth_data_len={})",
            method, data_len
        )
    }

    /// Key/value pairs for inspection: the raw method octet and the data length.
    fn inspection_fields(&self) -> Vec<(&'static str, String)> {
        let method = self.auth_method_value().to_string();
        let data_len = self.auth_data.len().to_string();
        vec![("auth_method", method), ("auth_data_len", data_len)]
    }

    /// Generic header plus fixed body prefix plus auth data. A payload-length
    /// override changes only the header field, not the bytes emitted.
    fn encoded_len(&self) -> usize {
        let body_len = AUTH_FIXED_LEN + self.auth_data.len();
        super::GENERIC_PAYLOAD_HEADER_LEN + body_len
    }

    /// Emits the generic payload header (honouring any overrides) followed by the body.
    fn compile(&self, ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
        let body = self.payload_body(ctx)?;
        let next_payload = self.next_payload_override();
        let critical = self.critical();
        let length_override = self.payload_length_override();
        write_generic_payload_header(
            out,
            ctx,
            next_payload,
            critical,
            length_override,
            body.len(),
        )?;
        out.extend_from_slice(&body);
        Ok(())
    }

    impl_layer_object!(IkeAuthPayload);
}
// NOTE(review): presumably provides the `/` (Div) layer-stacking operator for
// `IkeAuthPayload` (the tests build `Packet::from_layer(..) / shared_key_mic_payload()`);
// confirm against the macro definition in `protocols::transport::common`.
impl_layer_div!(IkeAuthPayload);
/// Parses an AUTH payload body, i.e. the bytes that follow the generic payload header.
///
/// The first octet is taken as the Auth Method. The remaining octets of the
/// [`AUTH_FIXED_LEN`]-byte prefix are skipped without inspection. Everything
/// after the prefix becomes the auth data.
///
/// # Errors
///
/// Returns a `CrafterError::buffer_too_short` error (context `"ikev2.auth"`)
/// when `bytes` is shorter than [`AUTH_FIXED_LEN`].
pub(crate) fn parse_auth_payload_body(bytes: &[u8]) -> Result<IkeAuthPayload> {
    if bytes.len() >= AUTH_FIXED_LEN {
        let (fixed, auth_data) = bytes.split_at(AUTH_FIXED_LEN);
        Ok(IkeAuthPayload::new(fixed[0], auth_data.to_vec()))
    } else {
        Err(CrafterError::buffer_too_short(
            "ikev2.auth",
            AUTH_FIXED_LEN,
            bytes.len(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{LayerContext, Packet, Raw};
    use crate::protocols::ipsec::ikev2::payload::GENERIC_PAYLOAD_HEADER_LEN;

    /// Compiles `payload` as the only layer of a packet and returns its bytes.
    fn compile_payload(payload: IkeAuthPayload) -> Vec<u8> {
        let packet = Packet::from_layer(payload);
        let ctx = LayerContext::new(&packet, 0);
        let mut out = Vec::new();
        packet.get(0).unwrap().compile(&ctx, &mut out).unwrap();
        out
    }

    /// The 20 auth-data bytes `0, 1, ..., 19` shared by several tests.
    fn sample_auth_data() -> Vec<u8> {
        (0u8..20).collect()
    }

    /// Shared Key MIC payload carrying `sample_auth_data()`.
    fn shared_key_mic_payload() -> IkeAuthPayload {
        IkeAuthPayload::shared_key_mic(sample_auth_data())
    }

    #[test]
    fn auth_constants_match_manifest() {
        assert_eq!(AUTH_FIXED_LEN, 4);
        assert_eq!(AUTH_RSA_DIGITAL_SIGNATURE, 1);
        assert_eq!(AUTH_SHARED_KEY_MIC, 2);
        assert_eq!(AUTH_DSS_DIGITAL_SIGNATURE, 3);
        assert_eq!(AUTH_DIGITAL_SIGNATURE, 14);
        assert_eq!(PayloadType::Authentication.codepoint(), 39);
    }

    #[test]
    fn auth_method_round_trips_through_u8() {
        for value in 0u8..=255 {
            let auth_method = AuthMethod::from(value);
            assert_eq!(auth_method.codepoint(), value);
            assert_eq!(u8::from(auth_method), value);
        }
    }

    #[test]
    fn named_auth_methods_map_to_codepoints() {
        assert_eq!(
            AuthMethod::from(AUTH_RSA_DIGITAL_SIGNATURE),
            AuthMethod::RsaDigitalSignature
        );
        assert_eq!(
            AuthMethod::from(AUTH_SHARED_KEY_MIC),
            AuthMethod::SharedKeyMic
        );
        assert_eq!(
            AuthMethod::from(AUTH_DSS_DIGITAL_SIGNATURE),
            AuthMethod::DssDigitalSignature
        );
        assert_eq!(
            AuthMethod::from(AUTH_DIGITAL_SIGNATURE),
            AuthMethod::DigitalSignature
        );
    }

    #[test]
    fn unknown_auth_method_is_preserved() {
        let unassigned = 200u8;
        assert_eq!(
            AuthMethod::from(unassigned),
            AuthMethod::Unknown(unassigned)
        );
        assert_eq!(AuthMethod::Unknown(unassigned).codepoint(), unassigned);
        assert_eq!(AuthMethod::from(9), AuthMethod::Unknown(9));
    }

    // The fallible TryFrom/TryInto path is exercised on purpose: it must stay
    // infallible because `From<u8>` is implemented, which is what clippy flags.
    #[allow(clippy::unnecessary_fallible_conversions)]
    #[test]
    fn try_from_u8_is_infallible_and_preserves_unknown() {
        let known: AuthMethod = u8::try_into(AUTH_SHARED_KEY_MIC).unwrap();
        assert_eq!(known, AuthMethod::SharedKeyMic);
        let unknown: AuthMethod = u8::try_into(222u8).unwrap();
        assert_eq!(unknown, AuthMethod::Unknown(222));
    }

    #[test]
    fn payload_type_is_authentication() {
        let payload = shared_key_mic_payload();
        assert_eq!(payload.payload_type(), PayloadType::Authentication);
        assert_eq!(payload.name(), IKE_AUTH_PAYLOAD_NAME);
    }

    #[test]
    fn body_lays_out_method_reserved_then_data() {
        let payload = shared_key_mic_payload();
        let body = payload.auth_body();
        assert_eq!(body[0], AUTH_SHARED_KEY_MIC);
        assert_eq!(&body[1..AUTH_FIXED_LEN], &[0, 0, 0]);
        assert_eq!(&body[AUTH_FIXED_LEN..], &sample_auth_data()[..]);
        assert_eq!(body.len(), AUTH_FIXED_LEN + 20);
        assert_eq!(payload.auth_method_kind(), AuthMethod::SharedKeyMic);
    }

    #[test]
    fn payload_compiles_generic_header_then_body() {
        let payload = shared_key_mic_payload();
        let bytes = compile_payload(payload.clone());
        assert_eq!(bytes[0], 0, "next payload");
        assert_eq!(bytes[1], 0, "flags");
        let payload_len = usize::from(u16::from_be_bytes([bytes[2], bytes[3]]));
        assert_eq!(payload_len, bytes.len());
        assert_eq!(payload_len, payload.encoded_len());
        assert_eq!(
            &bytes[GENERIC_PAYLOAD_HEADER_LEN..],
            &payload.auth_body()[..]
        );
    }

    #[test]
    fn payload_honors_generic_header_overrides() {
        let payload = shared_key_mic_payload()
            .next_payload(40)
            .critical(true)
            .payload_length(0xBEEF);
        let bytes = compile_payload(payload);
        assert_eq!(bytes[0], 40);
        assert_eq!(bytes[1], 0x80);
        assert_eq!(u16::from_be_bytes([bytes[2], bytes[3]]), 0xBEEF);
    }

    #[test]
    fn payload_chain_next_payload_points_at_auth() {
        use crate::protocols::ipsec::ikev2::payload::{
            following_next_payload, payload_type_for_layer_name, PAYLOAD_AUTH,
        };
        assert_eq!(
            payload_type_for_layer_name(IKE_AUTH_PAYLOAD_NAME),
            Some(PayloadType::Authentication)
        );
        let packet: Packet =
            Packet::from_layer(Raw::from_bytes([0u8; 0])) / shared_key_mic_payload();
        let ctx = LayerContext::new(&packet, 0);
        assert_eq!(following_next_payload(&ctx), PAYLOAD_AUTH);
    }

    #[test]
    fn builder_overrides_method_and_data() {
        let payload = shared_key_mic_payload()
            .auth_method(AUTH_DIGITAL_SIGNATURE)
            .auth_data(vec![7u8, 8]);
        assert_eq!(payload.auth_method_value(), AUTH_DIGITAL_SIGNATURE);
        assert_eq!(payload.auth_method_kind(), AuthMethod::DigitalSignature);
        assert_eq!(payload.auth_data_bytes(), &[7, 8]);
    }

    #[test]
    fn named_constructors_select_expected_method() {
        assert_eq!(
            IkeAuthPayload::rsa_digital_signature(vec![1u8]).auth_method_kind(),
            AuthMethod::RsaDigitalSignature
        );
        assert_eq!(
            IkeAuthPayload::shared_key_mic(vec![1u8]).auth_method_kind(),
            AuthMethod::SharedKeyMic
        );
        assert_eq!(
            IkeAuthPayload::digital_signature(vec![1u8]).auth_method_kind(),
            AuthMethod::DigitalSignature
        );
    }

    #[test]
    fn round_trip_shared_key_mic_preserves_method_and_data() {
        let payload = shared_key_mic_payload();
        let bytes = compile_payload(payload.clone());
        let parsed = parse_auth_payload_body(&bytes[GENERIC_PAYLOAD_HEADER_LEN..]).unwrap();
        assert_eq!(parsed.auth_method_value(), AUTH_SHARED_KEY_MIC);
        assert_eq!(parsed.auth_method_kind(), AuthMethod::SharedKeyMic);
        assert_eq!(parsed.auth_data_bytes(), &sample_auth_data()[..]);
        let recompiled = compile_payload(parsed);
        assert_eq!(recompiled, bytes);
    }

    #[test]
    fn round_trip_digital_signature_preserves_method_and_data() {
        let payload = IkeAuthPayload::digital_signature(vec![0xAA, 0xBB, 0xCC, 0xDD]);
        let bytes = compile_payload(payload.clone());
        let parsed = parse_auth_payload_body(&bytes[GENERIC_PAYLOAD_HEADER_LEN..]).unwrap();
        assert_eq!(parsed.auth_method_value(), AUTH_DIGITAL_SIGNATURE);
        assert_eq!(parsed.auth_method_kind(), AuthMethod::DigitalSignature);
        assert_eq!(parsed.auth_data_bytes(), &[0xAA, 0xBB, 0xCC, 0xDD]);
        let recompiled = compile_payload(parsed);
        assert_eq!(recompiled, bytes);
    }

    #[test]
    fn unknown_method_survives_compile_and_parse() {
        let payload = IkeAuthPayload::new(AuthMethod::Unknown(200), vec![0x01u8]);
        let bytes = compile_payload(payload);
        let parsed = parse_auth_payload_body(&bytes[GENERIC_PAYLOAD_HEADER_LEN..]).unwrap();
        assert_eq!(parsed.auth_method_kind(), AuthMethod::Unknown(200));
        assert_eq!(parsed.auth_data_bytes(), &[0x01]);
    }

    #[test]
    fn parse_accepts_body_without_auth_data() {
        let parsed = parse_auth_payload_body(&[AUTH_SHARED_KEY_MIC, 0, 0, 0]).unwrap();
        assert_eq!(parsed.auth_method_kind(), AuthMethod::SharedKeyMic);
        assert!(parsed.auth_data_bytes().is_empty());
    }

    #[test]
    fn parse_ignores_reserved_bytes_and_recompiles_them_as_zero() {
        let parsed =
            parse_auth_payload_body(&[AUTH_RSA_DIGITAL_SIGNATURE, 0xFF, 0xEE, 0xDD, 0x42])
                .unwrap();
        assert_eq!(parsed.auth_method_kind(), AuthMethod::RsaDigitalSignature);
        assert_eq!(parsed.auth_data_bytes(), &[0x42]);
        assert_eq!(
            parsed.auth_body(),
            vec![AUTH_RSA_DIGITAL_SIGNATURE, 0, 0, 0, 0x42]
        );
    }

    #[test]
    fn parse_rejects_truncated_body() {
        let err = parse_auth_payload_body(&[2u8, 0, 0]).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
        let err = parse_auth_payload_body(&[]).unwrap_err();
        assert!(matches!(err, CrafterError::BufferTooShort { .. }));
    }
}