Skip to main content

nl_wireguard/
parsed.rs

1// SPDX-License-Identifier: MIT
2
3use base64::{prelude::BASE64_STANDARD, Engine};
4
5use crate::{
6    ErrorKind, WireguardAttribute, WireguardCmd, WireguardError,
7    WireguardMessage, WireguardPeerParsed,
8};
9
10#[derive(Clone, PartialEq, Eq, Default)]
11#[non_exhaustive]
12pub struct WireguardParsed {
13    pub iface_name: Option<String>,
14    pub iface_index: Option<u32>,
15    /// Base64 encoded public key
16    pub public_key: Option<String>,
17    /// Base64 encoded private key, this property will be display as
18    /// `(hidden)` for `Debug` trait.
19    pub private_key: Option<String>,
20    pub listen_port: Option<u16>,
21    pub fwmark: Option<u32>,
22    pub peers: Option<Vec<WireguardPeerParsed>>,
23    // TODO: Flags
24}
25
26// For simplifying the code on hide `private_key` in Debug display of
27// [WireguardParsed]
28#[allow(dead_code)]
29#[derive(Debug)]
30struct _WireguardParsed<'a> {
31    iface_name: &'a Option<String>,
32    iface_index: &'a Option<u32>,
33    public_key: &'a Option<String>,
34    private_key: Option<String>,
35    listen_port: &'a Option<u16>,
36    fwmark: &'a Option<u32>,
37    peers: &'a Option<Vec<WireguardPeerParsed>>,
38}
39
40impl std::fmt::Debug for WireguardParsed {
41    fn fmt(
42        &self,
43        f: &mut std::fmt::Formatter<'_>,
44    ) -> Result<(), std::fmt::Error> {
45        let Self {
46            iface_name,
47            iface_index,
48            public_key,
49            private_key,
50            listen_port,
51            fwmark,
52            peers,
53        } = self;
54
55        std::fmt::Debug::fmt(
56            &_WireguardParsed {
57                iface_name,
58                iface_index,
59                public_key,
60                private_key: if private_key.is_some() {
61                    Some("(hidden)".to_string())
62                } else {
63                    None
64                },
65                listen_port,
66                fwmark,
67                peers,
68            },
69            f,
70        )
71    }
72}
73
74impl From<WireguardMessage> for WireguardParsed {
75    fn from(msg: WireguardMessage) -> Self {
76        let mut ret = Self::default();
77        for attr in msg.attributes {
78            match attr {
79                WireguardAttribute::IfName(v) => ret.iface_name = Some(v),
80                WireguardAttribute::IfIndex(v) => ret.iface_index = Some(v),
81                WireguardAttribute::PrivateKey(v) => {
82                    ret.private_key = Some(BASE64_STANDARD.encode(v))
83                }
84                WireguardAttribute::PublicKey(v) => {
85                    ret.public_key = Some(BASE64_STANDARD.encode(v))
86                }
87                WireguardAttribute::ListenPort(v) => ret.listen_port = Some(v),
88                WireguardAttribute::Fwmark(v) => ret.fwmark = Some(v),
89                WireguardAttribute::Peers(peers) => {
90                    ret.peers = Some(
91                        peers
92                            .into_iter()
93                            .map(WireguardPeerParsed::from)
94                            .collect(),
95                    );
96                }
97                _ => {
98                    log::debug!("Unsupported WireguardAttribute {attr:?}");
99                }
100            }
101        }
102        ret
103    }
104}
105
106impl WireguardParsed {
107    /// Build [WireguardMessage]
108    pub fn build(
109        &self,
110        cmd: WireguardCmd,
111    ) -> Result<WireguardMessage, WireguardError> {
112        let mut attributes: Vec<WireguardAttribute> = Vec::new();
113
114        if let Some(v) = self.iface_name.as_ref() {
115            attributes.push(WireguardAttribute::IfName(v.to_string()));
116        }
117
118        if let Some(v) = self.iface_index {
119            attributes.push(WireguardAttribute::IfIndex(v));
120        }
121
122        if let Some(v) = self.public_key.as_deref() {
123            attributes.push(WireguardAttribute::PublicKey(decode_key(
124                "public_key",
125                v,
126            )?));
127        }
128
129        if let Some(v) = self.private_key.as_deref() {
130            attributes.push(WireguardAttribute::PrivateKey(decode_key(
131                "private_key",
132                v,
133            )?));
134        }
135
136        if let Some(v) = self.listen_port {
137            attributes.push(WireguardAttribute::ListenPort(v));
138        }
139
140        if let Some(v) = self.fwmark {
141            attributes.push(WireguardAttribute::Fwmark(v));
142        }
143
144        if let Some(peers) = self.peers.as_ref() {
145            let mut peer_addrs = Vec::new();
146            for peer in peers {
147                peer_addrs.push(peer.build()?);
148            }
149            attributes.push(WireguardAttribute::Peers(peer_addrs));
150        }
151
152        Ok(WireguardMessage { cmd, attributes })
153    }
154}
155
156pub(crate) fn decode_key(
157    prop_name: &str,
158    key_str: &str,
159) -> Result<[u8; WireguardAttribute::WG_KEY_LEN], WireguardError> {
160    let key = BASE64_STANDARD.decode(key_str).map_err(|e| {
161        WireguardError::new(
162            ErrorKind::InvalidKey,
163            format!(
164                "Invalid {prop_name}: not valid base64 encoded string \
165                 {key_str}: {e}"
166            ),
167            None,
168        )
169    })?;
170    if key.len() != WireguardAttribute::WG_KEY_LEN {
171        return Err(WireguardError::new(
172            ErrorKind::InvalidKey,
173            format!(
174                "Invalid {prop_name}: current length {}, but expecting {} \
175                 length of u8 encoded base64 string, {key_str}",
176                key.len(),
177                WireguardAttribute::WG_KEY_LEN
178            ),
179            None,
180        ));
181    }
182    let mut key_data = [0u8; WireguardAttribute::WG_KEY_LEN];
183    key_data.copy_from_slice(&key);
184    Ok(key_data)
185}