1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::{bail, Result};
5use iroh::endpoint::{Connection, VarInt};
6use iroh_quinn_proto::coding::Codec;
7use n0_future::task::{self, AbortOnDropHandle, JoinSet};
8use tokio::io::{AsyncRead, AsyncReadExt};
9use tokio::sync::{mpsc, Mutex};
10use tokio_util::bytes::{Bytes, BytesMut};
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, error, trace, warn};
13
14use crate::receive_flow::ReceiveFlow;
15use crate::send_flow::SendFlow;
16
/// A multiplexed media session over a single QUIC connection.
///
/// Cloning is cheap: all state is behind `Arc`s, and the background
/// dispatch task (see `run`) is aborted once the last clone is dropped.
#[derive(Debug, Clone)]
pub struct Session {
    /// The underlying QUIC connection all flows run over.
    conn: Connection,
    /// Root token; child tokens are handed to every flow so the whole
    /// session can be shut down at once.
    cancel_token: CancellationToken,
    /// Locally created send flows, keyed by flow id.
    send_flows: Arc<Mutex<HashMap<VarInt, SendFlow>>>,
    /// Receive-side dispatch table, shared with the background task.
    receive_flows: Arc<Mutex<HashMap<VarInt, ReceiveFlowSender>>>,
    /// Handle to the background dispatch task; aborting on drop ties the
    /// task's lifetime to the last `Session` clone.
    _task: Arc<AbortOnDropHandle<()>>,
}
26
/// Book-keeping entry for one receive flow in `Session::receive_flows`.
#[derive(Debug)]
struct ReceiveFlowSender {
    /// Producer side of the packet channel feeding the `ReceiveFlow`.
    sender: mpsc::Sender<Bytes>,
    /// A flow created by the remote peer before the local side asked for
    /// it; `Session::new_receive_flow` takes it out when claimed.
    incoming_flow: Option<ReceiveFlow>,
    /// Fires when the flow is closed; used to detect stale entries.
    cancel_token: CancellationToken,
}
34
/// Capacity (in packets) of the channel buffering incoming data for each
/// receive flow before backpressure kicks in.
const RECV_FLOW_BUFFER: usize = 64;

38impl Session {
39 pub fn new(conn: Connection) -> Self {
41 let receive_flows = Arc::new(Mutex::new(HashMap::new()));
42 let cancel_token = CancellationToken::new();
43 let fut = run(conn.clone(), cancel_token.clone(), receive_flows.clone());
44 let task = AbortOnDropHandle::new(task::spawn(fut));
45
46 Self {
47 conn,
48 cancel_token,
49 send_flows: Default::default(),
50 receive_flows,
51 _task: Arc::new(task),
52 }
53 }
54
55 pub async fn new_send_flow(&self, id: VarInt) -> Result<SendFlow> {
57 let mut send_flows = self.send_flows.lock().await;
58 if send_flows.contains_key(&id) {
59 bail!("duplicated flow ID: {}", id);
60 }
61
62 let flow = SendFlow::new(self.conn.clone(), id, self.cancel_token.child_token());
63 send_flows.insert(id, flow.clone());
64
65 Ok(flow)
66 }
67
68 pub async fn new_receive_flow(&self, id: VarInt) -> Result<ReceiveFlow> {
72 let mut receive_flows = self.receive_flows.lock().await;
73 if let Some(flow) = receive_flows.get_mut(&id) {
74 if let Some(receiver) = flow.incoming_flow.take() {
75 debug!(flow_id = %id, "found incoming flow");
76 return Ok(receiver);
77 } else {
78 bail!("duplicated flow ID: {}", id);
79 }
80 }
81
82 let (s, r) = mpsc::channel(RECV_FLOW_BUFFER);
83 let cancel_token = self.cancel_token.child_token();
84 let flow = ReceiveFlow::new(id, r, cancel_token.clone());
85 receive_flows.insert(
86 id,
87 ReceiveFlowSender {
88 sender: s,
89 incoming_flow: None,
90 cancel_token,
91 },
92 );
93
94 Ok(flow)
95 }
96}
97
/// Connection-level receive loop.
///
/// Runs until `cancel_token` fires or the connection errors. It accepts
/// incoming unidirectional streams and datagrams and routes their packets
/// into the per-flow channels in `receive_flows`, creating new entries for
/// flow ids the local side has not claimed yet.
async fn run(
    conn: Connection,
    cancel_token: CancellationToken,
    receive_flows: Arc<Mutex<HashMap<VarInt, ReceiveFlowSender>>>,
) {
    // One reader task per accepted uni stream; all are aborted when this
    // JoinSet is dropped at the end of the loop.
    let mut tasks = JoinSet::new();

    loop {
        tokio::select! {
            // `biased` makes cancellation the highest-priority branch.
            biased;

            _ = cancel_token.cancelled() => {
                debug!("shutting down");
                break;
            }
            // Reap finished stream-reader tasks; a panic or unexpected
            // failure tears down the whole loop.
            Some(res) = tasks.join_next() => {
                match res {
                    Err(outer) => {
                        if outer.is_panic() {
                            error!("Task panicked: {outer:?}");
                            break;
                        } else if outer.is_cancelled() {
                            trace!("Task cancelled: {outer:?}");
                        } else {
                            error!("Task failed: {outer:?}");
                            break;
                        }
                    }
                    Ok(()) => {
                        trace!("Task finished");
                    }
                }
            },

            // A new unidirectional stream: the wire format is one varint
            // flow id followed by a sequence of varint-length-prefixed
            // packets.
            uni_stream = conn.accept_uni() => {
                match uni_stream {
                    Ok(mut recv) => {
                        let token = cancel_token.child_token();
                        let rf = receive_flows.clone();
                        tasks.spawn(async move {
                            // Token handed to a flow created by this stream;
                            // a child of `token` so session shutdown still
                            // reaches it.
                            let sub_token = token.child_token();
                            token.run_until_cancelled(async move {
                                let Ok(flow_id) = read_varint(&mut recv).await else {
                                    warn!("failed to read from stream");
                                    return;
                                };
                                debug!(%flow_id, "incoming send flow");

                                // Look up (or lazily create) the flow entry,
                                // then release the lock before the read loop.
                                let mut flows = rf.lock().await;
                                let sender = if let Some(flow) = flows.get(&flow_id) {
                                    debug!(%flow_id, "found existing recv flow");
                                    if flow.cancel_token.is_cancelled() {
                                        // Receiver is gone: drop the stale
                                        // entry and abandon this stream.
                                        flows.remove(&flow_id);
                                        debug!(%flow_id, "cleaning up closed recv flow");
                                        return;
                                    } else {
                                        flow.sender.clone()
                                    }
                                } else {
                                    debug!(%flow_id, "creating new recv flow");
                                    let (s, r) = mpsc::channel(RECV_FLOW_BUFFER);
                                    let cancel_token = sub_token.child_token();
                                    let flow = ReceiveFlow::new(flow_id, r, cancel_token.clone());
                                    // Stash the flow so Session::new_receive_flow
                                    // can claim it later.
                                    flows.insert(flow_id, ReceiveFlowSender {
                                        sender: s.clone(),
                                        incoming_flow: Some(flow),
                                        cancel_token,
                                    });
                                    s
                                };
                                drop(flows);

                                // Cap per-packet allocation: the length
                                // prefix is attacker-controlled.
                                const MAX_PACKET_SIZE: u64 = 1024 * 1024 * 64; loop {
                                    let len = match read_varint(&mut recv).await {
                                        Ok(len) => len.into_inner(),
                                        Err(err) => {
                                            warn!("failed to read: {:?}", err);
                                            break;
                                        }
                                    };
                                    if len > MAX_PACKET_SIZE {
                                        warn!("packet too large {}", len);
                                        break;
                                    }
                                    let mut buffer = BytesMut::zeroed(len as usize);
                                    match recv.read_exact(&mut buffer).await {
                                        Ok(()) => {
                                            // Send failure means the receiver
                                            // hung up; keep draining regardless.
                                            sender.send(buffer.freeze()).await.ok();
                                        }
                                        Err(err) => {
                                            warn!("failed to read: {:?}", err);
                                            break;
                                        }
                                    }
                                }
                            }).await;
                        });
                    }
                    Err(err) => {
                        warn!("connection terminated: {:?}", err);
                        break;
                    }
                }
            }
            // A datagram: a varint flow id prefix followed by the payload.
            datagram = conn.read_datagram() => {
                match datagram {
                    Ok(mut bytes) => {
                        debug!("received datagram: {} bytes", bytes.len());
                        // decode() advances `bytes` past the id, leaving only
                        // the payload (Codec impl from iroh_quinn_proto).
                        let Ok(flow_id) = VarInt::decode(&mut bytes) else {
                            warn!("invalid flow id");
                            continue;
                        };
                        let mut flows = receive_flows.lock().await;
                        if let Some(flow) = flows.get(&flow_id) {
                            debug!(%flow_id, "found existing recv flow");
                            if flow.cancel_token.is_cancelled() {
                                flows.remove(&flow_id);
                                debug!(%flow_id, "cleaning up closed recv flow");
                            } else if let Err(err) = flow.sender.send(bytes).await {
                                warn!(%flow_id, "failed to send to receiver: {:?}", err);
                            }
                        } else {
                            debug!(%flow_id, "creating new recv flow");
                            let (s, r) = mpsc::channel(RECV_FLOW_BUFFER);
                            let cancel_token = cancel_token.child_token();
                            let flow = ReceiveFlow::new(flow_id, r, cancel_token.clone());
                            // Channel capacity is RECV_FLOW_BUFFER and nothing
                            // else holds the sender yet, so this cannot block.
                            s.send(bytes).await.expect("just created");
                            flows.insert(flow_id, ReceiveFlowSender {
                                sender: s,
                                incoming_flow: Some(flow),
                                cancel_token,
                            });
                        }
                    }
                    Err(err) => {
                        warn!("connection terminated: {:?}", err);
                        break;
                    }
                }
            }
        }
    }
}
248
249async fn read_varint<R: AsyncRead + Unpin>(conn: &mut R) -> Result<VarInt> {
251 let mut buf = [0u8; VarInt::MAX_SIZE];
252
253 conn.read_exact(&mut buf[..1]).await?;
254 let tag = buf[0] >> 6;
255 buf[0] &= 0b0011_1111;
256
257 let x = match tag {
258 0b00 => u64::from(buf[0]),
259 0b01 => {
260 conn.read_exact(&mut buf[1..2]).await?;
261 u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
262 }
263 0b10 => {
264 conn.read_exact(&mut buf[1..4]).await?;
265 u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
266 }
267 0b11 => {
268 conn.read_exact(&mut buf[1..8]).await?;
269 u64::from_be_bytes(buf)
270 }
271 _ => unreachable!(),
272 };
273
274 let x = VarInt::from_u64(x)?;
275 Ok(x)
276}
277
#[cfg(test)]
mod tests {
    use iroh::Endpoint;
    use rtp::packet::Packet as RtpPacket;

    use crate::ALPN;

    use super::*;

    /// Binds a fresh localhost endpoint that speaks our ALPN.
    async fn localhost_endpoint() -> Result<Endpoint> {
        let ep = Endpoint::builder()
            .bind_addr_v4("127.0.0.1:0".parse().unwrap())
            .alpns(vec![ALPN.to_vec()])
            .bind()
            .await?;
        Ok(ep)
    }

    #[tokio::test]
    async fn test_datagram_flow() -> Result<()> {
        let client_ep = localhost_endpoint().await?;
        let server_ep = localhost_endpoint().await?;

        let flow_id = VarInt::from_u32(0);
        let server_addr = server_ep.node_addr().await?;

        // Echo server: bounce every RTP packet back on the same flow id.
        let _handle = task::spawn(async move {
            while let Some(incoming) = server_ep.accept().await {
                if let Ok(connection) = incoming.await {
                    assert_eq!(connection.alpn().unwrap(), ALPN, "invalid ALPN");

                    let session = Session::new(connection);
                    let send_flow = session.new_send_flow(flow_id).await.unwrap();
                    let mut recv_flow = session.new_receive_flow(flow_id).await.unwrap();

                    while let Ok(packet) = recv_flow.read_rtp().await {
                        send_flow.send_rtp(&packet).unwrap();
                    }
                }
            }
        });

        let conn = client_ep.connect(server_addr, ALPN).await?;
        let session = Session::new(conn);
        let send_flow = session.new_send_flow(flow_id).await.unwrap();
        let mut recv_flow = session.new_receive_flow(flow_id).await.unwrap();

        // Send ten distinct packets over datagrams and expect each echoed
        // back unchanged.
        for byte in 0u8..10 {
            let outgoing = RtpPacket {
                header: rtp::header::Header::default(),
                payload: vec![byte; 10].into(),
            };

            send_flow.send_rtp(&outgoing)?;
            let echoed = recv_flow.read_rtp().await?;
            assert_eq!(outgoing, echoed);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_session_flow() -> Result<()> {
        let client_ep = localhost_endpoint().await?;
        let server_ep = localhost_endpoint().await?;

        let flow_id = VarInt::from_u32(0);
        let server_addr = server_ep.node_addr().await?;

        // Echo server: bounce every RTP packet back on the same flow id.
        let _handle = task::spawn(async move {
            while let Some(incoming) = server_ep.accept().await {
                if let Ok(connection) = incoming.await {
                    assert_eq!(connection.alpn().unwrap(), ALPN, "invalid ALPN");

                    let session = Session::new(connection);
                    let send_flow = session.new_send_flow(flow_id).await.unwrap();
                    let mut recv_flow = session.new_receive_flow(flow_id).await.unwrap();

                    while let Ok(packet) = recv_flow.read_rtp().await {
                        send_flow.send_rtp(&packet).unwrap();
                    }
                }
            }
        });

        let conn = client_ep.connect(server_addr, ALPN).await?;
        let session = Session::new(conn);
        let send_flow = session.new_send_flow(flow_id).await.unwrap();
        let mut send_stream = send_flow.new_send_stream().await?;
        let mut recv_flow = session.new_receive_flow(flow_id).await.unwrap();

        // Same round trip as the datagram test, but over a reliable stream.
        for byte in 0u8..10 {
            let outgoing = RtpPacket {
                header: rtp::header::Header::default(),
                payload: vec![byte; 10].into(),
            };

            send_stream.send_rtp(&outgoing).await?;
            let echoed = recv_flow.read_rtp().await?;
            assert_eq!(outgoing, echoed);
        }

        Ok(())
    }
}