1use crate::{
2 ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, Error, NEGOTIATED, OriginConsumer, OriginProducer,
3 Session, Version, Versions,
4 coding::{self, Decode, Encode, Stream},
5 ietf, lite, setup,
6};
7
8#[derive(Default, Clone)]
10pub struct Client {
11 publish: Option<OriginConsumer>,
12 consume: Option<OriginProducer>,
13 versions: Versions,
14}
15
16impl Client {
17 pub fn new() -> Self {
18 Default::default()
19 }
20
21 pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22 self.publish = publish.into();
23 self
24 }
25
26 pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27 self.consume = consume.into();
28 self
29 }
30
31 pub fn with_versions(mut self, versions: Versions) -> Self {
32 self.versions = versions;
33 self
34 }
35
36 pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
38 if self.publish.is_none() && self.consume.is_none() {
39 tracing::warn!("not publishing or consuming anything");
40 }
41
42 let (encoding, supported) = match session.protocol() {
45 Some(ALPN_17) => {
46 let v = self
47 .versions
48 .select(Version::Ietf(ietf::Version::Draft17))
49 .ok_or(Error::Version)?;
50
51 ietf::start(
53 session.clone(),
54 None,
55 None,
56 true,
57 self.publish.clone(),
58 self.consume.clone(),
59 ietf::Version::Draft17,
60 )?;
61
62 tracing::debug!(version = ?v, "connected");
63 return Ok(Session::new(session, v));
64 }
65 Some(ALPN_16) => {
66 let v = self
67 .versions
68 .select(Version::Ietf(ietf::Version::Draft16))
69 .ok_or(Error::Version)?;
70 (v, v.into())
71 }
72 Some(ALPN_15) => {
73 let v = self
74 .versions
75 .select(Version::Ietf(ietf::Version::Draft15))
76 .ok_or(Error::Version)?;
77 (v, v.into())
78 }
79 Some(ALPN_14) => {
80 let v = self
81 .versions
82 .select(Version::Ietf(ietf::Version::Draft14))
83 .ok_or(Error::Version)?;
84 (v, v.into())
85 }
86 Some(ALPN_LITE_03) => {
87 self.versions
88 .select(Version::Lite(lite::Version::Lite03))
89 .ok_or(Error::Version)?;
90
91 lite::start(
93 session.clone(),
94 None,
95 self.publish.clone(),
96 self.consume.clone(),
97 lite::Version::Lite03,
98 )?;
99
100 return Ok(Session::new(session, lite::Version::Lite03.into()));
101 }
102 Some(ALPN_LITE) | None => {
103 let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
104 (Version::Ietf(ietf::Version::Draft14), supported)
105 }
106 Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
107 };
108
109 let mut stream = Stream::open(&session, encoding).await?;
110
111 let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
113
114 let mut parameters = ietf::Parameters::default();
115 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
116 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
117 let parameters = parameters.encode_bytes(ietf_encoding)?;
118
119 let client = setup::Client {
120 versions: supported.clone().into(),
121 parameters,
122 };
123
124 tracing::trace!(?client, "sending client setup");
126 stream.writer.encode(&client).await?;
127
128 let mut server: setup::Server = stream.reader.decode().await?;
129 tracing::trace!(?server, "received server setup");
130
131 let version = supported
132 .iter()
133 .find(|v| coding::Version::from(**v) == server.version)
134 .copied()
135 .ok_or(Error::Version)?;
136
137 match version {
138 Version::Lite(v) => {
139 let stream = stream.with_version(v);
140 lite::start(
141 session.clone(),
142 Some(stream),
143 self.publish.clone(),
144 self.consume.clone(),
145 v,
146 )?;
147 }
148 Version::Ietf(v) => {
149 let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
151 let request_id_max = parameters
152 .get_varint(ietf::ParameterVarInt::MaxRequestId)
153 .map(ietf::RequestId);
154
155 let stream = stream.with_version(v);
156 ietf::start(
157 session.clone(),
158 Some(stream),
159 request_id_max,
160 true,
161 self.publish.clone(),
162 self.consume.clone(),
163 v,
164 )?;
165 }
166 }
167
168 Ok(Session::new(session, version))
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use crate::coding::{Decode, Encode};
    use bytes::{BufMut, Bytes};

    /// Error type for the fake transport; every failure is the same value.
    #[derive(Debug, Clone, Default)]
    struct FakeError;

    impl std::fmt::Display for FakeError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "fake transport error")
        }
    }

    impl std::error::Error for FakeError {}

    impl web_transport_trait::Error for FakeError {
        fn session_error(&self) -> Option<(u32, String)> {
            Some((0, "closed".to_string()))
        }
    }

    /// In-memory transport session that hands out a single scripted control stream and
    /// records every `close()` call.
    #[derive(Clone, Default)]
    struct FakeSession {
        state: Arc<FakeSessionState>,
    }

    #[derive(Default)]
    struct FakeSessionState {
        /// ALPN reported by `protocol()`.
        protocol: Option<&'static str>,
        /// The one bidirectional stream returned by `open_bi()`; taken on first use.
        control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
        /// Every `(code, reason)` passed to `close()`, in call order.
        close_events: Mutex<Vec<(u32, String)>>,
        /// Notified (via `notify_waiters`) on every `close()`.
        close_notify: tokio::sync::Notify,
        /// Bytes written to the control stream, shared with its `FakeSendStream`.
        control_writes: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeSession {
        /// Creates a session reporting `protocol` whose control stream yields
        /// `server_control_bytes` to the reader and then EOF.
        fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
            let writes = Arc::new(Mutex::new(Vec::new()));
            let send = FakeSendStream { writes: writes.clone() };
            let recv = FakeRecvStream {
                data: VecDeque::from(server_control_bytes),
            };
            let state = FakeSessionState {
                protocol,
                control_stream: Mutex::new(Some((send, recv))),
                close_events: Mutex::new(Vec::new()),
                close_notify: tokio::sync::Notify::new(),
                control_writes: writes,
            };
            Self { state: Arc::new(state) }
        }

        /// Snapshot of everything written to the control stream so far.
        fn control_writes(&self) -> Vec<u8> {
            self.state.control_writes.lock().unwrap().clone()
        }

        /// The first recorded `close()` call, if any. The lock is released before returning.
        fn first_close(&self) -> Option<(u32, String)> {
            self.state.close_events.lock().unwrap().first().cloned()
        }

        /// Waits until `close()` has been called at least once and returns its `(code, reason)`.
        async fn wait_for_first_close(&self) -> (u32, String) {
            loop {
                // Create the `Notified` future before checking, so a `close()` that lands
                // between the check and the await still wakes this loop.
                let notified = self.state.close_notify.notified();
                if let Some(close) = self.first_close() {
                    return close;
                }
                notified.await;
            }
        }
    }

    impl web_transport_trait::Session for FakeSession {
        type SendStream = FakeSendStream;
        type RecvStream = FakeRecvStream;
        type Error = FakeError;

        async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
            std::future::pending().await
        }

        async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            std::future::pending().await
        }

        async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
            self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
        }

        async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
            std::future::pending().await
        }

        fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
            std::future::pending().await
        }

        fn max_datagram_size(&self) -> usize {
            1200
        }

        fn protocol(&self) -> Option<&str> {
            self.state.protocol
        }

        fn close(&self, code: u32, reason: &str) {
            self.state.close_events.lock().unwrap().push((code, reason.to_string()));
            self.state.close_notify.notify_waiters();
        }

        async fn closed(&self) -> Self::Error {
            // Resolve immediately if the session was already closed. Waiting only for the
            // next notification would hang forever, because `notify_waiters` does not
            // leave a permit behind for later waiters.
            self.wait_for_first_close().await;
            FakeError
        }
    }

    /// Send half that appends everything written to a shared buffer.
    #[derive(Clone, Default)]
    struct FakeSendStream {
        writes: Arc<Mutex<Vec<u8>>>,
    }

    impl web_transport_trait::SendStream for FakeSendStream {
        type Error = FakeError;

        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            self.writes.lock().unwrap().put_slice(buf);
            Ok(buf.len())
        }

        fn set_priority(&mut self, _order: u8) {}

        fn finish(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }

        fn reset(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Receive half that drains a pre-scripted byte queue, then reports EOF (`Ok(None)`).
    struct FakeRecvStream {
        data: VecDeque<u8>,
    }

    impl web_transport_trait::RecvStream for FakeRecvStream {
        type Error = FakeError;

        async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
            if self.data.is_empty() {
                return Ok(None);
            }

            let size = dst.len().min(self.data.len());
            for slot in dst.iter_mut().take(size) {
                *slot = self.data.pop_front().unwrap();
            }
            Ok(Some(size))
        }

        fn stop(&mut self, _code: u32) {}

        async fn closed(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Builds the bytes the fake server sends on the control stream.
    ///
    /// The bytes are a SETUP reply selecting `negotiated`, encoded as Draft 14, followed by a
    /// lite `SessionInfo` encoded with the negotiated version.
    ///
    /// Panics if `negotiated` cannot be converted to a lite version.
    fn mock_server_setup(negotiated: Version) -> Vec<u8> {
        let mut encoded = Vec::new();
        let server = setup::Server {
            version: negotiated.into(),
            parameters: Bytes::new(),
        };
        server
            .encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
            .unwrap();

        let info = lite::SessionInfo { bitrate: Some(1) };
        // Post-setup messages use the negotiated lite version, not the handshake encoding.
        let lite_v = lite::Version::try_from(negotiated).unwrap();
        info.encode(&mut encoded, lite_v).unwrap();

        encoded
    }

    /// Connects with `protocol` as the ALPN while the fake server picks Lite01, then checks that:
    /// - the client SETUP (decoded as Draft 14) advertises only the configured versions that
    ///   survive the `NEGOTIATED` filter (Lite03 is absent), and
    /// - the session is eventually closed with `Error::Cancel`'s code.
    async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
        let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
        let client = Client::new().with_versions(
            [
                Version::Lite(lite::Version::Lite03),
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
            .into(),
        );

        let _session = client.connect(fake.clone()).await.unwrap();

        // Everything the client wrote on the control stream starts with its SETUP message.
        let mut setup_bytes = Bytes::from(fake.control_writes());
        let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
        let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
        assert_eq!(
            advertised,
            vec![
                Version::Lite(lite::Version::Lite02),
                Version::Lite(lite::Version::Lite01),
                Version::Ietf(ietf::Version::Draft14),
            ]
        );

        // The scripted control stream hits EOF after the server bytes; the test expects
        // the session to be closed with the Cancel code.
        let (code, _) = fake.wait_for_first_close().await;
        assert_eq!(code, Error::Cancel.to_code());
    }

    #[tokio::test(start_paused = true)]
    async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
    }

    #[tokio::test(start_paused = true)]
    async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
        run_alpn_lite_fallback_case(None).await;
    }
}