// sosistab2/multiplex/stream/stream_state.rs

use std::{
    io::{Read, Write},
    sync::Arc,
    time::{Duration, Instant},
};

use bytes::Bytes;

use clone_macro::clone;
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use stdcode::StdcodeSerializeExt;

use crate::{
    multiplex::stream::{RelKind, StreamMessage},
    Stream,
};

use super::{inflight::Inflight, reorderer::Reorderer, StreamQueues};
/// Maximum number of payload bytes carried by a single reliable `Data` packet.
const MSS: usize = 1150;

22pub struct StreamState {
35 phase: Phase,
36 stream_id: u16,
37 additional_data: String,
38 incoming_queue: Vec<StreamMessage>,
39 queues: Arc<Mutex<StreamQueues>>,
40 local_notify: Arc<async_event::Event>,
41 tick_notify: Arc<dyn Fn() + Send + Sync + 'static>,
42
43 next_unseen_seqno: u64,
45 reorderer: Reorderer<Bytes>,
46
47 inflight: Inflight,
49 next_write_seqno: u64,
50 cwnd: f64,
51 ssthresh: f64,
52
53 in_recovery: bool,
54 last_write_time: Instant,
55}
56
57impl Drop for StreamState {
58 fn drop(&mut self) {
59 self.queues.lock().closed = true;
60 self.local_notify.notify_all();
61 }
62}
63
64impl StreamState {
65 pub fn new_pending(
67 tick_notify: impl Fn() + Send + Sync + 'static,
68 stream_id: u16,
69 label: String,
70 ) -> (Self, Stream) {
71 Self::new_in_phase(tick_notify, stream_id, Phase::Pending, label)
72 }
73
74 pub fn new_established(
76 tick_notify: impl Fn() + Send + Sync + 'static,
77 stream_id: u16,
78 label: String,
79 ) -> (Self, Stream) {
80 Self::new_in_phase(tick_notify, stream_id, Phase::Established, label)
81 }
82
83 fn new_in_phase(
85 tick_notify: impl Fn() + Send + Sync + 'static,
86 stream_id: u16,
87 phase: Phase,
88 label: String,
89 ) -> (Self, Stream) {
90 let queues = Arc::new(Mutex::new(StreamQueues::default()));
91 let ready = Arc::new(async_event::Event::new());
92 let tick_notify: Arc<dyn Fn() + Send + Sync + 'static> = Arc::new(tick_notify);
93 let handle = Stream::new(
94 clone!([tick_notify], move || tick_notify()),
95 ready.clone(),
96 queues.clone(),
97 label.clone().into(),
98 );
99
100 static START: Lazy<Instant> = Lazy::new(Instant::now);
101 let state = Self {
102 phase,
103 stream_id,
104 incoming_queue: Default::default(),
105 queues,
106 local_notify: ready,
107
108 next_unseen_seqno: 0,
109 reorderer: Reorderer::default(),
110 inflight: Inflight::new(),
111 next_write_seqno: 0,
112 cwnd: 4.0,
113 ssthresh: 0.0,
114 tick_notify,
115
116 in_recovery: false,
117
118 additional_data: label,
119 last_write_time: *START,
120 };
121 (state, handle)
122 }
123
124 pub fn inject_incoming(&mut self, msg: StreamMessage) {
126 self.incoming_queue.push(msg);
127 (self.tick_notify)();
128 }
129
130 pub fn tick(&mut self, mut outgoing_callback: impl FnMut(StreamMessage)) -> Option<Instant> {
134 log::trace!("ticking {} at {:?}", self.stream_id, self.phase);
135
136 let now: Instant = Instant::now();
137
138 match self.phase {
139 Phase::Pending => {
140 outgoing_callback(StreamMessage::Reliable {
142 kind: RelKind::Syn,
143 stream_id: self.stream_id,
144 seqno: 0,
145 payload: Bytes::copy_from_slice(self.additional_data.as_bytes()),
146 });
147 let next_resend = now + Duration::from_secs(1);
148 self.phase = Phase::SynSent { next_resend };
149 Some(next_resend)
150 }
151 Phase::SynSent { next_resend } => {
152 if self.incoming_queue.drain(..).any(|msg| {
153 matches!(
154 msg,
155 StreamMessage::Reliable {
156 kind: RelKind::SynAck,
157 stream_id: _,
158 seqno: _,
159 payload: _
160 }
161 )
162 }) {
163 self.phase = Phase::Established;
164 self.queues.lock().connected = true;
165 self.local_notify.notify_all();
166 Some(now)
167 } else if now >= next_resend {
168 outgoing_callback(StreamMessage::Reliable {
169 kind: RelKind::Syn,
170 stream_id: self.stream_id,
171 seqno: 0,
172 payload: Bytes::copy_from_slice(self.additional_data.as_bytes()),
173 });
174 let next_resend = now + Duration::from_secs(1);
175 self.phase = Phase::SynSent { next_resend };
176 Some(next_resend)
177 } else {
178 Some(next_resend)
179 }
180 }
181 Phase::Established => {
182 self.tick_read(now, &mut outgoing_callback);
184 self.tick_write(now, &mut outgoing_callback);
186 if self.queues.lock().closed {
188 self.phase = Phase::Closed;
189 }
190 Some(self.retick_time(now))
192 }
193 Phase::Closed => {
194 self.queues.lock().closed = true;
195 self.local_notify.notify_all();
196 for _ in self.incoming_queue.drain(..) {
197 outgoing_callback(StreamMessage::Reliable {
198 kind: RelKind::Rst,
199 stream_id: self.stream_id,
200 seqno: 0,
201 payload: Default::default(),
202 });
203 }
204 None
205 }
206 }
207 }
208
209 fn tick_read(&mut self, _now: Instant, mut outgoing_callback: impl FnMut(StreamMessage)) {
210 let mut to_ack = vec![];
212 for packet in self.incoming_queue.drain(..) {
214 if self.queues.lock().read_stream.len() > 10_000_000 {
217 continue;
218 }
219
220 match packet {
221 StreamMessage::Reliable {
222 kind: RelKind::Data,
223 stream_id,
224 seqno,
225 payload,
226 } => {
227 log::trace!("incoming seqno {stream_id}/{seqno}");
228 if self.reorderer.insert(seqno, payload) {
229 to_ack.push(seqno);
230 }
231 }
232 StreamMessage::Reliable {
233 kind: RelKind::DataAck,
234 stream_id: _,
235 seqno: lowest_unseen_seqno, payload: selective_acks,
237 } => {
238 let mut ack_count = self.inflight.mark_acked_lt(lowest_unseen_seqno);
240 if let Ok(sacks) = stdcode::deserialize::<Vec<u64>>(&selective_acks) {
242 for sack in sacks {
243 if self.inflight.mark_acked(sack) {
244 ack_count += 1;
245 }
246 }
247 }
248
249 for _ in 0..ack_count {
251 let bic_inc = if self.cwnd < self.ssthresh {
252 (self.ssthresh - self.cwnd) / 2.0
253 } else {
254 self.cwnd - self.ssthresh
255 }
256 .max(1.0)
257 .min(50.0)
258 .min(self.cwnd);
259 self.cwnd += bic_inc / self.cwnd;
260 }
261
262 log::debug!(
263 "ack_count = {ack_count}; send window {}; cwnd {:.1}; bdp {}; write queue {}",
264 self.inflight.inflight(),
265 self.cwnd,
266 self.inflight.bdp(),
267 self.queues.lock().write_stream.len()
268 );
269 self.local_notify.notify_all();
270 }
271 StreamMessage::Reliable {
272 kind: RelKind::Syn,
273 stream_id,
274 seqno,
275 payload,
276 } => {
277 outgoing_callback(StreamMessage::Reliable {
279 kind: RelKind::SynAck,
280 stream_id,
281 seqno,
282 payload,
283 });
284 }
285 StreamMessage::Reliable {
286 kind: RelKind::Rst | RelKind::Fin,
287 stream_id: _,
288 seqno: _,
289 payload: _,
290 } => {
291 self.phase = Phase::Closed;
292 }
293 StreamMessage::Unreliable {
294 stream_id: _,
295 payload,
296 } => {
297 self.queues.lock().recv_urel.push_back(payload);
298 self.local_notify.notify_all();
299 }
300 _ => log::warn!("discarding out-of-turn packet {:?}", packet),
301 }
302 }
303 for (seqno, packet) in self.reorderer.take() {
305 self.next_unseen_seqno = seqno + 1;
306 self.queues.lock().read_stream.write_all(&packet).unwrap();
307 }
308
309 if !to_ack.is_empty() {
311 self.local_notify.notify_all();
312 to_ack.retain(|a| a >= &self.next_unseen_seqno);
313 outgoing_callback(StreamMessage::Reliable {
314 kind: RelKind::DataAck,
315 stream_id: self.stream_id,
316 seqno: self.next_unseen_seqno,
317 payload: to_ack.stdcode().into(),
318 });
319 }
320 }
321
322 fn start_recovery(&mut self) {
323 if !self.in_recovery {
324 log::debug!("*** START RECOVRY AT CWND = {}", self.cwnd);
325
326 let beta = 0.15;
328 if self.cwnd < self.ssthresh {
329 self.ssthresh = self.cwnd * (2.0 - beta) / 2.0;
330 } else {
331 self.ssthresh = self.cwnd;
332 }
333
334 self.cwnd *= 1.0 - beta;
335 self.cwnd = self.cwnd.max(1.0);
336
337 self.in_recovery = true;
338 }
339 }
340
341 fn stop_recovery(&mut self) {
342 self.in_recovery = false;
343 }
344
345 fn congested(&self, now: Instant) -> bool {
346 self.inflight.inflight() - self.inflight.lost_at(now) >= self.cwnd as usize
347 }
348
349 fn tick_write(&mut self, now: Instant, mut outgoing_callback: impl FnMut(StreamMessage)) {
350 log::trace!("tick_write for {}", self.stream_id);
351 {
353 let mut queues = self.queues.lock();
354 while let Some(payload) = queues.send_urel.pop_front() {
355 outgoing_callback(StreamMessage::Unreliable {
356 stream_id: self.stream_id,
357 payload,
358 });
359 }
360 }
361
362 if self.inflight.lost_at(now) > 0 {
363 self.start_recovery();
364 } else {
365 self.stop_recovery();
366 }
367
368 let speed = self.speed();
370 let mut writes_allowed = (now
371 .saturating_duration_since(self.last_write_time)
372 .as_secs_f64()
373 * speed) as usize;
374
375 while !self.congested(now) && writes_allowed > 0 {
376 if let Some((seqno, retrans_time)) = self.inflight.first_rto() {
378 if now >= retrans_time {
379 log::debug!(
380 "inflight = {}, lost = {}, cwnd = {}",
381 self.inflight.inflight(),
382 self.inflight.lost_at(now),
383 self.cwnd
384 );
385 log::debug!("*** retransmit {}", seqno);
386 let first = self.inflight.retransmit(seqno).expect("no first");
387 writes_allowed -= 1;
388 log::debug!("RETRANSMIT {seqno} at {:.2} pkts/s", speed);
389 outgoing_callback(first);
390 continue;
391 }
392 }
393
394 let mut queues = self.queues.lock();
396 if !queues.write_stream.is_empty() {
397 let mut buffer = vec![0; MSS];
398 let n = queues.write_stream.read(&mut buffer).unwrap();
399 buffer.truncate(n);
400 let seqno = self.next_write_seqno;
401 self.next_write_seqno += 1;
402 let msg = StreamMessage::Reliable {
403 kind: RelKind::Data,
404 stream_id: self.stream_id,
405 seqno,
406 payload: buffer.into(),
407 };
408 self.inflight.insert(msg.clone());
409 self.local_notify.notify_all();
410
411 outgoing_callback(msg);
412 self.last_write_time = now;
413 writes_allowed -= 1;
414 log::debug!("{seqno} at {:.2} pkts/s", speed);
415 continue;
416 } else {
417 queues.write_stream.shrink_to_fit();
418 }
419
420 break;
421 }
422 }
423
424 fn speed(&self) -> f64 {
425 (self.cwnd / self.inflight.min_rtt().as_secs_f64()).max(1.0)
426 }
427
428 fn retick_time(&self, now: Instant) -> Instant {
429 let idle = { self.inflight.inflight() == 0 && self.queues.lock().write_stream.is_empty() };
430
431 if idle {
432 now + Duration::from_secs(100000)
433 } else {
434 now + Duration::from_secs_f64((1.0 / self.speed()).max(0.02)) }
436 }
437}
438
/// Connection phase of a `StreamState`.
#[derive(Clone, Copy, Debug)]
enum Phase {
    /// Created locally; the next tick sends the initial SYN.
    Pending,
    /// SYN sent and waiting for a SYN-ACK; the SYN is resent at `next_resend`.
    SynSent { next_resend: Instant },
    /// Handshake complete; data is exchanged.
    Established,
    /// Closed locally or reset/finished by the peer; further incoming packets get a Rst.
    Closed,
}