1use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
2
3use web_transport_trait::Stats;
4
5use crate::{BandwidthConsumer, BandwidthProducer, Error, Version};
6
7#[derive(Clone)]
13pub struct Session {
14 session: Arc<dyn SessionInner>,
15 version: Version,
16 send_bandwidth: Option<BandwidthConsumer>,
17 recv_bandwidth: Option<BandwidthConsumer>,
18 closed: bool,
19}
20
21impl Session {
22 pub(super) fn new<S: web_transport_trait::Session>(
23 session: S,
24 version: Version,
25 recv_bandwidth: Option<BandwidthConsumer>,
26 ) -> Self {
27 let send_bandwidth = if session.stats().estimated_send_rate().is_some() {
29 let producer = BandwidthProducer::new();
30 let consumer = producer.consume();
31
32 let session = session.clone();
33 web_async::spawn(async move {
34 run_send_bandwidth(&session, producer).await;
35 });
36
37 Some(consumer)
38 } else {
39 None
40 };
41
42 Self {
43 session: Arc::new(session),
44 version,
45 send_bandwidth,
46 recv_bandwidth,
47 closed: false,
48 }
49 }
50
51 pub fn version(&self) -> Version {
53 self.version
54 }
55
56 pub fn send_bandwidth(&self) -> Option<BandwidthConsumer> {
60 self.send_bandwidth.clone()
61 }
62
63 pub fn recv_bandwidth(&self) -> Option<BandwidthConsumer> {
67 self.recv_bandwidth.clone()
68 }
69
70 pub fn close(&mut self, err: Error) {
72 if self.closed {
73 return;
74 }
75 self.closed = true;
76 self.session.close(err.to_code(), err.to_string().as_ref());
77 }
78
79 pub async fn closed(&self) -> Result<(), Error> {
81 let err = self.session.closed().await;
82 Err(Error::Transport(err))
83 }
84}
85
86impl Drop for Session {
87 fn drop(&mut self) {
88 if !self.closed {
89 self.session.close(Error::Cancel.to_code(), "dropped");
90 }
91 }
92}
93
94async fn run_send_bandwidth<S: web_transport_trait::Session>(session: &S, producer: BandwidthProducer) {
99 tokio::select! {
100 _ = session.closed() => {}
101 _ = producer.closed() => {}
102 _ = run_send_bandwidth_inner(session, &producer) => {}
103 }
104}
105
106async fn run_send_bandwidth_inner<S: web_transport_trait::Session>(session: &S, producer: &BandwidthProducer) {
109 const POLL_INTERVAL: Duration = Duration::from_millis(100);
110
111 loop {
112 if producer.used().await.is_err() {
113 return;
114 }
115
116 let mut interval = tokio::time::interval(POLL_INTERVAL);
117 loop {
118 tokio::select! {
119 biased;
120 res = producer.unused() => {
121 if res.is_err() {
122 return;
123 }
124 break;
126 }
127 _ = interval.tick() => {
128 let bitrate = session.stats().estimated_send_rate();
129 if producer.set(bitrate).is_err() {
130 return;
131 }
132 }
133 }
134 }
135 }
136}
137
/// Object-safe subset of [`web_transport_trait::Session`], so that any transport
/// implementation can be stored behind a trait object.
trait SessionInner: Send + Sync {
    /// Resolves once the transport has closed, yielding the close error rendered as text.
    fn closed(&self) -> Pin<Box<dyn Future<Output = String> + Send + '_>>;

    /// Closes the transport with the application error `code` and a human-readable `reason`.
    fn close(&self, code: u32, reason: &str);
}
143
144impl<S: web_transport_trait::Session> SessionInner for S {
145 fn close(&self, code: u32, reason: &str) {
146 S::close(self, code, reason);
147 }
148
149 fn closed(&self) -> Pin<Box<dyn Future<Output = String> + Send + '_>> {
150 Box::pin(async move { S::closed(self).await.to_string() })
151 }
152}