Skip to main content

nl_wireguard/
peer_parsed.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    convert::TryFrom,
5    net::{IpAddr, SocketAddr},
6    time::Duration,
7};
8
9use base64::{prelude::BASE64_STANDARD, Engine};
10
11use super::parsed::decode_key;
12use crate::{
13    ErrorKind, WireguardAddressFamily, WireguardAllowedIp,
14    WireguardAllowedIpAttr, WireguardError, WireguardPeer,
15    WireguardPeerAttribute, WireguardTimeSpec,
16};
17
18#[derive(Clone, PartialEq, Eq, Default)]
19#[non_exhaustive]
20pub struct WireguardPeerParsed {
21    pub endpoint: Option<SocketAddr>,
22    /// Base64 encoded public key
23    pub public_key: Option<String>,
24    /// Base64 encoded pre-shared key, this property will be display as
25    /// `(hidden)` for `Debug` trait.
26    pub preshared_key: Option<String>,
27    pub persistent_keepalive: Option<u16>,
28    /// Last handshake time since UNIX_EPOCH
29    pub last_handshake: Option<Duration>,
30    pub rx_bytes: Option<u64>,
31    pub tx_bytes: Option<u64>,
32    pub allowed_ips: Option<Vec<WireguardIpAddress>>,
33    pub protocol_version: Option<u32>,
34    // TODO: Flags
35}
36
37// For simplifying the code on hide `preshared_key` in Debug display of
38// [WireguardPeerParsed]
39#[allow(dead_code)]
40#[derive(Debug)]
41struct _WireguardPeerParsed<'a> {
42    endpoint: &'a Option<SocketAddr>,
43    public_key: &'a Option<String>,
44    preshared_key: Option<String>,
45    persistent_keepalive: &'a Option<u16>,
46    last_handshake: &'a Option<Duration>,
47    rx_bytes: &'a Option<u64>,
48    tx_bytes: &'a Option<u64>,
49    allowed_ips: &'a Option<Vec<WireguardIpAddress>>,
50    protocol_version: &'a Option<u32>,
51}
52
53impl std::fmt::Debug for WireguardPeerParsed {
54    fn fmt(
55        &self,
56        f: &mut std::fmt::Formatter<'_>,
57    ) -> Result<(), std::fmt::Error> {
58        let Self {
59            endpoint,
60            public_key,
61            preshared_key,
62            persistent_keepalive,
63            last_handshake,
64            rx_bytes,
65            tx_bytes,
66            allowed_ips,
67            protocol_version,
68        } = self;
69
70        std::fmt::Debug::fmt(
71            &_WireguardPeerParsed {
72                endpoint,
73                public_key,
74                preshared_key: if preshared_key.is_some() {
75                    Some("(hidden)".to_string())
76                } else {
77                    None
78                },
79                persistent_keepalive,
80                last_handshake,
81                rx_bytes,
82                tx_bytes,
83                allowed_ips,
84                protocol_version,
85            },
86            f,
87        )
88    }
89}
90
91impl From<WireguardPeer> for WireguardPeerParsed {
92    fn from(attrs: WireguardPeer) -> Self {
93        let mut ret = Self::default();
94        for attr in attrs.0 {
95            match attr {
96                WireguardPeerAttribute::PublicKey(v) => {
97                    ret.public_key = Some(BASE64_STANDARD.encode(v));
98                }
99                WireguardPeerAttribute::PresharedKey(v) => {
100                    if v.as_slice().iter().all(|i| *i == 0) {
101                        ret.preshared_key = None;
102                    } else {
103                        ret.preshared_key = Some(BASE64_STANDARD.encode(v));
104                    }
105                }
106                WireguardPeerAttribute::Endpoint(v) => ret.endpoint = Some(v),
107                WireguardPeerAttribute::PersistentKeepalive(v) => {
108                    ret.persistent_keepalive = Some(v)
109                }
110                WireguardPeerAttribute::LastHandshake(v) => {
111                    if v.seconds == 0 && v.nano_seconds == 0 {
112                        ret.last_handshake = None;
113                    } else if v.seconds >= 0
114                        && v.nano_seconds >= 0
115                        && (v.nano_seconds as u64) < (u32::MAX as u64)
116                    {
117                        ret.last_handshake = Some(Duration::new(
118                            v.seconds as u64,
119                            v.nano_seconds as u32,
120                        ));
121                    } else {
122                        log::warn!(
123                            "Ignoring invalid last handshake time: {v:?}"
124                        );
125                    }
126                }
127                WireguardPeerAttribute::RxBytes(v) => ret.rx_bytes = Some(v),
128                WireguardPeerAttribute::TxBytes(v) => ret.tx_bytes = Some(v),
129                WireguardPeerAttribute::ProtocolVersion(v) => {
130                    ret.protocol_version = Some(v)
131                }
132                WireguardPeerAttribute::AllowedIps(wg_ips) => {
133                    let mut ips = Vec::new();
134                    for wg_ip in &wg_ips {
135                        match WireguardIpAddress::try_from(wg_ip) {
136                            Ok(i) => ips.push(i),
137                            Err(e) => {
138                                log::warn!(
139                                    "Ignoring invalid WireguardAllowedIp: {e}"
140                                );
141                            }
142                        }
143                    }
144                    ret.allowed_ips = Some(ips.into_iter().collect());
145                }
146                _ => {
147                    log::debug!("Unsupported WireguardPeerAttribute {attr:?}");
148                }
149            }
150        }
151        ret
152    }
153}
154
155impl WireguardPeerParsed {
156    pub fn build(&self) -> Result<WireguardPeer, WireguardError> {
157        let mut attrs: Vec<WireguardPeerAttribute> = Vec::new();
158        if let Some(v) = self.endpoint {
159            attrs.push(WireguardPeerAttribute::Endpoint(v));
160        }
161
162        if let Some(v) = self.public_key.as_deref() {
163            attrs.push(WireguardPeerAttribute::PublicKey(decode_key(
164                "peer.public_key",
165                v,
166            )?));
167        }
168
169        if let Some(v) = self.preshared_key.as_deref() {
170            attrs.push(WireguardPeerAttribute::PresharedKey(decode_key(
171                "peer.preshared_key",
172                v,
173            )?));
174        }
175
176        if let Some(v) = self.persistent_keepalive {
177            attrs.push(WireguardPeerAttribute::PersistentKeepalive(v));
178        }
179
180        if let Some(v) = self.last_handshake {
181            attrs.push(WireguardPeerAttribute::LastHandshake(
182                WireguardTimeSpec {
183                    seconds: v.as_secs() as i64,
184                    nano_seconds: v.subsec_nanos() as i64,
185                },
186            ));
187        }
188
189        if let Some(v) = self.rx_bytes {
190            attrs.push(WireguardPeerAttribute::RxBytes(v));
191        }
192
193        if let Some(v) = self.tx_bytes {
194            attrs.push(WireguardPeerAttribute::TxBytes(v));
195        }
196
197        if let Some(ips) = self.allowed_ips.as_ref() {
198            attrs.push(WireguardPeerAttribute::AllowedIps(
199                ips.iter()
200                    .map(|ip| {
201                        WireguardAllowedIp(Vec::<WireguardAllowedIpAttr>::from(
202                            ip,
203                        ))
204                    })
205                    .collect(),
206            ));
207        }
208
209        if let Some(v) = self.protocol_version {
210            attrs.push(WireguardPeerAttribute::ProtocolVersion(v));
211        }
212
213        Ok(WireguardPeer(attrs))
214    }
215}
216
/// One allowed-IP entry of a peer: an address plus its prefix length.
///
/// This type does not itself check `prefix_length` against the address
/// family's maximum.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WireguardIpAddress {
    /// Prefix length, carried as `WireguardAllowedIpAttr::Cidr`.
    pub prefix_length: u8,
    /// IP address, carried as `WireguardAllowedIpAttr::IpAddr`.
    pub ip_addr: IpAddr,
}
222
223impl TryFrom<&WireguardAllowedIp> for WireguardIpAddress {
224    type Error = WireguardError;
225
226    fn try_from(attrs: &WireguardAllowedIp) -> Result<Self, WireguardError> {
227        let mut ip_addr: Option<IpAddr> = None;
228        let mut prefix_length: Option<u8> = None;
229
230        for attr in &attrs.0 {
231            match attr {
232                WireguardAllowedIpAttr::IpAddr(v) => ip_addr = Some(*v),
233                WireguardAllowedIpAttr::Cidr(v) => prefix_length = Some(*v),
234                WireguardAllowedIpAttr::Family(_) => (),
235                _ => {
236                    log::debug!("Unsupported WireguardAllowedIpAttr {attr:?}");
237                }
238            }
239        }
240        if let Some(ip_addr) = ip_addr {
241            if let Some(prefix_length) = prefix_length {
242                Ok(Self {
243                    ip_addr,
244                    prefix_length,
245                })
246            } else {
247                Err(WireguardError::new(
248                    ErrorKind::DecodeError,
249                    "WireguardAllowedIp does not have \
250                     WireguardAllowedIpAttr::Cidr defined"
251                        .to_string(),
252                    None,
253                ))
254            }
255        } else {
256            Err(WireguardError::new(
257                ErrorKind::DecodeError,
258                "WireguardAllowedIp does not have \
259                 WireguardAllowedIpAttr::IpAddr defined"
260                    .to_string(),
261                None,
262            ))
263        }
264    }
265}
266
267impl From<&WireguardIpAddress> for Vec<WireguardAllowedIpAttr> {
268    fn from(ip: &WireguardIpAddress) -> Self {
269        vec![
270            WireguardAllowedIpAttr::Cidr(ip.prefix_length),
271            if ip.ip_addr.is_ipv4() {
272                WireguardAllowedIpAttr::Family(WireguardAddressFamily::Ipv4)
273            } else {
274                WireguardAllowedIpAttr::Family(WireguardAddressFamily::Ipv6)
275            },
276            WireguardAllowedIpAttr::IpAddr(ip.ip_addr),
277        ]
278    }
279}