Skip to main content

netlink_packet_wireguard/
peer.rs

1// SPDX-License-Identifier: MIT
2
3use std::{convert::TryInto, net::SocketAddr};
4
5use netlink_packet_core::{
6    emit_i64, emit_u16, emit_u32, emit_u64, parse_i64, parse_u16, parse_u32,
7    parse_u64, DecodeError, DefaultNla, Emitable, ErrorContext, Nla, NlaBuffer,
8    NlasIterator, Parseable, NLA_F_NESTED,
9};
10
11use super::{
12    allowedip::WireguardAllowedIps,
13    socket_addr::{
14        emit_socket_addr, parse_socket_addr, SOCKET_ADDR_V4_LEN,
15        SOCKET_ADDR_V6_LEN,
16    },
17};
18use crate::WireguardAllowedIp;
19
20pub(crate) struct WireguardPeers(pub(crate) Vec<WireguardPeer>);
21
22impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
23    for WireguardPeers
24{
25    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
26        let mut ret = Vec::new();
27        let nlas = NlasIterator::new(buf.value());
28        for nla in nlas {
29            let nla = nla?;
30            ret.push(WireguardPeer::parse(&nla)?);
31        }
32        Ok(Self(ret))
33    }
34}
35
36impl std::ops::Deref for WireguardPeers {
37    type Target = Vec<WireguardPeer>;
38
39    fn deref(&self) -> &Self::Target {
40        &self.0
41    }
42}
43
44#[derive(Clone, Debug, PartialEq, Eq)]
45pub struct WireguardPeer(pub Vec<WireguardPeerAttribute>);
46
47impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
48    for WireguardPeer
49{
50    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
51        let mut ret = Vec::new();
52        let nlas = NlasIterator::new(buf.value());
53        for nla in nlas {
54            let nla = nla?;
55            ret.push(WireguardPeerAttribute::parse(&nla)?);
56        }
57        Ok(Self(ret))
58    }
59}
60
61impl Nla for WireguardPeer {
62    fn kind(&self) -> u16 {
63        // linux kernel always set it to 0
64        NLA_F_NESTED
65    }
66
67    fn value_len(&self) -> usize {
68        self.0.as_slice().buffer_len()
69    }
70
71    fn emit_value(&self, buffer: &mut [u8]) {
72        self.0.as_slice().emit(buffer)
73    }
74}
75
76impl std::ops::Deref for WireguardPeer {
77    type Target = Vec<WireguardPeerAttribute>;
78
79    fn deref(&self) -> &Self::Target {
80        &self.0
81    }
82}
83
84const TIMESPEC_LEN: usize = 16;
85
86#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
87pub struct WireguardTimeSpec {
88    pub seconds: i64,
89    pub nano_seconds: i64,
90}
91
92impl Emitable for WireguardTimeSpec {
93    fn buffer_len(&self) -> usize {
94        TIMESPEC_LEN
95    }
96
97    fn emit(&self, buffer: &mut [u8]) {
98        emit_i64(&mut buffer[..8], self.seconds).unwrap();
99        emit_i64(&mut buffer[8..16], self.nano_seconds).unwrap();
100    }
101}
102
103impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
104    for WireguardTimeSpec
105{
106    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
107        let data = buf.value();
108        if data.len() < TIMESPEC_LEN {
109            Err(format!(
110                "Invalid WGPEER_A_LAST_HANDSHAKE_TIME, expecting size {}, but \
111                 got {:?}",
112                TIMESPEC_LEN, data
113            )
114            .into())
115        } else {
116            Ok(Self {
117                seconds: parse_i64(&data[..8])?,
118                nano_seconds: parse_i64(&data[8..16])?,
119            })
120        }
121    }
122}
123
/// Length in bytes of a Noise public key (payload of `WGPEER_A_PUBLIC_KEY`).
const NOISE_PUBLIC_KEY_LEN: usize = 32;
/// Length in bytes of a Noise symmetric key (payload of
/// `WGPEER_A_PRESHARED_KEY`).
const NOISE_SYMMETRIC_KEY_LEN: usize = 32;

// Peer attribute type numbers (`WGPEER_A_*`), in ascending order.
const WGPEER_A_PUBLIC_KEY: u16 = 0x01;
const WGPEER_A_PRESHARED_KEY: u16 = 0x02;
const WGPEER_A_FLAGS: u16 = 0x03;
const WGPEER_A_ENDPOINT: u16 = 0x04;
const WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: u16 = 0x05;
const WGPEER_A_LAST_HANDSHAKE_TIME: u16 = 0x06;
const WGPEER_A_RX_BYTES: u16 = 0x07;
const WGPEER_A_TX_BYTES: u16 = 0x08;
const WGPEER_A_ALLOWEDIPS: u16 = 0x09;
const WGPEER_A_PROTOCOL_VERSION: u16 = 0x0a;
137
138#[derive(Clone, Debug, PartialEq, Eq)]
139#[non_exhaustive]
140pub enum WireguardPeerAttribute {
141    PublicKey([u8; NOISE_PUBLIC_KEY_LEN]),
142    PresharedKey([u8; NOISE_SYMMETRIC_KEY_LEN]),
143    Endpoint(SocketAddr),
144    PersistentKeepalive(u16),
145    LastHandshake(WireguardTimeSpec),
146    RxBytes(u64),
147    TxBytes(u64),
148    AllowedIps(Vec<WireguardAllowedIp>),
149    ProtocolVersion(u32),
150    Flags(u32),
151    Other(DefaultNla),
152}
153
154impl Nla for WireguardPeerAttribute {
155    fn value_len(&self) -> usize {
156        match self {
157            Self::PublicKey(v) => size_of_val(v),
158            Self::PresharedKey(v) => size_of_val(v),
159            Self::Endpoint(v) => match *v {
160                SocketAddr::V4(_) => SOCKET_ADDR_V4_LEN,
161                SocketAddr::V6(_) => SOCKET_ADDR_V6_LEN,
162            },
163            Self::PersistentKeepalive(v) => size_of_val(v),
164            Self::LastHandshake(v) => v.buffer_len(),
165            Self::RxBytes(v) => size_of_val(v),
166            Self::TxBytes(v) => size_of_val(v),
167            Self::AllowedIps(v) => v.as_slice().buffer_len(),
168            Self::ProtocolVersion(v) => size_of_val(v),
169            Self::Flags(v) => size_of_val(v),
170            Self::Other(v) => v.value_len(),
171        }
172    }
173
174    fn kind(&self) -> u16 {
175        match self {
176            Self::PublicKey(_) => WGPEER_A_PUBLIC_KEY,
177            Self::PresharedKey(_) => WGPEER_A_PRESHARED_KEY,
178            Self::Endpoint(_) => WGPEER_A_ENDPOINT,
179            Self::PersistentKeepalive(_) => {
180                WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL
181            }
182            Self::LastHandshake(_) => WGPEER_A_LAST_HANDSHAKE_TIME,
183            Self::RxBytes(_) => WGPEER_A_RX_BYTES,
184            Self::TxBytes(_) => WGPEER_A_TX_BYTES,
185            Self::AllowedIps(_) => WGPEER_A_ALLOWEDIPS | NLA_F_NESTED,
186            Self::ProtocolVersion(_) => WGPEER_A_PROTOCOL_VERSION,
187            Self::Flags(_) => WGPEER_A_FLAGS,
188            Self::Other(v) => v.kind(),
189        }
190    }
191
192    fn emit_value(&self, buffer: &mut [u8]) {
193        match self {
194            Self::PublicKey(v) => buffer.copy_from_slice(v),
195            Self::PresharedKey(v) => buffer.copy_from_slice(v),
196            Self::Endpoint(v) => emit_socket_addr(v, buffer),
197            Self::PersistentKeepalive(v) => emit_u16(buffer, *v).unwrap(),
198            Self::LastHandshake(v) => v.emit(buffer),
199            Self::RxBytes(v) => emit_u64(buffer, *v).unwrap(),
200            Self::TxBytes(v) => emit_u64(buffer, *v).unwrap(),
201            Self::AllowedIps(v) => v.as_slice().emit(buffer),
202            Self::ProtocolVersion(v) => emit_u32(buffer, *v).unwrap(),
203            Self::Flags(v) => emit_u32(buffer, *v).unwrap(),
204            Self::Other(v) => v.emit_value(buffer),
205        }
206    }
207}
208
209impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
210    for WireguardPeerAttribute
211{
212    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
213        let payload = buf.value();
214        Ok(match buf.kind() {
215            WGPEER_A_PUBLIC_KEY => Self::PublicKey(
216                payload
217                    .try_into()
218                    .map_err(|e: std::array::TryFromSliceError| {
219                        DecodeError::from(e.to_string())
220                    })
221                    .context("invalid WGPEER_A_PUBLIC_KEY")?,
222            ),
223            WGPEER_A_PRESHARED_KEY => Self::PresharedKey(
224                payload
225                    .try_into()
226                    .map_err(|e: std::array::TryFromSliceError| {
227                        DecodeError::from(e.to_string())
228                    })
229                    .context("invalid WGPEER_A_PRESHARED_KEY")?,
230            ),
231            WGPEER_A_ENDPOINT => Self::Endpoint(
232                parse_socket_addr(payload)
233                    .context("invalid WGPEER_A_ENDPOINT")?,
234            ),
235            WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL => {
236                Self::PersistentKeepalive(parse_u16(payload).context(
237                    "invalid WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL value",
238                )?)
239            }
240            WGPEER_A_LAST_HANDSHAKE_TIME => Self::LastHandshake(
241                WireguardTimeSpec::parse(buf)
242                    .context("invalid WGPEER_A_LAST_HANDSHAKE_TIME")?,
243            ),
244            WGPEER_A_RX_BYTES => Self::RxBytes(
245                parse_u64(payload)
246                    .context("invalid WGPEER_A_RX_BYTES value")?,
247            ),
248            WGPEER_A_TX_BYTES => Self::TxBytes(
249                parse_u64(payload)
250                    .context("invalid WGPEER_A_TX_BYTES value")?,
251            ),
252            WGPEER_A_ALLOWEDIPS => {
253                Self::AllowedIps(WireguardAllowedIps::parse(buf)?.0)
254            }
255            WGPEER_A_PROTOCOL_VERSION => Self::ProtocolVersion(
256                parse_u32(payload)
257                    .context("invalid WGPEER_A_PROTOCOL_VERSION value")?,
258            ),
259            WGPEER_A_FLAGS => Self::Flags(
260                parse_u32(payload).context("invalid WGPEER_A_FLAGS value")?,
261            ),
262            kind => Self::Other(
263                DefaultNla::parse(buf)
264                    .context(format!("unknown NLA type {kind}"))?,
265            ),
266        })
267    }
268}