1#[macro_use]
4extern crate log;
5
6use crate::constants::*;
7use anyhow::Context;
8use netlink_packet_generic::{GenlFamily, GenlHeader};
9use netlink_packet_utils::{nla::NlasIterator, traits::*, DecodeError};
10use nlas::WgDeviceAttrs;
11use std::convert::{TryFrom, TryInto};
12
13pub mod constants;
14pub mod nlas;
15mod raw;
16
/// Commands of the `wireguard` generic netlink family.
///
/// Converted to the raw command byte with `u8::from` and back with
/// `WireguardCmd::try_from`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum WireguardCmd {
    /// Encoded as `WG_CMD_GET_DEVICE`.
    GetDevice,
    /// Encoded as `WG_CMD_SET_DEVICE`.
    SetDevice,
}
22
23impl From<WireguardCmd> for u8 {
24 fn from(cmd: WireguardCmd) -> Self {
25 use WireguardCmd::*;
26 match cmd {
27 GetDevice => WG_CMD_GET_DEVICE,
28 SetDevice => WG_CMD_SET_DEVICE,
29 }
30 }
31}
32
33impl TryFrom<u8> for WireguardCmd {
34 type Error = DecodeError;
35
36 fn try_from(value: u8) -> Result<Self, Self::Error> {
37 use WireguardCmd::*;
38 Ok(match value {
39 WG_CMD_GET_DEVICE => GetDevice,
40 WG_CMD_SET_DEVICE => SetDevice,
41 cmd => {
42 return Err(DecodeError::from(format!(
43 "Unknown wireguard command: {}",
44 cmd
45 )))
46 }
47 })
48 }
49}
50
51#[derive(Clone, Debug, PartialEq, Eq)]
52pub struct Wireguard {
53 pub cmd: WireguardCmd,
54 pub nlas: Vec<nlas::WgDeviceAttrs>,
55}
56
57impl GenlFamily for Wireguard {
58 fn family_name() -> &'static str {
59 "wireguard"
60 }
61
62 fn version(&self) -> u8 {
63 1
64 }
65
66 fn command(&self) -> u8 {
67 self.cmd.into()
68 }
69}
70
71impl Emitable for Wireguard {
72 fn emit(&self, buffer: &mut [u8]) {
73 self.nlas.as_slice().emit(buffer)
74 }
75
76 fn buffer_len(&self) -> usize {
77 self.nlas.as_slice().buffer_len()
78 }
79}
80
81impl ParseableParametrized<[u8], GenlHeader> for Wireguard {
82 fn parse_with_param(
83 buf: &[u8],
84 header: GenlHeader,
85 ) -> Result<Self, DecodeError> {
86 Ok(Self {
87 cmd: header.cmd.try_into()?,
88 nlas: parse_nlas(buf)?,
89 })
90 }
91}
92
93fn parse_nlas(buf: &[u8]) -> Result<Vec<WgDeviceAttrs>, DecodeError> {
94 let mut nlas = Vec::new();
95 let error_msg = "failed to parse message attributes";
96 for nla in NlasIterator::new(buf) {
97 let nla = &nla.context(error_msg)?;
98 let parsed = WgDeviceAttrs::parse(nla).context(error_msg)?;
99 nlas.push(parsed);
100 }
101 Ok(nlas)
102}
103
#[cfg(test)]
mod test {
    use netlink_packet_core::{
        NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_REQUEST,
    };
    use netlink_packet_generic::GenlMessage;

    use crate::nlas::{WgAllowedIp, WgAllowedIpAttrs, WgPeer, WgPeerAttrs};

    use super::*;

    /// A serialized request whose generic netlink command byte is 0x01.
    /// It names the interface `frands` and carries a single peer.
    const KNOWN_VALID_PACKET: &[u8] = &[
        0x74, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x05, 0x00, 0x38, 0x24, 0xd6, 0x61,
        0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x00,
        0x66, 0x72, 0x61, 0x6e, 0x64, 0x73, 0x00, 0x00, 0x54, 0x00, 0x08, 0x80,
        0x50, 0x00, 0x00, 0x80, 0x24, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x08, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00,
        0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, 0x00, 0x80, 0x06, 0x00, 0x01, 0x00,
        0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, 0x0a, 0x0a, 0x0a, 0x0a,
        0x05, 0x00, 0x03, 0x00, 0x1e, 0x00, 0x00, 0x00,
    ];

    /// Extracts the WireGuard body from a deserialized netlink message.
    /// Panics if the message carries any other payload kind.
    fn wireguard_payload(
        msg: NetlinkMessage<GenlMessage<Wireguard>>,
    ) -> Wireguard {
        match msg.payload {
            NetlinkPayload::InnerMessage(genlmsg) => genlmsg.payload,
            other => panic!("expected a wireguard message, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_known_valid_packet() {
        let msg = NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(
            KNOWN_VALID_PACKET,
        )
        .unwrap();
        let wg = wireguard_payload(msg);
        // Genl header bytes 0x01, 0x01: command 1 (SetDevice), version 1.
        assert_eq!(wg.cmd, WireguardCmd::SetDevice);
        // First attribute: len 11, type 2, "frands\0".
        assert_eq!(
            wg.nlas.first(),
            Some(&WgDeviceAttrs::IfName("frands".to_string()))
        );
    }

    #[test]
    fn test_serialize_then_deserialize() {
        let expected = Wireguard {
            cmd: WireguardCmd::SetDevice,
            nlas: vec![
                WgDeviceAttrs::IfName("wg0".to_string()),
                WgDeviceAttrs::PrivateKey([0xaa; 32]),
                WgDeviceAttrs::Peers(vec![
                    WgPeer(vec![
                        WgPeerAttrs::PublicKey([0x01; 32]),
                        WgPeerAttrs::PresharedKey([0x01; 32]),
                        WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                            WgAllowedIpAttrs::IpAddr([10, 0, 0, 0].into()),
                            WgAllowedIpAttrs::Cidr(24),
                            WgAllowedIpAttrs::Family(AF_INET),
                        ])]),
                    ]),
                    WgPeer(vec![
                        WgPeerAttrs::PublicKey([0x02; 32]),
                        WgPeerAttrs::PresharedKey([0x01; 32]),
                        WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                            WgAllowedIpAttrs::IpAddr([10, 0, 1, 0].into()),
                            WgAllowedIpAttrs::Cidr(24),
                            WgAllowedIpAttrs::Family(AF_INET),
                        ])]),
                    ]),
                ]),
            ],
        };
        let genlmsg: GenlMessage<Wireguard> =
            GenlMessage::from_payload(expected.clone());
        let mut nlmsg = NetlinkMessage::from(genlmsg);
        nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK;

        nlmsg.finalize();
        let mut buf = [0; 4096];
        nlmsg.serialize(&mut buf);
        let len = nlmsg.buffer_len();
        let parsed =
            NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(&buf[..len])
                .unwrap();
        // A successful parse alone does not prove the codec is lossless;
        // compare the decoded body against what was serialized.
        assert_eq!(wireguard_payload(parsed), expected);
    }
}