1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, Error, NEGOTIATED, OriginConsumer,
3 OriginProducer, Session, Version, Versions,
4 coding::{self, Decode, Encode, Stream},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 versions: Versions,
14}
15
16impl Client {
17 pub fn new() -> Self {
18 Default::default()
19 }
20
21 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22 self.publish = publish.into();
23 self
24 }
25
26 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27 self.consume = consume.into();
28 self
29 }
30
31 pub fn with_origin(self, origin: OriginProducer) -> Self {
35 let consumer = origin.consume();
36 self.with_publish(consumer).with_consume(origin)
37 }
38
39 pub fn with_versions(mut self, versions: Versions) -> Self {
40 self.versions = versions;
41 self
42 }
43
44 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
46 if self.publish.is_none() && self.consume.is_none() {
47 tracing::warn!("not publishing or consuming anything");
48 }
49
50 let (encoding, supported) = match session.protocol() {
53 Some(ALPN_17) => {
54 let v = self
55 .versions
56 .select(Version::Ietf(ietf::Version::Draft17))
57 .ok_or(Error::Version)?;
58
59 ietf::start(
61 session.clone(),
62 None,
63 None,
64 true,
65 self.publish.clone(),
66 self.consume.clone(),
67 ietf::Version::Draft17,
68 )?;
69
70 tracing::debug!(version = ?v, "connected");
71 return Ok(Session::new(session, v, None));
72 }
73 Some(ALPN_16) => {
74 let v = self
75 .versions
76 .select(Version::Ietf(ietf::Version::Draft16))
77 .ok_or(Error::Version)?;
78 (v, v.into())
79 }
80 Some(ALPN_15) => {
81 let v = self
82 .versions
83 .select(Version::Ietf(ietf::Version::Draft15))
84 .ok_or(Error::Version)?;
85 (v, v.into())
86 }
87 Some(ALPN_14) => {
88 let v = self
89 .versions
90 .select(Version::Ietf(ietf::Version::Draft14))
91 .ok_or(Error::Version)?;
92 (v, v.into())
93 }
94 Some(ALPN_LITE_04) => {
95 self.versions
96 .select(Version::Lite(lite::Version::Lite04))
97 .ok_or(Error::Version)?;
98
99 let recv_bw = lite::start(
100 session.clone(),
101 None,
102 self.publish.clone(),
103 self.consume.clone(),
104 lite::Version::Lite04,
105 )?;
106
107 return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
108 }
109 Some(ALPN_LITE_03) => {
110 self.versions
111 .select(Version::Lite(lite::Version::Lite03))
112 .ok_or(Error::Version)?;
113
114 let recv_bw = lite::start(
116 session.clone(),
117 None,
118 self.publish.clone(),
119 self.consume.clone(),
120 lite::Version::Lite03,
121 )?;
122
123 return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
124 }
125 Some(ALPN_LITE) | None => {
126 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
127 (Version::Ietf(ietf::Version::Draft14), supported)
128 }
129 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
130 };
131
132 let mut stream = Stream::open(&session, encoding).await?;
133
134 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
136
137 let mut parameters = ietf::Parameters::default();
138 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
139 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
140 let parameters = parameters.encode_bytes(ietf_encoding)?;
141
142 let client = setup::Client {
143 versions: supported.clone().into(),
144 parameters,
145 };
146
147 stream.writer.encode(&client).await?;
148
149 let mut server: setup::Server = stream.reader.decode().await?;
150
151 let version = supported
152 .iter()
153 .find(|v| coding::Version::from(**v) == server.version)
154 .copied()
155 .ok_or(Error::Version)?;
156
157 let recv_bw = match version {
158 Version::Lite(v) => {
159 let stream = stream.with_version(v);
160 lite::start(
161 session.clone(),
162 Some(stream),
163 self.publish.clone(),
164 self.consume.clone(),
165 v,
166 )?
167 }
168 Version::Ietf(v) => {
169 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
171 let request_id_max = parameters
172 .get_varint(ietf::ParameterVarInt::MaxRequestId)
173 .map(ietf::RequestId);
174
175 let stream = stream.with_version(v);
176 ietf::start(
177 session.clone(),
178 Some(stream),
179 request_id_max,
180 true,
181 self.publish.clone(),
182 self.consume.clone(),
183 v,
184 )?;
185 None
186 }
187 };
188
189 Ok(Session::new(session, version, recv_bw))
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error returned by every fallible operation of the fake transport.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory stand-in for a WebTransport session.
    ///
    /// It hands out a single scripted control stream from `open_bi`, records every `close`
    /// call, and never yields incoming streams or datagrams.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    /// State shared by every clone of a [`FakeSession`].
    #[derive(Default)]
    struct FakeSessionState {
        /// ALPN reported by `protocol()`.
        protocol: Option<&'static str>,
        /// Control stream returned by the first `open_bi`; later calls fail with `FakeError`.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        /// Every `(code, reason)` passed to `close`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        /// Signalled with `notify_waiters` on every `close`.
        close_notify: tokio::sync::Notify,
        /// Bytes written to the send half of the control stream.
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Builds a session reporting `protocol` whose control stream yields `server_control_bytes`.
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let writes = Arc::new(Mutex::new(Vec::new()));
            let send = FakeSendStream { writes: writes.clone() };
            let recv = FakeRecvStream {
                data: VecDeque::from(server_control_bytes),
            };
            let state = FakeSessionState {
                protocol,
                control_stream: Mutex::new(Some((send, recv))),
                close_events: Mutex::new(Vec::new()),
                close_notify: tokio::sync::Notify::new(),
                control_writes: writes,
            };
            Self { state: Arc::new(state) }
        }

        /// Snapshot of everything written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// Waits until `close` has been called at least once and returns the first `(code, reason)`.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // Register for the wakeup before checking, so a `close` cannot be missed
                // between the check and the await.
                let notified = self.state.close_notify.notified();
                if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
                    return close;
                }
                notified.await;
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        /// Hands out the scripted control stream once; any further call fails.
        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
        }

        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        fn close(&self, code: u32, reason: &str) {
            self.state.close_events.lock().unwrap().push((code, reason.to_string()));
            self.state.close_notify.notify_waiters();
        }

        /// Resolves once the session has been closed, including when `close` already ran
        /// before this future was created.
        async fn closed(&self) -> Self::Error {
            loop {
                let notified = self.state.close_notify.notified();
                // Take the lock in its own statement so the guard is dropped before awaiting.
                let already_closed = !self.state.close_events.lock().unwrap().is_empty();
                if already_closed {
                    return FakeError;
                }
                notified.await;
            }
        }
    }

    /// Send half of the fake control stream; appends every write to a shared buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.writes.lock().unwrap().put_slice(buf);
            Ok(buf.len())
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half of the fake control stream; replays the scripted bytes, then reports end of stream.
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let size = dst.len().min(self.data.len());
            for slot in dst.iter_mut().take(size) {
                *slot = self.data.pop_front().unwrap();
            }
            Ok(Some(size))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Encodes a SERVER_SETUP selecting `negotiated`, followed by a lite `SessionInfo`.
    ///
    /// Panics if `negotiated` is not a lite version.
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut encoded = Vec::new();
        let server = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        server
            .encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
            .unwrap();

        let info = lite::SessionInfo { bitrate: Some(1) };
        // The session info that follows SETUP is encoded with the negotiated lite version.
        let lite_v = lite::Version::try_from(negotiated).unwrap();
        info.encode(&mut encoded, lite_v).unwrap();

        encoded
    }

    /// Connects with the given ALPN (`ALPN_LITE` or none) and checks what the client advertised.
    ///
    /// The advertised list must leave out `Lite03`. The background session must then close
    /// with `Error::Cancel`.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
        let client = Client::new().with_versions(
            [
                Version::Lite(lite::Version::Lite03),
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
            .into(),
        );

        let _session = client.connect(fake.clone()).await.unwrap();

        // CLIENT_SETUP is the first message on the control stream and is draft-14 encoded.
        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        assert_eq!(
            advertised,
            vec![
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
        );

        // NOTE(review): presumably the session is cancelled once the scripted control stream
        // runs dry — confirm against the lite implementation.
        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}
427}