1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, Error, NEGOTIATED, OriginConsumer, OriginProducer,
3 Session, Version, Versions,
4 coding::{self, Decode, Encode, Reader, Stream, Writer},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 versions: Versions,
14}
15
16impl Client {
17 pub fn new() -> Self {
18 Default::default()
19 }
20
21 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22 self.publish = publish.into();
23 self
24 }
25
26 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27 self.consume = consume.into();
28 self
29 }
30
31 pub fn with_versions(mut self, versions: Versions) -> Self {
32 self.versions = versions;
33 self
34 }
35
36 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
38 if self.publish.is_none() && self.consume.is_none() {
39 tracing::warn!("not publishing or consuming anything");
40 }
41
42 let (encoding, supported) = match session.protocol() {
45 Some(ALPN_17) => {
46 let v = self
47 .versions
48 .select(Version::Ietf(ietf::Version::Draft17))
49 .ok_or(Error::Version)?;
50
51 let ietf_v = ietf::Version::Draft17;
52
53 let mut parameters = ietf::Parameters::default();
55 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
56 let parameters = parameters.encode_bytes(ietf_v)?;
57
58 let client_setup = setup::Client {
59 versions: coding::Versions::from([v.into()]),
60 parameters,
61 };
62
63 let send_fut = async {
65 let send = session.open_uni().await.map_err(Error::from_transport)?;
66 let mut writer: Writer<S::SendStream, Version> = Writer::new(send, v);
67 writer.encode(&client_setup).await?;
69 Ok::<_, Error>(writer)
70 };
71
72 let recv_fut = async {
73 let recv = session.accept_uni().await.map_err(Error::from_transport)?;
74 let mut reader: Reader<S::RecvStream, Version> = Reader::new(recv, v);
75 let server: setup::Server = reader.decode().await?;
77 Ok::<_, Error>((reader, server))
78 };
79
80 let (send_result, recv_result) = tokio::join!(send_fut, recv_fut);
81 let writer = send_result?;
82 let (reader, _server_setup) = recv_result?;
83
84 let stream = Stream {
86 writer: writer.with_version(ietf_v),
87 reader: reader.with_version(ietf_v),
88 };
89
90 ietf::start(
91 session.clone(),
92 stream,
93 None, true,
95 self.publish.clone(),
96 self.consume.clone(),
97 ietf_v,
98 )?;
99
100 tracing::debug!(version = ?v, "connected");
101 return Ok(Session::new(session, v));
102 }
103 Some(ALPN_16) => {
104 let v = self
105 .versions
106 .select(Version::Ietf(ietf::Version::Draft16))
107 .ok_or(Error::Version)?;
108 (v, v.into())
109 }
110 Some(ALPN_15) => {
111 let v = self
112 .versions
113 .select(Version::Ietf(ietf::Version::Draft15))
114 .ok_or(Error::Version)?;
115 (v, v.into())
116 }
117 Some(ALPN_14) => {
118 let v = self
119 .versions
120 .select(Version::Ietf(ietf::Version::Draft14))
121 .ok_or(Error::Version)?;
122 (v, v.into())
123 }
124 Some(ALPN_LITE_03) => {
125 self.versions
126 .select(Version::Lite(lite::Version::Lite03))
127 .ok_or(Error::Version)?;
128
129 lite::start(
131 session.clone(),
132 None,
133 self.publish.clone(),
134 self.consume.clone(),
135 lite::Version::Lite03,
136 )?;
137
138 return Ok(Session::new(session, lite::Version::Lite03.into()));
139 }
140 Some(ALPN_LITE) | None => {
141 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
142 (Version::Ietf(ietf::Version::Draft14), supported)
143 }
144 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
145 };
146
147 let mut stream = Stream::open(&session, encoding).await?;
148
149 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
151
152 let mut parameters = ietf::Parameters::default();
153 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
154 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
155 let parameters = parameters.encode_bytes(ietf_encoding)?;
156
157 let client = setup::Client {
158 versions: supported.clone().into(),
159 parameters,
160 };
161
162 tracing::trace!(?client, "sending client setup");
164 stream.writer.encode(&client).await?;
165
166 let mut server: setup::Server = stream.reader.decode().await?;
167 tracing::trace!(?server, "received server setup");
168
169 let version = supported
170 .iter()
171 .find(|v| coding::Version::from(**v) == server.version)
172 .copied()
173 .ok_or(Error::Version)?;
174
175 match version {
176 Version::Lite(v) => {
177 let stream = stream.with_version(v);
178 lite::start(
179 session.clone(),
180 Some(stream),
181 self.publish.clone(),
182 self.consume.clone(),
183 v,
184 )?;
185 }
186 Version::Ietf(v) => {
187 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
189 let request_id_max = parameters
190 .get_varint(ietf::ParameterVarInt::MaxRequestId)
191 .map(ietf::RequestId);
192
193 let stream = stream.with_version(v);
194 ietf::start(
195 session.clone(),
196 stream,
197 request_id_max,
198 true,
199 self.publish.clone(),
200 self.consume.clone(),
201 v,
202 )?;
203 }
204 }
205
206 Ok(Session::new(session, version))
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error type for the fake transport below; every failure is this unit value.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        // Always reports a session close with code 0 and reason "closed".
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory stand-in for a WebTransport session.
    ///
    /// It exposes a single pre-loaded bidirectional control stream and records every `close()`
    /// call. Clones share the same state.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    /// State shared by all clones of a `FakeSession`.
    #[derive(Default)]
    struct FakeSessionState {
        // ALPN returned from `protocol()`.
        protocol: Option<&'static str>,
        // The one control stream; taken by the first `open_bi()` call, later calls fail.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        // Every `(code, reason)` passed to `close()`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        // Signalled (via `notify_waiters`) on every `close()`.
        close_notify: tokio::sync::Notify,
        // Bytes written to the control stream's send half (the same `Arc` the send half holds).
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Builds a session reporting `protocol` as its ALPN. Reads from its control stream
        /// yield `server_control_bytes` and then EOF.
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let writes = Arc::new(Mutex::new(Vec::new()));
            let send = FakeSendStream { writes: writes.clone() };
            let recv = FakeRecvStream {
                data: VecDeque::from(server_control_bytes),
            };
            let state = FakeSessionState {
                protocol,
                control_stream: Mutex::new(Some((send, recv))),
                close_events: Mutex::new(Vec::new()),
                close_notify: tokio::sync::Notify::new(),
                control_writes: writes,
            };
            Self { state: Arc::new(state) }
        }

        /// Snapshot of everything the client has written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// Waits until `close()` has been called at least once and returns the first call's
        /// `(code, reason)`.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // Create the `Notified` future *before* checking `close_events`. Tokio delivers
                // `notify_waiters()` to futures that already exist even if not yet polled, so a
                // `close()` landing between the check and the `.await` is not missed.
                let notified = self.state.close_notify.notified();
                if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
                    return close;
                }
                notified.await;
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        // No incoming unidirectional streams: pending forever.
        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        // No incoming bidirectional streams: pending forever.
        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        // Hands out the pre-loaded control stream once; subsequent calls return `FakeError`.
        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
        }

        // Opening unidirectional streams never completes.
        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        // Datagrams are silently dropped.
        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        // Records the close and wakes every current waiter.
        fn close(&self, code: u32, reason: &str) {
            self.state.close_events.lock().unwrap().push((code, reason.to_string()));
            self.state.close_notify.notify_waiters();
        }

        // NOTE(review): the `Notified` future is created on entry, and `notify_waiters` stores
        // no permit. A `close()` that happened before this call is therefore not observed and
        // this stays pending. Confirm the code under test only relies on later closes.
        async fn closed(&self) -> Self::Error {
            self.state.close_notify.notified().await;
            FakeError
        }
    }

    /// Send half of a fake stream: appends every write to a shared byte buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        // Accepts the whole buffer every time.
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.writes.lock().unwrap().put_slice(buf);
            Ok(buf.len())
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half of a fake stream: serves pre-loaded bytes, then reports EOF.
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        // Copies up to `dst.len()` queued bytes into `dst`; `Ok(None)` once the queue is empty.
        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let size = dst.len().min(self.data.len());
            for slot in dst.iter_mut().take(size) {
                *slot = self.data.pop_front().unwrap();
            }
            Ok(Some(size))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Bytes a server would put on the control stream when it picks `negotiated`.
    ///
    /// The output is a server setup encoded as draft-14 (the setup encoding the client uses on
    /// the `ALPN_LITE`/no-ALPN path), followed by a `lite::SessionInfo` encoded in the
    /// negotiated lite version. Panics if `negotiated` is not a lite version.
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut encoded = Vec::new();
        let server = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        server
            .encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
            .unwrap();

        // Presumably the first message the lite session reads after setup; TODO confirm.
        let info = lite::SessionInfo { bitrate: Some(1) };
        let lite_v = lite::Version::try_from(negotiated).unwrap();
        info.encode(&mut encoded, lite_v).unwrap();

        encoded
    }

    /// Connects with ALPN `protocol` against a server that picks moq-lite-01, then checks
    /// what was offered and how the session was closed.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
        let client = Client::new().with_versions(
            [
                Version::Lite(lite::Version::Lite03),
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
            .into(),
        );

        let _session = client.connect(fake.clone()).await.unwrap();

        // The client setup is always draft-14 encoded on this path, whatever version is
        // eventually negotiated.
        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        // Lite03 is enabled but is expected to be absent from the offer (it has its own ALPN).
        assert_eq!(
            advertised,
            vec![
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
        );

        // The first close must carry the `Cancel` code. Presumably it is triggered once the
        // fake control stream hits EOF after the SessionInfo; TODO confirm.
        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}
444}