iroh_roq/
session.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::{bail, Result};
5use iroh::endpoint::{Connection, VarInt};
6use iroh_quinn_proto::coding::Codec;
7use n0_future::task::{self, AbortOnDropHandle, JoinSet};
8use tokio::io::{AsyncRead, AsyncReadExt};
9use tokio::sync::{mpsc, Mutex};
10use tokio_util::bytes::{Bytes, BytesMut};
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, error, trace, warn};
13
14use crate::receive_flow::ReceiveFlow;
15use crate::send_flow::SendFlow;
16
/// A RoQ session.
///
/// Multiplexes RTP flows, identified by varint flow IDs, over a single
/// `Connection`. All clones of a `Session` share the same flow tables and the
/// same background task.
#[derive(Debug, Clone)]
pub struct Session {
    /// The underlying connection. Send flows get a clone of it, and the
    /// background task accepts streams and datagrams from it.
    conn: Connection,
    /// Session-wide token. Every flow receives a child token of it, and the
    /// background task stops once it is cancelled.
    cancel_token: CancellationToken,
    /// Send flows created through `Session::new_send_flow`, keyed by flow ID
    /// (used to reject duplicate IDs).
    send_flows: Arc<Mutex<HashMap<VarInt, SendFlow>>>,
    /// Receive-side bookkeeping keyed by flow ID, shared with the background task.
    receive_flows: Arc<Mutex<HashMap<VarInt, ReceiveFlowSender>>>,
    /// Handle of the background task; it is aborted when the last clone of the
    /// session (and thus the last `Arc`) is dropped.
    _task: Arc<AbortOnDropHandle<()>>,
}
26
/// Session-side bookkeeping for a single receive flow.
#[derive(Debug)]
struct ReceiveFlowSender {
    /// Feeds received payloads into the channel read by the matching `ReceiveFlow`.
    sender: mpsc::Sender<Bytes>,
    /// `Some` for a flow that was discovered from incoming data and has not yet
    /// been claimed; `Session::new_receive_flow` takes it, leaving `None`.
    incoming_flow: Option<ReceiveFlow>,
    /// Token shared with the `ReceiveFlow`. Once cancelled, the entry is treated
    /// as closed and is removed when more data arrives for its flow ID.
    cancel_token: CancellationToken,
}
34
/// Capacity, in payloads, of the channel backing each receive flow.
const RECV_FLOW_BUFFER: usize = 64;
37
38impl Session {
39    /// Creates a new session, based on an existing `Connection`.
40    pub fn new(conn: Connection) -> Self {
41        let receive_flows = Arc::new(Mutex::new(HashMap::new()));
42        let cancel_token = CancellationToken::new();
43        let fut = run(conn.clone(), cancel_token.clone(), receive_flows.clone());
44        let task = AbortOnDropHandle::new(task::spawn(fut));
45
46        Self {
47            conn,
48            cancel_token,
49            send_flows: Default::default(),
50            receive_flows,
51            _task: Arc::new(task),
52        }
53    }
54
55    /// Create a new `SendFlow`.
56    pub async fn new_send_flow(&self, id: VarInt) -> Result<SendFlow> {
57        let mut send_flows = self.send_flows.lock().await;
58        if send_flows.contains_key(&id) {
59            bail!("duplicated flow ID: {}", id);
60        }
61
62        let flow = SendFlow::new(self.conn.clone(), id, self.cancel_token.child_token());
63        send_flows.insert(id, flow.clone());
64
65        Ok(flow)
66    }
67
68    /// Creates a new receive flow.
69    ///
70    /// If a message has already been received on a flow will return that flow.
71    pub async fn new_receive_flow(&self, id: VarInt) -> Result<ReceiveFlow> {
72        let mut receive_flows = self.receive_flows.lock().await;
73        if let Some(flow) = receive_flows.get_mut(&id) {
74            if let Some(receiver) = flow.incoming_flow.take() {
75                debug!(flow_id = %id, "found incoming flow");
76                return Ok(receiver);
77            } else {
78                bail!("duplicated flow ID: {}", id);
79            }
80        }
81
82        let (s, r) = mpsc::channel(RECV_FLOW_BUFFER);
83        let cancel_token = self.cancel_token.child_token();
84        let flow = ReceiveFlow::new(id, r, cancel_token.clone());
85        receive_flows.insert(
86            id,
87            ReceiveFlowSender {
88                sender: s,
89                incoming_flow: None,
90                cancel_token,
91            },
92        );
93
94        Ok(flow)
95    }
96}
97
98async fn run(
99    conn: Connection,
100    cancel_token: CancellationToken,
101    receive_flows: Arc<Mutex<HashMap<VarInt, ReceiveFlowSender>>>,
102) {
103    let mut tasks = JoinSet::new();
104
105    loop {
106        tokio::select! {
107            biased;
108
109            _ = cancel_token.cancelled() => {
110                debug!("shutting down");
111                break;
112            }
113            Some(res) = tasks.join_next() => {
114                match res {
115                    Err(outer) => {
116                        if outer.is_panic() {
117                            error!("Task panicked: {outer:?}");
118                            break;
119                        } else if outer.is_cancelled() {
120                            trace!("Task cancelled: {outer:?}");
121                        } else {
122                            error!("Task failed: {outer:?}");
123                            break;
124                        }
125                    }
126                    Ok(()) => {
127                        trace!("Task finished");
128                    }
129                }
130            },
131
132            uni_stream = conn.accept_uni() => {
133                match uni_stream {
134                    Ok(mut recv) =>  {
135                        let token = cancel_token.child_token();
136                        let rf = receive_flows.clone();
137                        tasks.spawn(async move {
138                            let sub_token = token.child_token();
139                            token.run_until_cancelled(async move {
140                                // Read flow id
141                                let Ok(flow_id) = read_varint(&mut recv).await else {
142                                    warn!("failed to read from stream");
143                                    return;
144                                };
145                                debug!(%flow_id, "incoming send flow");
146
147                                let mut flows = rf.lock().await;
148                                let sender = if let Some(flow) = flows.get(&flow_id) {
149                                    debug!(%flow_id, "found existing recv flow");
150                                    if flow.cancel_token.is_cancelled() {
151                                        flows.remove(&flow_id);
152                                        debug!(%flow_id, "cleaning up closed recv flow");
153                                        return;
154                                    } else {
155                                        flow.sender.clone()
156                                    }
157                                } else {
158                                    // Store incoming flow to be retrieved by the user
159                                    debug!(%flow_id, "creating new recv flow");
160                                    let (s, r) = mpsc::channel(RECV_FLOW_BUFFER);
161                                    let cancel_token = sub_token.child_token();
162                                    let flow = ReceiveFlow::new(flow_id, r, cancel_token.clone());
163                                    flows.insert(flow_id, ReceiveFlowSender {
164                                        sender: s.clone(),
165                                        incoming_flow: Some(flow),
166                                        cancel_token,
167                                    });
168                                    s
169                                };
170                                drop(flows);
171
172                                const MAX_PACKET_SIZE: u64 = 1024 * 1024 * 64; // TODO: what should this be?
173                                loop {
174                                    let len = match read_varint(&mut recv).await {
175                                        Ok(len) => len.into_inner(),
176                                        Err(err) => {
177                                            warn!("failed to read: {:?}", err);
178                                            break;
179                                        }
180                                    };
181                                    if len > MAX_PACKET_SIZE {
182                                        warn!("packet too large {}", len);
183                                        break;
184                                    }
185                                    let mut buffer = BytesMut::zeroed(len as usize);
186                                    match recv.read_exact(&mut buffer).await {
187                                        Ok(()) => {
188                                            sender.send(buffer.freeze()).await.ok();
189                                        }
190                                        Err(err) => {
191                                            warn!("failed to read: {:?}", err);
192                                            break;
193                                        }
194                                    }
195                                }
196                            }).await;
197                        });
198                    }
199                    Err(err) => {
200                        warn!("connection terminated: {:?}", err);
201                        break;
202                    }
203                }
204            }
205            datagram = conn.read_datagram() => {
206                // handle datagram
207                match datagram {
208                    Ok(mut bytes) => {
209                        debug!("received datagram: {} bytes", bytes.len());
210                        let Ok(flow_id) = VarInt::decode(&mut bytes) else {
211                            warn!("invalid flow id");
212                            continue;
213                        };
214                        let mut flows = receive_flows.lock().await;
215                        if let Some(flow) = flows.get(&flow_id) {
216                            debug!(%flow_id, "found existing recv flow");
217                            if flow.cancel_token.is_cancelled() {
218
219                                flows.remove(&flow_id);
220                                debug!(%flow_id, "cleaning up closed recv flow");
221                            } else if let Err(err) = flow.sender.send(bytes).await {
222                                warn!(%flow_id, "failed to send to receiver: {:?}", err);
223                            }
224                        } else {
225                            // Store incoming flow to be retrieved by the user
226                            debug!(%flow_id, "creating new recv flow");
227                            let (s, r) = mpsc::channel(RECV_FLOW_BUFFER);
228                            let cancel_token = cancel_token.child_token();
229                            let flow = ReceiveFlow::new(flow_id, r, cancel_token.clone());
230                            // store the newly received datagram
231                            s.send(bytes).await.expect("just created");
232                            flows.insert(flow_id, ReceiveFlowSender {
233                                sender: s,
234                                incoming_flow: Some(flow),
235                                cancel_token,
236                            });
237                        }
238                    }
239                    Err(err) => {
240                        warn!("connection terminated: {:?}", err);
241                        break;
242                    }
243                }
244            }
245        }
246    }
247}
248
249/// Async read based reading of a `VarInt`.
250async fn read_varint<R: AsyncRead + Unpin>(conn: &mut R) -> Result<VarInt> {
251    let mut buf = [0u8; VarInt::MAX_SIZE];
252
253    conn.read_exact(&mut buf[..1]).await?;
254    let tag = buf[0] >> 6;
255    buf[0] &= 0b0011_1111;
256
257    let x = match tag {
258        0b00 => u64::from(buf[0]),
259        0b01 => {
260            conn.read_exact(&mut buf[1..2]).await?;
261            u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
262        }
263        0b10 => {
264            conn.read_exact(&mut buf[1..4]).await?;
265            u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
266        }
267        0b11 => {
268            conn.read_exact(&mut buf[1..8]).await?;
269            u64::from_be_bytes(buf)
270        }
271        _ => unreachable!(),
272    };
273
274    let x = VarInt::from_u64(x)?;
275    Ok(x)
276}
277
#[cfg(test)]
mod tests {
    use iroh::Endpoint;
    use rtp::packet::Packet as RtpPacket;

    use crate::ALPN;

    use super::*;

    /// Binds an endpoint on an ephemeral loopback port that speaks the RoQ ALPN.
    async fn bind_endpoint() -> Result<Endpoint> {
        let endpoint = Endpoint::builder()
            .bind_addr_v4("127.0.0.1:0".parse().unwrap())
            .alpns(vec![ALPN.to_vec()])
            .bind()
            .await?;
        Ok(endpoint)
    }

    /// Accepts connections on `ep` and echoes every RTP packet received on
    /// `flow_id` back through a send flow with the same ID.
    async fn echo_server(ep: Endpoint, flow_id: VarInt) {
        while let Some(incoming) = ep.accept().await {
            let Ok(connection) = incoming.await else {
                continue;
            };
            assert_eq!(connection.alpn().unwrap(), ALPN, "invalid ALPN");

            let session = Session::new(connection);
            let send_flow = session.new_send_flow(flow_id).await.unwrap();
            let mut recv_flow = session.new_receive_flow(flow_id).await.unwrap();

            // echo
            while let Ok(packet) = recv_flow.read_rtp().await {
                send_flow.send_rtp(&packet).unwrap();
            }
        }
    }

    /// Builds the `i`-th test packet: a default header and ten bytes of `i`.
    fn test_packet(i: u8) -> RtpPacket {
        RtpPacket {
            header: rtp::header::Header::default(),
            payload: vec![i; 10].into(),
        }
    }

    #[tokio::test]
    async fn test_datagram_flow() -> Result<()> {
        let ep1 = bind_endpoint().await?;
        let ep2 = bind_endpoint().await?;

        let flow_id = VarInt::from_u32(0);
        let ep2_addr = ep2.node_addr().await?;
        let _handle = task::spawn(echo_server(ep2, flow_id));

        let conn = ep1.connect(ep2_addr, ALPN).await?;
        let session = Session::new(conn);
        let send_flow = session.new_send_flow(flow_id).await.unwrap();
        let mut recv_flow = session.new_receive_flow(flow_id).await.unwrap();

        for i in 0u8..10 {
            let packet = test_packet(i);
            send_flow.send_rtp(&packet)?;
            let echoed = recv_flow.read_rtp().await?;
            assert_eq!(packet, echoed);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_session_flow() -> Result<()> {
        let ep1 = bind_endpoint().await?;
        let ep2 = bind_endpoint().await?;

        let flow_id = VarInt::from_u32(0);
        let ep2_addr = ep2.node_addr().await?;
        let _handle = task::spawn(echo_server(ep2, flow_id));

        let conn = ep1.connect(ep2_addr, ALPN).await?;
        let session = Session::new(conn);
        let send_flow = session.new_send_flow(flow_id).await.unwrap();
        let mut send_stream = send_flow.new_send_stream().await?;
        let mut recv_flow = session.new_receive_flow(flow_id).await.unwrap();

        for i in 0u8..10 {
            let packet = test_packet(i);
            send_stream.send_rtp(&packet).await?;
            let echoed = recv_flow.read_rtp().await?;
            assert_eq!(packet, echoed);
        }

        Ok(())
    }
}