eznet/
reader.rs

1use crate::{
2    packet::{Packet, PacketHeader},
3    unwrap_or,
4};
5use bytes::{Bytes, BytesMut};
6use futures::{stream::SelectAll, StreamExt};
7use quinn::{ConnectionError, Datagrams, IncomingUniStreams, RecvStream};
8use std::{collections::HashMap, io::Error};
9use tokio::sync::{broadcast, mpsc};
10use tokio_util::codec::{FramedRead, LengthDelimitedCodec};
11
12//
13
14pub async fn reader_worker_job(
15    mut uni_streams: IncomingUniStreams,
16    mut datagrams: Datagrams,
17    mut send: mpsc::Sender<Packet>,
18    mut should_stop: broadcast::Receiver<()>,
19) {
20    let mut recv_streams = SelectAll::new();
21
22    let mut reliable_seq: HashMap<Option<u8>, u16> = Default::default();
23    let mut unreliable_seq: HashMap<Option<u8>, u16> = Default::default();
24
25    loop {
26        let new_stream = async {
27            uni_streams
28                .next()
29                .await
30                .map(|s| s.map(|s| FramedRead::new(s, LengthDelimitedCodec::default())))
31        };
32
33        let old_stream = recv_streams.next();
34
35        let datagram_stream = datagrams.next();
36
37        if tokio::select! {
38            stream = new_stream => handle_new_stream(stream, &mut recv_streams),
39            Some(bytes) = old_stream => handle_old_stream(bytes, &mut send, &mut reliable_seq, &mut unreliable_seq).await,
40            bytes = datagram_stream => handle_datagram(bytes, &mut send, &mut reliable_seq, &mut unreliable_seq).await,
41            _ = should_stop.recv() => true,
42        } {
43            break;
44        };
45    }
46
47    log::debug!("Reader worker stopped");
48}
49
50// returns true if reader should stop
51fn handle_new_stream(
52    stream: Option<Result<FRead, ConnectionError>>,
53    recv_streams: &mut SelectAll<FRead>,
54) -> bool {
55    let stream = stream.ok_or("Empty new stream");
56
57    let stream = unwrap_or!(stream, {
58        return true;
59    });
60
61    let stream = unwrap_or!(stream, {
62        return true;
63    });
64
65    recv_streams.push(stream);
66    false
67}
68
69// returns true if reader should stop
70async fn handle_old_stream(
71    bytes: Result<BytesMut, Error>,
72    send: &mut mpsc::Sender<Packet>,
73    reliable_seq: &mut HashMap<Option<u8>, u16>,
74    unreliable_seq: &mut HashMap<Option<u8>, u16>,
75) -> bool {
76    let packet = bytes.map(|b| bincode::deserialize(&b[..]));
77
78    let packet = unwrap_or!(packet, {
79        return true;
80    });
81
82    let packet = unwrap_or!(packet, {
83        return true;
84    });
85
86    if let Some(packet) = drop_sequenced(packet, reliable_seq, unreliable_seq) {
87        send.send(packet).await.is_err()
88    } else {
89        false
90    }
91}
92
93// returns true if reader should stop
94async fn handle_datagram(
95    bytes: Option<Result<Bytes, ConnectionError>>,
96    send: &mut mpsc::Sender<Packet>,
97    reliable_seq: &mut HashMap<Option<u8>, u16>,
98    unreliable_seq: &mut HashMap<Option<u8>, u16>,
99) -> bool {
100    let packet = bytes
101        .ok_or("Empty datagram")
102        .map(|b| b.map(|b| bincode::deserialize(&b[..])));
103
104    let packet = unwrap_or!(packet, {
105        return true;
106    });
107
108    let packet = unwrap_or!(packet, {
109        return true;
110    });
111
112    let packet = unwrap_or!(packet, {
113        return true;
114    });
115
116    if let Some(packet) = drop_sequenced(packet, reliable_seq, unreliable_seq) {
117        send.send(packet).await.is_err()
118    } else {
119        false
120    }
121}
122
123fn drop_sequenced(
124    packet: Packet,
125    reliable_seq: &mut HashMap<Option<u8>, u16>,
126    unreliable_seq: &mut HashMap<Option<u8>, u16>,
127) -> Option<Packet> {
128    match packet.header {
129        PacketHeader::ReliableSequenced { stream_id, seq_id } => {
130            drop_sequenced_common(stream_id, seq_id, reliable_seq)
131        }
132        PacketHeader::UnreliableSequenced { stream_id, seq_id } => {
133            drop_sequenced_common(stream_id, seq_id, unreliable_seq)
134        }
135        _ => true,
136    }
137    .then_some(packet)
138}
139
140fn drop_sequenced_common(
141    stream_id: Option<u8>,
142    seq_id: u16,
143    seq: &mut HashMap<Option<u8>, u16>,
144) -> bool {
145    let recv_seq_id = seq.entry(stream_id).or_insert(0);
146    let send_seq_id = seq_id;
147
148    // convert them to a form where it is comparable
149    let rsi = u16::MAX / 2 - 1;
150    let ssi = ((send_seq_id as i32 - *recv_seq_id as i32).rem_euclid(u16::MAX as i32) as u16)
151        .wrapping_add(u16::MAX / 2);
152
153    if cfg!(test) {
154        dbg!(&recv_seq_id);
155        dbg!(&send_seq_id);
156        dbg!(&rsi);
157        dbg!(&ssi);
158    }
159
160    if ssi > rsi {
161        // got packet that is 'newer'
162        *recv_seq_id = send_seq_id;
163        true
164    } else {
165        // got packet that is 'older'
166        log::debug!("Dropping out of sequence packet");
167        false
168    }
169}
170
171//
172
/// A quinn `RecvStream` wrapped in a `FramedRead` that uses `LengthDelimitedCodec` for framing.
173type FRead = FramedRead<RecvStream, LengthDelimitedCodec>;
174
175//
176
#[cfg(test)]
mod tests {
    use crate::reader::drop_sequenced_common;
    use std::collections::hash_map::HashMap;

    /// Feeds a fixed series of sequence ids for a single stream (`None`) and checks, step by
    /// step, which ones are accepted as newer and which ones are dropped. The series covers
    /// duplicates, an older id, large forward jumps, and wrapping back around to `0`.
    #[test]
    fn drop_sequenced_common_test_0() {
        let mut seq = HashMap::new();
        seq.insert(None, 0);

        // (incoming sequence id, whether it must be accepted), applied in order
        let steps: &[(u16, bool)] = &[
            (1, true),
            (1, false),
            (1, false),
            (2, true),
            (2, false),
            (2, false),
            (200, true),
            (2, false),
            (u16::MAX / 4, true),
            (u16::MAX / 2, true),
            (u16::MAX / 4 * 3, true),
            (u16::MAX - 100, true),
            (u16::MAX - 100, false),
            (u16::MAX - 99, true),
            (u16::MAX - 99, false),
            (0, true),
            (0, false),
        ];

        for (step, &(seq_id, accepted)) in steps.iter().enumerate() {
            assert_eq!(
                drop_sequenced_common(None, seq_id, &mut seq),
                accepted,
                "step {}: sequence id {}",
                step,
                seq_id
            );
        }
    }
}