// tokio_wireguard/interface.rs

1use std::{
2    any::type_name,
3    fmt,
4    future::{poll_fn, Future},
5    io,
6    net::{IpAddr, SocketAddr},
7    ops::Deref,
8    pin::Pin,
9    sync::{
10        atomic::{AtomicBool, Ordering},
11        Arc,
12    },
13    task::{Context, Poll},
14    time::Duration,
15};
16
17use allocations::Proto;
18use boringtun::x25519::PublicKey;
19use pin_project_lite::pin_project;
20use rand::{rngs::OsRng, Rng, SeedableRng};
21use random::AtomicXorShift32;
22use smoltcp::{
23    socket::{tcp, udp},
24    wire::HardwareAddress,
25};
26use tokio::{
27    runtime::Handle,
28    sync::{futures::Notified, mpsc, oneshot, Notify},
29    time,
30};
31
32mod allocations;
33mod device;
34mod random;
35mod sockets;
36mod tunnel;
37
38pub(crate) use allocations::Allocation;
39
40/// A handle to a WireGuard interface
41///
42/// Cloning returns a new handle to the same interface.
43#[derive(Clone)]
44pub struct Interface {
45    tx: mpsc::UnboundedSender<Message>,
46    shared: Arc<Shared>,
47    allocations: allocations::Allocations,
48    _drop: Arc<CloseOnDrop>,
49}
50
51/// Advanced options for configuring an [`Interface`]
52#[derive(Debug, Clone)]
53#[non_exhaustive]
54pub struct Options {
55    /// Handle to the Tokio runtime on which the interface will be run
56    pub runtime: Handle,
57    /// Default poll interval for the interface when idle
58    pub poll_interval: Duration,
59    /// Inteval at which to update the internal WireGuard timers
60    pub timer_interval: Duration,
61    /// TCP options
62    pub tcp: TcpOptions,
63    /// UDP options
64    pub udp: UdpOptions,
65}
66
/// Advanced TCP options for configuring an [`Interface`]
///
/// Part of [`Options`]. Defaults are provided by [`TcpOptions::default`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TcpOptions {
    /// How long to wait for a connection to a remote peer before giving up
    pub connect_timeout: Duration,
    /// Size of each TCP socket's receive buffer
    pub recv_buffer_size: usize,
    /// Size of each TCP socket's send buffer
    pub send_buffer_size: usize,
    /// Upper bound on connections waiting to be accepted by a listener
    pub backlog: usize,
}
80
/// Advanced UDP options for configuring an [`Interface`]
///
/// Part of [`Options`]. Defaults are provided by [`UdpOptions::default`].
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct UdpOptions {
    /// Size of each UDP socket's receive buffer
    pub recv_buffer_size: usize,
    /// Size of each UDP socket's send buffer
    pub send_buffer_size: usize,
}
90
91/// A trait for types that can be converted into an [`Interface`]
92pub trait ToInterface {
93    fn to_interface(self) -> impl Future<Output = Result<Interface, io::Error>>;
94}
95
96struct Shared {
97    is_closed: AtomicBool,
98    notify_closed: Notify,
99    options: Options,
100    rng: AtomicXorShift32,
101}
102
103#[derive(Debug)]
104enum Message {
105    Tcp(crate::Shared<tcp::Socket<'static>>),
106    Udp(crate::Shared<udp::Socket<'static>>),
107    TcpConnect {
108        socket: crate::Shared<tcp::Socket<'static>>,
109        allocation: Allocation,
110        target: SocketAddr,
111        result: oneshot::Sender<Result<(), tcp::ConnectError>>,
112    },
113    AddPeer {
114        config: crate::config::Peer,
115        result: oneshot::Sender<Result<(), io::Error>>,
116    },
117    RemovePeer {
118        key: PublicKey,
119        result: oneshot::Sender<bool>,
120    },
121    Close,
122}
123
124pin_project! {
125    /// Future that resolves once its associated interface is in the closed state
126    pub struct Closed<'a> {
127        #[pin]
128        notified: Notified<'a>,
129        shared: Arc<Shared>,
130    }
131}
132
133impl Interface {
134    /// Create a new interface
135    pub fn new(config: crate::config::Config) -> Result<Self, io::Error> {
136        Self::new_with(config, Options::default())
137    }
138
139    /// Local address of the interface within the WireGuard network
140    pub fn address(&self) -> Address {
141        self.allocations.address
142    }
143
144    /// Advanced options for the interface
145    pub fn options(&self) -> &Options {
146        &self.shared.options
147    }
148
149    /// Request that the interface be closed
150    ///
151    /// All sockets created by the interface will be closed, and any attempt to send or receive data
152    /// using them will result in an error. Once all remaining queued packets have been sent,
153    /// the interface will enter the closed state.
154    pub fn close(&self) {
155        self.tx.send(Message::Close).ok();
156    }
157
158    /// Returns a future that resolves once the interface is in the closed state
159    ///
160    /// See [`close`](Self::close) for more information.
161    pub fn closed(&self) -> Closed<'_> {
162        Closed {
163            notified: self.shared.notify_closed.notified(),
164            shared: self.shared.clone(),
165        }
166    }
167
168    /// Whether the interface is in the closed state
169    ///
170    /// See [`close`](Self::close) for more information.
171    pub fn is_closed(&self) -> bool {
172        self.shared.is_closed.load(Ordering::Acquire)
173    }
174
175    /// Create a new interface with advanced options
176    pub fn new_with(config: crate::config::Config, options: Options) -> Result<Self, io::Error> {
177        #[derive(Clone, Copy)]
178        enum Close {
179            No,
180            Requested,
181            Ready,
182        }
183        let runtime = options.runtime.clone();
184        let scope = runtime.enter();
185
186        let crate::config::Config {
187            interface: config,
188            peers: peer_configs,
189        } = config;
190        let (tx, mut rx) = mpsc::unbounded_channel();
191        let (mut interface, mut device) = Self::smol(&config);
192
193        let mut sockets = sockets::Sockets::new();
194        let mut tunnel = tunnel::Tunnel::new(&config, peer_configs)?;
195
196        let poll = time::sleep_until(time::Instant::now());
197        let mut timers = time::interval(options.timer_interval);
198        timers.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
199
200        let mut close = Close::No;
201        let shared = Arc::new(Shared {
202            is_closed: AtomicBool::new(false),
203            notify_closed: Notify::new(),
204            options,
205            rng: AtomicXorShift32::from_entropy(),
206        });
207
208        let i = Self {
209            tx: tx.clone(),
210            shared: shared.clone(),
211            allocations: allocations::Allocations::new(&config),
212            _drop: Arc::new(CloseOnDrop { tx }),
213        };
214        let shared = NotifyOnDrop(shared);
215
216        tokio::spawn(async move {
217            tokio::pin!(poll);
218            loop {
219                #[derive(Debug)]
220                enum Select {
221                    Poll,
222                    Message(Message),
223                    Recv,
224                    Send,
225                    Timers,
226                    Close,
227                }
228
229                let selected = poll_fn(|cx| {
230                    let recv = tunnel.socket().poll_recv_ready(cx);
231                    let send = tunnel.socket().poll_send_ready(cx);
232                    let can_send = device.can_send();
233
234                    if let Poll::Ready(()) = poll.as_mut().poll(cx) {
235                        Poll::Ready(Select::Poll)
236                    } else if let (Poll::Ready(Some(message)), Close::No) =
237                        (rx.poll_recv(cx), close)
238                    {
239                        Poll::Ready(Select::Message(message))
240                    } else if let (Poll::Ready(..), Poll::Ready(..)) = (recv, &send) {
241                        Poll::Ready(Select::Recv)
242                    } else if let (true, Poll::Ready(..)) = (can_send, send) {
243                        Poll::Ready(Select::Send)
244                    } else if let Poll::Ready(..) = timers.poll_tick(cx) {
245                        Poll::Ready(Select::Timers)
246                    } else if let (false, Close::Ready) = (can_send, close) {
247                        Poll::Ready(Select::Close)
248                    } else {
249                        Poll::Pending
250                    }
251                })
252                .await;
253
254                match selected {
255                    Select::Poll => {
256                        let wait = sockets.with(|s| {
257                            let now = smoltcp::time::Instant::now();
258                            interface.poll(now, &mut device, s);
259                            interface.poll_delay(now, s).map(time::Duration::from)
260                        });
261                        match wait {
262                            Some(wait) => poll.as_mut().reset(time::Instant::now() + wait),
263                            None => {
264                                poll.as_mut()
265                                    .reset(time::Instant::now() + shared.options.poll_interval);
266                                if let Close::Requested = close {
267                                    close = Close::Ready;
268                                }
269                            }
270                        }
271                    }
272                    Select::Message(Message::Tcp(socket)) => {
273                        sockets.register_tcp(socket);
274                        poll.as_mut().reset(time::Instant::now());
275                    }
276                    Select::Message(Message::Udp(socket)) => {
277                        sockets.register_udp(socket);
278                        poll.as_mut().reset(time::Instant::now());
279                    }
280                    Select::Message(Message::TcpConnect {
281                        socket,
282                        allocation,
283                        target,
284                        result,
285                    }) => {
286                        if let Some(socket) = socket.lock().as_mut() {
287                            result
288                                .send(socket.connect(
289                                    interface.context(),
290                                    target,
291                                    allocation.address(),
292                                ))
293                                .ok();
294                        }
295                        sockets.register_tcp(socket);
296                        poll.as_mut().reset(time::Instant::now());
297                    }
298                    Select::Message(Message::AddPeer { config, result }) => {
299                        result.send(tunnel.add_peer(config)).ok();
300                    }
301                    Select::Message(Message::RemovePeer { key, result }) => {
302                        result.send(tunnel.remove_peer(&key)).ok();
303                    }
304                    Select::Message(Message::Close) => {
305                        close = Close::Requested;
306                        sockets.close();
307                        poll.as_mut().reset(time::Instant::now());
308                    }
309                    Select::Recv => {
310                        if let Some(packet) = tunnel.recv().await {
311                            device.enqueue_received(packet);
312                        }
313                    }
314                    Select::Send => {
315                        if let Some(packet) = device.dequeue_sent() {
316                            tunnel.send(packet, |_| true).await;
317                        }
318                    }
319                    Select::Timers => tunnel.update_timers().await,
320                    Select::Close => break,
321                }
322            }
323        });
324
325        drop(scope);
326        Ok(i)
327    }
328
329    /// Dynamically add a new peer to the interface. The peer will only become available
330    /// once the returned future resolves.
331    pub async fn add_peer(&self, config: crate::config::Peer) -> Result<(), io::Error> {
332        let (tx, rx) = oneshot::channel();
333        self.tx
334            .send(Message::AddPeer { config, result: tx })
335            .map_err(|_| Self::error())?;
336        rx.await.map_err(|_| Self::error())?
337    }
338    /// Removes a peer from the interface. Returns whether the peer existed.
339    pub async fn remove_peer(&self, key: &PublicKey) -> Result<bool, io::Error> {
340        let (tx, rx) = oneshot::channel();
341        self.tx
342            .send(Message::RemovePeer {
343                key: key.clone(),
344                result: tx,
345            })
346            .map_err(|_| Self::error())?;
347        rx.await.map_err(|_| Self::error())
348    }
349
350    pub(crate) fn error() -> io::Error {
351        io::Error::new(io::ErrorKind::BrokenPipe, "interface is closed")
352    }
353
354    pub(crate) fn register_tcp(
355        &self,
356        socket: crate::Shared<tcp::Socket<'static>>,
357    ) -> io::Result<()> {
358        self.tx
359            .send(Message::Tcp(socket))
360            .map_err(|_| Self::error())
361    }
362    pub(crate) fn register_udp(
363        &self,
364        socket: crate::Shared<udp::Socket<'static>>,
365    ) -> io::Result<()> {
366        self.tx
367            .send(Message::Udp(socket))
368            .map_err(|_| Self::error())
369    }
370
371    pub(crate) fn allocate_tcp(&self, address: impl Into<SocketAddr>) -> Option<Allocation> {
372        self.allocations
373            .acquire(address, Proto::Tcp, &mut &self.shared.rng)
374    }
375    pub(crate) fn allocate_udp(&self, address: impl Into<SocketAddr>) -> Option<Allocation> {
376        self.allocations
377            .acquire(address, Proto::Udp, &mut &self.shared.rng)
378    }
379
380    pub(crate) async fn connect_tcp(
381        &self,
382        socket: crate::Shared<tcp::Socket<'static>>,
383        allocation: Allocation,
384        target: SocketAddr,
385    ) -> io::Result<Result<(), tcp::ConnectError>> {
386        let (tx, result) = oneshot::channel();
387        self.tx
388            .send(Message::TcpConnect {
389                socket,
390                allocation,
391                target,
392                result: tx,
393            })
394            .map_err(|_| Self::error())?;
395        result.await.map_err(|_| Self::error())
396    }
397
398    pub(crate) fn deallocate(&self, allocation: Allocation) {
399        self.allocations.release(allocation);
400    }
401
402    fn smol(config: &crate::config::Interface) -> (smoltcp::iface::Interface, device::Device) {
403        let ips = config.address.addresses();
404        let mut device = device::Device::new(&config);
405        let mut config = smoltcp::iface::Config::new(HardwareAddress::Ip);
406        config.random_seed = OsRng.gen();
407
408        let mut interface =
409            smoltcp::iface::Interface::new(config, &mut device, smoltcp::time::Instant::now());
410        interface.update_ip_addrs(|a| {
411            for ip in ips {
412                match ip {
413                    IpAddr::V4(ip) => a.push(smoltcp::wire::Ipv4Cidr::new(ip.into(), 32).into()),
414                    IpAddr::V6(ip) => a.push(smoltcp::wire::Ipv6Cidr::new(ip.into(), 128).into()),
415                }
416                .ok();
417            }
418        });
419
420        (interface, device)
421    }
422}
423
424use crate::config::Address;
425
426impl fmt::Debug for Interface {
427    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
428        f.debug_tuple(type_name::<Self>())
429            .field(&self.address())
430            .finish()
431    }
432}
433
434impl ToInterface for Interface {
435    async fn to_interface(self) -> Result<Interface, io::Error> {
436        Ok(self)
437    }
438}
439
440impl<'a> ToInterface for &'a Interface {
441    async fn to_interface(self) -> Result<Interface, io::Error> {
442        Ok(self.clone())
443    }
444}
445
446impl<'a> ToInterface for &'a mut Interface {
447    async fn to_interface(self) -> Result<Interface, io::Error> {
448        Ok(self.clone())
449    }
450}
451
452impl ToInterface for crate::config::Config {
453    async fn to_interface(self) -> Result<Interface, io::Error> {
454        Interface::new(self)
455    }
456}
457
458impl Default for Options {
459    fn default() -> Self {
460        Self {
461            runtime: Handle::current(),
462            poll_interval: Duration::from_millis(100),
463            timer_interval: Duration::from_millis(100),
464            tcp: TcpOptions::default(),
465            udp: UdpOptions::default(),
466        }
467    }
468}
469
470impl Default for TcpOptions {
471    fn default() -> Self {
472        Self {
473            connect_timeout: Duration::from_secs(10),
474            recv_buffer_size: 64 * 1024,
475            send_buffer_size: 16 * 1024,
476            backlog: 128,
477        }
478    }
479}
480
481impl Default for UdpOptions {
482    fn default() -> Self {
483        Self {
484            recv_buffer_size: 64 * 1024,
485            send_buffer_size: 16 * 1024,
486        }
487    }
488}
489
490impl<'a> Future for Closed<'a> {
491    type Output = <Notified<'a> as Future>::Output;
492
493    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
494        if self.shared.is_closed.load(Ordering::Acquire) {
495            Poll::Ready(())
496        } else {
497            self.project().notified.poll(cx)
498        }
499    }
500}
501
502struct CloseOnDrop {
503    tx: mpsc::UnboundedSender<Message>,
504}
505
506struct NotifyOnDrop(Arc<Shared>);
507
508impl Drop for CloseOnDrop {
509    fn drop(&mut self) {
510        self.tx.send(Message::Close).ok();
511    }
512}
513
514impl Deref for NotifyOnDrop {
515    type Target = Shared;
516
517    fn deref(&self) -> &Self::Target {
518        &self.0
519    }
520}
521
522impl Drop for NotifyOnDrop {
523    fn drop(&mut self) {
524        self.is_closed.store(true, Ordering::Release);
525        self.notify_closed.notify_waiters();
526    }
527}