1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_18, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, Error, NEGOTIATED,
3 OriginConsumer, OriginProducer, Session, Version, Versions,
4 coding::{self, Decode, Encode, Stream},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 versions: Versions,
14}
15
16impl Client {
17 pub fn new() -> Self {
18 Default::default()
19 }
20
21 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22 self.publish = publish.into();
23 self
24 }
25
26 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27 self.consume = consume.into();
28 self
29 }
30
31 pub fn with_origin(self, origin: OriginProducer) -> Self {
35 let consumer = origin.consume();
36 self.with_publish(consumer).with_consume(origin)
37 }
38
39 pub fn with_versions(mut self, versions: Versions) -> Self {
40 self.versions = versions;
41 self
42 }
43
44 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
46 if self.publish.is_none() && self.consume.is_none() {
47 tracing::warn!("not publishing or consuming anything");
48 }
49
50 let (encoding, supported) = match session.protocol() {
53 Some(ALPN_18) => {
54 let v = self
55 .versions
56 .select(Version::Ietf(ietf::Version::Draft18))
57 .ok_or(Error::Version)?;
58
59 ietf::start(
61 session.clone(),
62 None,
63 None,
64 true,
65 self.publish.clone(),
66 self.consume.clone(),
67 ietf::Version::Draft18,
68 )?;
69
70 tracing::debug!(version = ?v, "connected");
71 return Ok(Session::new(session, v, None));
72 }
73 Some(ALPN_17) => {
74 let v = self
75 .versions
76 .select(Version::Ietf(ietf::Version::Draft17))
77 .ok_or(Error::Version)?;
78
79 ietf::start(
81 session.clone(),
82 None,
83 None,
84 true,
85 self.publish.clone(),
86 self.consume.clone(),
87 ietf::Version::Draft17,
88 )?;
89
90 tracing::debug!(version = ?v, "connected");
91 return Ok(Session::new(session, v, None));
92 }
93 Some(ALPN_16) => {
94 let v = self
95 .versions
96 .select(Version::Ietf(ietf::Version::Draft16))
97 .ok_or(Error::Version)?;
98 (v, v.into())
99 }
100 Some(ALPN_15) => {
101 let v = self
102 .versions
103 .select(Version::Ietf(ietf::Version::Draft15))
104 .ok_or(Error::Version)?;
105 (v, v.into())
106 }
107 Some(ALPN_14) => {
108 let v = self
109 .versions
110 .select(Version::Ietf(ietf::Version::Draft14))
111 .ok_or(Error::Version)?;
112 (v, v.into())
113 }
114 Some(ALPN_LITE_04) => {
115 self.versions
116 .select(Version::Lite(lite::Version::Lite04))
117 .ok_or(Error::Version)?;
118
119 let recv_bw = lite::start(
120 session.clone(),
121 None,
122 self.publish.clone(),
123 self.consume.clone(),
124 lite::Version::Lite04,
125 )?;
126
127 return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
128 }
129 Some(ALPN_LITE_03) => {
130 self.versions
131 .select(Version::Lite(lite::Version::Lite03))
132 .ok_or(Error::Version)?;
133
134 let recv_bw = lite::start(
136 session.clone(),
137 None,
138 self.publish.clone(),
139 self.consume.clone(),
140 lite::Version::Lite03,
141 )?;
142
143 return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
144 }
145 Some(ALPN_LITE) | None => {
146 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
147 (Version::Ietf(ietf::Version::Draft14), supported)
148 }
149 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
150 };
151
152 let mut stream = Stream::open(&session, encoding).await?;
153
154 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
156
157 let mut parameters = ietf::Parameters::default();
158 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
159 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
160 let parameters = parameters.encode_bytes(ietf_encoding)?;
161
162 let client = setup::Client {
163 versions: supported.clone().into(),
164 parameters,
165 };
166
167 stream.writer.encode(&client).await?;
168
169 let mut server: setup::Server = stream.reader.decode().await?;
170
171 let version = supported
172 .iter()
173 .find(|v| coding::Version::from(**v) == server.version)
174 .copied()
175 .ok_or(Error::Version)?;
176
177 let recv_bw = match version {
178 Version::Lite(v) => {
179 let stream = stream.with_version(v);
180 lite::start(
181 session.clone(),
182 Some(stream),
183 self.publish.clone(),
184 self.consume.clone(),
185 v,
186 )?
187 }
188 Version::Ietf(v) => {
189 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
191 let request_id_max = parameters
192 .get_varint(ietf::ParameterVarInt::MaxRequestId)
193 .map(ietf::RequestId);
194
195 let stream = stream.with_version(v);
196 ietf::start(
197 session.clone(),
198 Some(stream),
199 request_id_max,
200 true,
201 self.publish.clone(),
202 self.consume.clone(),
203 v,
204 )?;
205 None
206 }
207 };
208
209 Ok(Session::new(session, version, recv_bw))
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error returned by every failing fake transport operation.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory transport session.
    ///
    /// It hands out one scripted bidirectional control stream and records every `close()`.
    /// Clones share the same state.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    /// State shared by all clones of a [`FakeSession`].
    #[derive(Default)]
    struct FakeSessionState {
        /// ALPN reported by `protocol()`.
        protocol: Option<&'static str>,
        /// Stream pair returned by the first `open_bi()`; later calls fail.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        /// Every `(code, reason)` passed to `close()`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        /// Notified on every `close()`.
        close_notify: tokio::sync::Notify,
        /// Bytes written by the client to the control stream.
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Build a session reporting `protocol` whose control stream yields `server_control_bytes`.
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let writes = Arc::new(Mutex::new(Vec::new()));
            let send = FakeSendStream { writes: writes.clone() };
            let recv = FakeRecvStream {
                data: VecDeque::from(server_control_bytes),
            };
            let state = FakeSessionState {
                protocol,
                control_stream: Mutex::new(Some((send, recv))),
                close_events: Mutex::new(Vec::new()),
                close_notify: tokio::sync::Notify::new(),
                control_writes: writes,
            };
            Self { state: Arc::new(state) }
        }

        /// Snapshot of everything written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// Wait until `close()` has been called at least once and return its first arguments.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // Register before checking so a concurrent `close()` cannot be missed.
                let notified = self.state.close_notify.notified();
                if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
                    return close;
                }
                notified.await;
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
        }

        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        fn close(&self, code: u32, reason: &str) {
            self.state.close_events.lock().unwrap().push((code, reason.to_string()));
            self.state.close_notify.notify_waiters();
        }

        async fn closed(&self) -> Self::Error {
            // Resolve immediately if the session was already closed. A bare
            // `notified().await` would hang forever when `close()` ran first. Registering
            // before the check means a `close()` racing with this call is still observed.
            loop {
                let notified = self.state.close_notify.notified();
                if !self.state.close_events.lock().unwrap().is_empty() {
                    return FakeError;
                }
                notified.await;
            }
        }
    }

    /// Send half that appends every write to a shared buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.writes.lock().unwrap().put_slice(buf);
            Ok(buf.len())
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half that replays scripted bytes and then reports end-of-stream (`Ok(None)`).
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let size = dst.len().min(self.data.len());
            for slot in dst.iter_mut().take(size) {
                *slot = self.data.pop_front().unwrap();
            }
            Ok(Some(size))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Bytes a server would send on the control stream.
    ///
    /// The bytes are a SETUP reply selecting `negotiated` (encoded as draft-14) followed by
    /// a lite `SessionInfo`. Panics if `negotiated` is not a lite version.
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut encoded = Vec::new();
        let server = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        server
            .encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
            .unwrap();

        let info = lite::SessionInfo { bitrate: Some(1) };
        let lite_v = lite::Version::try_from(negotiated).unwrap();
        info.encode(&mut encoded, lite_v).unwrap();

        encoded
    }

    /// Connect using `protocol` as the ALPN against a server that picks lite-01.
    ///
    /// Asserts that the client SETUP advertised only the negotiable versions. Lite-03 is
    /// configured but must not be offered. Also asserts that the session is eventually
    /// closed with [`Error::Cancel`]'s code.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
        let client = Client::new().with_versions(
            [
                Version::Lite(lite::Version::Lite03),
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
            .into(),
        );

        let _session = client.connect(fake.clone()).await.unwrap();

        // The client SETUP is always encoded as draft-14 on this path.
        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        assert_eq!(
            advertised,
            vec![
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
        );

        // NOTE(review): presumably triggered once the scripted control stream reaches EOF;
        // confirm against the lite session driver.
        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}