1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, Error, NEGOTIATED, OriginConsumer, OriginProducer,
3 Session, Version, Versions,
4 coding::{self, Decode, Encode, Reader, Stream, Writer},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 versions: Versions,
14}
15
16impl Client {
17 pub fn new() -> Self {
18 Default::default()
19 }
20
21 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22 self.publish = publish.into();
23 self
24 }
25
26 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27 self.consume = consume.into();
28 self
29 }
30
31 pub fn with_versions(mut self, versions: Versions) -> Self {
32 self.versions = versions;
33 self
34 }
35
36 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
38 if self.publish.is_none() && self.consume.is_none() {
39 tracing::warn!("not publishing or consuming anything");
40 }
41
42 let (encoding, supported) = match session.protocol() {
45 Some(ALPN_17) => {
46 let v = self
47 .versions
48 .select(Version::Ietf(ietf::Version::Draft17))
49 .ok_or(Error::Version)?;
50
51 let ietf_v = ietf::Version::Draft17;
52
53 let mut parameters = ietf::Parameters::default();
55 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
56 let parameters = parameters.encode_bytes(ietf_v)?;
57
58 let client_setup = setup::Client {
59 versions: coding::Versions::from([v.into()]),
60 parameters,
61 };
62
63 let send_fut = async {
65 let send = session.open_uni().await.map_err(Error::from_transport)?;
66 let mut writer: Writer<S::SendStream, Version> = Writer::new(send, v);
67 writer.encode(&0x2F00u64).await?;
69 writer.encode(&client_setup).await?;
71 Ok::<_, Error>(writer)
72 };
73
74 let recv_fut = async {
75 let recv = session.accept_uni().await.map_err(Error::from_transport)?;
76 let mut reader: Reader<S::RecvStream, Version> = Reader::new(recv, v);
77 let stream_type: u64 = reader.decode().await?;
79 if stream_type != 0x2F00 {
80 return Err(Error::UnexpectedStream);
81 }
82 let server: setup::Server = reader.decode().await?;
84 Ok::<_, Error>((reader, server))
85 };
86
87 let (send_result, recv_result) = tokio::join!(send_fut, recv_fut);
88 let writer = send_result?;
89 let (reader, _server_setup) = recv_result?;
90
91 let stream = Stream {
93 writer: writer.with_version(ietf_v),
94 reader: reader.with_version(ietf_v),
95 };
96
97 ietf::start(
98 session.clone(),
99 stream,
100 None, true,
102 self.publish.clone(),
103 self.consume.clone(),
104 ietf_v,
105 )?;
106
107 tracing::debug!(version = ?v, "connected");
108 return Ok(Session::new(session, v));
109 }
110 Some(ALPN_16) => {
111 let v = self
112 .versions
113 .select(Version::Ietf(ietf::Version::Draft16))
114 .ok_or(Error::Version)?;
115 (v, v.into())
116 }
117 Some(ALPN_15) => {
118 let v = self
119 .versions
120 .select(Version::Ietf(ietf::Version::Draft15))
121 .ok_or(Error::Version)?;
122 (v, v.into())
123 }
124 Some(ALPN_14) => {
125 let v = self
126 .versions
127 .select(Version::Ietf(ietf::Version::Draft14))
128 .ok_or(Error::Version)?;
129 (v, v.into())
130 }
131 Some(ALPN_LITE_03) => {
132 self.versions
133 .select(Version::Lite(lite::Version::Lite03))
134 .ok_or(Error::Version)?;
135
136 lite::start(
138 session.clone(),
139 None,
140 self.publish.clone(),
141 self.consume.clone(),
142 lite::Version::Lite03,
143 )?;
144
145 return Ok(Session::new(session, lite::Version::Lite03.into()));
146 }
147 Some(ALPN_LITE) | None => {
148 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
149 (Version::Ietf(ietf::Version::Draft14), supported)
150 }
151 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
152 };
153
154 let mut stream = Stream::open(&session, encoding).await?;
155
156 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
158
159 let mut parameters = ietf::Parameters::default();
160 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
161 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
162 let parameters = parameters.encode_bytes(ietf_encoding)?;
163
164 let client = setup::Client {
165 versions: supported.clone().into(),
166 parameters,
167 };
168
169 tracing::trace!(?client, "sending client setup");
171 stream.writer.encode(&client).await?;
172
173 let mut server: setup::Server = stream.reader.decode().await?;
174 tracing::trace!(?server, "received server setup");
175
176 let version = supported
177 .iter()
178 .find(|v| coding::Version::from(**v) == server.version)
179 .copied()
180 .ok_or(Error::Version)?;
181
182 match version {
183 Version::Lite(v) => {
184 let stream = stream.with_version(v);
185 lite::start(
186 session.clone(),
187 Some(stream),
188 self.publish.clone(),
189 self.consume.clone(),
190 v,
191 )?;
192 }
193 Version::Ietf(v) => {
194 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
196 let request_id_max = parameters
197 .get_varint(ietf::ParameterVarInt::MaxRequestId)
198 .map(ietf::RequestId);
199
200 let stream = stream.with_version(v);
201 ietf::start(
202 session.clone(),
203 stream,
204 request_id_max,
205 true,
206 self.publish.clone(),
207 self.consume.clone(),
208 v,
209 )?;
210 }
211 }
212
213 Ok(Session::new(session, version))
214 }
215}
216
// Tests for `Client::connect`, driven by an in-memory fake WebTransport session.
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error type for the fake transport; it carries no information.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        // Always reports a session closure with code 0 and reason "closed".
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory session that hands out exactly one bidirectional control
    /// stream and records every `close` call. Clones share the same state.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    /// State shared by a `FakeSession` and all of its clones.
    #[derive(Default)]
    struct FakeSessionState {
        /// Value returned by `protocol()`, i.e. the ALPN the test pretends was negotiated.
        protocol: Option<&'static str>,
        /// The single control stream; taken by the first `open_bi` call.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        /// Every `(code, reason)` passed to `close`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        /// Signalled (via `notify_waiters`) on every `close`.
        close_notify: tokio::sync::Notify,
        /// Bytes written to the control stream; shared with its send half.
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Build a session whose control stream delivers `server_control_bytes`
        /// to the client and then reports end-of-stream.
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let writes = Arc::new(Mutex::new(Vec::new()));
            let send = FakeSendStream { writes: writes.clone() };
            let recv = FakeRecvStream {
                data: VecDeque::from(server_control_bytes),
            };
            let state = FakeSessionState {
                protocol,
                control_stream: Mutex::new(Some((send, recv))),
                close_events: Mutex::new(Vec::new()),
                close_notify: tokio::sync::Notify::new(),
                control_writes: writes,
            };
            Self { state: Arc::new(state) }
        }

        /// Snapshot of everything written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// Wait until `close` has been called at least once, then return the
        /// first recorded `(code, reason)`.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // Create the `Notified` future *before* checking, so a `close`
                // landing between the check and the await still wakes us.
                let notified = self.state.close_notify.notified();
                if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
                    return close;
                }
                notified.await;
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        // Never yields an incoming unidirectional stream.
        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        // Never yields an incoming bidirectional stream.
        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        // Returns the pre-built control stream once; later calls fail with `FakeError`.
        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
        }

        // Never completes.
        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        // Datagrams are accepted and discarded.
        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        // Never yields a datagram.
        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        // Records the close and wakes every future currently waiting on `close_notify`.
        fn close(&self, code: u32, reason: &str) {
            self.state.close_events.lock().unwrap().push((code, reason.to_string()));
            self.state.close_notify.notify_waiters();
        }

        // NOTE(review): `notify_waiters` stores no permit, so this only resolves
        // for a `close` that happens after it is first polled. An earlier close
        // is not observed; the current tests do not appear to depend on that.
        async fn closed(&self) -> Self::Error {
            self.state.close_notify.notified().await;
            FakeError
        }
    }

    /// Send half that appends every write to a shared buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        // Accepts the whole buffer on every call.
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.writes.lock().unwrap().put_slice(buf);
            Ok(buf.len())
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half that serves a fixed byte queue, then end-of-stream.
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        // Copies up to `dst.len()` queued bytes; returns `Ok(None)` once the queue is empty.
        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let size = dst.len().min(self.data.len());
            for slot in dst.iter_mut().take(size) {
                *slot = self.data.pop_front().unwrap();
            }
            Ok(Some(size))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Bytes a server would send on the control stream: SERVER_SETUP selecting
    /// `negotiated` (draft-14 encoded), followed by a lite `SessionInfo`.
    ///
    /// Panics unless `negotiated` converts to a `lite::Version`.
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut encoded = Vec::new();
        let server = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        server
            .encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
            .unwrap();

        // NOTE(review): presumably the lite session reads this after setup; confirm.
        let info = lite::SessionInfo { bitrate: Some(1) };
        let lite_v = lite::Version::try_from(negotiated).unwrap();
        info.encode(&mut encoded, lite_v).unwrap();

        encoded
    }

    /// Connect with ALPN `protocol` while the fake server selects Lite01, then check:
    /// 1. CLIENT_SETUP decodes as draft-14 and advertises Lite02, Lite01 and
    ///    Draft14, but not Lite03, although Lite03 is configured.
    /// 2. The session is eventually closed with `Error::Cancel`'s code.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
        let client = Client::new().with_versions(
            [
                Version::Lite(lite::Version::Lite03),
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
            .into(),
        );

        let _session = client.connect(fake.clone()).await.unwrap();

        // The control stream should contain exactly the draft-14 encoded CLIENT_SETUP.
        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        assert_eq!(
            advertised,
            vec![
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
        );

        // NOTE(review): the close presumably comes from the lite session task once
        // the fake control stream reaches end-of-stream; confirm the trigger.
        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    // `start_paused` lets tokio auto-advance time, so any timers in the session
    // tasks do not slow the test down.
    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}