Skip to main content

nl_wireguard/
handle.rs

1// SPDX-License-Identifier: MIT
2
3use futures_util::{Stream, StreamExt};
4use genetlink::GenetlinkHandle;
5use netlink_packet_core::{
6    DecodeError, NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_DUMP,
7    NLM_F_REQUEST,
8};
9use netlink_packet_generic::GenlMessage;
10
11use crate::{
12    ErrorKind, WireguardCmd, WireguardError, WireguardMessage, WireguardParsed,
13};
14
15#[derive(Clone, Debug)]
16pub struct WireguardHandle {
17    handle: GenetlinkHandle,
18}
19
20impl WireguardHandle {
21    pub(crate) fn new(handle: GenetlinkHandle) -> Self {
22        WireguardHandle { handle }
23    }
24
25    pub async fn get_by_name(
26        &mut self,
27        iface_name: &str,
28    ) -> Result<WireguardParsed, WireguardError> {
29        let msg = WireguardParsed {
30            iface_name: Some(iface_name.to_string()),
31            ..Default::default()
32        }
33        .build(WireguardCmd::GetDevice)?;
34        match self
35            .request(NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP, msg.clone())
36            .await?
37            .next()
38            .await
39        {
40            None => Err(WireguardError::new(
41                ErrorKind::Bug,
42                "Got no reply from kernel for request".to_string(),
43                Some(NetlinkMessage::from(GenlMessage::from_payload(msg))),
44            )),
45            Some(reply) => reply.map(WireguardParsed::from),
46        }
47    }
48
49    pub async fn set(
50        &mut self,
51        parsed: WireguardParsed,
52    ) -> Result<(), WireguardError> {
53        let msg = parsed.build(WireguardCmd::SetDevice)?;
54        //TODO: Polished this
55        match self
56            .request(NLM_F_REQUEST | NLM_F_ACK, msg.clone())
57            .await?
58            .next()
59            .await
60        {
61            None | Some(Ok(_)) => Ok(()),
62            Some(Err(e)) => Err(e),
63        }
64    }
65
66    /// Sending arbitrary [WireguardMessage] message and manually handle
67    /// [WireguardMessage] reply from kernel.
68    pub async fn request(
69        &mut self,
70        nl_header_flags: u16,
71        message: WireguardMessage,
72    ) -> Result<
73        impl Stream<Item = Result<WireguardMessage, WireguardError>>,
74        WireguardError,
75    > {
76        let mut nl_msg =
77            NetlinkMessage::from(GenlMessage::from_payload(message));
78        nl_msg.header.flags = nl_header_flags;
79
80        match self.handle.request(nl_msg.clone()).await {
81            Ok(stream) => Ok(parse_nl_msg_stream(nl_msg, stream)),
82            Err(e) => Err(WireguardError::new(
83                ErrorKind::NetlinkError,
84                format!("Netlink request failed: {e}"),
85                Some(nl_msg),
86            )),
87        }
88    }
89}
90
91fn parse_nl_msg_stream(
92    nl_msg: NetlinkMessage<GenlMessage<WireguardMessage>>,
93    stream: impl Stream<
94        Item = Result<
95            NetlinkMessage<GenlMessage<WireguardMessage>>,
96            DecodeError,
97        >,
98    >,
99) -> impl Stream<Item = Result<WireguardMessage, WireguardError>> {
100    stream.map(move |reply| match reply {
101        Ok(reply_msg) => {
102            let (header, payload) = reply_msg.into_parts();
103            match payload {
104                NetlinkPayload::InnerMessage(genl_msg) => {
105                    let (_genl_hdr, wg_msg) = genl_msg.into_parts();
106                    Ok(wg_msg)
107                }
108                NetlinkPayload::Error(ref err) => Err(WireguardError::new(
109                    ErrorKind::NetlinkError,
110                    format!("netlink error: {err:?}"),
111                    Some(NetlinkMessage::new(header, payload)),
112                )),
113                _ => Err(WireguardError::new(
114                    ErrorKind::Bug,
115                    format!("Unexpected NetlinkPayload type: {payload:?}"),
116                    Some(NetlinkMessage::new(header, payload)),
117                )),
118            }
119        }
120        Err(e) => Err(WireguardError::new(
121            ErrorKind::DecodeError,
122            format!("netlink decode error: {e}"),
123            Some(nl_msg.clone()),
124        )),
125    })
126}