use std::{
    io::{Read, Write},
    net::IpAddr,
    os::fd::{AsRawFd, FromRawFd},
};

use futures::TryStreamExt;
use ipnet::IpNet;
use tokio::{
    sync::{broadcast, mpsc},
    task,
};
use tun_tap::Mode;

use crate::Result;
16
17#[derive(Debug)]
18pub struct TunDevice {
19 device: tun_tap::Iface,
20 rt_handle: rtnetlink::Handle,
21 link_index: u32,
22 mtu: usize,
23}
24
25impl TunDevice {
26 pub async fn new(name: &str) -> Result<Self> {
33 let (rt_connection, rt_handle, _) = rtnetlink::new_connection().map_err(|err| {
35 log::error!("Failed to open rtnetlink connection");
36 log::error!("{}", err);
37 err
38 })?;
39 tokio::spawn(rt_connection);
40
41 let tun_device = tun_tap::Iface::without_packet_info(name, Mode::Tun)?;
43 log::debug!("Created new TUN device: {}", tun_device.name());
44
45 let tun_link = rt_handle
48 .link()
49 .get()
50 .match_name(tun_device.name().to_owned())
51 .execute()
52 .try_next()
53 .await?
54 .expect("Failed to access newly created TUN device");
55
56 rt_handle
58 .link()
59 .set(tun_link.header.index)
60 .up()
61 .execute()
62 .await
63 .map_err(|err| {
64 log::error!("Failed to bring up link");
65 log::error!("{}", err);
66 err
67 })?;
68 log::debug!("Brought {} up", tun_device.name());
69
70 let mtu: usize =
72 std::fs::read_to_string(format!("/sys/class/net/{}/mtu", tun_device.name()))
73 .expect("Failed to read link MTU")
74 .strip_suffix("\n")
75 .unwrap()
76 .parse()
77 .unwrap();
78
79 Ok(Self {
80 device: tun_device,
81 rt_handle,
82 link_index: tun_link.header.index,
83 mtu,
84 })
85 }
86
87 pub async fn add_address(&mut self, ip_address: IpAddr, prefix_len: u8) -> Result<()> {
89 self.rt_handle
90 .address()
91 .add(self.link_index, ip_address, prefix_len)
92 .execute()
93 .await
94 .map_err(|err| {
95 log::error!("Failed to add address {} to link", ip_address);
96 log::error!("{}", err);
97 err
98 })?;
99
100 Ok(())
101 }
102
103 pub async fn remove_address(&mut self, ip_address: IpAddr, prefix_len: u8) -> Result<()> {
105 if let Some(address_message) = self
107 .rt_handle
108 .address()
109 .get()
110 .set_link_index_filter(self.link_index)
111 .set_address_filter(ip_address)
112 .set_prefix_length_filter(prefix_len)
113 .execute()
114 .try_next()
115 .await
116 .map_err(|err| {
117 log::error!("Failed to find address {} on link", ip_address);
118 log::error!("{}", err);
119 err
120 })?
121 {
122 self.rt_handle
124 .address()
125 .del(address_message)
126 .execute()
127 .await
128 .map_err(|err| {
129 log::error!("Failed to remove address {} from link", ip_address);
130 log::error!("{}", err);
131 err
132 })?;
133 }
134
135 Ok(())
136 }
137
138 pub async fn add_route(&mut self, destination: IpNet) -> Result<()> {
140 match destination {
141 IpNet::V4(destination) => {
142 self.rt_handle
143 .route()
144 .add()
145 .v4()
146 .output_interface(self.link_index)
147 .destination_prefix(destination.addr(), destination.prefix_len())
148 .execute()
149 .await
150 .map_err(|err| {
151 log::error!("Failed to add route {} to link", destination);
152 log::error!("{}", err);
153 err
154 })?;
155 }
156 IpNet::V6(destination) => {
157 self.rt_handle
158 .route()
159 .add()
160 .v6()
161 .output_interface(self.link_index)
162 .destination_prefix(destination.addr(), destination.prefix_len())
163 .execute()
164 .await
165 .map_err(|err| {
166 log::error!("Failed to add route {} to link", destination);
167 log::error!("{}", err);
168 err
169 })?;
170 }
171 }
172
173 Ok(())
174 }
175
176 pub async fn spawn_worker(&self) -> (mpsc::Sender<Vec<u8>>, broadcast::Receiver<Vec<u8>>) {
178 let (tx_to_caller, rx_from_worker) = broadcast::channel(65535);
180
181 let (tx_to_worker, mut rx_from_caller) = mpsc::channel(65535);
183
184 let mtu = self.mtu;
186 let device_fd = self.device.as_raw_fd();
187
188 let _rx_task = task::spawn_blocking(move || {
190 let mut buffer = vec![0u8; mtu];
192
193 let mut device = unsafe { std::fs::File::from_raw_fd(device_fd) };
195
196 loop {
197 let packet_len = device.read(&mut buffer[..]).unwrap();
199 let packet = buffer[..packet_len].to_vec();
200
201 tx_to_caller.send(packet).unwrap();
203 }
204 });
205
206 let _tx_task = task::spawn(async move {
208 let mut device = unsafe { std::fs::File::from_raw_fd(device_fd) };
210
211 loop {
212 let packet: Vec<u8> = rx_from_caller.recv().await.unwrap();
214
215 device.write_all(&packet[..]).unwrap();
217 }
218 });
219
220 let _tx_task = task::spawn_blocking(|| {});
222
223 (tx_to_worker, rx_from_worker)
225 }
226}