netlink_packet_wireguard/nlas/
peer.rs

1// SPDX-License-Identifier: MIT
2
3use super::WgAllowedIpAttrs;
4use crate::{
5    constants::*,
6    raw::{
7        emit_socket_addr, emit_timespec, parse_socket_addr, parse_timespec,
8        SOCKET_ADDR_V4_LEN, SOCKET_ADDR_V6_LEN, TIMESPEC_LEN,
9    },
10};
11use netlink_packet_core::{
12    emit_u16, emit_u32, emit_u64, parse_u16, parse_u32, parse_u64, DecodeError,
13    Emitable, ErrorContext, Nla, NlaBuffer, NlasIterator, Parseable,
14};
15use std::{
16    convert::TryInto, mem::size_of_val, net::SocketAddr, ops::Deref,
17    time::SystemTime,
18};
19
20#[derive(Clone, Debug, PartialEq, Eq)]
21pub struct WgPeer(pub Vec<WgPeerAttrs>);
22
23impl Nla for WgPeer {
24    fn value_len(&self) -> usize {
25        self.0.as_slice().buffer_len()
26    }
27
28    fn kind(&self) -> u16 {
29        0
30    }
31
32    fn emit_value(&self, buffer: &mut [u8]) {
33        self.0.as_slice().emit(buffer);
34    }
35
36    fn is_nested(&self) -> bool {
37        true
38    }
39}
40
41impl Deref for WgPeer {
42    type Target = Vec<WgPeerAttrs>;
43
44    fn deref(&self) -> &Self::Target {
45        &self.0
46    }
47}
48
49#[derive(Clone, Debug, PartialEq, Eq)]
50pub struct WgAllowedIp(pub Vec<WgAllowedIpAttrs>);
51
52impl Nla for WgAllowedIp {
53    fn value_len(&self) -> usize {
54        self.0.as_slice().buffer_len()
55    }
56
57    fn kind(&self) -> u16 {
58        0
59    }
60
61    fn emit_value(&self, buffer: &mut [u8]) {
62        self.0.as_slice().emit(buffer);
63    }
64
65    fn is_nested(&self) -> bool {
66        true
67    }
68}
69
70impl Deref for WgAllowedIp {
71    type Target = Vec<WgAllowedIpAttrs>;
72
73    fn deref(&self) -> &Self::Target {
74        &self.0
75    }
76}
77
78#[derive(Clone, Debug, PartialEq, Eq)]
79pub enum WgPeerAttrs {
80    Unspec(Vec<u8>),
81    PublicKey([u8; WG_KEY_LEN]),
82    PresharedKey([u8; WG_KEY_LEN]),
83    Endpoint(SocketAddr),
84    PersistentKeepalive(u16),
85    LastHandshake(SystemTime),
86    RxBytes(u64),
87    TxBytes(u64),
88    AllowedIps(Vec<WgAllowedIp>),
89    ProtocolVersion(u32),
90    Flags(u32),
91}
92
93impl Nla for WgPeerAttrs {
94    fn value_len(&self) -> usize {
95        match self {
96            WgPeerAttrs::Unspec(bytes) => bytes.len(),
97            WgPeerAttrs::PublicKey(v) => size_of_val(v),
98            WgPeerAttrs::PresharedKey(v) => size_of_val(v),
99            WgPeerAttrs::Endpoint(v) => match *v {
100                SocketAddr::V4(_) => SOCKET_ADDR_V4_LEN,
101                SocketAddr::V6(_) => SOCKET_ADDR_V6_LEN,
102            },
103            WgPeerAttrs::PersistentKeepalive(v) => size_of_val(v),
104            WgPeerAttrs::LastHandshake(_) => TIMESPEC_LEN,
105            WgPeerAttrs::RxBytes(v) => size_of_val(v),
106            WgPeerAttrs::TxBytes(v) => size_of_val(v),
107            WgPeerAttrs::AllowedIps(nlas) => {
108                nlas.iter().map(|op| op.buffer_len()).sum()
109            }
110            WgPeerAttrs::ProtocolVersion(v) => size_of_val(v),
111            WgPeerAttrs::Flags(v) => size_of_val(v),
112        }
113    }
114
115    fn kind(&self) -> u16 {
116        match self {
117            WgPeerAttrs::Unspec(_) => WGPEER_A_UNSPEC,
118            WgPeerAttrs::PublicKey(_) => WGPEER_A_PUBLIC_KEY,
119            WgPeerAttrs::PresharedKey(_) => WGPEER_A_PRESHARED_KEY,
120            WgPeerAttrs::Endpoint(_) => WGPEER_A_ENDPOINT,
121            WgPeerAttrs::PersistentKeepalive(_) => {
122                WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL
123            }
124            WgPeerAttrs::LastHandshake(_) => WGPEER_A_LAST_HANDSHAKE_TIME,
125            WgPeerAttrs::RxBytes(_) => WGPEER_A_RX_BYTES,
126            WgPeerAttrs::TxBytes(_) => WGPEER_A_TX_BYTES,
127            WgPeerAttrs::AllowedIps(_) => WGPEER_A_ALLOWEDIPS,
128            WgPeerAttrs::ProtocolVersion(_) => WGPEER_A_PROTOCOL_VERSION,
129            WgPeerAttrs::Flags(_) => WGPEER_A_FLAGS,
130        }
131    }
132
133    fn emit_value(&self, buffer: &mut [u8]) {
134        match self {
135            WgPeerAttrs::Unspec(bytes) => buffer.copy_from_slice(bytes),
136            WgPeerAttrs::PublicKey(v) => buffer.copy_from_slice(v),
137            WgPeerAttrs::PresharedKey(v) => buffer.copy_from_slice(v),
138            WgPeerAttrs::Endpoint(v) => emit_socket_addr(v, buffer),
139            WgPeerAttrs::PersistentKeepalive(v) => {
140                emit_u16(buffer, *v).unwrap()
141            }
142            WgPeerAttrs::LastHandshake(v) => emit_timespec(v, buffer),
143            WgPeerAttrs::RxBytes(v) => emit_u64(buffer, *v).unwrap(),
144            WgPeerAttrs::TxBytes(v) => emit_u64(buffer, *v).unwrap(),
145            WgPeerAttrs::AllowedIps(nlas) => {
146                let mut len = 0;
147                for op in nlas {
148                    op.emit(&mut buffer[len..]);
149                    len += op.buffer_len();
150                }
151            }
152            WgPeerAttrs::ProtocolVersion(v) => emit_u32(buffer, *v).unwrap(),
153            WgPeerAttrs::Flags(v) => emit_u32(buffer, *v).unwrap(),
154        }
155    }
156
157    fn is_nested(&self) -> bool {
158        matches!(self, WgPeerAttrs::AllowedIps(_))
159    }
160}
161
162impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>> for WgPeerAttrs {
163    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
164        let payload = buf.value();
165        Ok(match buf.kind() {
166            WGPEER_A_UNSPEC => Self::Unspec(payload.to_vec()),
167            WGPEER_A_PUBLIC_KEY => Self::PublicKey(
168                payload
169                    .try_into()
170                    .map_err(|e: std::array::TryFromSliceError| {
171                        DecodeError::from(e.to_string())
172                    })
173                    .context("invalid WGPEER_A_PUBLIC_KEY")?,
174            ),
175            WGPEER_A_PRESHARED_KEY => Self::PresharedKey(
176                payload
177                    .try_into()
178                    .map_err(|e: std::array::TryFromSliceError| {
179                        DecodeError::from(e.to_string())
180                    })
181                    .context("invalid WGPEER_A_PRESHARED_KEY")?,
182            ),
183            WGPEER_A_ENDPOINT => Self::Endpoint(
184                parse_socket_addr(payload)
185                    .context("invalid WGPEER_A_ENDPOINT")?,
186            ),
187            WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL => {
188                Self::PersistentKeepalive(parse_u16(payload).context(
189                    "invalid WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL value",
190                )?)
191            }
192            WGPEER_A_LAST_HANDSHAKE_TIME => Self::LastHandshake(
193                parse_timespec(payload)
194                    .context("invalid WGPEER_A_LAST_HANDSHAKE_TIME")?,
195            ),
196            WGPEER_A_RX_BYTES => Self::RxBytes(
197                parse_u64(payload)
198                    .context("invalid WGPEER_A_RX_BYTES value")?,
199            ),
200            WGPEER_A_TX_BYTES => Self::TxBytes(
201                parse_u64(payload)
202                    .context("invalid WGPEER_A_TX_BYTES value")?,
203            ),
204            WGPEER_A_ALLOWEDIPS => {
205                let error_msg = "failed to parse WGPEER_A_ALLOWEDIPS";
206                let mut ips = Vec::new();
207                for nlas in NlasIterator::new(payload) {
208                    let nlas = &nlas.context(error_msg)?;
209                    let mut group = Vec::new();
210                    for nla in NlasIterator::new(nlas.value()) {
211                        let nla = &nla.context(error_msg)?;
212                        let parsed =
213                            WgAllowedIpAttrs::parse(nla).context(error_msg)?;
214                        group.push(parsed);
215                    }
216                    ips.push(WgAllowedIp(group));
217                }
218                Self::AllowedIps(ips)
219            }
220            WGPEER_A_PROTOCOL_VERSION => Self::ProtocolVersion(
221                parse_u32(payload)
222                    .context("invalid WGPEER_A_PROTOCOL_VERSION value")?,
223            ),
224            WGPEER_A_FLAGS => Self::Flags(
225                parse_u32(payload).context("invalid WGPEER_A_FLAGS value")?,
226            ),
227            kind => {
228                return Err(DecodeError::from(format!(
229                    "invalid NLA kind: {}",
230                    kind
231                )))
232            }
233        })
234    }
235}