1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_18, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, ALPN_LITE_05_WIP, Error,
3 NEGOTIATED, OriginConsumer, OriginProducer, Session, StatsHandle, Version, Versions,
4 coding::{self, Decode, Encode, Stream},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 stats: StatsHandle,
14 versions: Versions,
15}
16
17impl Client {
18 pub fn new() -> Self {
19 Default::default()
20 }
21
22 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
23 self.publish = publish.into();
24 self
25 }
26
27 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
28 self.consume = consume.into();
29 self
30 }
31
32 pub fn with_stats(mut self, stats: StatsHandle) -> Self {
36 self.stats = stats;
37 self
38 }
39
40 pub fn with_origin(self, origin: OriginProducer) -> Self {
44 let consumer = origin.consume();
45 self.with_publish(consumer).with_consume(origin)
46 }
47
48 pub fn with_versions(mut self, versions: Versions) -> Self {
49 self.versions = versions;
50 self
51 }
52
53 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
55 if self.publish.is_none() && self.consume.is_none() {
56 tracing::warn!("not publishing or consuming anything");
57 }
58
59 let (encoding, supported) = match session.protocol() {
62 Some(ALPN_18) => {
63 let v = self
64 .versions
65 .select(Version::Ietf(ietf::Version::Draft18))
66 .ok_or(Error::Version)?;
67
68 ietf::start(
70 session.clone(),
71 None,
72 None,
73 true,
74 self.publish.clone(),
75 self.consume.clone(),
76 ietf::Version::Draft18,
77 )?;
78
79 tracing::debug!(version = ?v, "connected");
80 return Ok(Session::new(session, v, None));
81 }
82 Some(ALPN_17) => {
83 let v = self
84 .versions
85 .select(Version::Ietf(ietf::Version::Draft17))
86 .ok_or(Error::Version)?;
87
88 ietf::start(
90 session.clone(),
91 None,
92 None,
93 true,
94 self.publish.clone(),
95 self.consume.clone(),
96 ietf::Version::Draft17,
97 )?;
98
99 tracing::debug!(version = ?v, "connected");
100 return Ok(Session::new(session, v, None));
101 }
102 Some(ALPN_16) => {
103 let v = self
104 .versions
105 .select(Version::Ietf(ietf::Version::Draft16))
106 .ok_or(Error::Version)?;
107 (v, v.into())
108 }
109 Some(ALPN_15) => {
110 let v = self
111 .versions
112 .select(Version::Ietf(ietf::Version::Draft15))
113 .ok_or(Error::Version)?;
114 (v, v.into())
115 }
116 Some(ALPN_14) => {
117 let v = self
118 .versions
119 .select(Version::Ietf(ietf::Version::Draft14))
120 .ok_or(Error::Version)?;
121 (v, v.into())
122 }
123 Some(ALPN_LITE_05_WIP) => {
124 self.versions
125 .select(Version::Lite(lite::Version::Lite05Wip))
126 .ok_or(Error::Version)?;
127
128 let recv_bw = lite::start(
129 session.clone(),
130 None,
131 self.publish.clone(),
132 self.consume.clone(),
133 self.stats.clone(),
134 lite::Version::Lite05Wip,
135 )?;
136
137 return Ok(Session::new(session, lite::Version::Lite05Wip.into(), recv_bw));
138 }
139 Some(ALPN_LITE_04) => {
140 self.versions
141 .select(Version::Lite(lite::Version::Lite04))
142 .ok_or(Error::Version)?;
143
144 let recv_bw = lite::start(
145 session.clone(),
146 None,
147 self.publish.clone(),
148 self.consume.clone(),
149 self.stats.clone(),
150 lite::Version::Lite04,
151 )?;
152
153 return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
154 }
155 Some(ALPN_LITE_03) => {
156 self.versions
157 .select(Version::Lite(lite::Version::Lite03))
158 .ok_or(Error::Version)?;
159
160 let recv_bw = lite::start(
162 session.clone(),
163 None,
164 self.publish.clone(),
165 self.consume.clone(),
166 self.stats.clone(),
167 lite::Version::Lite03,
168 )?;
169
170 return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
171 }
172 Some(ALPN_LITE) | None => {
173 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
174 (Version::Ietf(ietf::Version::Draft14), supported)
175 }
176 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
177 };
178
179 let mut stream = Stream::open(&session, encoding).await?;
180
181 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
183
184 let mut parameters = ietf::Parameters::default();
185 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
186 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
187 let parameters = parameters.encode_bytes(ietf_encoding)?;
188
189 let client = setup::Client {
190 versions: supported.clone().into(),
191 parameters,
192 };
193
194 stream.writer.encode(&client).await?;
195
196 let mut server: setup::Server = stream.reader.decode().await?;
197
198 let version = supported
199 .iter()
200 .find(|v| coding::Version::from(**v) == server.version)
201 .copied()
202 .ok_or(Error::Version)?;
203
204 let recv_bw = match version {
205 Version::Lite(v) => {
206 let stream = stream.with_version(v);
207 lite::start(
208 session.clone(),
209 Some(stream),
210 self.publish.clone(),
211 self.consume.clone(),
212 self.stats.clone(),
213 v,
214 )?
215 }
216 Version::Ietf(v) => {
217 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
219 let request_id_max = parameters
220 .get_varint(ietf::ParameterVarInt::MaxRequestId)
221 .map(ietf::RequestId);
222
223 let stream = stream.with_version(v);
224 ietf::start(
225 session.clone(),
226 Some(stream),
227 request_id_max,
228 true,
229 self.publish.clone(),
230 self.consume.clone(),
231 v,
232 )?;
233 None
234 }
235 };
236
237 Ok(Session::new(session, version, recv_bw))
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error type shared by every fake transport object in this module.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        /// Always reports a session-level close with code 0.
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory session exposing exactly one scripted bidirectional stream.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    /// State shared between all clones of a [`FakeSession`].
    #[derive(Default)]
    struct FakeSessionState {
        /// Value reported by `protocol()`.
        protocol: Option<&'static str>,
        /// The single control stream; taken by the first `open_bi` call.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        /// Every `(code, reason)` passed to `close`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        /// Signalled on every `close` call.
        close_notify: tokio::sync::Notify,
        /// Bytes written to the control stream's send half.
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Builds a session reporting `protocol` whose control stream will
        /// yield `server_control_bytes` to the reader and then end.
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let control_writes = Arc::new(Mutex::new(Vec::new()));
            let halves = (
                FakeSendStream {
                    writes: control_writes.clone(),
                },
                FakeRecvStream {
                    data: VecDeque::from(server_control_bytes),
                },
            );

            Self {
                state: Arc::new(FakeSessionState {
                    protocol,
                    control_stream: Mutex::new(Some(halves)),
                    close_events: Mutex::new(Vec::new()),
                    close_notify: tokio::sync::Notify::new(),
                    control_writes,
                }),
            }
        }

        /// Snapshot of everything written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// Resolves with the first recorded `close` call, waiting for one if needed.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // The waiter is created before the check so that a close landing
                // between the check and the await is not missed.
                let waiter = self.state.close_notify.notified();
                let first = self.state.close_events.lock().unwrap().first().cloned();
                match first {
                    Some(event) => return event,
                    None => waiter.await,
                }
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        /// Never resolves: the fake has no incoming unidirectional streams.
        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        /// Never resolves: the fake has no incoming bidirectional streams.
        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        /// Hands out the scripted control stream once; later calls fail.
        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            let mut slot = self.state.control_stream.lock().unwrap();
            slot.take().ok_or(FakeError)
        }

        /// Never resolves: outgoing unidirectional streams are not supported.
        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        /// Accepts and discards the datagram.
        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        /// Never resolves: no datagrams are ever delivered.
        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        /// Records the close and wakes everything currently waiting on it.
        fn close(&self, code: u32, reason: &str) {
            let event = (code, reason.to_string());
            self.state.close_events.lock().unwrap().push(event);
            self.state.close_notify.notify_waiters();
        }

        /// Resolves on the next `close` call.
        // NOTE(review): a close recorded before this is awaited is not observed here.
        async fn closed(&self) -> Self::Error {
            self.state.close_notify.notified().await;
            FakeError
        }
    }

    /// Send half that appends every write to a shared buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        /// Appends all of `buf` to the shared buffer and reports it fully written.
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            let written = buf.len();
            self.writes.lock().unwrap().put_slice(buf);
            Ok(written)
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half that replays a fixed byte queue and then reports end of stream.
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        /// Copies as many queued bytes as fit into `dst`; `Ok(None)` once the queue is empty.
        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let count = dst.len().min(self.data.len());
            for (slot, byte) in dst.iter_mut().zip(self.data.drain(..count)) {
                *slot = byte;
            }
            Ok(Some(count))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Encodes the bytes the fake server sends on the control stream: a
    /// `setup::Server` selecting `negotiated` (encoded as Draft14), followed by
    /// a `lite::SessionInfo` encoded with the negotiated lite version.
    ///
    /// Panics if `negotiated` is not a lite version.
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut wire = Vec::new();

        let reply = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        reply.encode(&mut wire, Version::Ietf(ietf::Version::Draft14)).unwrap();

        let lite_version = lite::Version::try_from(negotiated).unwrap();
        let info = lite::SessionInfo { bitrate: Some(1) };
        info.encode(&mut wire, lite_version).unwrap();

        wire
    }

    /// Connects over a session reporting `protocol`, with a server that picks
    /// Lite01, and checks both the advertised versions and the first close code.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let server_bytes = mock_server_setup(Version::Lite(lite::Version::Lite01));
        let fake = FakeSession::new(protocol, server_bytes);

        let configured = [
            Version::Lite(lite::Version::Lite03),
            Version::Lite(lite::Version::Lite02),
            Version::Lite(lite::Version::Lite01),
            Version::Ietf(ietf::Version::Draft14),
        ];
        let client = Client::new().with_versions(configured.into());

        let _session = client.connect(fake.clone()).await.unwrap();

        // The client SETUP is decoded as Draft14 and must not advertise Lite03.
        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        let expected = vec![
            Version::Lite(lite::Version::Lite02),
            Version::Lite(lite::Version::Lite01),
            Version::Ietf(ietf::Version::Draft14),
        ];
        assert_eq!(advertised, expected);

        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}