1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_18, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, ALPN_LITE_05_WIP, Error,
3 NEGOTIATED, OriginConsumer, OriginProducer, Session, StatsHandle, Version, Versions,
4 coding::{self, Decode, Encode, Stream},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 stats: StatsHandle,
14 versions: Versions,
15}
16
17impl Client {
18 pub fn new() -> Self {
19 Default::default()
20 }
21
22 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
23 self.publish = publish.into();
24 self
25 }
26
27 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
28 self.consume = consume.into();
29 self
30 }
31
32 pub fn with_stats(mut self, stats: StatsHandle) -> Self {
36 self.stats = stats;
37 self
38 }
39
40 pub fn with_origin(self, origin: OriginProducer) -> Self {
44 let consumer = origin.consume();
45 self.with_publish(consumer).with_consume(origin)
46 }
47
48 pub fn with_versions(mut self, versions: Versions) -> Self {
49 self.versions = versions;
50 self
51 }
52
53 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
55 if self.publish.is_none() && self.consume.is_none() {
56 tracing::warn!("not publishing or consuming anything");
57 }
58
59 let (encoding, supported) = match session.protocol() {
62 Some(ALPN_18) => {
63 let v = self
64 .versions
65 .select(Version::Ietf(ietf::Version::Draft18))
66 .ok_or(Error::Version)?;
67
68 ietf::start(
70 session.clone(),
71 None,
72 None,
73 true,
74 self.publish.clone(),
75 self.consume.clone(),
76 self.stats.clone(),
77 ietf::Version::Draft18,
78 )?;
79
80 tracing::debug!(version = ?v, "connected");
81 return Ok(Session::new(session, v, None));
82 }
83 Some(ALPN_17) => {
84 let v = self
85 .versions
86 .select(Version::Ietf(ietf::Version::Draft17))
87 .ok_or(Error::Version)?;
88
89 ietf::start(
91 session.clone(),
92 None,
93 None,
94 true,
95 self.publish.clone(),
96 self.consume.clone(),
97 self.stats.clone(),
98 ietf::Version::Draft17,
99 )?;
100
101 tracing::debug!(version = ?v, "connected");
102 return Ok(Session::new(session, v, None));
103 }
104 Some(ALPN_16) => {
105 let v = self
106 .versions
107 .select(Version::Ietf(ietf::Version::Draft16))
108 .ok_or(Error::Version)?;
109 (v, v.into())
110 }
111 Some(ALPN_15) => {
112 let v = self
113 .versions
114 .select(Version::Ietf(ietf::Version::Draft15))
115 .ok_or(Error::Version)?;
116 (v, v.into())
117 }
118 Some(ALPN_14) => {
119 let v = self
120 .versions
121 .select(Version::Ietf(ietf::Version::Draft14))
122 .ok_or(Error::Version)?;
123 (v, v.into())
124 }
125 Some(ALPN_LITE_05_WIP) => {
126 self.versions
127 .select(Version::Lite(lite::Version::Lite05Wip))
128 .ok_or(Error::Version)?;
129
130 let recv_bw = lite::start(
131 session.clone(),
132 None,
133 self.publish.clone(),
134 self.consume.clone(),
135 self.stats.clone(),
136 lite::Version::Lite05Wip,
137 )?;
138
139 return Ok(Session::new(session, lite::Version::Lite05Wip.into(), recv_bw));
140 }
141 Some(ALPN_LITE_04) => {
142 self.versions
143 .select(Version::Lite(lite::Version::Lite04))
144 .ok_or(Error::Version)?;
145
146 let recv_bw = lite::start(
147 session.clone(),
148 None,
149 self.publish.clone(),
150 self.consume.clone(),
151 self.stats.clone(),
152 lite::Version::Lite04,
153 )?;
154
155 return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
156 }
157 Some(ALPN_LITE_03) => {
158 self.versions
159 .select(Version::Lite(lite::Version::Lite03))
160 .ok_or(Error::Version)?;
161
162 let recv_bw = lite::start(
164 session.clone(),
165 None,
166 self.publish.clone(),
167 self.consume.clone(),
168 self.stats.clone(),
169 lite::Version::Lite03,
170 )?;
171
172 return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
173 }
174 Some(ALPN_LITE) | None => {
175 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
176 (Version::Ietf(ietf::Version::Draft14), supported)
177 }
178 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
179 };
180
181 let mut stream = Stream::open(&session, encoding).await?;
182
183 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
185
186 let mut parameters = ietf::Parameters::default();
187 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
188 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
189 let parameters = parameters.encode_bytes(ietf_encoding)?;
190
191 let client = setup::Client {
192 versions: supported.clone().into(),
193 parameters,
194 };
195
196 stream.writer.encode(&client).await?;
197
198 let mut server: setup::Server = stream.reader.decode().await?;
199
200 let version = supported
201 .iter()
202 .find(|v| coding::Version::from(**v) == server.version)
203 .copied()
204 .ok_or(Error::Version)?;
205
206 let recv_bw = match version {
207 Version::Lite(v) => {
208 let stream = stream.with_version(v);
209 lite::start(
210 session.clone(),
211 Some(stream),
212 self.publish.clone(),
213 self.consume.clone(),
214 self.stats.clone(),
215 v,
216 )?
217 }
218 Version::Ietf(v) => {
219 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
221 let request_id_max = parameters
222 .get_varint(ietf::ParameterVarInt::MaxRequestId)
223 .map(ietf::RequestId);
224
225 let stream = stream.with_version(v);
226 ietf::start(
227 session.clone(),
228 Some(stream),
229 request_id_max,
230 true,
231 self.publish.clone(),
232 self.consume.clone(),
233 self.stats.clone(),
234 v,
235 )?;
236 None
237 }
238 };
239
240 Ok(Session::new(session, version, recv_bw))
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error type shared by the fake session and its streams.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        // Always reports a session close with code 0 and reason "closed".
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory transport session.
    ///
    /// It hands out one pre-scripted bidirectional control stream, records every
    /// `close` call, and never yields uni streams, incoming bi streams or datagrams.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    /// Shared state behind every clone of a [`FakeSession`].
    #[derive(Default)]
    struct FakeSessionState {
        // ALPN reported by `protocol()`.
        protocol: Option<&'static str>,
        // The single control stream; `open_bi` takes it, so later calls fail.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        // Every `(code, reason)` passed to `close`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        // Signalled on each `close` call.
        close_notify: tokio::sync::Notify,
        // Same buffer the control send stream appends to, kept for inspection.
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Builds a session reporting `protocol` as its ALPN.
        ///
        /// The control stream's receive half yields `server_control_bytes`; everything
        /// the client writes to the send half is captured for
        /// [`FakeSession::control_writes`].
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let writes = Arc::new(Mutex::new(Vec::new()));
            let send = FakeSendStream { writes: writes.clone() };
            let recv = FakeRecvStream {
                data: VecDeque::from(server_control_bytes),
            };
            let state = FakeSessionState {
                protocol,
                control_stream: Mutex::new(Some((send, recv))),
                close_events: Mutex::new(Vec::new()),
                close_notify: tokio::sync::Notify::new(),
                control_writes: writes,
            };
            Self { state: Arc::new(state) }
        }

        /// Returns a snapshot of all bytes written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// Waits until `close` has been called at least once and returns the first
        /// `(code, reason)` recorded.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // Create the `Notified` future *before* checking the event list. Tokio
                // guarantees a created `Notified` observes later `notify_waiters` calls,
                // so a close landing between the check and the await is not lost.
                let notified = self.state.close_notify.notified();
                if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
                    return close;
                }
                notified.await;
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        // Never resolves: the fake peer opens no uni streams.
        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        // Never resolves: the fake peer opens no bi streams.
        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        // Hands out the scripted control stream once; any further call fails with FakeError.
        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
        }

        // Never resolves, so any task opening a uni stream simply parks.
        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        // Datagrams are silently discarded.
        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        // Records the close and wakes every current waiter.
        fn close(&self, code: u32, reason: &str) {
            self.state.close_events.lock().unwrap().push((code, reason.to_string()));
            self.state.close_notify.notify_waiters();
        }

        // NOTE(review): `notify_waiters` stores no permit. A `close` issued before this
        // future is first polled is therefore not observed here.
        async fn closed(&self) -> Self::Error {
            self.state.close_notify.notified().await;
            FakeError
        }
    }

    /// Send half that appends every write to a shared buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        // Accepts the whole buffer in one call.
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.writes.lock().unwrap().put_slice(buf);
            Ok(buf.len())
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half that replays a fixed byte script.
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        // Copies up to `dst.len()` scripted bytes. Once the script is exhausted it returns
        // `Ok(None)` (presumably end-of-stream in this trait).
        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let size = dst.len().min(self.data.len());
            for slot in dst.iter_mut().take(size) {
                *slot = self.data.pop_front().unwrap();
            }
            Ok(Some(size))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Builds the bytes the fake server sends on the control stream.
    ///
    /// The bytes are a SERVER_SETUP selecting `negotiated`, framed with draft-14 encoding,
    /// followed by a moq-lite `SessionInfo` encoded for that version.
    ///
    /// Panics unless `negotiated` is a moq-lite version (`lite::Version::try_from(..).unwrap()`).
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut encoded = Vec::new();
        let server = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        server
            .encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
            .unwrap();

        let info = lite::SessionInfo { bitrate: Some(1) };
        let lite_v = lite::Version::try_from(negotiated).unwrap();
        info.encode(&mut encoded, lite_v).unwrap();

        encoded
    }

    /// Shared body for the generic-ALPN / no-ALPN fallback tests.
    ///
    /// The server picks Lite01. The test checks two things:
    /// - CLIENT_SETUP (decoded with draft-14 framing) advertises Lite02, Lite01 and
    ///   Draft14, but not the enabled Lite03. Presumably Lite03 is dropped because it is
    ///   not in `NEGOTIATED`.
    /// - The session's first close carries `Error::Cancel`'s code.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
        let client = Client::new().with_versions(
            [
                Version::Lite(lite::Version::Lite03),
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
            .into(),
        );

        // A named binding (not `_`) keeps the session alive until the end of the test.
        let _session = client.connect(fake.clone()).await.unwrap();

        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        assert_eq!(
            advertised,
            vec![
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
        );

        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}