netlink_packet_wireguard/
lib.rs

1// SPDX-License-Identifier: MIT
2
3#[macro_use]
4extern crate log;
5
6use crate::constants::*;
7use anyhow::Context;
8use netlink_packet_generic::{GenlFamily, GenlHeader};
9use netlink_packet_utils::{nla::NlasIterator, traits::*, DecodeError};
10use nlas::WgDeviceAttrs;
11use std::convert::{TryFrom, TryInto};
12
13pub mod constants;
14pub mod nlas;
15mod raw;
16
/// Commands of the `wireguard` generic netlink family.
///
/// Each variant converts to and from the raw `u8` command number through the
/// `From<WireguardCmd> for u8` and `TryFrom<u8>` implementations in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireguardCmd {
    /// Encoded as `WG_CMD_GET_DEVICE`.
    GetDevice,
    /// Encoded as `WG_CMD_SET_DEVICE`.
    SetDevice,
}
22
23impl From<WireguardCmd> for u8 {
24    fn from(cmd: WireguardCmd) -> Self {
25        use WireguardCmd::*;
26        match cmd {
27            GetDevice => WG_CMD_GET_DEVICE,
28            SetDevice => WG_CMD_SET_DEVICE,
29        }
30    }
31}
32
33impl TryFrom<u8> for WireguardCmd {
34    type Error = DecodeError;
35
36    fn try_from(value: u8) -> Result<Self, Self::Error> {
37        use WireguardCmd::*;
38        Ok(match value {
39            WG_CMD_GET_DEVICE => GetDevice,
40            WG_CMD_SET_DEVICE => SetDevice,
41            cmd => {
42                return Err(DecodeError::from(format!(
43                    "Unknown wireguard command: {}",
44                    cmd
45                )))
46            }
47        })
48    }
49}
50
51#[derive(Clone, Debug, PartialEq, Eq)]
52pub struct Wireguard {
53    pub cmd: WireguardCmd,
54    pub nlas: Vec<nlas::WgDeviceAttrs>,
55}
56
57impl GenlFamily for Wireguard {
58    fn family_name() -> &'static str {
59        "wireguard"
60    }
61
62    fn version(&self) -> u8 {
63        1
64    }
65
66    fn command(&self) -> u8 {
67        self.cmd.into()
68    }
69}
70
71impl Emitable for Wireguard {
72    fn emit(&self, buffer: &mut [u8]) {
73        self.nlas.as_slice().emit(buffer)
74    }
75
76    fn buffer_len(&self) -> usize {
77        self.nlas.as_slice().buffer_len()
78    }
79}
80
81impl ParseableParametrized<[u8], GenlHeader> for Wireguard {
82    fn parse_with_param(
83        buf: &[u8],
84        header: GenlHeader,
85    ) -> Result<Self, DecodeError> {
86        Ok(Self {
87            cmd: header.cmd.try_into()?,
88            nlas: parse_nlas(buf)?,
89        })
90    }
91}
92
93fn parse_nlas(buf: &[u8]) -> Result<Vec<WgDeviceAttrs>, DecodeError> {
94    let mut nlas = Vec::new();
95    let error_msg = "failed to parse message attributes";
96    for nla in NlasIterator::new(buf) {
97        let nla = &nla.context(error_msg)?;
98        let parsed = WgDeviceAttrs::parse(nla).context(error_msg)?;
99        nlas.push(parsed);
100    }
101    Ok(nlas)
102}
103
#[cfg(test)]
mod test {
    use netlink_packet_core::{NetlinkMessage, NLM_F_ACK, NLM_F_REQUEST};
    use netlink_packet_generic::GenlMessage;

    use crate::nlas::{WgAllowedIp, WgAllowedIpAttrs, WgPeer, WgPeerAttrs};

    use super::*;

    // A complete netlink message. Its first four bytes hold the length
    // 0x74 (116), which is exactly the number of bytes in this array.
    const KNOWN_VALID_PACKET: &[u8] = &[
        0x74, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x05, 0x00, 0x38, 0x24, 0xd6, 0x61,
        0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x00,
        0x66, 0x72, 0x61, 0x6e, 0x64, 0x73, 0x00, 0x00, 0x54, 0x00, 0x08, 0x80,
        0x50, 0x00, 0x00, 0x80, 0x24, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x08, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00,
        0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, 0x00, 0x80, 0x06, 0x00, 0x01, 0x00,
        0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, 0x0a, 0x0a, 0x0a, 0x0a,
        0x05, 0x00, 0x03, 0x00, 0x1e, 0x00, 0x00, 0x00,
    ];

    /// Builds a peer whose public key is `key_fill` repeated 32 times, with
    /// an all-`0x01` preshared key and a single allowed IPv4 `/24` network.
    fn ipv4_peer(key_fill: u8, network: [u8; 4]) -> WgPeer {
        let allowed_ip = WgAllowedIp(vec![
            WgAllowedIpAttrs::IpAddr(network.into()),
            WgAllowedIpAttrs::Cidr(24),
            WgAllowedIpAttrs::Family(AF_INET),
        ]);
        WgPeer(vec![
            WgPeerAttrs::PublicKey([key_fill; 32]),
            WgPeerAttrs::PresharedKey([0x01; 32]),
            WgPeerAttrs::AllowedIps(vec![allowed_ip]),
        ])
    }

    /// The captured packet must deserialize without error.
    #[test]
    fn test_parse_known_valid_packet() {
        let parsed = NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(
            KNOWN_VALID_PACKET,
        );
        parsed.unwrap();
    }

    /// A message built in memory must survive serialization followed by
    /// deserialization of exactly the bytes it reports as its length.
    #[test]
    fn test_serialize_then_deserialize() {
        let payload = Wireguard {
            cmd: WireguardCmd::SetDevice,
            nlas: vec![
                WgDeviceAttrs::IfName("wg0".to_string()),
                WgDeviceAttrs::PrivateKey([0xaa; 32]),
                WgDeviceAttrs::Peers(vec![
                    ipv4_peer(0x01, [10, 0, 0, 0]),
                    ipv4_peer(0x02, [10, 0, 1, 0]),
                ]),
            ],
        };
        let genl: GenlMessage<Wireguard> = GenlMessage::from_payload(payload);

        let mut packet = NetlinkMessage::from(genl);
        packet.header.flags = NLM_F_REQUEST | NLM_F_ACK;
        packet.finalize();

        let mut wire = [0; 4096];
        packet.serialize(&mut wire);
        let used = packet.buffer_len();
        NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(&wire[..used])
            .unwrap();
    }
}