1use derive_more::Display;
2use smoltcp::iface::SocketHandle;
3use smoltcp::socket::*;
4use smoltcp::wire::IpEndpoint;
5use std::cell::RefCell;
6use std::convert::TryFrom;
7use std::future::Future;
8use std::net::SocketAddr;
9use std::pin::Pin;
10use std::rc::Rc;
11use std::task::{Context, Poll};
12
13use crate::interface::CaptureInterface;
14use crate::patch_smoltcp::GetSocketSafe;
15use crate::socket::{SocketDesc, SocketEndpoint};
16use crate::{Error, Protocol, Result};
17
18use ya_relay_util::Payload;
19
/// Describes why a connection was disconnected.
///
/// The variants are plain tags; how each one is produced is decided by code
/// outside this module.
#[derive(Clone, Copy, Debug)]
pub enum DisconnectReason {
    SinkClosed,
    SocketClosed,
    ConnectionFinished,
    ConnectionFailed,
    ConnectionTimeout,
}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
33pub struct Connection {
34 pub handle: SocketHandle,
35 pub meta: ConnectionMeta,
36}
37
38impl Connection {
39 pub fn try_new<T, E>(handle: SocketHandle, t: T) -> Result<Self>
40 where
41 ConnectionMeta: TryFrom<T, Error = E>,
42 Error: From<E>,
43 {
44 Ok(Self {
45 handle,
46 meta: ConnectionMeta::try_from(t)?,
47 })
48 }
49}
50
51impl From<Connection> for SocketDesc {
52 fn from(c: Connection) -> Self {
53 SocketDesc {
54 protocol: c.meta.protocol,
55 local: c.meta.local.into(),
56 remote: c.meta.remote.into(),
57 }
58 }
59}
60
61impl From<Connection> for SocketHandle {
62 fn from(c: Connection) -> Self {
63 c.handle
64 }
65}
66
67impl From<Connection> for ConnectionMeta {
68 fn from(c: Connection) -> Self {
69 c.meta
70 }
71}
72
73#[derive(Debug, Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
74#[display(
75 fmt = "ConnectionMeta {{ protocol: {}, local: {}, remote: {} }}",
76 protocol,
77 local,
78 remote
79)]
80pub struct ConnectionMeta {
81 pub protocol: Protocol,
82 pub local: IpEndpoint,
83 pub remote: IpEndpoint,
84}
85
86impl ConnectionMeta {
87 pub fn new(protocol: Protocol, local: IpEndpoint, remote: IpEndpoint) -> Self {
88 Self {
89 protocol,
90 local,
91 remote,
92 }
93 }
94
95 pub fn unspecified(protocol: Protocol) -> Self {
96 Self {
97 protocol,
98 local: IpEndpoint {
99 addr: smoltcp::wire::Ipv4Address::UNSPECIFIED.into_address(),
100 port: 0,
101 },
102 remote: IpEndpoint {
103 addr: smoltcp::wire::Ipv4Address::UNSPECIFIED.into_address(),
104 port: 0,
105 },
106 }
107 }
108
109 #[inline]
110 pub fn to_socket_addr(&self) -> SocketAddr {
111 SocketAddr::from((self.local.addr, self.local.port))
112 }
113}
114
115impl From<ConnectionMeta> for SocketDesc {
116 fn from(c: ConnectionMeta) -> Self {
117 SocketDesc {
118 protocol: c.protocol,
119 local: c.local.into(),
120 remote: c.remote.into(),
121 }
122 }
123}
124
125impl<'a> From<&'a ConnectionMeta> for SocketEndpoint {
126 fn from(c: &'a ConnectionMeta) -> Self {
127 SocketEndpoint::Ip(c.local)
128 }
129}
130
131impl TryFrom<SocketDesc> for ConnectionMeta {
132 type Error = Error;
133
134 fn try_from(desc: SocketDesc) -> std::result::Result<Self, Self::Error> {
135 let local = match desc.local {
136 SocketEndpoint::Ip(endpoint) => endpoint,
137 endpoint => return Err(Error::EndpointInvalid(endpoint)),
138 };
139 let remote = match desc.remote {
140 SocketEndpoint::Ip(endpoint) => endpoint,
141 endpoint => return Err(Error::EndpointInvalid(endpoint)),
142 };
143 Ok(Self {
144 protocol: desc.protocol,
145 local,
146 remote,
147 })
148 }
149}
150
151pub struct Connect<'a> {
153 pub connection: Connection,
154 iface: Rc<RefCell<CaptureInterface<'a>>>,
155}
156
157impl<'a> Connect<'a> {
158 pub fn new(connection: Connection, iface: Rc<RefCell<CaptureInterface<'a>>>) -> Self {
159 log::trace!("[Connect::new]: {:?}", connection);
160 Self { connection, iface }
161 }
162}
163
164impl<'a> Future for Connect<'a> {
165 type Output = Result<Connection>;
166
167 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168 let iface_rfc = self.iface.clone();
169 let mut iface = iface_rfc.borrow_mut();
170
171 log::trace!("[Connect::poll]: {:?}", self.connection);
172
173 let socket = match iface.get_socket_safe::<tcp::Socket>(self.connection.handle) {
174 Ok(s) => s,
175 Err(_) => return Poll::Ready(Err(Error::SocketClosed)),
176 };
177
178 if !socket.is_open() {
179 Poll::Ready(Err(Error::SocketClosed))
180 } else if socket.can_send() {
181 Poll::Ready(Ok(self.connection))
182 } else {
183 socket.register_send_waker(cx.waker());
184 Poll::Pending
185 }
186 }
187}
188
189pub struct Disconnect<'a> {
191 handle: SocketHandle,
192 iface: Rc<RefCell<CaptureInterface<'a>>>,
193}
194
195impl<'a> Disconnect<'a> {
196 pub fn new(handle: SocketHandle, iface: Rc<RefCell<CaptureInterface<'a>>>) -> Self {
197 Self { handle, iface }
198 }
199}
200
201impl<'a> Future for Disconnect<'a> {
202 type Output = Result<()>;
203
204 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205 let iface_rfc = self.iface.clone();
206 let mut iface = iface_rfc.borrow_mut();
207
208 let socket = match iface.get_socket_safe::<tcp::Socket>(self.handle) {
209 Ok(s) => s,
210 Err(_) => return Poll::Ready(Ok(())),
211 };
212
213 if !socket.is_open() {
214 Poll::Ready(Ok(()))
215 } else {
216 socket.register_recv_waker(cx.waker());
217 Poll::Pending
218 }
219 }
220}
221
222pub struct Send<'a> {
224 data: Payload,
225 offset: usize,
226 connection: Connection,
227 iface: Rc<RefCell<CaptureInterface<'a>>>,
228 sent: Box<dyn Fn()>,
230}
231
232impl<'a> Send<'a> {
233 pub fn new<F: Fn() + 'static>(
234 data: Payload,
235 connection: Connection,
236 iface: Rc<RefCell<CaptureInterface<'a>>>,
237 sent: F,
238 ) -> Self {
239 log::trace!("[Send::new]: {:?}", connection);
240 Self {
241 data,
242 offset: 0,
243 connection,
244 iface,
245 sent: Box::new(sent),
246 }
247 }
248}
249
250impl<'a> Future for Send<'a> {
251 type Output = Result<()>;
252
253 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
254 let result = {
255 let mut iface = self.iface.borrow_mut();
256 let conn = &self.connection;
257
258 match conn.meta.protocol {
259 Protocol::Tcp => {
260 log::trace!(
261 "[Future(Send)::poll]: Sending TCP packet: {:?}",
262 self.data.as_ref()
263 );
264 let result = {
265 let socket = match iface.get_socket_safe::<tcp::Socket>(conn.handle) {
266 Ok(socket) => socket,
267 Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
268 };
269 socket.register_send_waker(cx.waker());
270 socket.send_slice(&self.data.as_ref()[self.offset..])
271 };
272
273 drop(iface);
274 (*self.sent)();
275
276 return match result {
277 Ok(count) => {
278 self.offset += count;
279 if self.offset >= self.data.len() {
280 Poll::Ready(Ok(()))
281 } else {
282 Poll::Pending
283 }
284 }
285 Err(smoltcp::socket::tcp::SendError::InvalidState) => Poll::Pending,
286 };
287 }
288 Protocol::Udp => {
289 log::trace!(
290 "[Future(Send)::poll]: Sending UDP packet: {:?}",
291 self.data.as_ref()
292 );
293 let socket = match iface.get_socket_safe::<udp::Socket>(conn.handle) {
294 Ok(socket) => socket,
295 Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
296 };
297 socket.register_send_waker(cx.waker());
298 socket
299 .send_slice(self.data.as_ref(), conn.meta.remote)
300 .map_err(|err| err.to_string())
301 }
302 Protocol::Icmp | Protocol::Ipv6Icmp => {
303 log::trace!(
304 "[Future(Send)::poll]: Sending ICMP packet: {:?}",
305 self.data.as_ref()
306 );
307 let socket = match iface.get_socket_safe::<icmp::Socket>(conn.handle) {
308 Ok(socket) => socket,
309 Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
310 };
311 socket.register_send_waker(cx.waker());
312 socket
313 .send_slice(self.data.as_ref(), conn.meta.remote.addr)
314 .map_err(|err| err.to_string())
315 }
316 _ => {
317 let socket = match iface.get_socket_safe::<raw::Socket>(conn.handle) {
318 Ok(socket) => socket,
319 Err(e) => return Poll::Ready(Err(Error::Other(e.to_string()))),
320 };
321 socket.register_send_waker(cx.waker());
322 socket
323 .send_slice(self.data.as_ref())
324 .map_err(|err| err.to_string())
325 }
326 }
327 };
328
329 (*self.sent)();
330
331 match result {
332 Ok(_) => Poll::Ready(Ok(())),
333 Err(err) => Poll::Ready(Err(Error::Other(err))),
334 }
335 }
336}