Skip to main content

moq_lite/
client.rs

1use crate::{
2	ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, Error, NEGOTIATED, OriginConsumer,
3	OriginProducer, Session, Version, Versions,
4	coding::{self, Decode, Encode, Stream},
5	ietf, lite, setup,
6};
7
8/// A MoQ client session builder.
9#[derive(Default, Clone)]
10pub struct Client {
11	publish: Option<OriginConsumer>,
12	consume: Option<OriginProducer>,
13	versions: Versions,
14}
15
16impl Client {
17	pub fn new() -> Self {
18		Default::default()
19	}
20
21	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22		self.publish = publish.into();
23		self
24	}
25
26	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27		self.consume = consume.into();
28		self
29	}
30
31	/// Set both publish and consume from an `OriginProducer`.
32	///
33	/// This is equivalent to calling `with_publish(origin.consume())` and `with_consume(origin)`.
34	pub fn with_origin(self, origin: OriginProducer) -> Self {
35		let consumer = origin.consume();
36		self.with_publish(consumer).with_consume(origin)
37	}
38
39	pub fn with_versions(mut self, versions: Versions) -> Self {
40		self.versions = versions;
41		self
42	}
43
44	/// Perform the MoQ handshake as a client negotiating the version.
45	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
46		if self.publish.is_none() && self.consume.is_none() {
47			tracing::warn!("not publishing or consuming anything");
48		}
49
50		// If ALPN was used to negotiate the version, use the appropriate encoding.
51		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
52		let (encoding, supported) = match session.protocol() {
53			Some(ALPN_17) => {
54				let v = self
55					.versions
56					.select(Version::Ietf(ietf::Version::Draft17))
57					.ok_or(Error::Version)?;
58
59				// Draft-17: SETUP is exchanged in the background by the session.
60				ietf::start(
61					session.clone(),
62					None,
63					None,
64					true,
65					self.publish.clone(),
66					self.consume.clone(),
67					ietf::Version::Draft17,
68				)?;
69
70				tracing::debug!(version = ?v, "connected");
71				return Ok(Session::new(session, v, None));
72			}
73			Some(ALPN_16) => {
74				let v = self
75					.versions
76					.select(Version::Ietf(ietf::Version::Draft16))
77					.ok_or(Error::Version)?;
78				(v, v.into())
79			}
80			Some(ALPN_15) => {
81				let v = self
82					.versions
83					.select(Version::Ietf(ietf::Version::Draft15))
84					.ok_or(Error::Version)?;
85				(v, v.into())
86			}
87			Some(ALPN_14) => {
88				let v = self
89					.versions
90					.select(Version::Ietf(ietf::Version::Draft14))
91					.ok_or(Error::Version)?;
92				(v, v.into())
93			}
94			Some(ALPN_LITE_04) => {
95				self.versions
96					.select(Version::Lite(lite::Version::Lite04))
97					.ok_or(Error::Version)?;
98
99				let recv_bw = lite::start(
100					session.clone(),
101					None,
102					self.publish.clone(),
103					self.consume.clone(),
104					lite::Version::Lite04,
105				)?;
106
107				return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
108			}
109			Some(ALPN_LITE_03) => {
110				self.versions
111					.select(Version::Lite(lite::Version::Lite03))
112					.ok_or(Error::Version)?;
113
114				// Starting with draft-03, there's no more SETUP control stream.
115				let recv_bw = lite::start(
116					session.clone(),
117					None,
118					self.publish.clone(),
119					self.consume.clone(),
120					lite::Version::Lite03,
121				)?;
122
123				return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
124			}
125			Some(ALPN_LITE) | None => {
126				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
127				(Version::Ietf(ietf::Version::Draft14), supported)
128			}
129			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
130		};
131
132		let mut stream = Stream::open(&session, encoding).await?;
133
134		// The encoding is always an IETF version for SETUP negotiation.
135		let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
136
137		let mut parameters = ietf::Parameters::default();
138		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
139		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
140		let parameters = parameters.encode_bytes(ietf_encoding)?;
141
142		let client = setup::Client {
143			versions: supported.clone().into(),
144			parameters,
145		};
146
147		stream.writer.encode(&client).await?;
148
149		let mut server: setup::Server = stream.reader.decode().await?;
150
151		let version = supported
152			.iter()
153			.find(|v| coding::Version::from(**v) == server.version)
154			.copied()
155			.ok_or(Error::Version)?;
156
157		let recv_bw = match version {
158			Version::Lite(v) => {
159				let stream = stream.with_version(v);
160				lite::start(
161					session.clone(),
162					Some(stream),
163					self.publish.clone(),
164					self.consume.clone(),
165					v,
166				)?
167			}
168			Version::Ietf(v) => {
169				// Decode the parameters to get the initial request ID.
170				let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
171				let request_id_max = parameters
172					.get_varint(ietf::ParameterVarInt::MaxRequestId)
173					.map(ietf::RequestId);
174
175				let stream = stream.with_version(v);
176				ietf::start(
177					session.clone(),
178					Some(stream),
179					request_id_max,
180					true,
181					self.publish.clone(),
182					self.consume.clone(),
183					v,
184				)?;
185				None
186			}
187		};
188
189		Ok(Session::new(session, version, recv_bw))
190	}
191}
192
#[cfg(test)]
mod tests {
	use super::*;
	use std::{
		collections::VecDeque,
		sync::{Arc, Mutex},
	};

	use crate::coding::{Decode, Encode};
	use bytes::{BufMut, Bytes};

	/// Error returned by every fallible operation of the fake transport.
	#[derive(Debug, Clone, Default)]
	struct FakeError;

	impl std::fmt::Display for FakeError {
		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
			write!(f, "fake transport error")
		}
	}

	impl std::error::Error for FakeError {}

	impl web_transport_trait::Error for FakeError {
		fn session_error(&self) -> Option<(u32, String)> {
			Some((0, "closed".to_string()))
		}
	}

	/// In-memory transport session with a single scripted control stream.
	///
	/// The first `open_bi` call hands out the control stream; later `open_bi` calls fail, and
	/// every accept/open-uni/datagram-receive call pends forever. Calls to `close` are recorded.
	#[derive(Clone, Default)]
	struct FakeSession {
		state: Arc<FakeSessionState>,
	}

	#[derive(Default)]
	struct FakeSessionState {
		/// ALPN reported by `protocol()`.
		protocol: Option<&'static str>,
		/// Control stream returned by the first `open_bi` call.
		control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
		/// Every `(code, reason)` passed to `close`, in call order.
		close_events: Mutex<Vec<(u32, String)>>,
		/// Woken on every `close` call.
		close_notify: tokio::sync::Notify,
		/// Bytes written by the client to the control stream.
		control_writes: Arc<Mutex<Vec<u8>>>,
	}

	impl FakeSession {
		/// Create a session whose control stream yields `server_control_bytes` to the client.
		fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
			let writes = Arc::new(Mutex::new(Vec::new()));
			let send = FakeSendStream { writes: writes.clone() };
			let recv = FakeRecvStream {
				data: VecDeque::from(server_control_bytes),
			};
			let state = FakeSessionState {
				protocol,
				control_stream: Mutex::new(Some((send, recv))),
				close_events: Mutex::new(Vec::new()),
				close_notify: tokio::sync::Notify::new(),
				control_writes: writes,
			};
			Self { state: Arc::new(state) }
		}

		/// Snapshot of everything the client has written to the control stream so far.
		fn control_writes(&self) -> Vec<u8> {
			self.state.control_writes.lock().unwrap().clone()
		}

		/// Wait until `close` has been called and return the arguments of the first call.
		async fn wait_for_first_close(&self) -> (u32, String) {
			loop {
				// Register interest before checking, so a close racing with the check is not missed.
				let notified = self.state.close_notify.notified();
				if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
					return close;
				}
				notified.await;
			}
		}
	}

	impl web_transport_trait::Session for FakeSession {
		type SendStream = FakeSendStream;
		type RecvStream = FakeRecvStream;
		type Error = FakeError;

		async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
			std::future::pending().await
		}

		async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			std::future::pending().await
		}

		async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
		}

		async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
			std::future::pending().await
		}

		fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
			Ok(())
		}

		async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
			std::future::pending().await
		}

		fn max_datagram_size(&self) -> usize {
			1200
		}

		fn protocol(&self) -> Option<&str> {
			self.state.protocol
		}

		fn close(&self, code: u32, reason: &str) {
			self.state.close_events.lock().unwrap().push((code, reason.to_string()));
			self.state.close_notify.notify_waiters();
		}

		async fn closed(&self) -> Self::Error {
			// `notify_waiters` stores no permit, so a caller arriving after `close` would wait
			// forever; register first, then resolve immediately if a close was already recorded.
			let notified = self.state.close_notify.notified();
			let already_closed = !self.state.close_events.lock().unwrap().is_empty();
			if already_closed {
				return FakeError;
			}
			notified.await;
			FakeError
		}
	}

	/// Send half of the control stream; appends every write to a shared buffer.
	#[derive(Clone, Default)]
	struct FakeSendStream {
		writes: Arc<Mutex<Vec<u8>>>,
	}

	impl web_transport_trait::SendStream for FakeSendStream {
		type Error = FakeError;

		async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
			self.writes.lock().unwrap().put_slice(buf);
			Ok(buf.len())
		}

		fn set_priority(&mut self, _order: u8) {}

		fn finish(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}

		fn reset(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Receive half of the control stream; yields the scripted server bytes, then EOF.
	struct FakeRecvStream {
		data: VecDeque<u8>,
	}

	impl web_transport_trait::RecvStream for FakeRecvStream {
		type Error = FakeError;

		async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
			if self.data.is_empty() {
				return Ok(None);
			}

			let size = dst.len().min(self.data.len());
			for (slot, byte) in dst.iter_mut().zip(self.data.drain(..size)) {
				*slot = byte;
			}
			Ok(Some(size))
		}

		fn stop(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Encode a server SETUP choosing `negotiated` (Draft14 framing), followed by a SessionInfo
	/// frame encoded with the negotiated lite version.
	fn mock_server_setup(negotiated: Version) -> Vec<u8> {
		let mut encoded = Vec::new();
		let server = setup::Server {
			version: negotiated.into(),
			parameters: Bytes::new(),
		};
		server
			.encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
			.unwrap();

		// Add a setup-stream SessionInfo frame using the negotiated Lite version.
		let info = lite::SessionInfo { bitrate: Some(1) };
		let lite_v = lite::Version::try_from(negotiated).unwrap();
		info.encode(&mut encoded, lite_v).unwrap();

		encoded
	}

	async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
		let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
		let client = Client::new().with_versions(
			[
				Version::Lite(lite::Version::Lite03),
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
			.into(),
		);

		let _session = client.connect(fake.clone()).await.unwrap();

		// Verify the client setup was encoded using Draft14 framing (ALPN_LITE fallback path).
		let mut setup_bytes = Bytes::from(fake.control_writes());
		let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
		let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
		assert_eq!(
			advertised,
			vec![
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
		);

		// The first close comes from the background lite session task.
		// Code 0 ("cancelled") means SessionInfo decoded successfully after the stream
		// switched to the negotiated version via `with_version()`.
		let (code, _) = fake.wait_for_first_close().await;
		assert_eq!(code, Error::Cancel.to_code());
	}

	#[tokio::test(start_paused = true)]
	async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
	}

	#[tokio::test(start_paused = true)]
	async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(None).await;
	}
}