protomask_tun/
tun.rs

1use std::{
2    io::{Read, Write},
3    net::IpAddr,
4    os::fd::{AsRawFd, FromRawFd},
5};
6
7use futures::TryStreamExt;
8use ipnet::IpNet;
9use tokio::{
10    sync::{broadcast, mpsc},
11    task,
12};
13use tun_tap::Mode;
14
15use crate::Result;
16
/// A TUN network device together with the rtnetlink handle used to configure it.
#[derive(Debug)]
pub struct TunDevice {
    /// The underlying TUN interface, opened in `Mode::Tun` without packet info
    device: tun_tap::Iface,
    /// Handle to the rtnetlink connection opened when the device was created
    rt_handle: rtnetlink::Handle,
    /// rtnetlink link index of the TUN interface (used for address/route requests)
    link_index: u32,
    /// Link MTU read from sysfs at creation time; sizes the worker read buffer
    mtu: usize,
}
24
25impl TunDevice {
26    /// Create and bring up a new TUN device
27    ///
28    /// ## Name format
29    ///
30    /// The name field can be any string. If `%d` is present in the string,
31    /// it will be replaced with a unique number.
32    pub async fn new(name: &str) -> Result<Self> {
33        // Bring up an rtnetlink connection
34        let (rt_connection, rt_handle, _) = rtnetlink::new_connection().map_err(|err| {
35            log::error!("Failed to open rtnetlink connection");
36            log::error!("{}", err);
37            err
38        })?;
39        tokio::spawn(rt_connection);
40
41        // Create the TUN device
42        let tun_device = tun_tap::Iface::without_packet_info(name, Mode::Tun)?;
43        log::debug!("Created new TUN device: {}", tun_device.name());
44
45        // Get access to the link through rtnetlink
46        // NOTE: I don't think there is any way this can fail, so `except` is probably OK
47        let tun_link = rt_handle
48            .link()
49            .get()
50            .match_name(tun_device.name().to_owned())
51            .execute()
52            .try_next()
53            .await?
54            .expect("Failed to access newly created TUN device");
55
56        // Bring the link up
57        rt_handle
58            .link()
59            .set(tun_link.header.index)
60            .up()
61            .execute()
62            .await
63            .map_err(|err| {
64                log::error!("Failed to bring up link");
65                log::error!("{}", err);
66                err
67            })?;
68        log::debug!("Brought {} up", tun_device.name());
69
70        // Read the link MTU
71        let mtu: usize =
72            std::fs::read_to_string(format!("/sys/class/net/{}/mtu", tun_device.name()))
73                .expect("Failed to read link MTU")
74                .strip_suffix("\n")
75                .unwrap()
76                .parse()
77                .unwrap();
78
79        Ok(Self {
80            device: tun_device,
81            rt_handle,
82            link_index: tun_link.header.index,
83            mtu,
84        })
85    }
86
87    /// Add an IP address to this device
88    pub async fn add_address(&mut self, ip_address: IpAddr, prefix_len: u8) -> Result<()> {
89        self.rt_handle
90            .address()
91            .add(self.link_index, ip_address, prefix_len)
92            .execute()
93            .await
94            .map_err(|err| {
95                log::error!("Failed to add address {} to link", ip_address);
96                log::error!("{}", err);
97                err
98            })?;
99
100        Ok(())
101    }
102
103    /// Remove an IP address from this device
104    pub async fn remove_address(&mut self, ip_address: IpAddr, prefix_len: u8) -> Result<()> {
105        // Find the address message that matches the given address
106        if let Some(address_message) = self
107            .rt_handle
108            .address()
109            .get()
110            .set_link_index_filter(self.link_index)
111            .set_address_filter(ip_address)
112            .set_prefix_length_filter(prefix_len)
113            .execute()
114            .try_next()
115            .await
116            .map_err(|err| {
117                log::error!("Failed to find address {} on link", ip_address);
118                log::error!("{}", err);
119                err
120            })?
121        {
122            // Delete the address
123            self.rt_handle
124                .address()
125                .del(address_message)
126                .execute()
127                .await
128                .map_err(|err| {
129                    log::error!("Failed to remove address {} from link", ip_address);
130                    log::error!("{}", err);
131                    err
132                })?;
133        }
134
135        Ok(())
136    }
137
138    /// Add a route to this device
139    pub async fn add_route(&mut self, destination: IpNet) -> Result<()> {
140        match destination {
141            IpNet::V4(destination) => {
142                self.rt_handle
143                    .route()
144                    .add()
145                    .v4()
146                    .output_interface(self.link_index)
147                    .destination_prefix(destination.addr(), destination.prefix_len())
148                    .execute()
149                    .await
150                    .map_err(|err| {
151                        log::error!("Failed to add route {} to link", destination);
152                        log::error!("{}", err);
153                        err
154                    })?;
155            }
156            IpNet::V6(destination) => {
157                self.rt_handle
158                    .route()
159                    .add()
160                    .v6()
161                    .output_interface(self.link_index)
162                    .destination_prefix(destination.addr(), destination.prefix_len())
163                    .execute()
164                    .await
165                    .map_err(|err| {
166                        log::error!("Failed to add route {} to link", destination);
167                        log::error!("{}", err);
168                        err
169                    })?;
170            }
171        }
172
173        Ok(())
174    }
175
176    /// Spawns worker threads, and returns a tx/rx pair for the caller to interact with them
177    pub async fn spawn_worker(&self) -> (mpsc::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>) {
178        // Create a channel for packets to be sent to the caller
179        let (tx_to_caller, rx_from_worker) = broadcast::channel(65535);
180
181        // Create a channel for packets being received from the caller
182        let (tx_to_worker, mut rx_from_caller) = mpsc::channel(65535);
183
184        // Clone some values for use in worker threads
185        let mtu = self.mtu;
186        let device_fd = self.device.as_raw_fd();
187
188        // Create a task that broadcasts all incoming packets
189        let _rx_task = task::spawn_blocking(move || {
190            // Build a buffer to read packets into
191            let mut buffer = vec![0u8; mtu];
192
193            // Create a file to access the TUN device
194            let mut device = unsafe { std::fs::File::from_raw_fd(device_fd) };
195
196            loop {
197                // Read a packet from the TUN device
198                let packet_len = device.read(&mut buffer[..]).unwrap();
199                let packet = buffer[..packet_len].to_vec();
200
201                // Broadcast the packet to all listeners
202                tx_to_caller.send(packet).unwrap();
203            }
204        });
205
206        // Create a task that sends all outgoing packets
207        let _tx_task = task::spawn(async move {
208            // Create a file to access the TUN device
209            let mut device = unsafe { std::fs::File::from_raw_fd(device_fd) };
210
211            loop {
212                // Wait for a packet to be sent
213                let packet: Vec<u8> = rx_from_caller.recv().await.unwrap();
214
215                // Write the packet to the TUN device
216                device.write_all(&packet[..]).unwrap();
217            }
218        });
219
220        // Create a task that sends all outgoing packets
221        let _tx_task = task::spawn_blocking(|| {});
222
223        // Return an rx/tx pair for the caller to interact with the workers
224        (tx_to_worker, rx_from_worker)
225    }
226}