1use std::io;
2use std::io::Error;
3use std::ops::Add;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use bytes::{Buf, BytesMut};
9use pnet_packet::ip::IpNextHeaderProtocols;
10use tokio::sync::mpsc::error::TrySendError;
11use tokio::sync::mpsc::{Receiver, Sender};
12use tokio::sync::Notify;
13use tokio::time::Instant;
14
15use crate::ip_stack::{BindAddr, IpStack, NetworkTuple, TransportPacket};
16use crate::tcp::tcb::Tcb;
17
/// Background task that drives a single TCP connection: it moves packets
/// between the IP stack and the application channels, and handles flushing,
/// retransmission, ACKs, FIN and timeouts in its `run` loop.
#[derive(Debug)]
pub struct TcpStreamTask {
    // Held only to keep the local bind address alive for the task's lifetime.
    _bind_addr: Option<BindAddr>,
    // From `tcp_config.quick_end`: when true, `run0` exits as soon as both
    // halves are closed instead of lingering.
    quick_end: bool,
    // TCP control block: per-connection protocol state (sequencing, windows, timers).
    tcb: Tcb,
    ip_stack: IpStack,
    // Outgoing payload written by the application, to be segmented and sent.
    application_layer_receiver: Receiver<BytesMut>,
    // Application data that did not fit in the send window yet; retried by `flush`.
    last_buffer: Option<BytesMut>,
    // Inbound network packets routed to this connection.
    packet_receiver: Receiver<TransportPacket>,
    // Delivers received payload to the application; `None` after the read half
    // is closed (see `close_read`).
    application_layer_sender: Option<Sender<BytesMut>>,
    // Set once a FIN has been queued for the write half.
    write_half_closed: bool,
    // True while a retransmitted segment is outstanding; suppresses `flush`.
    retransmission: bool,
    // Wakes the select loop when the reader side wants more data pushed.
    read_notify: ReadNotify,
}
32
/// Cloneable wake-up handle shared between the stream reader and the task.
/// `readable` gates notifications so wake-ups only fire while the task has
/// marked buffered data as available.
#[derive(Clone, Default, Debug)]
pub struct ReadNotify {
    // Last readable state published by the task (see `set_state`).
    readable: Arc<AtomicBool>,
    // Underlying wake-up primitive awaited in the task's select loops.
    notify: Arc<Notify>,
}
38
impl ReadNotify {
    /// Wake the task, but only while it has advertised readable data;
    /// otherwise the wake-up would be spurious and is dropped.
    pub fn notify(&self) {
        if self.readable.load(Ordering::Acquire) {
            self.notify.notify_one();
        }
    }
    /// Unconditionally wake the task (used when the reader is closing,
    /// regardless of the readable flag).
    pub fn close(&self) {
        self.notify.notify_one();
    }
    /// Wait until `notify`/`close` fires (one stored permit is consumed).
    async fn notified(&self) {
        self.notify.notified().await
    }
    /// Publish the current readable state; Release pairs with the Acquire
    /// load in `notify`.
    fn set_state(&self, readable: bool) {
        self.readable.store(readable, Ordering::Release);
    }
}
55
56impl Drop for TcpStreamTask {
57 fn drop(&mut self) {
58 let peer_addr = self.tcb.peer_addr();
59 let local_addr = self.tcb.local_addr();
60 let network_tuple = NetworkTuple::new(peer_addr, local_addr, IpNextHeaderProtocols::Tcp);
61 self.ip_stack.remove_tcp_socket(&network_tuple);
62 }
63}
64
65impl TcpStreamTask {
66 pub fn new(
67 _bind_addr: Option<BindAddr>,
68 tcb: Tcb,
69 ip_stack: IpStack,
70 application_layer_sender: Sender<BytesMut>,
71 application_layer_receiver: Receiver<BytesMut>,
72 packet_receiver: Receiver<TransportPacket>,
73 ) -> Self {
74 Self {
75 _bind_addr,
76 quick_end: ip_stack.config.tcp_config.quick_end,
77 tcb,
78 ip_stack,
79 application_layer_receiver,
80 last_buffer: None,
81 packet_receiver,
82 application_layer_sender: Some(application_layer_sender),
83 write_half_closed: false,
84 retransmission: false,
85 read_notify: Default::default(),
86 }
87 }
88 pub fn read_notify(&self) -> ReadNotify {
89 self.read_notify.clone()
90 }
91}
92
93impl TcpStreamTask {
94 pub async fn run(&mut self) -> io::Result<()> {
95 let result = self.run0().await;
96 self.push_application_layer();
97 result
98 }
/// Main event loop: flush pending writes, wait for the next event
/// (packet in, app data out, timeout, read-notify), react, then run the
/// per-iteration retransmission/ACK bookkeeping. The statement order here
/// is the protocol — do not reorder casually.
pub async fn run0(&mut self) -> io::Result<()> {
    loop {
        // Connection fully closed -> task done.
        if self.tcb.is_close() {
            return Ok(());
        }
        // Optional early exit once both halves are closed.
        if self.quick_end && self.read_half_closed() && self.write_half_closed {
            return Ok(());
        }
        // Retry any backlog in `last_buffer`, but not while a
        // retransmission is outstanding or the write half is closed.
        if !self.write_half_closed && !self.retransmission {
            self.flush().await?;
        }
        let data = self.recv_data().await;

        match data {
            TaskRecvData::In(mut buf) => {
                // Drain up to 10 queued inbound packets in one go before
                // pushing to the application, to amortize the hand-off.
                let mut count = 0;
                loop {
                    if let Some(reply_packet) = self.tcb.push_packet(buf) {
                        self.send_packet(reply_packet).await?;
                    }

                    if self.tcb.is_close() {
                        return Ok(());
                    }
                    // Stop batching once the TCB has nothing readable.
                    if !self.tcb.readable_state() {
                        break;
                    }
                    count += 1;
                    if count >= 10 {
                        break;
                    }
                    // Only continue if another packet is already queued.
                    if let Some(v) = self.try_recv_in() {
                        buf = v
                    } else {
                        break;
                    }
                }
                self.push_application_layer();
            }
            TaskRecvData::Out(buf) => {
                // Application wrote data: send what fits, buffer the rest.
                self.write(buf).await?;
            }
            // Network-side channel closed: the stack is gone.
            TaskRecvData::InClose => return Err(Error::new(io::ErrorKind::Other, "NetworkDown")),
            TaskRecvData::OutClose => {
                // Application dropped its writer. `recv_data` only listens for
                // this while `last_buffer` is empty, hence the assert.
                assert!(self.last_buffer.is_none());
                self.write_half_closed = true;
                let packet = self.tcb.fin_packet();
                self.send_packet(packet).await?;
                self.tcb.sent_fin();
            }
            TaskRecvData::Timeout => {
                self.tcb.timeout();
                if self.tcb.is_close() {
                    return Ok(());
                }
                // Re-send FIN if the write side can no longer transmit data.
                if self.tcb.cannot_write() {
                    let packet = self.tcb.fin_packet();
                    self.send_packet(packet).await?;
                }
                if self.read_half_closed() && self.write_half_closed {
                    return Ok(());
                }
            }
            TaskRecvData::ReadNotify => {
                // Reader drained its channel: push more data and ACK it.
                self.push_application_layer();
                self.try_send_ack().await?;
            }
        }
        // Per-iteration bookkeeping: retransmit if due, ACK if due.
        self.retransmission = self.try_retransmission().await?;
        self.try_send_ack().await?;
        self.tcb.perform_post_ack_action();
        // Peer can send no more data: propagate EOF to the application.
        if !self.read_half_closed() && self.tcb.cannot_read() {
            self.close_read();
        }
    }
}
/// Send one segment through the IP stack, then let the TCB update its
/// post-ACK bookkeeping for what was just put on the wire.
async fn send_packet(&mut self, transport_packet: TransportPacket) -> io::Result<()> {
    self.ip_stack.send_packet(transport_packet).await?;
    self.tcb.perform_post_ack_action();
    Ok(())
}
185 fn read_half_closed(&self) -> bool {
186 if let Some(v) = self.application_layer_sender.as_ref() {
187 v.is_closed()
188 } else {
189 true
190 }
191 }
/// Maximum segment size negotiated for this connection (delegated to the TCB).
pub fn mss(&self) -> u16 {
    self.tcb.mss()
}
/// True when the task must not accept more application writes and should
/// only listen for inbound packets: a retransmission is in flight, there is
/// unflushed backlog, the write half is closed, or the TCB limits sending.
fn only_recv_in(&self) -> bool {
    self.retransmission || self.last_buffer.is_some() || self.write_half_closed || self.tcb.limit()
}
/// Move readable data from the TCB into the application channel without
/// blocking: reserve a slot, read one buffer, repeat until the channel is
/// full, closed, or the TCB runs dry.
fn push_application_layer(&mut self) {
    if let Some(sender) = self.application_layer_sender.as_ref() {
        let mut read_half_closed = false;
        while self.tcb.readable() {
            match sender.try_reserve() {
                Ok(sender) => {
                    // `try_reserve` guarantees capacity, so this send cannot fail.
                    if let Some(buffer) = self.tcb.read() {
                        sender.send(buffer);
                    }
                }
                Err(e) => match e {
                    // Channel full: stop; a ReadNotify will resume us later.
                    TrySendError::Full(_) => break,
                    TrySendError::Closed(_) => {
                        read_half_closed = true;
                        break;
                    }
                },
            }
            // Publish whether more data remains for `ReadNotify::notify`.
            // NOTE(review): the `break` arms above skip this update, so the
            // flag can stay stale after a Full channel — verify intended.
            self.read_notify.set_state(self.tcb.readable());
        }
        if self.tcb.cannot_read() || read_half_closed {
            self.close_read();
        }
    } else {
        // Read half already closed: discard whatever the TCB buffered.
        self.tcb.read_none();
    }
}
225 fn close_read(&mut self) {
226 if let Some(sender) = self.application_layer_sender.take() {
227 _ = sender.try_send(BytesMut::new());
228 }
229 }
230 async fn write_slice0(tcb: &mut Tcb, ip_stack: &IpStack, mut buf: &[u8]) -> io::Result<usize> {
231 let len = buf.len();
232 while !buf.is_empty() {
233 if let Some((packet, len)) = tcb.write(buf) {
234 if len == 0 {
235 break;
236 }
237 ip_stack.send_packet(packet).await?;
238 tcb.perform_post_ack_action();
239 buf = &buf[len..];
240 } else {
241 break;
242 }
243 }
244 Ok(len - buf.len())
245 }
/// Convenience wrapper over `write_slice0` using this task's own TCB/stack.
async fn write_slice(&mut self, buf: &[u8]) -> io::Result<usize> {
    Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await
}
249 async fn write(&mut self, mut buf: BytesMut) -> io::Result<usize> {
250 let len = self.write_slice(&buf).await?;
251 if len != buf.len() {
252 buf.advance(len);
254 self.last_buffer.replace(buf);
255 }
256 Ok(len)
257 }
258 async fn flush(&mut self) -> io::Result<()> {
259 if let Some(buf) = self.last_buffer.as_mut() {
260 let len = Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await?;
261 if buf.len() == len {
262 self.last_buffer.take();
263 } else {
264 buf.advance(len);
265 }
266 }
267 Ok(())
268 }
269
/// Retransmit at most one segment if one is due. Returns `true` when a
/// retransmission was sent (which suppresses new writes in `run0`).
async fn try_retransmission(&mut self) -> io::Result<bool> {
    // Nothing to retransmit once our write half has shut down.
    if self.write_half_closed {
        return Ok(false);
    }
    // First chance: a segment the TCB already has queued for retransmission.
    if let Some(v) = self.tcb.retransmission() {
        self.send_packet(v).await?;
        return Ok(true);
    }
    if self.tcb.no_inflight_packet() {
        return Ok(false);
    }
    // Second chance: in-flight data whose timer says it is now due.
    // NOTE(review): `retransmission()` is queried again after
    // `need_retransmission()`; presumably the check arms it — confirm
    // against `Tcb`.
    if self.tcb.need_retransmission() {
        if let Some(v) = self.tcb.retransmission() {
            self.send_packet(v).await?;
            return Ok(true);
        }
    }
    Ok(false)
}
289 async fn try_send_ack(&mut self) -> io::Result<()> {
290 if self.tcb.need_ack() {
291 let packet = self.tcb.ack_packet();
292 self.ip_stack.send_packet(packet).await?;
293 }
294 Ok(())
295 }
296
297 async fn recv_data(&mut self) -> TaskRecvData {
298 let deadline = if let Some(v) = self.tcb.time_wait() {
299 Some(v.into())
300 } else {
301 self.tcb.write_timeout().map(|v| v.into())
302 };
303
304 if let Some(deadline) = deadline {
305 if self.only_recv_in() {
306 self.recv_in_timeout_at(deadline).await
307 } else {
308 self.recv_timeout_at(deadline).await
309 }
310 } else if self.write_half_closed {
311 let timeout_at = Instant::now().add(self.ip_stack.config.tcp_config.time_wait_timeout);
312 self.recv_in_timeout_at(timeout_at).await
313 } else {
314 self.recv().await
315 }
316 }
/// Wait (without deadline) for the first of: inbound packet, application
/// write, or a read-notify wake-up. Closed channels map to the matching
/// `*Close` event.
async fn recv(&mut self) -> TaskRecvData {
    tokio::select! {
        rs=self.packet_receiver.recv()=>{
            rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
        }
        rs=self.application_layer_receiver.recv()=>{
            rs.map(TaskRecvData::Out).unwrap_or(TaskRecvData::OutClose)
        }
        _=self.read_notify.notified()=>{
            TaskRecvData::ReadNotify
        }
    }
}
/// Like `recv`, but also resolves with `Timeout` once `deadline` passes.
async fn recv_timeout_at(&mut self, deadline: Instant) -> TaskRecvData {
    tokio::select! {
        rs=self.packet_receiver.recv()=>{
            rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
        }
        rs=self.application_layer_receiver.recv()=>{
            rs.map(TaskRecvData::Out).unwrap_or(TaskRecvData::OutClose)
        }
        _=tokio::time::sleep_until(deadline)=>{
            TaskRecvData::Timeout
        }
        _=self.read_notify.notified()=>{
            TaskRecvData::ReadNotify
        }
    }
}
346
/// Wait for an inbound packet, a timeout, or a read-notify — deliberately
/// NOT polling the application channel (used while writes must be held off).
async fn recv_in_timeout_at(&mut self, deadline: Instant) -> TaskRecvData {
    tokio::select! {
        rs=self.packet_receiver.recv()=>{
            rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
        }
        _=tokio::time::sleep_until(deadline)=>{
            TaskRecvData::Timeout
        }
        _=self.read_notify.notified()=>{
            TaskRecvData::ReadNotify
        }
    }
}
360 async fn recv_in_timeout(&mut self, duration: Duration) -> TaskRecvData {
361 self.recv_in_timeout_at(Instant::now().add(duration)).await
362 }
363
364 fn try_recv_in(&mut self) -> Option<BytesMut> {
365 self.packet_receiver.try_recv().map(|v| v.buf).ok()
366 }
367}
368
impl TcpStreamTask {
    /// Active open: send SYN with exponential backoff (100ms doubling, capped
    /// at 3s, at most 50 attempts) and wait for the handshake to complete.
    /// Returns `ConnectionRefused` if the peer never answers acceptably.
    pub async fn connect(&mut self) -> io::Result<()> {
        let mut count = 0;
        let mut time = 50;
        // `try_syn_sent` yields a fresh SYN while the TCB is still in SYN-SENT.
        while let Some(packet) = self.tcb.try_syn_sent() {
            count += 1;
            if count > 50 {
                break;
            }
            self.send_packet(packet).await?;
            // Double before use: first wait is 100ms, capped at 3000ms below.
            time *= 2;
            return match self.recv_in_timeout(Duration::from_millis(time.min(3000))).await {
                TaskRecvData::In(buf) => {
                    // Expect SYN-ACK; reply (final ACK) completes the handshake.
                    if let Some(relay) = self.tcb.try_syn_sent_to_established(buf) {
                        self.send_packet(relay).await?;
                        Ok(())
                    } else {
                        Err(io::Error::from(io::ErrorKind::ConnectionRefused))
                    }
                }
                TaskRecvData::InClose => Err(io::Error::from(io::ErrorKind::ConnectionRefused)),
                // Timeout: `continue` diverges out of this `return match`,
                // looping back to resend the SYN.
                TaskRecvData::Timeout => continue,
                _ => {
                    // `recv_in_timeout` never yields Out/OutClose/ReadNotify here.
                    unreachable!()
                }
            };
        }
        Err(io::Error::from(io::ErrorKind::ConnectionRefused))
    }
}
399
/// Event returned by the task's `recv*` waiters.
enum TaskRecvData {
    // Inbound packet payload from the network side.
    In(BytesMut),
    // Outbound data written by the application.
    Out(BytesMut),
    // The reader drained its channel and wants more data pushed.
    ReadNotify,
    // Network-side packet channel closed (stack gone).
    InClose,
    // Application dropped its write half.
    OutClose,
    // A deadline (write timeout / TIME-WAIT / backoff) elapsed.
    Timeout,
}