use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
use super::{
connection::Connection,
error::Result,
protocol::{ProtocolState, Xfrm},
socket::NetlinkSocket,
};
// --- Generic netlink framing (values from <linux/netlink.h>) ---

/// Length of `struct nlmsghdr`: u32 len, u16 type, u16 flags, u32 seq, u32 pid.
const NLMSG_HDRLEN: usize = 16;
/// Control message whose payload starts with an `int` error code.
const NLMSG_ERROR: u16 = 2;
/// Control message that terminates a multipart (dump) reply.
const NLMSG_DONE: u16 = 3;
/// Flag set on every request this module sends to the kernel.
const NLM_F_REQUEST: u16 = 0x01;
/// `NLM_F_ROOT | NLM_F_MATCH`: ask for every entry of the table.
const NLM_F_DUMP: u16 = 0x100 | 0x200;

// --- XFRM message types (values from <linux/xfrm.h>) ---

/// Request type used to dump the Security Association database.
const XFRM_MSG_GETSA: u16 = 0x12;
/// Request type used to dump the Security Policy database.
const XFRM_MSG_GETPOLICY: u16 = 0x15;

// --- XFRM netlink attribute types ---

/// Authentication algorithm (`struct xfrm_algo` payload).
const XFRMA_ALG_AUTH: u16 = 1;
/// Encryption algorithm (`struct xfrm_algo` payload).
const XFRMA_ALG_CRYPT: u16 = 2;
/// Compression algorithm (`struct xfrm_algo` payload).
const XFRMA_ALG_COMP: u16 = 3;
/// Encapsulation template (`struct xfrm_encap_tmpl` payload).
const XFRMA_ENCAP: u16 = 4;
/// AEAD algorithm (`struct xfrm_algo_aead` payload).
const XFRMA_ALG_AEAD: u16 = 18;
/// Authentication algorithm with truncation length (`struct xfrm_algo_auth` payload).
const XFRMA_ALG_AUTH_TRUNC: u16 = 20;
/// Mark value/mask (`struct xfrm_mark` payload).
const XFRMA_MARK: u16 = 21;
/// XFRM interface id (`u32` payload).
const XFRMA_IF_ID: u16 = 31;

// --- SA modes ---

const XFRM_MODE_TRANSPORT: u8 = 0;
const XFRM_MODE_TUNNEL: u8 = 1;
const XFRM_MODE_BEET: u8 = 4;

// --- IPsec IP protocol numbers ---

const IPPROTO_ESP: u8 = 50;
const IPPROTO_AH: u8 = 51;
const IPPROTO_COMP: u8 = 108;

// --- Policy directions and actions ---

const XFRM_POLICY_IN: u8 = 0;
const XFRM_POLICY_OUT: u8 = 1;
const XFRM_POLICY_FWD: u8 = 2;
const XFRM_POLICY_ALLOW: u8 = 0;
const XFRM_POLICY_BLOCK: u8 = 1;
/// Raw 16-byte address as carried in XFRM messages.
///
/// IPv4 addresses occupy the first four bytes and the rest is zero (see `from_v4`);
/// IPv6 addresses use all sixteen bytes. Being a plain byte array, the type has
/// alignment 1.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmAddress {
/// Address octets in the same order as `Ipv4Addr::octets` / `Ipv6Addr::octets`.
pub bytes: [u8; 16],
}
impl XfrmAddress {
    /// Stores an IPv4 address in the leading four bytes and zero-fills the remainder.
    pub fn from_v4(addr: Ipv4Addr) -> Self {
        let mut out = Self::default();
        out.bytes[..4].copy_from_slice(&addr.octets());
        out
    }

    /// Stores all sixteen octets of an IPv6 address.
    pub fn from_v6(addr: Ipv6Addr) -> Self {
        Self { bytes: addr.octets() }
    }

    /// Interprets the raw bytes according to a Linux address-family number.
    ///
    /// Family `2` (`AF_INET`) reads the first four bytes as IPv4, family `10`
    /// (`AF_INET6`) reads all sixteen as IPv6, and any other family yields `None`.
    pub fn to_ip(&self, family: u16) -> Option<IpAddr> {
        const AF_INET: u16 = 2;
        const AF_INET6: u16 = 10;

        match family {
            AF_INET => {
                let [a, b, c, d, ..] = self.bytes;
                Some(Ipv4Addr::new(a, b, c, d).into())
            }
            AF_INET6 => Some(Ipv6Addr::from(self.bytes).into()),
            _ => None,
        }
    }
}
/// SA identity triple carried in `XfrmUsersaInfo::id`: destination address, SPI
/// and IPsec protocol number.
///
/// `#[repr(C)]` with the tail padding spelled out in `_pad`, giving a 24-byte
/// struct with no implicit padding bytes (required for deriving `IntoBytes`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmId {
/// Destination address; interpret with the SA's address family.
pub daddr: XfrmAddress,
/// SPI as stored in the message; `parse_sa` decodes it with `u32::from_be`,
/// i.e. it is treated as network byte order.
pub spi: u32,
/// IP protocol number (see the `IPPROTO_*` constants / `IpsecProtocol`).
pub proto: u8,
/// Explicit padding up to the 4-byte alignment of `spi`.
pub _pad: [u8; 3],
}
/// Traffic selector as laid out in XFRM messages (the `sel` member of the SA and
/// policy info structs).
///
/// `#[repr(C, packed)]` with padding written out in `_pad1`, giving a fixed
/// 56-byte layout. Because the struct is packed, multi-byte fields must be read
/// by value (copied) rather than borrowed.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmSelector {
/// Destination address; interpret with `family`.
pub daddr: XfrmAddress,
/// Source address; interpret with `family`.
pub saddr: XfrmAddress,
/// Destination port; `TrafficSelector::from_selector` decodes it with `u16::from_be`.
pub dport: u16,
/// Destination port mask (not decoded by this module).
pub dport_mask: u16,
/// Source port; decoded with `u16::from_be` like `dport`.
pub sport: u16,
/// Source port mask (not decoded by this module).
pub sport_mask: u16,
/// Address family number; `XfrmAddress::to_ip` understands 2 (IPv4) and 10 (IPv6).
pub family: u16,
/// Destination prefix length.
pub prefixlen_d: u8,
/// Source prefix length.
pub prefixlen_s: u8,
/// IP protocol number matched by the selector.
pub proto: u8,
/// Explicit padding so `ifindex` starts at byte offset 48.
pub _pad1: [u8; 3],
/// Interface index (not decoded by this module).
pub ifindex: i32,
/// User id field (not decoded by this module).
pub user: u32,
}
/// Configured soft/hard lifetime limits (bytes, packets, seconds) of an SA or
/// policy; 64 bytes of `u64` fields.
///
/// Only carried as part of the fixed info structs; this module does not decode it.
/// Its `u64` members give the enclosing kernel structs 8-byte alignment on 64-bit
/// targets.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmLifetimeCfg {
pub soft_byte_limit: u64,
pub hard_byte_limit: u64,
pub soft_packet_limit: u64,
pub hard_packet_limit: u64,
pub soft_add_expires_seconds: u64,
pub hard_add_expires_seconds: u64,
pub soft_use_expires_seconds: u64,
pub hard_use_expires_seconds: u64,
}
/// Current lifetime counters of an SA or policy (32 bytes of `u64` fields).
///
/// `parse_sa` copies `bytes` and `packets` into `SecurityAssociation`; the two
/// timestamps are carried but not interpreted here (units not established in
/// this module — presumably seconds as in the kernel; confirm before use).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmLifetimeCur {
pub bytes: u64,
pub packets: u64,
pub add_time: u64,
pub use_time: u64,
}
/// Per-SA statistics counters (12 bytes); carried in `XfrmUsersaInfo` but not
/// decoded by this module.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmStats {
pub replay_window: u32,
pub replay: u32,
pub integrity_failed: u32,
}
/// Fixed body of SA messages: sent (zeroed) as the payload of the
/// `XFRM_MSG_GETSA` dump request and parsed from every SA reply by `parse_sa`.
///
/// `#[repr(C, packed)]`: the fields add up to 217 bytes and `_pad` brings the total
/// to 224, a multiple of 8 — presumably to reproduce the C struct's tail padding on
/// 64-bit targets (TODO confirm for 32-bit targets, where `u64` alignment differs).
/// Netlink attributes follow immediately after this struct.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmUsersaInfo {
/// Traffic selector of the SA.
pub sel: XfrmSelector,
/// Destination address, SPI and protocol.
pub id: XfrmId,
/// Source address; interpret with `family`.
pub saddr: XfrmAddress,
/// Configured lifetime limits (not decoded here).
pub lft: XfrmLifetimeCfg,
/// Current lifetime counters; `bytes`/`packets` are surfaced by `parse_sa`.
pub curlft: XfrmLifetimeCur,
/// Statistics (not decoded here).
pub stats: XfrmStats,
/// Not decoded here.
pub seq: u32,
/// Copied verbatim into `SecurityAssociation::reqid`.
pub reqid: u32,
/// Address family used for `saddr` and `id.daddr`.
pub family: u16,
/// Raw mode byte, decoded by `XfrmMode::from_u8`.
pub mode: u8,
/// Copied verbatim into `SecurityAssociation::replay_window`.
pub replay_window: u8,
/// Copied verbatim into `SecurityAssociation::flags`.
pub flags: u8,
/// Explicit tail padding (see the struct docs).
pub _pad: [u8; 7],
}
/// Fixed body of policy messages: sent (zeroed) as the payload of the
/// `XFRM_MSG_GETPOLICY` dump request and parsed from every policy reply.
///
/// The kernel's `struct xfrm_userpolicy_info` contains `u64` members (inside `lft`
/// and `curlft`), so its `sizeof` is rounded up to a multiple of 8. That is 164
/// bytes of fields plus 4 bytes of tail padding, 168 in total, on 64-bit targets.
/// Like `XfrmUsersaInfo`, this packed mirror spells that padding out explicitly
/// (`_pad`). This keeps `size_of::<Self>()` equal to the kernel's size, so that
/// netlink attributes, which follow the padded struct, are located correctly.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmUserpolicyInfo {
    /// Traffic selector the policy matches.
    pub sel: XfrmSelector,
    /// Configured lifetime limits (not decoded here).
    pub lft: XfrmLifetimeCfg,
    /// Current lifetime counters (not decoded here).
    pub curlft: XfrmLifetimeCur,
    /// Policy priority.
    pub priority: u32,
    /// Policy index.
    pub index: u32,
    /// Raw direction byte, decoded by `PolicyDirection::from_u8`.
    pub dir: u8,
    /// Raw action byte, decoded by `PolicyAction::from_u8`.
    pub action: u8,
    /// Policy flags, copied verbatim.
    pub flags: u8,
    /// Share mode (not decoded here).
    pub share: u8,
    /// Tail padding of the C struct; its contents are ignored.
    pub _pad: [u8; 4],
}
/// Algorithm decoded from an `XFRMA_ALG_AUTH`, `XFRMA_ALG_CRYPT` or
/// `XFRMA_ALG_COMP` attribute by `parse_algorithm`.
///
/// NOTE(review): the derived `Debug` prints `key` verbatim, so formatting an SA
/// with `{:?}` exposes key material — avoid logging these values.
#[derive(Debug, Clone)]
pub struct XfrmAlgorithm {
/// Algorithm name, decoded from a NUL-terminated 64-byte field.
pub name: String,
/// Key length in bits as reported; the key holds `key_len.div_ceil(8)` bytes.
pub key_len: u32,
/// Raw key bytes; empty if the attribute was shorter than `key_len` implies.
pub key: Vec<u8>,
}
/// AEAD algorithm decoded from an `XFRMA_ALG_AEAD` attribute by
/// `parse_aead_algorithm`.
///
/// NOTE(review): the derived `Debug` prints `key` verbatim — avoid logging.
#[derive(Debug, Clone)]
pub struct XfrmAlgorithmAead {
/// Algorithm name, decoded from a NUL-terminated 64-byte field.
pub name: String,
/// Key length in bits as reported; the key holds `key_len.div_ceil(8)` bytes.
pub key_len: u32,
/// ICV length as reported (presumably bits, as in the kernel's
/// `struct xfrm_algo_aead` — confirm before relying on the unit).
pub icv_len: u32,
/// Raw key bytes; empty if the attribute was shorter than `key_len` implies.
pub key: Vec<u8>,
}
/// Authentication algorithm with truncation length, decoded from an
/// `XFRMA_ALG_AUTH_TRUNC` attribute by `parse_auth_trunc_algorithm`.
///
/// NOTE(review): the derived `Debug` prints `key` verbatim — avoid logging.
#[derive(Debug, Clone)]
pub struct XfrmAlgorithmAuthTrunc {
/// Algorithm name, decoded from a NUL-terminated 64-byte field.
pub name: String,
/// Key length in bits as reported; the key holds `key_len.div_ceil(8)` bytes.
pub key_len: u32,
/// Truncation length as reported (presumably bits, as in the kernel's
/// `struct xfrm_algo_auth` — confirm before relying on the unit).
pub trunc_len: u32,
/// Raw key bytes; empty if the attribute was shorter than `key_len` implies.
pub key: Vec<u8>,
}
/// Encapsulation template parsed from an `XFRMA_ENCAP` attribute (24 bytes).
///
/// The ports are copied without any byte-order conversion; presumably they are
/// network byte order as in the kernel's `struct xfrm_encap_tmpl` — convert with
/// `u16::from_be` before use (TODO confirm).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmEncapTmpl {
/// Encapsulation type (not interpreted here).
pub encap_type: u16,
/// Encapsulation source port (raw).
pub encap_sport: u16,
/// Encapsulation destination port (raw).
pub encap_dport: u16,
/// Explicit padding so `encap_oa` starts at byte offset 8.
pub _pad: u16,
/// Address associated with the encapsulation (not interpreted here).
pub encap_oa: XfrmAddress,
}
/// Mark attached to an SA or policy via the `XFRMA_MARK` attribute (8 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, FromBytes, IntoBytes, Immutable, KnownLayout)]
pub struct XfrmMark {
/// Mark value (per the kernel's `struct xfrm_mark`).
pub v: u32,
/// Mark mask (per the kernel's `struct xfrm_mark`).
pub m: u32,
}
/// IPsec transform protocol of an SA, identified by its IP protocol number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum IpsecProtocol {
    /// Encapsulating Security Payload (IP protocol 50).
    Esp,
    /// Authentication Header (IP protocol 51).
    Ah,
    /// IP payload compression (IP protocol 108).
    Comp,
    /// Any other protocol number, preserved verbatim.
    Other(u8),
}

impl IpsecProtocol {
    /// Maps a raw protocol number onto a variant; unrecognised numbers become `Other`.
    fn from_u8(raw: u8) -> Self {
        match raw {
            IPPROTO_ESP => Self::Esp,
            IPPROTO_AH => Self::Ah,
            IPPROTO_COMP => Self::Comp,
            unknown => Self::Other(unknown),
        }
    }

    /// Returns the IP protocol number this variant stands for.
    pub fn number(&self) -> u8 {
        match *self {
            Self::Other(raw) => raw,
            Self::Esp => IPPROTO_ESP,
            Self::Ah => IPPROTO_AH,
            Self::Comp => IPPROTO_COMP,
        }
    }
}
/// Encapsulation mode of an SA, decoded from the raw `mode` byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum XfrmMode {
    /// `XFRM_MODE_TRANSPORT` (0).
    Transport,
    /// `XFRM_MODE_TUNNEL` (1).
    Tunnel,
    /// `XFRM_MODE_BEET` (4).
    Beet,
    /// Any other mode value, preserved verbatim.
    Other(u8),
}

impl XfrmMode {
    /// Decodes a raw mode byte; values without a dedicated variant become `Other`.
    fn from_u8(raw: u8) -> Self {
        match raw {
            XFRM_MODE_TRANSPORT => Self::Transport,
            XFRM_MODE_TUNNEL => Self::Tunnel,
            XFRM_MODE_BEET => Self::Beet,
            unknown => Self::Other(unknown),
        }
    }
}
/// Direction a security policy applies to, decoded from the raw `dir` byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum PolicyDirection {
    /// `XFRM_POLICY_IN` (0).
    In,
    /// `XFRM_POLICY_OUT` (1).
    Out,
    /// `XFRM_POLICY_FWD` (2).
    Forward,
    /// Any other direction value, preserved verbatim.
    Unknown(u8),
}

impl PolicyDirection {
    /// Decodes a raw direction byte; unrecognised values become `Unknown`.
    fn from_u8(raw: u8) -> Self {
        match raw {
            XFRM_POLICY_IN => Self::In,
            XFRM_POLICY_OUT => Self::Out,
            XFRM_POLICY_FWD => Self::Forward,
            unknown => Self::Unknown(unknown),
        }
    }
}
/// Verdict of a security policy, decoded from the raw `action` byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum PolicyAction {
    /// `XFRM_POLICY_ALLOW` (0).
    Allow,
    /// `XFRM_POLICY_BLOCK` (1).
    Block,
    /// Any other action value, preserved verbatim.
    Unknown(u8),
}

impl PolicyAction {
    /// Decodes a raw action byte; unrecognised values become `Unknown`.
    fn from_u8(raw: u8) -> Self {
        match raw {
            XFRM_POLICY_ALLOW => Self::Allow,
            XFRM_POLICY_BLOCK => Self::Block,
            unknown => Self::Unknown(unknown),
        }
    }
}
/// Decoded form of a kernel traffic selector (`XfrmSelector`).
///
/// Addresses are `None` when the selector's family is neither 2 (IPv4) nor
/// 10 (IPv6). Ports are converted to host byte order and are `None` when the raw
/// port is zero.
#[derive(Debug, Clone)]
pub struct TrafficSelector {
    /// Source address, if the family is understood.
    pub src_addr: Option<IpAddr>,
    /// Destination address, if the family is understood.
    pub dst_addr: Option<IpAddr>,
    /// Source prefix length.
    pub src_prefix_len: u8,
    /// Destination prefix length.
    pub dst_prefix_len: u8,
    /// Source port in host byte order; `None` when zero.
    pub src_port: Option<u16>,
    /// Destination port in host byte order; `None` when zero.
    pub dst_port: Option<u16>,
    /// IP protocol number matched by the selector.
    pub proto: u8,
}

impl TrafficSelector {
    /// Decodes a raw selector.
    fn from_selector(sel: XfrmSelector) -> Self {
        // Copy packed fields into locals up front; they must not be borrowed in place.
        let family = sel.family;
        let sport = u16::from_be(sel.sport);
        let dport = u16::from_be(sel.dport);

        Self {
            src_addr: sel.saddr.to_ip(family),
            dst_addr: sel.daddr.to_ip(family),
            src_prefix_len: sel.prefixlen_s,
            dst_prefix_len: sel.prefixlen_d,
            // A byte swap never turns a non-zero value into zero, so testing the
            // converted port is the same as testing the raw one.
            src_port: (sport != 0).then_some(sport),
            dst_port: (dport != 0).then_some(dport),
            proto: sel.proto,
        }
    }
}
/// Decoded view of one Security Association, as returned by
/// `Connection::<Xfrm>::get_security_associations`.
///
/// Built by `parse_sa` from the fixed `XfrmUsersaInfo` body plus any recognised
/// netlink attributes. Each attribute-backed `Option` field is `None` when that
/// attribute was absent or could not be parsed.
///
/// NOTE(review): the algorithm fields carry raw key bytes and the derived `Debug`
/// prints them — avoid logging SAs with `{:?}`.
#[derive(Debug, Clone)]
pub struct SecurityAssociation {
/// Source address, interpreted with the SA's family; `None` for an unknown family.
pub src_addr: Option<IpAddr>,
/// Destination address (from the SA id), interpreted with the SA's family.
pub dst_addr: Option<IpAddr>,
/// SPI in host byte order (decoded with `u32::from_be`).
pub spi: u32,
/// IPsec protocol of the SA.
pub protocol: IpsecProtocol,
/// Encapsulation mode.
pub mode: XfrmMode,
/// `reqid`, copied verbatim.
pub reqid: u32,
/// Traffic selector the SA applies to.
pub selector: TrafficSelector,
/// Encryption algorithm (`XFRMA_ALG_CRYPT`).
pub enc_alg: Option<XfrmAlgorithm>,
/// Authentication algorithm (`XFRMA_ALG_AUTH`).
pub auth_alg: Option<XfrmAlgorithm>,
/// AEAD algorithm (`XFRMA_ALG_AEAD`).
pub aead_alg: Option<XfrmAlgorithmAead>,
/// Authentication algorithm with truncation length (`XFRMA_ALG_AUTH_TRUNC`).
pub auth_trunc_alg: Option<XfrmAlgorithmAuthTrunc>,
/// Compression algorithm (`XFRMA_ALG_COMP`).
pub comp_alg: Option<XfrmAlgorithm>,
/// Encapsulation template (`XFRMA_ENCAP`).
pub encap: Option<XfrmEncapTmpl>,
/// Mark value/mask (`XFRMA_MARK`).
pub mark: Option<XfrmMark>,
/// XFRM interface id (`XFRMA_IF_ID`).
pub if_id: Option<u32>,
/// Byte counter from the SA's current lifetime (`curlft.bytes`).
pub bytes: u64,
/// Packet counter from the SA's current lifetime (`curlft.packets`).
pub packets: u64,
/// Replay window byte, copied verbatim.
pub replay_window: u8,
/// Flags byte, copied verbatim (not decoded).
pub flags: u8,
}
/// Decoded view of one security policy, as returned by
/// `Connection::<Xfrm>::get_security_policies`.
///
/// Built by `parse_policy` from the fixed `XfrmUserpolicyInfo` body plus the
/// `XFRMA_MARK` / `XFRMA_IF_ID` attributes when present.
#[derive(Debug, Clone)]
pub struct SecurityPolicy {
/// Traffic selector the policy matches.
pub selector: TrafficSelector,
/// Direction the policy applies to.
pub direction: PolicyDirection,
/// Verdict applied to matching traffic.
pub action: PolicyAction,
/// Priority, copied verbatim.
pub priority: u32,
/// Index, copied verbatim.
pub index: u32,
/// Flags byte, copied verbatim (not decoded).
pub flags: u8,
/// Mark value/mask (`XFRMA_MARK`), if present.
pub mark: Option<XfrmMark>,
/// XFRM interface id (`XFRMA_IF_ID`), if present.
pub if_id: Option<u32>,
}
impl Connection<Xfrm> {
pub fn new() -> Result<Self> {
let socket = NetlinkSocket::new(Xfrm::PROTOCOL)?;
Ok(Self::from_parts(socket, Xfrm))
}
#[tracing::instrument(
level = "debug",
skip_all,
fields(method = "get_security_associations")
)]
pub async fn get_security_associations(&self) -> Result<Vec<SecurityAssociation>> {
let seq = self.socket().next_seq();
let pid = self.socket().pid();
let mut buf = Vec::with_capacity(64);
buf.extend_from_slice(&0u32.to_ne_bytes()); buf.extend_from_slice(&XFRM_MSG_GETSA.to_ne_bytes()); buf.extend_from_slice(&(NLM_F_REQUEST | NLM_F_DUMP).to_ne_bytes()); buf.extend_from_slice(&seq.to_ne_bytes()); buf.extend_from_slice(&pid.to_ne_bytes());
let sa_info = XfrmUsersaInfo::default();
buf.extend_from_slice(sa_info.as_bytes());
let len = buf.len() as u32;
buf[0..4].copy_from_slice(&len.to_ne_bytes());
self.socket().send(&buf).await?;
let mut sas = Vec::new();
loop {
let data = self.socket().recv_msg().await?;
let mut offset = 0;
while offset + NLMSG_HDRLEN <= data.len() {
let nlmsg_len = u32::from_ne_bytes([
data[offset],
data[offset + 1],
data[offset + 2],
data[offset + 3],
]) as usize;
let nlmsg_type = u16::from_ne_bytes([data[offset + 4], data[offset + 5]]);
if nlmsg_len < NLMSG_HDRLEN || offset + nlmsg_len > data.len() {
break;
}
match nlmsg_type {
NLMSG_DONE => return Ok(sas),
NLMSG_ERROR => {
if nlmsg_len >= 20 {
let errno = i32::from_ne_bytes([
data[offset + 16],
data[offset + 17],
data[offset + 18],
data[offset + 19],
]);
if errno != 0 {
return Err(super::error::Error::from_errno(-errno));
}
}
}
_ => {
if let Some(sa) = self.parse_sa(&data[offset..offset + nlmsg_len]) {
sas.push(sa);
}
}
}
offset += (nlmsg_len + 3) & !3;
}
}
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "get_security_policies"))]
pub async fn get_security_policies(&self) -> Result<Vec<SecurityPolicy>> {
let seq = self.socket().next_seq();
let pid = self.socket().pid();
let mut buf = Vec::with_capacity(64);
buf.extend_from_slice(&0u32.to_ne_bytes()); buf.extend_from_slice(&XFRM_MSG_GETPOLICY.to_ne_bytes()); buf.extend_from_slice(&(NLM_F_REQUEST | NLM_F_DUMP).to_ne_bytes()); buf.extend_from_slice(&seq.to_ne_bytes()); buf.extend_from_slice(&pid.to_ne_bytes());
let pol_info = XfrmUserpolicyInfo::default();
buf.extend_from_slice(pol_info.as_bytes());
let len = buf.len() as u32;
buf[0..4].copy_from_slice(&len.to_ne_bytes());
self.socket().send(&buf).await?;
let mut policies = Vec::new();
loop {
let data = self.socket().recv_msg().await?;
let mut offset = 0;
while offset + NLMSG_HDRLEN <= data.len() {
let nlmsg_len = u32::from_ne_bytes([
data[offset],
data[offset + 1],
data[offset + 2],
data[offset + 3],
]) as usize;
let nlmsg_type = u16::from_ne_bytes([data[offset + 4], data[offset + 5]]);
if nlmsg_len < NLMSG_HDRLEN || offset + nlmsg_len > data.len() {
break;
}
match nlmsg_type {
NLMSG_DONE => return Ok(policies),
NLMSG_ERROR => {
if nlmsg_len >= 20 {
let errno = i32::from_ne_bytes([
data[offset + 16],
data[offset + 17],
data[offset + 18],
data[offset + 19],
]);
if errno != 0 {
return Err(super::error::Error::from_errno(-errno));
}
}
}
_ => {
if let Some(pol) = self.parse_policy(&data[offset..offset + nlmsg_len]) {
policies.push(pol);
}
}
}
offset += (nlmsg_len + 3) & !3;
}
}
}
fn parse_sa(&self, data: &[u8]) -> Option<SecurityAssociation> {
if data.len() < NLMSG_HDRLEN + std::mem::size_of::<XfrmUsersaInfo>() {
return None;
}
let msg_data = &data[NLMSG_HDRLEN..];
let (info, _) = XfrmUsersaInfo::ref_from_prefix(msg_data).ok()?;
let mut sa = SecurityAssociation {
src_addr: info.saddr.to_ip(info.family),
dst_addr: info.id.daddr.to_ip(info.family),
spi: u32::from_be(info.id.spi),
protocol: IpsecProtocol::from_u8(info.id.proto),
mode: XfrmMode::from_u8(info.mode),
reqid: info.reqid,
selector: TrafficSelector::from_selector(info.sel),
enc_alg: None,
auth_alg: None,
aead_alg: None,
auth_trunc_alg: None,
comp_alg: None,
encap: None,
mark: None,
if_id: None,
bytes: info.curlft.bytes,
packets: info.curlft.packets,
replay_window: info.replay_window,
flags: info.flags,
};
let attr_start = NLMSG_HDRLEN + std::mem::size_of::<XfrmUsersaInfo>();
if data.len() > attr_start {
let mut input = &data[attr_start..];
while let Some((attr_type, attr_data)) = parse_nla(&mut input) {
match attr_type {
XFRMA_ALG_CRYPT => {
sa.enc_alg = parse_algorithm(attr_data);
}
XFRMA_ALG_AUTH => {
sa.auth_alg = parse_algorithm(attr_data);
}
XFRMA_ALG_AEAD => {
sa.aead_alg = parse_aead_algorithm(attr_data);
}
XFRMA_ALG_AUTH_TRUNC => {
sa.auth_trunc_alg = parse_auth_trunc_algorithm(attr_data);
}
XFRMA_ALG_COMP => {
sa.comp_alg = parse_algorithm(attr_data);
}
XFRMA_ENCAP => {
if attr_data.len() >= std::mem::size_of::<XfrmEncapTmpl>()
&& let Ok((encap, _)) = XfrmEncapTmpl::ref_from_prefix(attr_data)
{
sa.encap = Some(*encap);
}
}
XFRMA_MARK => {
if attr_data.len() >= std::mem::size_of::<XfrmMark>()
&& let Ok((mark, _)) = XfrmMark::ref_from_prefix(attr_data)
{
sa.mark = Some(*mark);
}
}
XFRMA_IF_ID if attr_data.len() >= 4 => {
sa.if_id = Some(u32::from_ne_bytes([
attr_data[0],
attr_data[1],
attr_data[2],
attr_data[3],
]));
}
_ => {}
}
}
}
Some(sa)
}
fn parse_policy(&self, data: &[u8]) -> Option<SecurityPolicy> {
if data.len() < NLMSG_HDRLEN + std::mem::size_of::<XfrmUserpolicyInfo>() {
return None;
}
let msg_data = &data[NLMSG_HDRLEN..];
let (info, _) = XfrmUserpolicyInfo::ref_from_prefix(msg_data).ok()?;
let mut policy = SecurityPolicy {
selector: TrafficSelector::from_selector(info.sel),
direction: PolicyDirection::from_u8(info.dir),
action: PolicyAction::from_u8(info.action),
priority: info.priority,
index: info.index,
flags: info.flags,
mark: None,
if_id: None,
};
let attr_start = NLMSG_HDRLEN + std::mem::size_of::<XfrmUserpolicyInfo>();
if data.len() > attr_start {
let mut input = &data[attr_start..];
while let Some((attr_type, attr_data)) = parse_nla(&mut input) {
match attr_type {
XFRMA_MARK => {
if attr_data.len() >= std::mem::size_of::<XfrmMark>()
&& let Ok((mark, _)) = XfrmMark::ref_from_prefix(attr_data)
{
policy.mark = Some(*mark);
}
}
XFRMA_IF_ID if attr_data.len() >= 4 => {
policy.if_id = Some(u32::from_ne_bytes([
attr_data[0],
attr_data[1],
attr_data[2],
attr_data[3],
]));
}
_ => {}
}
}
}
Some(policy)
}
}
/// Pops the next netlink attribute (`struct nlattr`) off the front of `input`.
///
/// Returns the attribute type, with the `NLA_F_NESTED` and `NLA_F_NET_BYTEORDER`
/// flag bits masked off, together with its payload. `input` is advanced past the
/// attribute and its padding to the next 4-byte boundary; a missing final pad is
/// tolerated.
///
/// Returns `None` once no complete, well-formed attribute remains: fewer than four
/// header bytes, a declared length smaller than the header, or a payload that runs
/// past the end of `input`.
fn parse_nla<'a>(input: &mut &'a [u8]) -> Option<(u16, &'a [u8])> {
    /// Size of `struct nlattr { u16 nla_len; u16 nla_type; }`.
    const NLA_HDRLEN: usize = 4;
    /// `NLA_TYPE_MASK`: clears `NLA_F_NESTED` (bit 15) and `NLA_F_NET_BYTEORDER` (bit 14).
    const NLA_TYPE_MASK: u16 = 0x3fff;

    let remaining: &'a [u8] = input;
    let (header, after_header) = remaining.split_first_chunk::<NLA_HDRLEN>()?;
    // Like `nlmsghdr`, the attribute header is in host byte order, not little-endian.
    let len = usize::from(u16::from_ne_bytes([header[0], header[1]]));
    let attr_type = u16::from_ne_bytes([header[2], header[3]]) & NLA_TYPE_MASK;
    *input = after_header;

    let payload_len = len.checked_sub(NLA_HDRLEN)?;
    let (payload, after_payload) = after_header.split_at_checked(payload_len)?;
    let padding = len.next_multiple_of(4) - len;
    *input = after_payload.get(padding..).unwrap_or(after_payload);
    Some((attr_type, payload))
}
/// Decodes the payload of an `XFRMA_ALG_AUTH`, `XFRMA_ALG_CRYPT` or
/// `XFRMA_ALG_COMP` attribute, laid out like the kernel's `struct xfrm_algo`:
/// a 64-byte NUL-padded name, a `u32` key length in bits, then the key bytes.
///
/// Returns `None` if the 68-byte fixed part is incomplete. If fewer key bytes are
/// present than `key_len` announces, the key is returned empty.
fn parse_algorithm(data: &[u8]) -> Option<XfrmAlgorithm> {
    let (name_bytes, rest) = data.split_first_chunk::<64>()?;
    let (key_len_bytes, key_area) = rest.split_first_chunk::<4>()?;
    // `alg_key_len` is a C `unsigned int` in host byte order, counted in bits.
    let key_len = u32::from_ne_bytes(*key_len_bytes);
    let key = key_area
        .get(..(key_len as usize).div_ceil(8))
        .map(<[u8]>::to_vec)
        .unwrap_or_default();
    Some(XfrmAlgorithm {
        name: parse_cstring(name_bytes),
        key_len,
        key,
    })
}
/// Decodes the payload of an `XFRMA_ALG_AEAD` attribute, laid out like the
/// kernel's `struct xfrm_algo_aead`: a 64-byte NUL-padded name, a `u32` key length
/// in bits, a `u32` ICV length, then the key bytes.
///
/// Returns `None` if the 72-byte fixed part is incomplete. If fewer key bytes are
/// present than `key_len` announces, the key is returned empty.
fn parse_aead_algorithm(data: &[u8]) -> Option<XfrmAlgorithmAead> {
    let (name_bytes, rest) = data.split_first_chunk::<64>()?;
    let (key_len_bytes, rest) = rest.split_first_chunk::<4>()?;
    let (icv_len_bytes, key_area) = rest.split_first_chunk::<4>()?;
    // Both lengths are C `unsigned int`s in host byte order, not little-endian.
    let key_len = u32::from_ne_bytes(*key_len_bytes);
    let icv_len = u32::from_ne_bytes(*icv_len_bytes);
    let key = key_area
        .get(..(key_len as usize).div_ceil(8))
        .map(<[u8]>::to_vec)
        .unwrap_or_default();
    Some(XfrmAlgorithmAead {
        name: parse_cstring(name_bytes),
        key_len,
        icv_len,
        key,
    })
}
/// Decodes the payload of an `XFRMA_ALG_AUTH_TRUNC` attribute, laid out like the
/// kernel's `struct xfrm_algo_auth`: a 64-byte NUL-padded name, a `u32` key length
/// in bits, a `u32` truncation length, then the key bytes.
///
/// Returns `None` if the 72-byte fixed part is incomplete. If fewer key bytes are
/// present than `key_len` announces, the key is returned empty.
fn parse_auth_trunc_algorithm(data: &[u8]) -> Option<XfrmAlgorithmAuthTrunc> {
    let (name_bytes, rest) = data.split_first_chunk::<64>()?;
    let (key_len_bytes, rest) = rest.split_first_chunk::<4>()?;
    let (trunc_len_bytes, key_area) = rest.split_first_chunk::<4>()?;
    // Both lengths are C `unsigned int`s in host byte order, not little-endian.
    let key_len = u32::from_ne_bytes(*key_len_bytes);
    let trunc_len = u32::from_ne_bytes(*trunc_len_bytes);
    let key = key_area
        .get(..(key_len as usize).div_ceil(8))
        .map(<[u8]>::to_vec)
        .unwrap_or_default();
    Some(XfrmAlgorithmAuthTrunc {
        name: parse_cstring(name_bytes),
        key_len,
        trunc_len,
        key,
    })
}
/// Decodes a NUL-terminated C string stored in a fixed-size buffer.
///
/// Bytes from the first NUL onward are ignored; without a NUL the whole buffer is
/// used. Invalid UTF-8 sequences are replaced with U+FFFD.
fn parse_cstring(data: &[u8]) -> String {
    // `split` always yields at least one (possibly empty) piece.
    let text = data.split(|&byte| byte == 0).next().unwrap_or_default();
    String::from_utf8_lossy(text).into_owned()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xfrm_address_ipv4() {
        let addr = XfrmAddress::from_v4(Ipv4Addr::new(192, 168, 1, 1));
        assert_eq!(
            addr.to_ip(2),
            Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)))
        );
    }

    #[test]
    fn xfrm_address_ipv6() {
        let addr = XfrmAddress::from_v6(Ipv6Addr::LOCALHOST);
        assert_eq!(addr.to_ip(10), Some(IpAddr::V6(Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn xfrm_address_unknown_family() {
        assert_eq!(XfrmAddress::default().to_ip(0), None);
    }

    #[test]
    fn ipsec_protocol_roundtrip() {
        assert_eq!(IpsecProtocol::Esp.number(), 50);
        assert_eq!(IpsecProtocol::from_u8(50), IpsecProtocol::Esp);
        assert_eq!(IpsecProtocol::Ah.number(), 51);
        assert_eq!(IpsecProtocol::from_u8(51), IpsecProtocol::Ah);
        assert_eq!(IpsecProtocol::from_u8(17), IpsecProtocol::Other(17));
        assert_eq!(IpsecProtocol::Other(17).number(), 17);
    }

    #[test]
    fn xfrm_mode_from_u8() {
        assert_eq!(XfrmMode::from_u8(0), XfrmMode::Transport);
        assert_eq!(XfrmMode::from_u8(1), XfrmMode::Tunnel);
        assert_eq!(XfrmMode::from_u8(4), XfrmMode::Beet);
    }

    #[test]
    fn policy_direction_from_u8() {
        assert_eq!(PolicyDirection::from_u8(0), PolicyDirection::In);
        assert_eq!(PolicyDirection::from_u8(1), PolicyDirection::Out);
        assert_eq!(PolicyDirection::from_u8(2), PolicyDirection::Forward);
    }

    #[test]
    fn zerocopy_sizes() {
        assert_eq!(std::mem::size_of::<XfrmAddress>(), 16);
        assert_eq!(std::mem::size_of::<XfrmId>(), 24);
        assert_eq!(std::mem::size_of::<XfrmMark>(), 8);
    }

    /// The packed mirrors must match the kernel's struct sizes exactly, since
    /// attribute parsing starts right after them.
    #[test]
    fn kernel_struct_sizes() {
        assert_eq!(std::mem::size_of::<XfrmSelector>(), 56);
        assert_eq!(std::mem::size_of::<XfrmLifetimeCfg>(), 64);
        assert_eq!(std::mem::size_of::<XfrmLifetimeCur>(), 32);
        assert_eq!(std::mem::size_of::<XfrmStats>(), 12);
        assert_eq!(std::mem::size_of::<XfrmUsersaInfo>(), 224);
        assert_eq!(std::mem::size_of::<XfrmEncapTmpl>(), 24);
    }

    #[test]
    fn parse_nla_walks_padded_attributes() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&6u16.to_ne_bytes());
        buf.extend_from_slice(&XFRMA_MARK.to_ne_bytes());
        buf.extend_from_slice(&[0xaa, 0xbb, 0, 0]);
        buf.extend_from_slice(&8u16.to_ne_bytes());
        buf.extend_from_slice(&XFRMA_IF_ID.to_ne_bytes());
        buf.extend_from_slice(&7u32.to_ne_bytes());
        let mut input: &[u8] = &buf;
        assert_eq!(
            parse_nla(&mut input),
            Some((XFRMA_MARK, &[0xaa_u8, 0xbb][..]))
        );
        assert_eq!(
            parse_nla(&mut input),
            Some((XFRMA_IF_ID, &7u32.to_ne_bytes()[..]))
        );
        assert_eq!(parse_nla(&mut input), None);
    }

    #[test]
    fn parse_cstring_stops_at_nul() {
        assert_eq!(parse_cstring(b"cbc(aes)\0garbage"), "cbc(aes)");
        assert_eq!(parse_cstring(b"no-nul"), "no-nul");
    }
}