1use crate::Result;
2use crate::TunBuilder;
3use crate::linux::interface::Interface;
4use crate::linux::io::TunIo;
5use crate::linux::params::Params;
6use std::io::{self, ErrorKind, IoSlice, Read, Write};
7use std::mem;
8use std::net::Ipv4Addr;
9use std::os::raw::c_char;
10use std::os::unix::io::{AsRawFd, RawFd};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{self, Context, Poll};
14use tokio::io::unix::AsyncFd;
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
/// Path of the Linux TUN/TAP device node.
///
/// The byte string carries an explicit trailing NUL so that `TUN.as_ptr()` can
/// be passed directly to C APIs (it is handed to `libc::open` when queues are
/// allocated).
static TUN: &[u8] = b"/dev/net/tun\0";
18
/// Unwraps a `std::task::Poll` value inside a poll function.
///
/// Evaluates to the inner value when the expression is `Poll::Ready`, and
/// makes the *enclosing function* return `Poll::Pending` otherwise. A single
/// trailing comma after the expression is accepted.
macro_rules! ready {
    ($poll:expr $(,)?) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
28
29pub struct Tun {
31 iface: Arc<Interface>,
32 io: AsyncFd<TunIo>,
33}
34
35impl AsRawFd for Tun {
36 fn as_raw_fd(&self) -> RawFd {
37 self.io.as_raw_fd()
38 }
39}
40
41impl AsyncRead for Tun {
42 fn poll_read(
43 self: Pin<&mut Self>,
44 cx: &mut Context<'_>,
45 buf: &mut ReadBuf<'_>,
46 ) -> task::Poll<io::Result<()>> {
47 let self_mut = self.get_mut();
48 loop {
49 let mut guard = ready!(self_mut.io.poll_read_ready_mut(cx))?;
50
51 match guard.try_io(|inner| inner.get_mut().read(buf.initialize_unfilled())) {
52 Ok(Ok(n)) => {
53 buf.set_filled(buf.filled().len() + n);
54 return Poll::Ready(Ok(()));
55 }
56 Ok(Err(err)) => return Poll::Ready(Err(err)),
57 Err(_) => continue,
58 }
59 }
60 }
61}
62
63impl AsyncWrite for Tun {
64 fn poll_write(
65 self: Pin<&mut Self>,
66 cx: &mut Context<'_>,
67 buf: &[u8],
68 ) -> task::Poll<io::Result<usize>> {
69 let self_mut = self.get_mut();
70 loop {
71 let mut guard = ready!(self_mut.io.poll_write_ready_mut(cx))?;
72
73 match guard.try_io(|inner| inner.get_mut().write(buf)) {
74 Ok(result) => return Poll::Ready(result),
75 Err(_would_block) => continue,
76 }
77 }
78 }
79
80 fn poll_write_vectored(
81 self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 bufs: &[IoSlice<'_>],
84 ) -> Poll<std::result::Result<usize, io::Error>> {
85 let self_mut = self.get_mut();
86 loop {
87 let mut guard = ready!(self_mut.io.poll_write_ready_mut(cx))?;
88
89 match guard.try_io(|inner| inner.get_mut().write_vectored(bufs)) {
90 Ok(result) => return Poll::Ready(result),
91 Err(_would_block) => continue,
92 }
93 }
94 }
95
96 fn is_write_vectored(&self) -> bool {
97 true
98 }
99
100 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> task::Poll<io::Result<()>> {
101 let self_mut = self.get_mut();
102 loop {
103 let mut guard = ready!(self_mut.io.poll_write_ready_mut(cx))?;
104
105 match guard.try_io(|inner| inner.get_mut().flush()) {
106 Ok(result) => return Poll::Ready(result),
107 Err(_) => continue,
108 }
109 }
110 }
111
112 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> task::Poll<io::Result<()>> {
113 Poll::Ready(Ok(()))
114 }
115}
116
117impl Tun {
118 pub fn builder() -> TunBuilder {
119 TunBuilder::new()
120 }
121
122 pub(crate) fn new(params: Params) -> Result<Self> {
124 let iface = Self::allocate(params, 1)?;
125 let fd = iface.files()[0];
126 Ok(Self {
127 iface: Arc::new(iface),
128 io: AsyncFd::new(TunIo::from(fd))?,
129 })
130 }
131
132 pub(crate) fn new_mq(params: Params, queues: usize) -> Result<Vec<Self>> {
134 let iface = Self::allocate(params, queues)?;
135 let mut tuns = Vec::with_capacity(queues);
136 let iface = Arc::new(iface);
137 for &fd in iface.files() {
138 tuns.push(Self {
139 iface: iface.clone(),
140 io: AsyncFd::new(TunIo::from(fd))?,
141 })
142 }
143 Ok(tuns)
144 }
145
146 fn allocate(params: Params, queues: usize) -> Result<Interface> {
147 let extra_flags = if params.cloexec { libc::O_CLOEXEC } else { 0 };
148
149 let fds = (0..queues)
150 .map(|_| unsafe {
151 match libc::open(
152 TUN.as_ptr().cast::<c_char>(),
153 libc::O_RDWR | libc::O_NONBLOCK | extra_flags,
154 ) {
155 fd if fd >= 0 => Ok(fd),
156 _ => Err(io::Error::last_os_error().into()),
157 }
158 })
159 .collect::<Result<Vec<_>>>()?;
160
161 let iface = Interface::new(
162 fds,
163 params.name.as_deref().unwrap_or_default(),
164 params.flags,
165 params.cloexec,
166 )?;
167 iface.init(params)?;
168 Ok(iface)
169 }
170
171 pub async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
175 loop {
176 let mut guard = self.io.readable().await?;
177 match guard.try_io(|inner| inner.get_ref().recv(buf)) {
178 Ok(res) => return res,
179 Err(_) => continue,
180 }
181 }
182 }
183
184 pub async fn send(&self, buf: &[u8]) -> io::Result<usize> {
188 loop {
189 let mut guard = self.io.writable().await?;
190 match guard.try_io(|inner| inner.get_ref().send(buf)) {
191 Ok(res) => return res,
192 Err(_) => continue,
193 }
194 }
195 }
196
197 pub async fn send_all(&self, buf: &[u8]) -> io::Result<()> {
201 let mut remaining = buf;
202 while !remaining.is_empty() {
203 match self.send(remaining).await? {
204 0 => return Err(ErrorKind::WriteZero.into()),
205 n => {
206 let (_, rest) = mem::take(&mut remaining).split_at(n);
207 remaining = rest;
208 }
209 }
210 }
211 Ok(())
212 }
213
214 pub async fn sendv(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
218 loop {
219 let mut guard = self.io.writable().await?;
220 match guard.try_io(|inner| inner.get_ref().sendv(bufs)) {
221 Ok(res) => return res,
222 Err(_) => continue,
223 }
224 }
225 }
226
227 pub async fn sendv_all(&self, bufs: &mut [IoSlice<'_>]) -> io::Result<()> {
240 let mut bufs = bufs;
241 while !bufs.is_empty() {
242 match self.sendv(bufs).await? {
243 0 => {
244 return Err(std::io::Error::new(
245 std::io::ErrorKind::WriteZero,
246 "failed to write whole buffer",
247 ));
248 }
249 n => {
250 IoSlice::advance_slices(&mut bufs, n);
251 }
252 }
253 }
254 Ok(())
255 }
256
257 pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> {
263 self.io.get_ref().recv(buf)
264 }
265
266 pub fn try_send(&self, buf: &[u8]) -> io::Result<usize> {
272 self.io.get_ref().send(buf)
273 }
274
275 pub fn try_sendv(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
281 self.io.get_ref().sendv(bufs)
282 }
283
284 pub fn name(&self) -> &str {
286 self.iface.name()
287 }
288
289 pub fn mtu(&self) -> Result<i32> {
291 self.iface.mtu(None)
292 }
293
294 pub fn address(&self) -> Result<Ipv4Addr> {
296 self.iface.address(None)
297 }
298
299 pub fn destination(&self) -> Result<Ipv4Addr> {
301 self.iface.destination(None)
302 }
303
304 pub fn broadcast(&self) -> Result<Ipv4Addr> {
306 self.iface.broadcast(None)
307 }
308
309 pub fn netmask(&self) -> Result<Ipv4Addr> {
311 self.iface.netmask(None)
312 }
313
314 pub fn flags(&self) -> Result<i16> {
316 self.iface.flags(None)
317 }
318}