// ipstack_geph/stream/tcp_wrapper.rs

use std::{net::SocketAddr, pin::Pin, time::Duration};

use futures_lite::{AsyncRead, AsyncWrite, AsyncWriteExt};

use super::tcp::IpStackTcpStream as IpStackTcpStreamInner;
use crate::{packet::TcpHeaderWrapper, PacketSender};

7pub struct IpStackTcpStream {
8    inner: Option<Box<IpStackTcpStreamInner>>,
9    peer_addr: SocketAddr,
10    local_addr: SocketAddr,
11    stream_sender: PacketSender,
12}
13
14impl IpStackTcpStream {
15    pub(crate) fn new(
16        local_addr: SocketAddr,
17        peer_addr: SocketAddr,
18        tcp: TcpHeaderWrapper,
19        pkt_sender: PacketSender,
20        mtu: u16,
21        tcp_timeout: Duration,
22    ) -> anyhow::Result<IpStackTcpStream> {
23        let (stream_sender, stream_receiver) = async_channel::unbounded();
24        IpStackTcpStreamInner::new(
25            local_addr,
26            peer_addr,
27            tcp,
28            pkt_sender,
29            stream_receiver,
30            mtu,
31            tcp_timeout,
32        )
33        .map(|inner| IpStackTcpStream {
34            inner: Some(Box::new(inner)),
35            peer_addr,
36            local_addr,
37            stream_sender,
38        })
39    }
40    pub fn local_addr(&self) -> SocketAddr {
41        self.local_addr
42    }
43    pub fn peer_addr(&self) -> SocketAddr {
44        self.peer_addr
45    }
46    pub fn stream_sender(&self) -> PacketSender {
47        self.stream_sender.clone()
48    }
49}
50
51impl AsyncRead for IpStackTcpStream {
52    fn poll_read(
53        mut self: std::pin::Pin<&mut Self>,
54        cx: &mut std::task::Context<'_>,
55        buf: &mut [u8],
56    ) -> std::task::Poll<std::io::Result<usize>> {
57        match self.inner.as_mut() {
58            Some(mut inner) => Pin::new(&mut inner).poll_read(cx, buf),
59            None => {
60                std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
61            }
62        }
63    }
64}
65
66impl AsyncWrite for IpStackTcpStream {
67    fn poll_write(
68        mut self: std::pin::Pin<&mut Self>,
69        cx: &mut std::task::Context<'_>,
70        buf: &[u8],
71    ) -> std::task::Poll<Result<usize, std::io::Error>> {
72        match self.inner.as_mut() {
73            Some(mut inner) => Pin::new(&mut inner).poll_write(cx, buf),
74            None => {
75                std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
76            }
77        }
78    }
79    fn poll_flush(
80        mut self: std::pin::Pin<&mut Self>,
81        cx: &mut std::task::Context<'_>,
82    ) -> std::task::Poll<Result<(), std::io::Error>> {
83        match self.inner.as_mut() {
84            Some(mut inner) => Pin::new(&mut inner).poll_flush(cx),
85            None => {
86                std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
87            }
88        }
89    }
90    fn poll_close(
91        mut self: std::pin::Pin<&mut Self>,
92        cx: &mut std::task::Context<'_>,
93    ) -> std::task::Poll<Result<(), std::io::Error>> {
94        match self.inner.as_mut() {
95            Some(mut inner) => Pin::new(&mut inner).poll_close(cx),
96            None => {
97                std::task::Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::NotConnected)))
98            }
99        }
100    }
101}
102
103impl Drop for IpStackTcpStream {
104    fn drop(&mut self) {
105        if let Some(mut inner) = self.inner.take() {
106            smolscale::spawn(async move {
107                let _ = Box::pin(inner.close()).await;
108            })
109            .detach()
110        }
111    }
112}