// tcp_ip/tcp/sys.rs

1use std::io;
2use std::io::Error;
3use std::ops::Add;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use bytes::{Buf, BytesMut};
9use pnet_packet::ip::IpNextHeaderProtocols;
10use tokio::sync::mpsc::error::TrySendError;
11use tokio::sync::mpsc::{Receiver, Sender};
12use tokio::sync::Notify;
13use tokio::time::Instant;
14
15use crate::ip_stack::{BindAddr, IpStack, NetworkTuple, TransportPacket};
16use crate::tcp::tcb::Tcb;
17
/// Per-connection task that drives a single TCP stream: it shuttles packets
/// between the IP stack and the application layer and runs the `Tcb` state
/// machine (see `run`/`run0` in the main impl block).
#[derive(Debug)]
pub struct TcpStreamTask {
    // Held only to keep the bound local address alive for the task's lifetime.
    _bind_addr: Option<BindAddr>,
    // Copied from `tcp_config.quick_end`; when true, `run0` exits as soon as
    // both the read and write halves are closed.
    quick_end: bool,
    // TCP control block: the per-connection protocol state machine.
    tcb: Tcb,
    ip_stack: IpStack,
    // Outbound data arriving from the application layer.
    application_layer_receiver: Receiver<BytesMut>,
    // Outbound bytes that did not fit the send window; retried by `flush`.
    last_buffer: Option<BytesMut>,
    // Inbound packets delivered by the IP stack.
    packet_receiver: Receiver<TransportPacket>,
    // Received data handed up to the application; `None` once the read half
    // is closed (taken by `close_read`).
    application_layer_sender: Option<Sender<BytesMut>>,
    write_half_closed: bool,
    // True while a retransmission was just sent; gates `flush` in `run0`.
    retransmission: bool,
    // Shared wake-up handle used by the stream reader (see `ReadNotify`).
    read_notify: ReadNotify,
}
32
/// Clonable handle the stream reader uses to wake the owning `TcpStreamTask`
/// when it has consumed data (or when the read side closes).
#[derive(Clone, Default, Debug)]
pub struct ReadNotify {
    // Mirrors `tcb.readable()` as last observed by `push_application_layer`
    // (via `set_state`); `notify` only wakes the task while this is true.
    readable: Arc<AtomicBool>,
    // Wakes the task's `tokio::select!` loop.
    notify: Arc<Notify>,
}
38
39impl ReadNotify {
40    pub fn notify(&self) {
41        if self.readable.load(Ordering::Acquire) {
42            self.notify.notify_one();
43        }
44    }
45    pub fn close(&self) {
46        self.notify.notify_one();
47    }
48    async fn notified(&self) {
49        self.notify.notified().await
50    }
51    fn set_state(&self, readable: bool) {
52        self.readable.store(readable, Ordering::Release);
53    }
54}
55
56impl Drop for TcpStreamTask {
57    fn drop(&mut self) {
58        let peer_addr = self.tcb.peer_addr();
59        let local_addr = self.tcb.local_addr();
60        let network_tuple = NetworkTuple::new(peer_addr, local_addr, IpNextHeaderProtocols::Tcp);
61        self.ip_stack.remove_tcp_socket(&network_tuple);
62    }
63}
64
65impl TcpStreamTask {
66    pub fn new(
67        _bind_addr: Option<BindAddr>,
68        tcb: Tcb,
69        ip_stack: IpStack,
70        application_layer_sender: Sender<BytesMut>,
71        application_layer_receiver: Receiver<BytesMut>,
72        packet_receiver: Receiver<TransportPacket>,
73    ) -> Self {
74        Self {
75            _bind_addr,
76            quick_end: ip_stack.config.tcp_config.quick_end,
77            tcb,
78            ip_stack,
79            application_layer_receiver,
80            last_buffer: None,
81            packet_receiver,
82            application_layer_sender: Some(application_layer_sender),
83            write_half_closed: false,
84            retransmission: false,
85            read_notify: Default::default(),
86        }
87    }
88    pub fn read_notify(&self) -> ReadNotify {
89        self.read_notify.clone()
90    }
91}
92
impl TcpStreamTask {
    /// Drives the connection until it closes or errors, then flushes any
    /// remaining received bytes up to the application layer in either case.
    pub async fn run(&mut self) -> io::Result<()> {
        let result = self.run0().await;
        self.push_application_layer();
        result
    }
    /// Core event loop. Each iteration: check terminal states, retry parked
    /// outbound data, wait for the next event (`recv_data`), feed it through
    /// the TCB, then handle retransmission / ACK bookkeeping.
    pub async fn run0(&mut self) -> io::Result<()> {
        loop {
            if self.tcb.is_close() {
                return Ok(());
            }
            // With `quick_end` set, exit as soon as both halves are closed
            // instead of lingering for the full teardown.
            if self.quick_end && self.read_half_closed() && self.write_half_closed {
                return Ok(());
            }
            // Retry outbound data that did not fit the send window earlier;
            // skipped while a retransmission is pending or writes are closed.
            if !self.write_half_closed && !self.retransmission {
                self.flush().await?;
            }
            let data = self.recv_data().await;

            match data {
                TaskRecvData::In(mut buf) => {
                    let mut count = 0;
                    // Drain up to 10 already-queued inbound packets in one
                    // batch before pushing data up to the application layer.
                    loop {
                        if let Some(reply_packet) = self.tcb.push_packet(buf) {
                            self.send_packet(reply_packet).await?;
                        }

                        if self.tcb.is_close() {
                            return Ok(());
                        }
                        if !self.tcb.readable_state() {
                            break;
                        }
                        count += 1;
                        if count >= 10 {
                            break;
                        }
                        if let Some(v) = self.try_recv_in() {
                            buf = v
                        } else {
                            break;
                        }
                    }
                    self.push_application_layer();
                    // if self.tcb.readable_state() && self.application_layer_sender.is_some() && self.tcb.readable() && self.tcb.recv_busy() {
                    //     // The window is too small and requires blocking to wait; otherwise, it will lead to severe packet loss
                    //     self.read_notify.notified().await;
                    //     self.push_application_layer();
                    // }
                }
                TaskRecvData::Out(buf) => {
                    self.write(buf).await?;
                }
                // The packet channel closed underneath us: the stack is gone.
                TaskRecvData::InClose => return Err(Error::new(io::ErrorKind::Other, "NetworkDown")),
                TaskRecvData::OutClose => {
                    // The application dropped its sender; start an orderly
                    // close. `last_buffer` must already have been flushed.
                    assert!(self.last_buffer.is_none());
                    self.write_half_closed = true;
                    let packet = self.tcb.fin_packet();
                    self.send_packet(packet).await?;
                    self.tcb.sent_fin();
                }
                TaskRecvData::Timeout => {
                    self.tcb.timeout();
                    if self.tcb.is_close() {
                        return Ok(());
                    }
                    if self.tcb.cannot_write() {
                        let packet = self.tcb.fin_packet();
                        self.send_packet(packet).await?;
                    }
                    if self.read_half_closed() && self.write_half_closed {
                        return Ok(());
                    }
                }
                TaskRecvData::ReadNotify => {
                    // The reader made progress (or closed): move buffered data
                    // along and acknowledge the freed window if needed.
                    self.push_application_layer();
                    self.try_send_ack().await?;
                }
            }
            self.retransmission = self.try_retransmission().await?;
            self.try_send_ack().await?;
            self.tcb.perform_post_ack_action();
            if !self.read_half_closed() && self.tcb.cannot_read() {
                self.close_read();
            }
        }
    }
    /// Sends one packet through the IP stack, then runs the TCB's post-ACK hook.
    async fn send_packet(&mut self, transport_packet: TransportPacket) -> io::Result<()> {
        self.ip_stack.send_packet(transport_packet).await?;
        self.tcb.perform_post_ack_action();
        Ok(())
    }
    /// True when the application can no longer receive data: the sender was
    /// already taken by `close_read`, or its receiving end was dropped.
    fn read_half_closed(&self) -> bool {
        if let Some(v) = self.application_layer_sender.as_ref() {
            v.is_closed()
        } else {
            true
        }
    }
    /// Maximum segment size, as reported by the TCB.
    pub fn mss(&self) -> u16 {
        self.tcb.mss()
    }
    /// True when only inbound packets should be awaited: outbound writes are
    /// blocked by a parked buffer, an in-flight retransmission, a closed write
    /// half, or the TCB's flow limit.
    fn only_recv_in(&self) -> bool {
        self.retransmission || self.last_buffer.is_some() || self.write_half_closed || self.tcb.limit()
    }
    /// Moves readable bytes from the TCB into the application-layer channel
    /// without blocking; closes the read half when the TCB can produce no more
    /// data or the receiver has gone away.
    fn push_application_layer(&mut self) {
        if let Some(sender) = self.application_layer_sender.as_ref() {
            let mut read_half_closed = false;
            while self.tcb.readable() {
                // Reserve a slot first so a full channel never drops a buffer.
                match sender.try_reserve() {
                    Ok(sender) => {
                        if let Some(buffer) = self.tcb.read() {
                            sender.send(buffer);
                        }
                    }
                    Err(e) => match e {
                        TrySendError::Full(_) => break,
                        TrySendError::Closed(_) => {
                            read_half_closed = true;
                            break;
                        }
                    },
                }
                // Mirror current readability into the shared flag the reader
                // checks before waking this task. Note this runs only after a
                // successful reserve, not on the break paths above.
                self.read_notify.set_state(self.tcb.readable());
            }
            if self.tcb.cannot_read() || read_half_closed {
                self.close_read();
            }
        } else {
            // No receiver any more: discard whatever the TCB buffered.
            self.tcb.read_none();
        }
    }
    /// Signals EOF to the reader by sending an empty buffer, then drops the
    /// sender so `read_half_closed` reports true from now on.
    fn close_read(&mut self) {
        if let Some(sender) = self.application_layer_sender.take() {
            _ = sender.try_send(BytesMut::new());
        }
    }
    /// Writes as much of `buf` as the TCB allows, returning the number of
    /// bytes consumed. Associated-function form (no `&self`) so `flush` can
    /// borrow `tcb` and `last_buffer` disjointly.
    async fn write_slice0(tcb: &mut Tcb, ip_stack: &IpStack, mut buf: &[u8]) -> io::Result<usize> {
        let len = buf.len();
        while !buf.is_empty() {
            if let Some((packet, len)) = tcb.write(buf) {
                // A zero-length write means no more data fits right now.
                if len == 0 {
                    break;
                }
                ip_stack.send_packet(packet).await?;
                tcb.perform_post_ack_action();
                buf = &buf[len..];
            } else {
                break;
            }
        }
        Ok(len - buf.len())
    }
    async fn write_slice(&mut self, buf: &[u8]) -> io::Result<usize> {
        Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await
    }
    /// Writes `buf`, parking any unsent remainder in `last_buffer` so `flush`
    /// can retry it on a later loop iteration.
    async fn write(&mut self, mut buf: BytesMut) -> io::Result<usize> {
        let len = self.write_slice(&buf).await?;
        if len != buf.len() {
            // Buffer is full
            buf.advance(len);
            self.last_buffer.replace(buf);
        }
        Ok(len)
    }
    /// Retries the parked outbound buffer, clearing it once fully sent.
    async fn flush(&mut self) -> io::Result<()> {
        if let Some(buf) = self.last_buffer.as_mut() {
            let len = Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await?;
            if buf.len() == len {
                self.last_buffer.take();
            } else {
                buf.advance(len);
            }
        }
        Ok(())
    }

    /// Retransmits at most one packet if the TCB says one is due. Returns
    /// `Ok(true)` when a retransmission was sent (which pauses `flush`).
    async fn try_retransmission(&mut self) -> io::Result<bool> {
        if self.write_half_closed {
            return Ok(false);
        }
        if let Some(v) = self.tcb.retransmission() {
            self.send_packet(v).await?;
            return Ok(true);
        }
        if self.tcb.no_inflight_packet() {
            return Ok(false);
        }
        // Nothing was immediately ready; ask the TCB again only if it reports
        // that a retransmission is now needed.
        if self.tcb.need_retransmission() {
            if let Some(v) = self.tcb.retransmission() {
                self.send_packet(v).await?;
                return Ok(true);
            }
        }
        Ok(false)
    }
    /// Sends a bare ACK if the TCB has one pending.
    async fn try_send_ack(&mut self) -> io::Result<()> {
        if self.tcb.need_ack() {
            let packet = self.tcb.ack_packet();
            self.ip_stack.send_packet(packet).await?;
        }
        Ok(())
    }

    /// Waits for the next event, choosing the wait mode and deadline from the
    /// TCB state: TIME-WAIT timer first, then write timeout, else unbounded.
    async fn recv_data(&mut self) -> TaskRecvData {
        let deadline = if let Some(v) = self.tcb.time_wait() {
            Some(v.into())
        } else {
            self.tcb.write_timeout().map(|v| v.into())
        };

        if let Some(deadline) = deadline {
            if self.only_recv_in() {
                self.recv_in_timeout_at(deadline).await
            } else {
                self.recv_timeout_at(deadline).await
            }
        } else if self.write_half_closed {
            // Write half closed and no TCB deadline: wait for inbound traffic
            // only, bounded by the configured TIME-WAIT timeout.
            let timeout_at = Instant::now().add(self.ip_stack.config.tcp_config.time_wait_timeout);
            self.recv_in_timeout_at(timeout_at).await
        } else {
            self.recv().await
        }
    }
    /// Waits without a deadline for an inbound packet, outbound data, or a
    /// read notification; channel closure maps to `InClose`/`OutClose`.
    async fn recv(&mut self) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            rs=self.application_layer_receiver.recv()=>{
                rs.map(TaskRecvData::Out).unwrap_or(TaskRecvData::OutClose)
            }
            _=self.read_notify.notified()=>{
                TaskRecvData::ReadNotify
            }
        }
    }
    /// Like `recv`, but also resolves to `Timeout` once `deadline` passes.
    async fn recv_timeout_at(&mut self, deadline: Instant) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            rs=self.application_layer_receiver.recv()=>{
                rs.map(TaskRecvData::Out).unwrap_or(TaskRecvData::OutClose)
            }
            _=tokio::time::sleep_until(deadline)=>{
                TaskRecvData::Timeout
            }
            _=self.read_notify.notified()=>{
                TaskRecvData::ReadNotify
            }
        }
    }

    /// Waits only for inbound packets or a read notification (no outbound
    /// data) until `deadline`; used when `only_recv_in()` holds.
    async fn recv_in_timeout_at(&mut self, deadline: Instant) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            _=tokio::time::sleep_until(deadline)=>{
                TaskRecvData::Timeout
            }
            _=self.read_notify.notified()=>{
                TaskRecvData::ReadNotify
            }
        }
    }
    async fn recv_in_timeout(&mut self, duration: Duration) -> TaskRecvData {
        self.recv_in_timeout_at(Instant::now().add(duration)).await
    }

    /// Non-blocking poll of the inbound packet queue.
    fn try_recv_in(&mut self) -> Option<BytesMut> {
        self.packet_receiver.try_recv().map(|v| v.buf).ok()
    }
}
368
369impl TcpStreamTask {
370    pub async fn connect(&mut self) -> io::Result<()> {
371        let mut count = 0;
372        let mut time = 50;
373        while let Some(packet) = self.tcb.try_syn_sent() {
374            count += 1;
375            if count > 50 {
376                break;
377            }
378            self.send_packet(packet).await?;
379            time *= 2;
380            return match self.recv_in_timeout(Duration::from_millis(time.min(3000))).await {
381                TaskRecvData::In(buf) => {
382                    if let Some(relay) = self.tcb.try_syn_sent_to_established(buf) {
383                        self.send_packet(relay).await?;
384                        Ok(())
385                    } else {
386                        Err(io::Error::from(io::ErrorKind::ConnectionRefused))
387                    }
388                }
389                TaskRecvData::InClose => Err(io::Error::from(io::ErrorKind::ConnectionRefused)),
390                TaskRecvData::Timeout => continue,
391                _ => {
392                    unreachable!()
393                }
394            };
395        }
396        Err(io::Error::from(io::ErrorKind::ConnectionRefused))
397    }
398}
399
/// Events consumed by the `TcpStreamTask` event loop (`run0`/`connect`).
enum TaskRecvData {
    /// Inbound packet payload from the IP stack.
    In(BytesMut),
    /// Outbound data from the application layer.
    Out(BytesMut),
    /// The `ReadNotify` handle fired (the reader made progress or closed).
    ReadNotify,
    /// The inbound packet channel closed: the network side is gone.
    InClose,
    /// The application-layer sender was dropped: no more outbound data.
    OutClose,
    /// A wait deadline elapsed.
    Timeout,
}