1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_18, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, Error, NEGOTIATED,
3 OriginConsumer, OriginProducer, Session, StatsHandle, Version, Versions,
4 coding::{self, Decode, Encode, Stream},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 stats: StatsHandle,
14 versions: Versions,
15}
16
17impl Client {
18 pub fn new() -> Self {
19 Default::default()
20 }
21
22 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
23 self.publish = publish.into();
24 self
25 }
26
27 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
28 self.consume = consume.into();
29 self
30 }
31
32 pub fn with_stats(mut self, stats: StatsHandle) -> Self {
36 self.stats = stats;
37 self
38 }
39
40 pub fn with_origin(self, origin: OriginProducer) -> Self {
44 let consumer = origin.consume();
45 self.with_publish(consumer).with_consume(origin)
46 }
47
48 pub fn with_versions(mut self, versions: Versions) -> Self {
49 self.versions = versions;
50 self
51 }
52
53 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
55 if self.publish.is_none() && self.consume.is_none() {
56 tracing::warn!("not publishing or consuming anything");
57 }
58
59 let (encoding, supported) = match session.protocol() {
62 Some(ALPN_18) => {
63 let v = self
64 .versions
65 .select(Version::Ietf(ietf::Version::Draft18))
66 .ok_or(Error::Version)?;
67
68 ietf::start(
70 session.clone(),
71 None,
72 None,
73 true,
74 self.publish.clone(),
75 self.consume.clone(),
76 ietf::Version::Draft18,
77 )?;
78
79 tracing::debug!(version = ?v, "connected");
80 return Ok(Session::new(session, v, None));
81 }
82 Some(ALPN_17) => {
83 let v = self
84 .versions
85 .select(Version::Ietf(ietf::Version::Draft17))
86 .ok_or(Error::Version)?;
87
88 ietf::start(
90 session.clone(),
91 None,
92 None,
93 true,
94 self.publish.clone(),
95 self.consume.clone(),
96 ietf::Version::Draft17,
97 )?;
98
99 tracing::debug!(version = ?v, "connected");
100 return Ok(Session::new(session, v, None));
101 }
102 Some(ALPN_16) => {
103 let v = self
104 .versions
105 .select(Version::Ietf(ietf::Version::Draft16))
106 .ok_or(Error::Version)?;
107 (v, v.into())
108 }
109 Some(ALPN_15) => {
110 let v = self
111 .versions
112 .select(Version::Ietf(ietf::Version::Draft15))
113 .ok_or(Error::Version)?;
114 (v, v.into())
115 }
116 Some(ALPN_14) => {
117 let v = self
118 .versions
119 .select(Version::Ietf(ietf::Version::Draft14))
120 .ok_or(Error::Version)?;
121 (v, v.into())
122 }
123 Some(ALPN_LITE_04) => {
124 self.versions
125 .select(Version::Lite(lite::Version::Lite04))
126 .ok_or(Error::Version)?;
127
128 let recv_bw = lite::start(
129 session.clone(),
130 None,
131 self.publish.clone(),
132 self.consume.clone(),
133 self.stats.clone(),
134 lite::Version::Lite04,
135 )?;
136
137 return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
138 }
139 Some(ALPN_LITE_03) => {
140 self.versions
141 .select(Version::Lite(lite::Version::Lite03))
142 .ok_or(Error::Version)?;
143
144 let recv_bw = lite::start(
146 session.clone(),
147 None,
148 self.publish.clone(),
149 self.consume.clone(),
150 self.stats.clone(),
151 lite::Version::Lite03,
152 )?;
153
154 return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
155 }
156 Some(ALPN_LITE) | None => {
157 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
158 (Version::Ietf(ietf::Version::Draft14), supported)
159 }
160 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
161 };
162
163 let mut stream = Stream::open(&session, encoding).await?;
164
165 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
167
168 let mut parameters = ietf::Parameters::default();
169 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
170 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
171 let parameters = parameters.encode_bytes(ietf_encoding)?;
172
173 let client = setup::Client {
174 versions: supported.clone().into(),
175 parameters,
176 };
177
178 stream.writer.encode(&client).await?;
179
180 let mut server: setup::Server = stream.reader.decode().await?;
181
182 let version = supported
183 .iter()
184 .find(|v| coding::Version::from(**v) == server.version)
185 .copied()
186 .ok_or(Error::Version)?;
187
188 let recv_bw = match version {
189 Version::Lite(v) => {
190 let stream = stream.with_version(v);
191 lite::start(
192 session.clone(),
193 Some(stream),
194 self.publish.clone(),
195 self.consume.clone(),
196 self.stats.clone(),
197 v,
198 )?
199 }
200 Version::Ietf(v) => {
201 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
203 let request_id_max = parameters
204 .get_varint(ietf::ParameterVarInt::MaxRequestId)
205 .map(ietf::RequestId);
206
207 let stream = stream.with_version(v);
208 ietf::start(
209 session.clone(),
210 Some(stream),
211 request_id_max,
212 true,
213 self.publish.clone(),
214 self.consume.clone(),
215 v,
216 )?;
217 None
218 }
219 };
220
221 Ok(Session::new(session, version, recv_bw))
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error type shared by every fake transport object in this module.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        // Always reports a session-level close with code 0 and reason "closed".
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory implementation of `web_transport_trait::Session`.
    ///
    /// Clones share the same state through an `Arc`.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    #[derive(Default)]
    struct FakeSessionState {
        // ALPN reported by `protocol()`; `None` simulates a transport without ALPN.
        protocol: Option<&'static str>,
        // The single bidirectional stream handed out by the first `open_bi()` call.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        // Every `(code, reason)` passed to `close()`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        // Woken via `notify_waiters()` on every `close()`.
        close_notify: tokio::sync::Notify,
        // Bytes written by the client on the control stream; shared with its `FakeSendStream`.
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Builds a session that reports `protocol` as its ALPN.
        ///
        /// The control stream yields `server_control_bytes` to the client and then reports
        /// end-of-stream.
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let writes = Arc::new(Mutex::new(Vec::new()));
            let send = FakeSendStream { writes: writes.clone() };
            let recv = FakeRecvStream {
                data: VecDeque::from(server_control_bytes),
            };
            let state = FakeSessionState {
                protocol,
                control_stream: Mutex::new(Some((send, recv))),
                close_events: Mutex::new(Vec::new()),
                close_notify: tokio::sync::Notify::new(),
                control_writes: writes,
            };
            Self { state: Arc::new(state) }
        }

        /// Returns a snapshot of everything the client has written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// Waits until `close()` has been called at least once, then returns the first close.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // Create the `Notified` future BEFORE checking the state. A `close()` that lands
                // between the check and the `.await` then still wakes this loop instead of being
                // missed. Do not reorder these statements.
                let notified = self.state.close_notify.notified();
                if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
                    return close;
                }
                notified.await;
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        // The server never opens streams toward the client in these tests.
        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        // Hands out the pre-seeded control stream once; later calls fail with `FakeError`.
        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
        }

        // Unidirectional streams never open; the future stays pending forever.
        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        // Datagrams are accepted and silently dropped.
        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        // Records the close and wakes everything currently waiting on `close_notify`.
        fn close(&self, code: u32, reason: &str) {
            self.state.close_events.lock().unwrap().push((code, reason.to_string()));
            self.state.close_notify.notify_waiters();
        }

        // NOTE(review): this resolves only on a `close()` made after the future is first polled.
        // `notify_waiters` stores no permit, so a close that already happened is not observed here.
        async fn closed(&self) -> Self::Error {
            self.state.close_notify.notified().await;
            FakeError
        }
    }

    /// Send half that appends every write to a shared byte buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        // Always accepts the whole buffer.
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.writes.lock().unwrap().put_slice(buf);
            Ok(buf.len())
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half that serves pre-seeded bytes, then reports end-of-stream.
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        // Copies up to `dst.len()` buffered bytes; returns `Ok(None)` (EOF) once drained.
        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let size = dst.len().min(self.data.len());
            for slot in dst.iter_mut().take(size) {
                *slot = self.data.pop_front().unwrap();
            }
            Ok(Some(size))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Builds the scripted server side of the control stream. It contains, in order:
    /// 1. A `setup::Server` selecting `negotiated`, with empty parameters. It is encoded as
    ///    draft-14, the SETUP encoding `Client::connect` uses for `ALPN_LITE` or no ALPN.
    /// 2. A lite `SessionInfo` with `bitrate: Some(1)`, encoded in the negotiated lite version.
    ///
    /// The `unwrap` on `lite::Version::try_from` means a non-lite `negotiated` is expected to
    /// panic.
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut encoded = Vec::new();
        let server = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        server
            .encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
            .unwrap();

        let info = lite::SessionInfo { bitrate: Some(1) };
        let lite_v = lite::Version::try_from(negotiated).unwrap();
        info.encode(&mut encoded, lite_v).unwrap();

        encoded
    }

    /// Runs the generic-ALPN negotiation against a fake server that picks Lite01.
    ///
    /// `protocol` is expected to be `Some(ALPN_LITE)` or `None`.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
        let client = Client::new().with_versions(
            [
                Version::Lite(lite::Version::Lite03),
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
            .into(),
        );

        let _session = client.connect(fake.clone()).await.unwrap();

        // The client SETUP is always encoded as draft-14 on this path.
        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        // Lite03 is expected to be dropped by the `NEGOTIATED` filter. Presumably this is because
        // it is only reachable through its own `ALPN_LITE_03`; confirm against `NEGOTIATED`.
        assert_eq!(
            advertised,
            vec![
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
        );

        // NOTE(review): `_session` is still alive here. The close is presumably triggered by the
        // background session once the fake control stream reaches EOF; confirm.
        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}