// ya_relay_stack/connection.rs

1use derive_more::Display;
2use smoltcp::iface::SocketHandle;
3use smoltcp::socket::*;
4use smoltcp::wire::IpEndpoint;
5use std::cell::RefCell;
6use std::convert::TryFrom;
7use std::future::Future;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::rc::Rc;
11use std::task::{Context, Poll};
12
13use crate::interface::CaptureInterface;
14use crate::patch_smoltcp::GetSocketSafe;
15use crate::socket::{SocketDesc, SocketEndpoint};
16use crate::{Error, Protocol, Result};
17
18use ya_relay_util::Payload;
19
/// Reason for tearing down a virtual connection.
///
/// A plain, copyable tag; it carries no data besides the variant itself.
#[derive(Debug, Clone, Copy)]
pub enum DisconnectReason {
    SinkClosed,
    SocketClosed,
    ConnectionFinished,
    ConnectionFailed,
    ConnectionTimeout,
}

30/// Virtual connection representing 2 endpoints with an existing record
31/// of exchanging packets via a known protocol; not necessarily a TCP connection
32#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
33pub struct Connection {
34    pub handle: SocketHandle,
35    pub meta: ConnectionMeta,
36}
37
38impl Connection {
39    pub fn try_new<T, E>(handle: SocketHandle, t: T) -> Result<Self>
40    where
41        ConnectionMeta: TryFrom<T, Error = E>,
42        Error: From<E>,
43    {
44        Ok(Self {
45            handle,
46            meta: ConnectionMeta::try_from(t)?,
47        })
48    }
49}
50
51impl From<Connection> for SocketDesc {
52    fn from(c: Connection) -> Self {
53        SocketDesc {
54            protocol: c.meta.protocol,
55            local: c.meta.local.into(),
56            remote: c.meta.remote.into(),
57        }
58    }
59}
60
61impl From<Connection> for SocketHandle {
62    fn from(c: Connection) -> Self {
63        c.handle
64    }
65}
66
67impl From<Connection> for ConnectionMeta {
68    fn from(c: Connection) -> Self {
69        c.meta
70    }
71}
72
73#[derive(Debug, Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
74#[display(
75    fmt = "ConnectionMeta {{ protocol: {}, local: {}, remote: {} }}",
76    protocol,
77    local,
78    remote
79)]
80pub struct ConnectionMeta {
81    pub protocol: Protocol,
82    pub local: IpEndpoint,
83    pub remote: IpEndpoint,
84}
85
86impl ConnectionMeta {
87    pub fn new(protocol: Protocol, local: IpEndpoint, remote: IpEndpoint) -> Self {
88        Self {
89            protocol,
90            local,
91            remote,
92        }
93    }
94
95    pub fn unspecified(protocol: Protocol) -> Self {
96        Self {
97            protocol,
98            local: IpEndpoint {
99                addr: smoltcp::wire::Ipv4Address::UNSPECIFIED.into_address(),
100                port: 0,
101            },
102            remote: IpEndpoint {
103                addr: smoltcp::wire::Ipv4Address::UNSPECIFIED.into_address(),
104                port: 0,
105            },
106        }
107    }
108
109    #[inline]
110    pub fn to_socket_addr(&self) -> SocketAddr {
111        SocketAddr::from((self.local.addr, self.local.port))
112    }
113}
114
115impl From<ConnectionMeta> for SocketDesc {
116    fn from(c: ConnectionMeta) -> Self {
117        SocketDesc {
118            protocol: c.protocol,
119            local: c.local.into(),
120            remote: c.remote.into(),
121        }
122    }
123}
124
125impl<'a> From<&'a ConnectionMeta> for SocketEndpoint {
126    fn from(c: &'a ConnectionMeta) -> Self {
127        SocketEndpoint::Ip(c.local)
128    }
129}
130
131impl TryFrom<SocketDesc> for ConnectionMeta {
132    type Error = Error;
133
134    fn try_from(desc: SocketDesc) -> std::result::Result<Self, Self::Error> {
135        let local = match desc.local {
136            SocketEndpoint::Ip(endpoint) => endpoint,
137            endpoint => return Err(Error::EndpointInvalid(endpoint)),
138        };
139        let remote = match desc.remote {
140            SocketEndpoint::Ip(endpoint) => endpoint,
141            endpoint => return Err(Error::EndpointInvalid(endpoint)),
142        };
143        Ok(Self {
144            protocol: desc.protocol,
145            local,
146            remote,
147        })
148    }
149}
150
151/// TCP connection future
152pub struct Connect<'a> {
153    pub connection: Connection,
154    iface: Rc<RefCell<CaptureInterface<'a>>>,
155}
156
157impl<'a> Connect<'a> {
158    pub fn new(connection: Connection, iface: Rc<RefCell<CaptureInterface<'a>>>) -> Self {
159        log::trace!("[Connect::new]: {:?}", connection);
160        Self { connection, iface }
161    }
162}
163
164impl<'a> Future for Connect<'a> {
165    type Output = Result<Connection>;
166
167    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168        let iface_rfc = self.iface.clone();
169        let mut iface = iface_rfc.borrow_mut();
170
171        log::trace!("[Connect::poll]: {:?}", self.connection);
172
173        let socket = match iface.get_socket_safe::<tcp::Socket>(self.connection.handle) {
174            Ok(s) => s,
175            Err(_) => return Poll::Ready(Err(Error::SocketClosed)),
176        };
177
178        if !socket.is_open() {
179            Poll::Ready(Err(Error::SocketClosed))
180        } else if socket.can_send() {
181            Poll::Ready(Ok(self.connection))
182        } else {
183            socket.register_send_waker(cx.waker());
184            Poll::Pending
185        }
186    }
187}
188
189/// TCP disconnection future
190pub struct Disconnect<'a> {
191    handle: SocketHandle,
192    iface: Rc<RefCell<CaptureInterface<'a>>>,
193}
194
195impl<'a> Disconnect<'a> {
196    pub fn new(handle: SocketHandle, iface: Rc<RefCell<CaptureInterface<'a>>>) -> Self {
197        Self { handle, iface }
198    }
199}
200
201impl<'a> Future for Disconnect<'a> {
202    type Output = Result<()>;
203
204    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205        let iface_rfc = self.iface.clone();
206        let mut iface = iface_rfc.borrow_mut();
207
208        let socket = match iface.get_socket_safe::<tcp::Socket>(self.handle) {
209            Ok(s) => s,
210            Err(_) => return Poll::Ready(Ok(())),
211        };
212
213        if !socket.is_open() {
214            Poll::Ready(Ok(()))
215        } else {
216            socket.register_recv_waker(cx.waker());
217            Poll::Pending
218        }
219    }
220}
221
222/// Packet send future
223pub struct Send<'a> {
224    data: Payload,
225    offset: usize,
226    connection: Connection,
227    iface: Rc<RefCell<CaptureInterface<'a>>>,
228    /// Send completion callback; there may as well have been no data sent
229    sent: Box<dyn Fn()>,
230}
231
232impl<'a> Send<'a> {
233    pub fn new<F: Fn() + 'static>(
234        data: Payload,
235        connection: Connection,
236        iface: Rc<RefCell<CaptureInterface<'a>>>,
237        sent: F,
238    ) -> Self {
239        log::trace!("[Send::new]: {:?}", connection);
240        Self {
241            data,
242            offset: 0,
243            connection,
244            iface,
245            sent: Box::new(sent),
246        }
247    }
248}
249
250impl<'a> Future for Send<'a> {
251    type Output = Result<()>;
252
253    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
254        let result = {
255            let mut iface = self.iface.borrow_mut();
256            let conn = &self.connection;
257
258            match conn.meta.protocol {
259                Protocol::Tcp => {
260                    log::trace!(
261                        "[Future(Send)::poll]: Sending TCP packet: {:?}",
262                        self.data.as_ref()
263                    );
264                    let result = {
265                        let socket = match iface.get_socket_safe::<tcp::Socket>(conn.handle) {
266                            Ok(socket) => socket,
267                            Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
268                        };
269                        socket.register_send_waker(cx.waker());
270                        socket.send_slice(&self.data.as_ref()[self.offset..])
271                    };
272
273                    drop(iface);
274                    (*self.sent)();
275
276                    return match result {
277                        Ok(count) => {
278                            self.offset += count;
279                            if self.offset >= self.data.len() {
280                                Poll::Ready(Ok(()))
281                            } else {
282                                Poll::Pending
283                            }
284                        }
285                        Err(smoltcp::socket::tcp::SendError::InvalidState) => Poll::Pending,
286                    };
287                }
288                Protocol::Udp => {
289                    log::trace!(
290                        "[Future(Send)::poll]: Sending UDP packet: {:?}",
291                        self.data.as_ref()
292                    );
293                    let socket = match iface.get_socket_safe::<udp::Socket>(conn.handle) {
294                        Ok(socket) => socket,
295                        Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
296                    };
297                    socket.register_send_waker(cx.waker());
298                    socket
299                        .send_slice(self.data.as_ref(), conn.meta.remote)
300                        .map_err(|err| err.to_string())
301                }
302                Protocol::Icmp | Protocol::Ipv6Icmp => {
303                    log::trace!(
304                        "[Future(Send)::poll]: Sending ICMP packet: {:?}",
305                        self.data.as_ref()
306                    );
307                    let socket = match iface.get_socket_safe::<icmp::Socket>(conn.handle) {
308                        Ok(socket) => socket,
309                        Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
310                    };
311                    socket.register_send_waker(cx.waker());
312                    socket
313                        .send_slice(self.data.as_ref(), conn.meta.remote.addr)
314                        .map_err(|err| err.to_string())
315                }
316                _ => {
317                    let socket = match iface.get_socket_safe::<raw::Socket>(conn.handle) {
318                        Ok(socket) => socket,
319                        Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
320                    };
321                    socket.register_send_waker(cx.waker());
322                    socket
323                        .send_slice(self.data.as_ref())
324                        .map_err(|err| err.to_string())
325                }
326            }
327        };
328
329        (*self.sent)();
330
331        match result {
332            Ok(_) => Poll::Ready(Ok(())),
333            Err(err) => Poll::Ready(Err(Error::Other(err))),
334        }
335    }
336}