1#[macro_use]
4extern crate log;
5
6use crate::constants::*;
7use netlink_packet_core::{
8 DecodeError, Emitable, ErrorContext, NlasIterator, Parseable,
9 ParseableParametrized,
10};
11use netlink_packet_generic::{GenlFamily, GenlHeader};
12use nlas::WgDeviceAttrs;
13use std::convert::{TryFrom, TryInto};
14
15pub mod constants;
16pub mod nlas;
17mod raw;
18
/// Commands understood by the `wireguard` generic-netlink family.
///
/// Converted to and from the raw command byte via `From<WireguardCmd> for u8`
/// and `TryFrom<u8> for WireguardCmd`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireguardCmd {
    /// Encoded as `WG_CMD_GET_DEVICE`.
    GetDevice,
    /// Encoded as `WG_CMD_SET_DEVICE`.
    SetDevice,
}
24
25impl From<WireguardCmd> for u8 {
26 fn from(cmd: WireguardCmd) -> Self {
27 use WireguardCmd::*;
28 match cmd {
29 GetDevice => WG_CMD_GET_DEVICE,
30 SetDevice => WG_CMD_SET_DEVICE,
31 }
32 }
33}
34
35impl TryFrom<u8> for WireguardCmd {
36 type Error = DecodeError;
37
38 fn try_from(value: u8) -> Result<Self, Self::Error> {
39 use WireguardCmd::*;
40 Ok(match value {
41 WG_CMD_GET_DEVICE => GetDevice,
42 WG_CMD_SET_DEVICE => SetDevice,
43 cmd => {
44 return Err(DecodeError::from(format!(
45 "Unknown wireguard command: {}",
46 cmd
47 )))
48 }
49 })
50 }
51}
52
53#[derive(Clone, Debug, PartialEq, Eq)]
54pub struct Wireguard {
55 pub cmd: WireguardCmd,
56 pub nlas: Vec<nlas::WgDeviceAttrs>,
57}
58
59impl GenlFamily for Wireguard {
60 fn family_name() -> &'static str {
61 "wireguard"
62 }
63
64 fn version(&self) -> u8 {
65 1
66 }
67
68 fn command(&self) -> u8 {
69 self.cmd.into()
70 }
71}
72
73impl Emitable for Wireguard {
74 fn emit(&self, buffer: &mut [u8]) {
75 self.nlas.as_slice().emit(buffer)
76 }
77
78 fn buffer_len(&self) -> usize {
79 self.nlas.as_slice().buffer_len()
80 }
81}
82
83impl ParseableParametrized<[u8], GenlHeader> for Wireguard {
84 fn parse_with_param(
85 buf: &[u8],
86 header: GenlHeader,
87 ) -> Result<Self, DecodeError> {
88 Ok(Self {
89 cmd: header.cmd.try_into()?,
90 nlas: parse_nlas(buf)?,
91 })
92 }
93}
94
95fn parse_nlas(buf: &[u8]) -> Result<Vec<WgDeviceAttrs>, DecodeError> {
96 let mut nlas = Vec::new();
97 let error_msg = "failed to parse message attributes";
98 for nla in NlasIterator::new(buf) {
99 let nla = &nla.context(error_msg)?;
100 let parsed = WgDeviceAttrs::parse(nla).context(error_msg)?;
101 nlas.push(parsed);
102 }
103 Ok(nlas)
104}
105
#[cfg(test)]
mod test {
    use netlink_packet_core::{
        NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_REQUEST,
    };
    use netlink_packet_generic::GenlMessage;

    use crate::nlas::{WgAllowedIp, WgAllowedIpAttrs, WgPeer, WgPeerAttrs};

    use super::*;

    /// A captured 0x74-byte netlink message.
    ///
    /// Layout, read from the bytes:
    /// - Netlink header: message type 0x1e (presumably the family id resolved
    ///   on the capturing host).
    /// - Generic-netlink header: cmd 0x01, version 0x01.
    /// - Top-level attribute type 2: 12 bytes including padding, value "frands\0".
    /// - Nested top-level attribute type 8: 84 bytes, holding one nested peer.
    const KNOWN_VALID_PACKET: &[u8] = &[
        0x74, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x05, 0x00, 0x38, 0x24, 0xd6, 0x61,
        0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x0b, 0x00, 0x02, 0x00,
        0x66, 0x72, 0x61, 0x6e, 0x64, 0x73, 0x00, 0x00, 0x54, 0x00, 0x08, 0x80,
        0x50, 0x00, 0x00, 0x80, 0x24, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
        0x01, 0x01, 0x01, 0x01, 0x08, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00,
        0x20, 0x00, 0x09, 0x80, 0x1c, 0x00, 0x00, 0x80, 0x06, 0x00, 0x01, 0x00,
        0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 0x00, 0x0a, 0x0a, 0x0a, 0x0a,
        0x05, 0x00, 0x03, 0x00, 0x1e, 0x00, 0x00, 0x00,
    ];

    /// Returns the WireGuard payload of a deserialized message.
    ///
    /// Panics if the message carries anything other than an inner generic
    /// message, so tests fail loudly instead of passing vacuously.
    fn wireguard_payload(
        msg: NetlinkMessage<GenlMessage<Wireguard>>,
    ) -> Wireguard {
        match msg.payload {
            NetlinkPayload::InnerMessage(genlmsg) => genlmsg.payload,
            other => panic!("expected an inner generic message, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_known_valid_packet() {
        let msg = NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(
            KNOWN_VALID_PACKET,
        )
        .unwrap();
        let wg = wireguard_payload(msg);

        // Generic-netlink header byte 16 is cmd 0x01, i.e. WG_CMD_SET_DEVICE.
        assert_eq!(wg.cmd, WireguardCmd::SetDevice);
        // 0x74 total - 16 (nlmsg) - 4 (genl) = 12 (IFNAME) + 84 (PEERS).
        assert_eq!(wg.nlas.len(), 2);
        assert_eq!(wg.nlas[0], WgDeviceAttrs::IfName("frands".to_string()));
    }

    #[test]
    fn test_serialize_then_deserialize() {
        let expected = Wireguard {
            cmd: WireguardCmd::SetDevice,
            nlas: vec![
                WgDeviceAttrs::IfName("wg0".to_string()),
                WgDeviceAttrs::PrivateKey([0xaa; 32]),
                WgDeviceAttrs::Peers(vec![
                    WgPeer(vec![
                        WgPeerAttrs::PublicKey([0x01; 32]),
                        WgPeerAttrs::PresharedKey([0x01; 32]),
                        WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                            WgAllowedIpAttrs::IpAddr([10, 0, 0, 0].into()),
                            WgAllowedIpAttrs::Cidr(24),
                            WgAllowedIpAttrs::Family(AF_INET),
                        ])]),
                    ]),
                    WgPeer(vec![
                        WgPeerAttrs::PublicKey([0x02; 32]),
                        WgPeerAttrs::PresharedKey([0x01; 32]),
                        WgPeerAttrs::AllowedIps(vec![WgAllowedIp(vec![
                            WgAllowedIpAttrs::IpAddr([10, 0, 1, 0].into()),
                            WgAllowedIpAttrs::Cidr(24),
                            WgAllowedIpAttrs::Family(AF_INET),
                        ])]),
                    ]),
                ]),
            ],
        };

        let genlmsg: GenlMessage<Wireguard> =
            GenlMessage::from_payload(expected.clone());
        let mut nlmsg = NetlinkMessage::from(genlmsg);
        nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK;
        nlmsg.finalize();

        // Size the buffer exactly, so an overlong buffer_len() cannot hide
        // trailing garbage from the parser.
        let mut buf = vec![0; nlmsg.buffer_len()];
        nlmsg.serialize(&mut buf);

        let parsed =
            NetlinkMessage::<GenlMessage<Wireguard>>::deserialize(&buf)
                .unwrap();
        // The point of a round trip: what comes back must equal what went in.
        assert_eq!(wireguard_payload(parsed), expected);
    }
}