netlink_packet_wireguard/
peer.rs1use std::{convert::TryInto, net::SocketAddr};
4
5use netlink_packet_core::{
6 emit_i64, emit_u16, emit_u32, emit_u64, parse_i64, parse_u16, parse_u32,
7 parse_u64, DecodeError, DefaultNla, Emitable, ErrorContext, Nla, NlaBuffer,
8 NlasIterator, Parseable, NLA_F_NESTED,
9};
10
11use super::{
12 allowedip::WireguardAllowedIps,
13 socket_addr::{
14 emit_socket_addr, parse_socket_addr, SOCKET_ADDR_V4_LEN,
15 SOCKET_ADDR_V6_LEN,
16 },
17};
18use crate::WireguardAllowedIp;
19
20pub(crate) struct WireguardPeers(pub(crate) Vec<WireguardPeer>);
21
22impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
23 for WireguardPeers
24{
25 fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
26 let mut ret = Vec::new();
27 let nlas = NlasIterator::new(buf.value());
28 for nla in nlas {
29 let nla = nla?;
30 ret.push(WireguardPeer::parse(&nla)?);
31 }
32 Ok(Self(ret))
33 }
34}
35
36impl std::ops::Deref for WireguardPeers {
37 type Target = Vec<WireguardPeer>;
38
39 fn deref(&self) -> &Self::Target {
40 &self.0
41 }
42}
43
44#[derive(Clone, Debug, PartialEq, Eq)]
45pub struct WireguardPeer(pub Vec<WireguardPeerAttribute>);
46
47impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
48 for WireguardPeer
49{
50 fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
51 let mut ret = Vec::new();
52 let nlas = NlasIterator::new(buf.value());
53 for nla in nlas {
54 let nla = nla?;
55 ret.push(WireguardPeerAttribute::parse(&nla)?);
56 }
57 Ok(Self(ret))
58 }
59}
60
61impl Nla for WireguardPeer {
62 fn kind(&self) -> u16 {
63 NLA_F_NESTED
65 }
66
67 fn value_len(&self) -> usize {
68 self.0.as_slice().buffer_len()
69 }
70
71 fn emit_value(&self, buffer: &mut [u8]) {
72 self.0.as_slice().emit(buffer)
73 }
74}
75
76impl std::ops::Deref for WireguardPeer {
77 type Target = Vec<WireguardPeerAttribute>;
78
79 fn deref(&self) -> &Self::Target {
80 &self.0
81 }
82}
83
/// Encoded size, in bytes, of a [`WireguardTimeSpec`]: two 8-byte integers.
const TIMESPEC_LEN: usize = 16;

/// Timestamp made of a seconds part and a nanoseconds part, both stored as
/// signed 64-bit integers.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct WireguardTimeSpec {
    /// Seconds component.
    pub seconds: i64,
    /// Nanoseconds component.
    pub nano_seconds: i64,
}
91
92impl Emitable for WireguardTimeSpec {
93 fn buffer_len(&self) -> usize {
94 TIMESPEC_LEN
95 }
96
97 fn emit(&self, buffer: &mut [u8]) {
98 emit_i64(&mut buffer[..8], self.seconds).unwrap();
99 emit_i64(&mut buffer[8..16], self.nano_seconds).unwrap();
100 }
101}
102
103impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
104 for WireguardTimeSpec
105{
106 fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
107 let data = buf.value();
108 if data.len() < TIMESPEC_LEN {
109 Err(format!(
110 "Invalid WGPEER_A_LAST_HANDSHAKE_TIME, expecting size {}, but \
111 got {:?}",
112 TIMESPEC_LEN, data
113 )
114 .into())
115 } else {
116 Ok(Self {
117 seconds: parse_i64(&data[..8])?,
118 nano_seconds: parse_i64(&data[8..16])?,
119 })
120 }
121 }
122}
123
/// Size in bytes of the payload carried by `WGPEER_A_PUBLIC_KEY`.
const NOISE_PUBLIC_KEY_LEN: usize = 32;
/// Size in bytes of the payload carried by `WGPEER_A_PRESHARED_KEY`.
const NOISE_SYMMETRIC_KEY_LEN: usize = 32;

// Netlink attribute type numbers used for the per-peer attributes.
const WGPEER_A_PUBLIC_KEY: u16 = 1;
const WGPEER_A_PRESHARED_KEY: u16 = 2;
const WGPEER_A_FLAGS: u16 = 3;
const WGPEER_A_ENDPOINT: u16 = 4;
const WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: u16 = 5;
const WGPEER_A_LAST_HANDSHAKE_TIME: u16 = 6;
const WGPEER_A_RX_BYTES: u16 = 7;
const WGPEER_A_TX_BYTES: u16 = 8;
const WGPEER_A_ALLOWEDIPS: u16 = 9;
const WGPEER_A_PROTOCOL_VERSION: u16 = 10;
137
138#[derive(Clone, Debug, PartialEq, Eq)]
139#[non_exhaustive]
140pub enum WireguardPeerAttribute {
141 PublicKey([u8; NOISE_PUBLIC_KEY_LEN]),
142 PresharedKey([u8; NOISE_SYMMETRIC_KEY_LEN]),
143 Endpoint(SocketAddr),
144 PersistentKeepalive(u16),
145 LastHandshake(WireguardTimeSpec),
146 RxBytes(u64),
147 TxBytes(u64),
148 AllowedIps(Vec<WireguardAllowedIp>),
149 ProtocolVersion(u32),
150 Flags(u32),
151 Other(DefaultNla),
152}
153
154impl Nla for WireguardPeerAttribute {
155 fn value_len(&self) -> usize {
156 match self {
157 Self::PublicKey(v) => size_of_val(v),
158 Self::PresharedKey(v) => size_of_val(v),
159 Self::Endpoint(v) => match *v {
160 SocketAddr::V4(_) => SOCKET_ADDR_V4_LEN,
161 SocketAddr::V6(_) => SOCKET_ADDR_V6_LEN,
162 },
163 Self::PersistentKeepalive(v) => size_of_val(v),
164 Self::LastHandshake(v) => v.buffer_len(),
165 Self::RxBytes(v) => size_of_val(v),
166 Self::TxBytes(v) => size_of_val(v),
167 Self::AllowedIps(v) => v.as_slice().buffer_len(),
168 Self::ProtocolVersion(v) => size_of_val(v),
169 Self::Flags(v) => size_of_val(v),
170 Self::Other(v) => v.value_len(),
171 }
172 }
173
174 fn kind(&self) -> u16 {
175 match self {
176 Self::PublicKey(_) => WGPEER_A_PUBLIC_KEY,
177 Self::PresharedKey(_) => WGPEER_A_PRESHARED_KEY,
178 Self::Endpoint(_) => WGPEER_A_ENDPOINT,
179 Self::PersistentKeepalive(_) => {
180 WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL
181 }
182 Self::LastHandshake(_) => WGPEER_A_LAST_HANDSHAKE_TIME,
183 Self::RxBytes(_) => WGPEER_A_RX_BYTES,
184 Self::TxBytes(_) => WGPEER_A_TX_BYTES,
185 Self::AllowedIps(_) => WGPEER_A_ALLOWEDIPS | NLA_F_NESTED,
186 Self::ProtocolVersion(_) => WGPEER_A_PROTOCOL_VERSION,
187 Self::Flags(_) => WGPEER_A_FLAGS,
188 Self::Other(v) => v.kind(),
189 }
190 }
191
192 fn emit_value(&self, buffer: &mut [u8]) {
193 match self {
194 Self::PublicKey(v) => buffer.copy_from_slice(v),
195 Self::PresharedKey(v) => buffer.copy_from_slice(v),
196 Self::Endpoint(v) => emit_socket_addr(v, buffer),
197 Self::PersistentKeepalive(v) => emit_u16(buffer, *v).unwrap(),
198 Self::LastHandshake(v) => v.emit(buffer),
199 Self::RxBytes(v) => emit_u64(buffer, *v).unwrap(),
200 Self::TxBytes(v) => emit_u64(buffer, *v).unwrap(),
201 Self::AllowedIps(v) => v.as_slice().emit(buffer),
202 Self::ProtocolVersion(v) => emit_u32(buffer, *v).unwrap(),
203 Self::Flags(v) => emit_u32(buffer, *v).unwrap(),
204 Self::Other(v) => v.emit_value(buffer),
205 }
206 }
207}
208
209impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
210 for WireguardPeerAttribute
211{
212 fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
213 let payload = buf.value();
214 Ok(match buf.kind() {
215 WGPEER_A_PUBLIC_KEY => Self::PublicKey(
216 payload
217 .try_into()
218 .map_err(|e: std::array::TryFromSliceError| {
219 DecodeError::from(e.to_string())
220 })
221 .context("invalid WGPEER_A_PUBLIC_KEY")?,
222 ),
223 WGPEER_A_PRESHARED_KEY => Self::PresharedKey(
224 payload
225 .try_into()
226 .map_err(|e: std::array::TryFromSliceError| {
227 DecodeError::from(e.to_string())
228 })
229 .context("invalid WGPEER_A_PRESHARED_KEY")?,
230 ),
231 WGPEER_A_ENDPOINT => Self::Endpoint(
232 parse_socket_addr(payload)
233 .context("invalid WGPEER_A_ENDPOINT")?,
234 ),
235 WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL => {
236 Self::PersistentKeepalive(parse_u16(payload).context(
237 "invalid WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL value",
238 )?)
239 }
240 WGPEER_A_LAST_HANDSHAKE_TIME => Self::LastHandshake(
241 WireguardTimeSpec::parse(buf)
242 .context("invalid WGPEER_A_LAST_HANDSHAKE_TIME")?,
243 ),
244 WGPEER_A_RX_BYTES => Self::RxBytes(
245 parse_u64(payload)
246 .context("invalid WGPEER_A_RX_BYTES value")?,
247 ),
248 WGPEER_A_TX_BYTES => Self::TxBytes(
249 parse_u64(payload)
250 .context("invalid WGPEER_A_TX_BYTES value")?,
251 ),
252 WGPEER_A_ALLOWEDIPS => {
253 Self::AllowedIps(WireguardAllowedIps::parse(buf)?.0)
254 }
255 WGPEER_A_PROTOCOL_VERSION => Self::ProtocolVersion(
256 parse_u32(payload)
257 .context("invalid WGPEER_A_PROTOCOL_VERSION value")?,
258 ),
259 WGPEER_A_FLAGS => Self::Flags(
260 parse_u32(payload).context("invalid WGPEER_A_FLAGS value")?,
261 ),
262 kind => Self::Other(
263 DefaultNla::parse(buf)
264 .context(format!("unknown NLA type {kind}"))?,
265 ),
266 })
267 }
268}