netlink_packet_wireguard/
lib.rs

1// SPDX-License-Identifier: MIT
2
3#[macro_use]
4extern crate log;
5
6use crate::constants::*;
7use netlink_packet_core::{
8    DecodeError, Emitable, ErrorContext, NlasIterator, Parseable,
9    ParseableParametrized,
10};
11use netlink_packet_generic::{GenlFamily, GenlHeader};
12use nlas::WgDeviceAttrs;
13use std::convert::{TryFrom, TryInto};
14
15pub mod constants;
16pub mod nlas;
17mod raw;
18
/// Commands of the WireGuard generic netlink family.
///
/// Convert to the raw command byte with `u8::from` and back with
/// `WireguardCmd::try_from`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireguardCmd {
    /// Encoded on the wire as `WG_CMD_GET_DEVICE`.
    GetDevice,
    /// Encoded on the wire as `WG_CMD_SET_DEVICE`.
    SetDevice,
}
24
25impl From<WireguardCmd> for u8 {
26    fn from(cmd: WireguardCmd) -> Self {
27        use WireguardCmd::*;
28        match cmd {
29            GetDevice => WG_CMD_GET_DEVICE,
30            SetDevice => WG_CMD_SET_DEVICE,
31        }
32    }
33}
34
35impl TryFrom<u8> for WireguardCmd {
36    type Error = DecodeError;
37
38    fn try_from(value: u8) -> Result<Self, Self::Error> {
39        use WireguardCmd::*;
40        Ok(match value {
41            WG_CMD_GET_DEVICE => GetDevice,
42            WG_CMD_SET_DEVICE => SetDevice,
43            cmd => {
44                return Err(DecodeError::from(format!(
45                    "Unknown wireguard command: {}",
46                    cmd
47                )))
48            }
49        })
50    }
51}
52
53#[derive(Clone, Debug, PartialEq, Eq)]
54pub struct Wireguard {
55    pub cmd: WireguardCmd,
56    pub nlas: Vec<nlas::WgDeviceAttrs>,
57}
58
59impl GenlFamily for Wireguard {
60    fn family_name() -> &'static str {
61        "wireguard"
62    }
63
64    fn version(&self) -> u8 {
65        1
66    }
67
68    fn command(&self) -> u8 {
69        self.cmd.into()
70    }
71}
72
73impl Emitable for Wireguard {
74    fn emit(&self, buffer: &mut [u8]) {
75        self.nlas.as_slice().emit(buffer)
76    }
77
78    fn buffer_len(&self) -> usize {
79        self.nlas.as_slice().buffer_len()
80    }
81}
82
83impl ParseableParametrized<[u8], GenlHeader> for Wireguard {
84    fn parse_with_param(
85        buf: &[u8],
86        header: GenlHeader,
87    ) -> Result<Self, DecodeError> {
88        Ok(Self {
89            cmd: header.cmd.try_into()?,
90            nlas: parse_nlas(buf)?,
91        })
92    }
93}
94
95fn parse_nlas(buf: &[u8]) -> Result<Vec<WgDeviceAttrs>, DecodeError> {
96    let mut nlas = Vec::new();
97    let error_msg = "failed to parse message attributes";
98    for nla in NlasIterator::new(buf) {
99        let nla = &nla.context(error_msg)?;
100        let parsed = WgDeviceAttrs::parse(nla).context(error_msg)?;
101        nlas.push(parsed);
102    }
103    Ok(nlas)
104}
105
#[cfg(test)]
mod test {
    use netlink_packet_core::{
        NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_REQUEST,
    };
    use netlink_packet_generic::GenlMessage;

    use crate::nlas::{WgAllowedIp, WgAllowedIpAttrs, WgPeer, WgPeerAttrs};

    use super::*;

    /// A captured 116-byte (0x74) message with the following layout:
    /// - 16-byte netlink header;
    /// - 4-byte genl header (cmd 0x01, version 0x01);
    /// - the attributes, starting with the NUL-terminated name "frands".
    const KNOWN_VALID_PACKET: &[u8] = &[
        0x74, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x05, 0x00, 0x38, 0x24, 0xd6, 0x61,
        0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x00,
        0x66, 0x72, 0x61, 0x6e, 0x64, 0x73, 0x00, 0x00, 0x54, 0x00, 0x08, 0x80,
        0x50, 0x00, 0x00, 0x80, 0x24, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x08, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00,
        0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, 0x00, 0x80, 0x06, 0x00, 0x01, 0x00,
        0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, 0x0a, 0x0a, 0x0a, 0x0a,
        0x05, 0x00, 0x03, 0x00, 0x1e, 0x00, 0x00, 0x00,
    ];

    /// Extracts the [`Wireguard`] payload from a decoded netlink message.
    ///
    /// Panics if the message carries a control payload instead of an inner
    /// generic netlink message.
    fn wireguard_payload(
        msg: NetlinkMessage<GenlMessage<Wireguard>>,
    ) -> Wireguard {
        match msg.payload {
            NetlinkPayload::InnerMessage(genl) => genl.payload,
            other => panic!("expected a wireguard message, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_known_valid_packet() {
        let msg = NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(
            KNOWN_VALID_PACKET,
        )
        .unwrap();
        let wg = wireguard_payload(msg);
        // Check the parsed content, not just that decoding succeeded.
        assert_eq!(wg.cmd, WireguardCmd::SetDevice);
        assert_eq!(
            wg.nlas.first(),
            Some(&WgDeviceAttrs::IfName("frands".to_string()))
        );
    }

    #[test]
    fn test_serialize_then_deserialize() {
        let original = Wireguard {
            cmd: WireguardCmd::SetDevice,
            nlas: vec![
                WgDeviceAttrs::IfName("wg0".to_string()),
                WgDeviceAttrs::PrivateKey([0xaa; 32]),
                WgDeviceAttrs::Peers(vec![
                    WgPeer(vec![
                        WgPeerAttrs::PublicKey([0x01; 32]),
                        WgPeerAttrs::PresharedKey([0x01; 32]),
                        WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                            WgAllowedIpAttrs::IpAddr([10, 0, 0, 0].into()),
                            WgAllowedIpAttrs::Cidr(24),
                            WgAllowedIpAttrs::Family(AF_INET),
                        ])]),
                    ]),
                    WgPeer(vec![
                        WgPeerAttrs::PublicKey([0x02; 32]),
                        WgPeerAttrs::PresharedKey([0x01; 32]),
                        WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                            WgAllowedIpAttrs::IpAddr([10, 0, 1, 0].into()),
                            WgAllowedIpAttrs::Cidr(24),
                            WgAllowedIpAttrs::Family(AF_INET),
                        ])]),
                    ]),
                ]),
            ],
        };
        let genlmsg: GenlMessage<Wireguard> =
            GenlMessage::from_payload(original.clone());
        let mut nlmsg = NetlinkMessage::from(genlmsg);
        nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK;

        nlmsg.finalize();
        let mut buf = [0; 4096];
        nlmsg.serialize(&mut buf);
        let len = nlmsg.buffer_len();
        let decoded =
            NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(&buf[..len])
                .unwrap();
        // Every attribute must survive the round trip unchanged, not merely
        // decode without error.
        assert_eq!(wireguard_payload(decoded), original);
    }
}