1use futures_util::{Stream, StreamExt};
4use genetlink::GenetlinkHandle;
5use netlink_packet_core::{
6 DecodeError, NetlinkMessage, NetlinkPayload, NLM_F_ACK, NLM_F_DUMP,
7 NLM_F_REQUEST,
8};
9use netlink_packet_generic::GenlMessage;
10
11use crate::{
12 ErrorKind, WireguardCmd, WireguardError, WireguardMessage, WireguardParsed,
13};
14
15#[derive(Clone, Debug)]
16pub struct WireguardHandle {
17 handle: GenetlinkHandle,
18}
19
20impl WireguardHandle {
21 pub(crate) fn new(handle: GenetlinkHandle) -> Self {
22 WireguardHandle { handle }
23 }
24
25 pub async fn get_by_name(
26 &mut self,
27 iface_name: &str,
28 ) -> Result<WireguardParsed, WireguardError> {
29 let msg = WireguardParsed {
30 iface_name: Some(iface_name.to_string()),
31 ..Default::default()
32 }
33 .build(WireguardCmd::GetDevice)?;
34 match self
35 .request(NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP, msg.clone())
36 .await?
37 .next()
38 .await
39 {
40 None => Err(WireguardError::new(
41 ErrorKind::Bug,
42 "Got no reply from kernel for request".to_string(),
43 Some(NetlinkMessage::from(GenlMessage::from_payload(msg))),
44 )),
45 Some(reply) => reply.map(WireguardParsed::from),
46 }
47 }
48
49 pub async fn set(
50 &mut self,
51 parsed: WireguardParsed,
52 ) -> Result<(), WireguardError> {
53 let msg = parsed.build(WireguardCmd::SetDevice)?;
54 match self
56 .request(NLM_F_REQUEST | NLM_F_ACK, msg.clone())
57 .await?
58 .next()
59 .await
60 {
61 None | Some(Ok(_)) => Ok(()),
62 Some(Err(e)) => Err(e),
63 }
64 }
65
66 pub async fn request(
69 &mut self,
70 nl_header_flags: u16,
71 message: WireguardMessage,
72 ) -> Result<
73 impl Stream<Item = Result<WireguardMessage, WireguardError>>,
74 WireguardError,
75 > {
76 let mut nl_msg =
77 NetlinkMessage::from(GenlMessage::from_payload(message));
78 nl_msg.header.flags = nl_header_flags;
79
80 match self.handle.request(nl_msg.clone()).await {
81 Ok(stream) => Ok(parse_nl_msg_stream(nl_msg, stream)),
82 Err(e) => Err(WireguardError::new(
83 ErrorKind::NetlinkError,
84 format!("Netlink request failed: {e}"),
85 Some(nl_msg),
86 )),
87 }
88 }
89}
90
91fn parse_nl_msg_stream(
92 nl_msg: NetlinkMessage<GenlMessage<WireguardMessage>>,
93 stream: impl Stream<
94 Item = Result<
95 NetlinkMessage<GenlMessage<WireguardMessage>>,
96 DecodeError,
97 >,
98 >,
99) -> impl Stream<Item = Result<WireguardMessage, WireguardError>> {
100 stream.map(move |reply| match reply {
101 Ok(reply_msg) => {
102 let (header, payload) = reply_msg.into_parts();
103 match payload {
104 NetlinkPayload::InnerMessage(genl_msg) => {
105 let (_genl_hdr, wg_msg) = genl_msg.into_parts();
106 Ok(wg_msg)
107 }
108 NetlinkPayload::Error(ref err) => Err(WireguardError::new(
109 ErrorKind::NetlinkError,
110 format!("netlink error: {err:?}"),
111 Some(NetlinkMessage::new(header, payload)),
112 )),
113 _ => Err(WireguardError::new(
114 ErrorKind::Bug,
115 format!("Unexpected NetlinkPayload type: {payload:?}"),
116 Some(NetlinkMessage::new(header, payload)),
117 )),
118 }
119 }
120 Err(e) => Err(WireguardError::new(
121 ErrorKind::DecodeError,
122 format!("netlink decode error: {e}"),
123 Some(nl_msg.clone()),
124 )),
125 })
126}