1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use futures_lite::{AsyncRead, AsyncWrite, AsyncWriteExt};

use super::tcp::IpStackTcpStream as IpStackTcpStreamInner;
use crate::{packet::TcpHeaderWrapper, PacketSender};
use std::{net::SocketAddr, pin::Pin, time::Duration};

/// Public TCP stream handle that wraps the crate-internal TCP stream
/// implementation and forwards `AsyncRead` / `AsyncWrite` calls to it.
pub struct IpStackTcpStream {
    /// The wrapped stream. Held in an `Option` so that `Drop` can move it out
    /// with `take()`; the poll methods report `NotConnected` when it is `None`.
    inner: Option<Box<IpStackTcpStreamInner>>,
    /// Address returned by [`IpStackTcpStream::peer_addr`].
    peer_addr: SocketAddr,
    /// Address returned by [`IpStackTcpStream::local_addr`].
    local_addr: SocketAddr,
    /// Sending half of the unbounded channel created in `new`; the receiving
    /// half is handed to the inner stream.
    stream_sender: PacketSender,
}

impl IpStackTcpStream {
    pub(crate) fn new(
        local_addr: SocketAddr,
        peer_addr: SocketAddr,
        tcp: TcpHeaderWrapper,
        pkt_sender: PacketSender,
        mtu: u16,
        tcp_timeout: Duration,
    ) -> anyhow::Result<IpStackTcpStream> {
        let (stream_sender, stream_receiver) = async_channel::unbounded();
        IpStackTcpStreamInner::new(
            local_addr,
            peer_addr,
            tcp,
            pkt_sender,
            stream_receiver,
            mtu,
            tcp_timeout,
        )
        .map(|inner| IpStackTcpStream {
            inner: Some(Box::new(inner)),
            peer_addr,
            local_addr,
            stream_sender,
        })
    }
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }
    pub fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }
    pub fn stream_sender(&self) -> PacketSender {
        self.stream_sender.clone()
    }
}

impl AsyncRead for IpStackTcpStream {
    /// Forwards the read to the inner stream.
    ///
    /// Resolves immediately with `ErrorKind::NotConnected` when the inner
    /// stream is absent.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        if let Some(stream) = self.get_mut().inner.as_mut() {
            Pin::new(stream).poll_read(cx, buf)
        } else {
            std::task::Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
        }
    }
}

impl AsyncWrite for IpStackTcpStream {
    /// Forwards the write to the inner stream, or resolves with
    /// `ErrorKind::NotConnected` when the inner stream is absent.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        if let Some(stream) = self.get_mut().inner.as_mut() {
            Pin::new(stream).poll_write(cx, buf)
        } else {
            std::task::Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
        }
    }

    /// Forwards the flush to the inner stream, or resolves with
    /// `ErrorKind::NotConnected` when the inner stream is absent.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        if let Some(stream) = self.get_mut().inner.as_mut() {
            Pin::new(stream).poll_flush(cx)
        } else {
            std::task::Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
        }
    }

    /// Forwards the close to the inner stream, or resolves with
    /// `ErrorKind::NotConnected` when the inner stream is absent.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        if let Some(stream) = self.get_mut().inner.as_mut() {
            Pin::new(stream).poll_close(cx)
        } else {
            std::task::Poll::Ready(Err(std::io::ErrorKind::NotConnected.into()))
        }
    }
}

impl Drop for IpStackTcpStream {
    fn drop(&mut self) {
        if let Some(mut inner) = self.inner.take() {
            std::thread::spawn(move || async move {
                let _ = Box::pin(inner.close()).await;
            });
        }
    }
}