sosistab2/multiplex/stream/
stream_state.rs

1use std::{
2    io::{Read, Write},
3    sync::Arc,
4    time::{Duration, Instant},
5};
6
7use bytes::Bytes;
8
9use clone_macro::clone;
10use once_cell::sync::Lazy;
11use parking_lot::Mutex;
12use stdcode::StdcodeSerializeExt;
13
14use crate::{
15    multiplex::stream::{RelKind, StreamMessage},
16    Stream,
17};
18
19use super::{inflight::Inflight, reorderer::Reorderer, StreamQueues};
/// Maximum number of payload bytes carried by a single reliable `Data` segment: each new
/// segment is filled by reading at most this many bytes from the shared write buffer.
const MSS: usize = 1_150;
21
22/// The raw internal state of a stream.
23///
24/// This is exposed so that crates other than `sosistab2` itself can use the reliable-stream logic of `sosistab2`, outside the context of multiplexing streams over a `sosistab2::Multiplex`.
25///
26/// A StreamState is constructed and used in a rather particular way:
27/// - On construction, a `tick_notify` closure is passed in.
28/// - The caller must arrange so that `StreamState::tick` is called
29///     - every time `tick_notify` is called
30///     - `tick_retval` after the last tick, where `tick_retval` is the return value of the last time the state was ticked
31/// - inject_incoming is called on every incoming message
32///
33/// As long as the above holds, the `Stream` corresponding to the `StreamState`, which is returned from the `StreamState` constructor as well, will work properly.
34pub struct StreamState {
35    phase: Phase,
36    stream_id: u16,
37    additional_data: String,
38    incoming_queue: Vec<StreamMessage>,
39    queues: Arc<Mutex<StreamQueues>>,
40    local_notify: Arc<async_event::Event>,
41    tick_notify: Arc<dyn Fn() + Send + Sync + 'static>,
42
43    // read variables
44    next_unseen_seqno: u64,
45    reorderer: Reorderer<Bytes>,
46
47    // write variables
48    inflight: Inflight,
49    next_write_seqno: u64,
50    cwnd: f64,
51    ssthresh: f64,
52
53    in_recovery: bool,
54    last_write_time: Instant,
55}
56
57impl Drop for StreamState {
58    fn drop(&mut self) {
59        self.queues.lock().closed = true;
60        self.local_notify.notify_all();
61    }
62}
63
64impl StreamState {
65    /// Creates a new StreamState, in the pre-SYN-sent state. Also returns the "user-facing" handle.
66    pub fn new_pending(
67        tick_notify: impl Fn() + Send + Sync + 'static,
68        stream_id: u16,
69        label: String,
70    ) -> (Self, Stream) {
71        Self::new_in_phase(tick_notify, stream_id, Phase::Pending, label)
72    }
73
74    /// Creates a new StreamState, in the established state. Also returns the "user-facing" handle.
75    pub fn new_established(
76        tick_notify: impl Fn() + Send + Sync + 'static,
77        stream_id: u16,
78        label: String,
79    ) -> (Self, Stream) {
80        Self::new_in_phase(tick_notify, stream_id, Phase::Established, label)
81    }
82
83    /// Creates a new StreamState, in the specified state. Also returns the "user-facing" handle.
84    fn new_in_phase(
85        tick_notify: impl Fn() + Send + Sync + 'static,
86        stream_id: u16,
87        phase: Phase,
88        label: String,
89    ) -> (Self, Stream) {
90        let queues = Arc::new(Mutex::new(StreamQueues::default()));
91        let ready = Arc::new(async_event::Event::new());
92        let tick_notify: Arc<dyn Fn() + Send + Sync + 'static> = Arc::new(tick_notify);
93        let handle = Stream::new(
94            clone!([tick_notify], move || tick_notify()),
95            ready.clone(),
96            queues.clone(),
97            label.clone().into(),
98        );
99
100        static START: Lazy<Instant> = Lazy::new(Instant::now);
101        let state = Self {
102            phase,
103            stream_id,
104            incoming_queue: Default::default(),
105            queues,
106            local_notify: ready,
107
108            next_unseen_seqno: 0,
109            reorderer: Reorderer::default(),
110            inflight: Inflight::new(),
111            next_write_seqno: 0,
112            cwnd: 4.0,
113            ssthresh: 0.0,
114            tick_notify,
115
116            in_recovery: false,
117
118            additional_data: label,
119            last_write_time: *START,
120        };
121        (state, handle)
122    }
123
124    /// Injects an incoming message.
125    pub fn inject_incoming(&mut self, msg: StreamMessage) {
126        self.incoming_queue.push(msg);
127        (self.tick_notify)();
128    }
129
130    /// "Ticks" this StreamState, which advances its state. Any outgoing messages generated are passed to the callback given. Returns the correct time to call tick again at --- but if tick_notify, passed in during construction, fires, the stream must be ticked again.
131    ///
132    /// Returns None if the correct option is to delete the whole thing.
133    pub fn tick(&mut self, mut outgoing_callback: impl FnMut(StreamMessage)) -> Option<Instant> {
134        log::trace!("ticking {} at {:?}", self.stream_id, self.phase);
135
136        let now: Instant = Instant::now();
137
138        match self.phase {
139            Phase::Pending => {
140                // send a SYN, and transition into SynSent
141                outgoing_callback(StreamMessage::Reliable {
142                    kind: RelKind::Syn,
143                    stream_id: self.stream_id,
144                    seqno: 0,
145                    payload: Bytes::copy_from_slice(self.additional_data.as_bytes()),
146                });
147                let next_resend = now + Duration::from_secs(1);
148                self.phase = Phase::SynSent { next_resend };
149                Some(next_resend)
150            }
151            Phase::SynSent { next_resend } => {
152                if self.incoming_queue.drain(..).any(|msg| {
153                    matches!(
154                        msg,
155                        StreamMessage::Reliable {
156                            kind: RelKind::SynAck,
157                            stream_id: _,
158                            seqno: _,
159                            payload: _
160                        }
161                    )
162                }) {
163                    self.phase = Phase::Established;
164                    self.queues.lock().connected = true;
165                    self.local_notify.notify_all();
166                    Some(now)
167                } else if now >= next_resend {
168                    outgoing_callback(StreamMessage::Reliable {
169                        kind: RelKind::Syn,
170                        stream_id: self.stream_id,
171                        seqno: 0,
172                        payload: Bytes::copy_from_slice(self.additional_data.as_bytes()),
173                    });
174                    let next_resend = now + Duration::from_secs(1);
175                    self.phase = Phase::SynSent { next_resend };
176                    Some(next_resend)
177                } else {
178                    Some(next_resend)
179                }
180            }
181            Phase::Established => {
182                // First, handle receiving packets. This is the easier part.
183                self.tick_read(now, &mut outgoing_callback);
184                // Then, handle sending packets. This involves congestion control, so it's the harder part.
185                self.tick_write(now, &mut outgoing_callback);
186                // If closed, then die
187                if self.queues.lock().closed {
188                    self.phase = Phase::Closed;
189                }
190                // Finally, calculate the next interval.
191                Some(self.retick_time(now))
192            }
193            Phase::Closed => {
194                self.queues.lock().closed = true;
195                self.local_notify.notify_all();
196                for _ in self.incoming_queue.drain(..) {
197                    outgoing_callback(StreamMessage::Reliable {
198                        kind: RelKind::Rst,
199                        stream_id: self.stream_id,
200                        seqno: 0,
201                        payload: Default::default(),
202                    });
203                }
204                None
205            }
206        }
207    }
208
209    fn tick_read(&mut self, _now: Instant, mut outgoing_callback: impl FnMut(StreamMessage)) {
210        // Put all incoming packets into the reorderer.
211        let mut to_ack = vec![];
212        // log::debug!("processing incoming queue of {}", self.incoming_queue.len());
213        for packet in self.incoming_queue.drain(..) {
214            // If the receive queue is too large, then we pretend like we don't see anything. The sender will eventually retransmit.
215            // This unifies flow control with congestion control at the cost of a bit of efficiency.
216            if self.queues.lock().read_stream.len() > 10_000_000 {
217                continue;
218            }
219
220            match packet {
221                StreamMessage::Reliable {
222                    kind: RelKind::Data,
223                    stream_id,
224                    seqno,
225                    payload,
226                } => {
227                    log::trace!("incoming seqno {stream_id}/{seqno}");
228                    if self.reorderer.insert(seqno, payload) {
229                        to_ack.push(seqno);
230                    }
231                }
232                StreamMessage::Reliable {
233                    kind: RelKind::DataAck,
234                    stream_id: _,
235                    seqno: lowest_unseen_seqno, // *one greater* than the last packet that got to the other side
236                    payload: selective_acks,
237                } => {
238                    // mark every packet whose seqno is less than the given seqno as acked.
239                    let mut ack_count = self.inflight.mark_acked_lt(lowest_unseen_seqno);
240                    // then, we interpret the payload as a vector of acks that should additionally be taken care of.
241                    if let Ok(sacks) = stdcode::deserialize::<Vec<u64>>(&selective_acks) {
242                        for sack in sacks {
243                            if self.inflight.mark_acked(sack) {
244                                ack_count += 1;
245                            }
246                        }
247                    }
248
249                    // use BIC
250                    for _ in 0..ack_count {
251                        let bic_inc = if self.cwnd < self.ssthresh {
252                            (self.ssthresh - self.cwnd) / 2.0
253                        } else {
254                            self.cwnd - self.ssthresh
255                        }
256                        .max(1.0)
257                        .min(50.0)
258                        .min(self.cwnd);
259                        self.cwnd += bic_inc / self.cwnd;
260                    }
261
262                    log::debug!(
263                        "ack_count = {ack_count}; send window {}; cwnd {:.1}; bdp {}; write queue {}",
264                        self.inflight.inflight(),
265                        self.cwnd,
266                        self.inflight.bdp(),
267                        self.queues.lock().write_stream.len()
268                    );
269                    self.local_notify.notify_all();
270                }
271                StreamMessage::Reliable {
272                    kind: RelKind::Syn,
273                    stream_id,
274                    seqno,
275                    payload,
276                } => {
277                    // retransmit our syn-ack
278                    outgoing_callback(StreamMessage::Reliable {
279                        kind: RelKind::SynAck,
280                        stream_id,
281                        seqno,
282                        payload,
283                    });
284                }
285                StreamMessage::Reliable {
286                    kind: RelKind::Rst | RelKind::Fin,
287                    stream_id: _,
288                    seqno: _,
289                    payload: _,
290                } => {
291                    self.phase = Phase::Closed;
292                }
293                StreamMessage::Unreliable {
294                    stream_id: _,
295                    payload,
296                } => {
297                    self.queues.lock().recv_urel.push_back(payload);
298                    self.local_notify.notify_all();
299                }
300                _ => log::warn!("discarding out-of-turn packet {:?}", packet),
301            }
302        }
303        // Then, drain the reorderer
304        for (seqno, packet) in self.reorderer.take() {
305            self.next_unseen_seqno = seqno + 1;
306            self.queues.lock().read_stream.write_all(&packet).unwrap();
307        }
308
309        // Then, generate an ack.
310        if !to_ack.is_empty() {
311            self.local_notify.notify_all();
312            to_ack.retain(|a| a >= &self.next_unseen_seqno);
313            outgoing_callback(StreamMessage::Reliable {
314                kind: RelKind::DataAck,
315                stream_id: self.stream_id,
316                seqno: self.next_unseen_seqno,
317                payload: to_ack.stdcode().into(),
318            });
319        }
320    }
321
322    fn start_recovery(&mut self) {
323        if !self.in_recovery {
324            log::debug!("*** START RECOVRY AT CWND = {}", self.cwnd);
325
326            // BIC
327            let beta = 0.15;
328            if self.cwnd < self.ssthresh {
329                self.ssthresh = self.cwnd * (2.0 - beta) / 2.0;
330            } else {
331                self.ssthresh = self.cwnd;
332            }
333
334            self.cwnd *= 1.0 - beta;
335            self.cwnd = self.cwnd.max(1.0);
336
337            self.in_recovery = true;
338        }
339    }
340
341    fn stop_recovery(&mut self) {
342        self.in_recovery = false;
343    }
344
345    fn congested(&self, now: Instant) -> bool {
346        self.inflight.inflight() - self.inflight.lost_at(now) >= self.cwnd as usize
347    }
348
349    fn tick_write(&mut self, now: Instant, mut outgoing_callback: impl FnMut(StreamMessage)) {
350        log::trace!("tick_write for {}", self.stream_id);
351        // we first handle unreliable datagrams
352        {
353            let mut queues = self.queues.lock();
354            while let Some(payload) = queues.send_urel.pop_front() {
355                outgoing_callback(StreamMessage::Unreliable {
356                    stream_id: self.stream_id,
357                    payload,
358                });
359            }
360        }
361
362        if self.inflight.lost_at(now) > 0 {
363            self.start_recovery();
364        } else {
365            self.stop_recovery();
366        }
367
368        // speed here is calculated based on the idea that we should be able to transmit a whole cwnd of things in an rtt.
369        let speed = self.speed();
370        let mut writes_allowed = (now
371            .saturating_duration_since(self.last_write_time)
372            .as_secs_f64()
373            * speed) as usize;
374
375        while !self.congested(now) && writes_allowed > 0 {
376            // we do any retransmissions if necessary
377            if let Some((seqno, retrans_time)) = self.inflight.first_rto() {
378                if now >= retrans_time {
379                    log::debug!(
380                        "inflight = {}, lost = {}, cwnd = {}",
381                        self.inflight.inflight(),
382                        self.inflight.lost_at(now),
383                        self.cwnd
384                    );
385                    log::debug!("*** retransmit {}", seqno);
386                    let first = self.inflight.retransmit(seqno).expect("no first");
387                    writes_allowed -= 1;
388                    log::debug!("RETRANSMIT {seqno} at {:.2} pkts/s", speed);
389                    outgoing_callback(first);
390                    continue;
391                }
392            }
393
394            // okay, we don't have retransmissions. this means we get to send a "normal" packet.
395            let mut queues = self.queues.lock();
396            if !queues.write_stream.is_empty() {
397                let mut buffer = vec![0; MSS];
398                let n = queues.write_stream.read(&mut buffer).unwrap();
399                buffer.truncate(n);
400                let seqno = self.next_write_seqno;
401                self.next_write_seqno += 1;
402                let msg = StreamMessage::Reliable {
403                    kind: RelKind::Data,
404                    stream_id: self.stream_id,
405                    seqno,
406                    payload: buffer.into(),
407                };
408                self.inflight.insert(msg.clone());
409                self.local_notify.notify_all();
410
411                outgoing_callback(msg);
412                self.last_write_time = now;
413                writes_allowed -= 1;
414                log::debug!("{seqno} at {:.2} pkts/s", speed);
415                continue;
416            } else {
417                queues.write_stream.shrink_to_fit();
418            }
419
420            break;
421        }
422    }
423
424    fn speed(&self) -> f64 {
425        (self.cwnd / self.inflight.min_rtt().as_secs_f64()).max(1.0)
426    }
427
428    fn retick_time(&self, now: Instant) -> Instant {
429        let idle = { self.inflight.inflight() == 0 && self.queues.lock().write_stream.is_empty() };
430
431        if idle {
432            now + Duration::from_secs(100000)
433        } else {
434            now + Duration::from_secs_f64((1.0 / self.speed()).max(0.02)) // max 50Hz
435        }
436    }
437}
438
/// Lifecycle phase of a `StreamState`.
#[derive(Debug, Clone, Copy)]
enum Phase {
    /// Created but no SYN sent yet; the next tick sends one.
    Pending,
    /// SYN sent and awaiting a SYN-ACK.
    SynSent {
        /// When to resend the SYN if no SYN-ACK has arrived by then.
        next_resend: Instant,
    },
    /// Handshake complete; data is exchanged.
    Established,
    /// Stream is finished; the next tick reports that the state can be deleted.
    Closed,
}