Skip to main content

moq_lite/
client.rs

1use crate::{
2	ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, Error, NEGOTIATED, OriginConsumer, OriginProducer,
3	Session, Version, Versions,
4	coding::{self, Decode, Encode, Stream},
5	ietf, lite, setup,
6};
7
8/// A MoQ client session builder.
9#[derive(Default, Clone)]
10pub struct Client {
11	publish: Option<OriginConsumer>,
12	consume: Option<OriginProducer>,
13	versions: Versions,
14}
15
16impl Client {
17	pub fn new() -> Self {
18		Default::default()
19	}
20
21	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22		self.publish = publish.into();
23		self
24	}
25
26	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27		self.consume = consume.into();
28		self
29	}
30
31	pub fn with_versions(mut self, versions: Versions) -> Self {
32		self.versions = versions;
33		self
34	}
35
36	/// Perform the MoQ handshake as a client negotiating the version.
37	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
38		if self.publish.is_none() && self.consume.is_none() {
39			tracing::warn!("not publishing or consuming anything");
40		}
41
42		// If ALPN was used to negotiate the version, use the appropriate encoding.
43		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
44		let (encoding, supported) = match session.protocol() {
45			Some(ALPN_17) => {
46				let v = self
47					.versions
48					.select(Version::Ietf(ietf::Version::Draft17))
49					.ok_or(Error::Version)?;
50
51				// Draft-17: SETUP is exchanged in the background by the session.
52				ietf::start(
53					session.clone(),
54					None,
55					None,
56					true,
57					self.publish.clone(),
58					self.consume.clone(),
59					ietf::Version::Draft17,
60				)?;
61
62				tracing::debug!(version = ?v, "connected");
63				return Ok(Session::new(session, v));
64			}
65			Some(ALPN_16) => {
66				let v = self
67					.versions
68					.select(Version::Ietf(ietf::Version::Draft16))
69					.ok_or(Error::Version)?;
70				(v, v.into())
71			}
72			Some(ALPN_15) => {
73				let v = self
74					.versions
75					.select(Version::Ietf(ietf::Version::Draft15))
76					.ok_or(Error::Version)?;
77				(v, v.into())
78			}
79			Some(ALPN_14) => {
80				let v = self
81					.versions
82					.select(Version::Ietf(ietf::Version::Draft14))
83					.ok_or(Error::Version)?;
84				(v, v.into())
85			}
86			Some(ALPN_LITE_03) => {
87				self.versions
88					.select(Version::Lite(lite::Version::Lite03))
89					.ok_or(Error::Version)?;
90
91				// Starting with draft-03, there's no more SETUP control stream.
92				lite::start(
93					session.clone(),
94					None,
95					self.publish.clone(),
96					self.consume.clone(),
97					lite::Version::Lite03,
98				)?;
99
100				return Ok(Session::new(session, lite::Version::Lite03.into()));
101			}
102			Some(ALPN_LITE) | None => {
103				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
104				(Version::Ietf(ietf::Version::Draft14), supported)
105			}
106			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
107		};
108
109		let mut stream = Stream::open(&session, encoding).await?;
110
111		// The encoding is always an IETF version for SETUP negotiation.
112		let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
113
114		let mut parameters = ietf::Parameters::default();
115		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
116		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
117		let parameters = parameters.encode_bytes(ietf_encoding)?;
118
119		let client = setup::Client {
120			versions: supported.clone().into(),
121			parameters,
122		};
123
124		// TODO pretty print the parameters.
125		tracing::trace!(?client, "sending client setup");
126		stream.writer.encode(&client).await?;
127
128		let mut server: setup::Server = stream.reader.decode().await?;
129		tracing::trace!(?server, "received server setup");
130
131		let version = supported
132			.iter()
133			.find(|v| coding::Version::from(**v) == server.version)
134			.copied()
135			.ok_or(Error::Version)?;
136
137		match version {
138			Version::Lite(v) => {
139				let stream = stream.with_version(v);
140				lite::start(
141					session.clone(),
142					Some(stream),
143					self.publish.clone(),
144					self.consume.clone(),
145					v,
146				)?;
147			}
148			Version::Ietf(v) => {
149				// Decode the parameters to get the initial request ID.
150				let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
151				let request_id_max = parameters
152					.get_varint(ietf::ParameterVarInt::MaxRequestId)
153					.map(ietf::RequestId);
154
155				let stream = stream.with_version(v);
156				ietf::start(
157					session.clone(),
158					Some(stream),
159					request_id_max,
160					true,
161					self.publish.clone(),
162					self.consume.clone(),
163					v,
164				)?;
165			}
166		}
167
168		Ok(Session::new(session, version))
169	}
170}
171
#[cfg(test)]
mod tests {
	use super::*;
	use std::{
		collections::VecDeque,
		sync::{Arc, Mutex},
	};

	use crate::coding::{Decode, Encode};
	use bytes::{BufMut, Bytes};

	/// Error returned by every failing fake transport operation.
	#[derive(Debug, Clone, Default)]
	struct FakeError;

	impl std::fmt::Display for FakeError {
		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
			write!(f, "fake transport error")
		}
	}

	impl std::error::Error for FakeError {}

	impl web_transport_trait::Error for FakeError {
		fn session_error(&self) -> Option<(u32, String)> {
			Some((0, "closed".to_string()))
		}
	}

	/// In-memory transport session: hands out one scripted bidirectional control stream and
	/// records every `close()` call. Clones share the same state.
	#[derive(Clone, Default)]
	struct FakeSession {
		state: Arc<FakeSessionState>,
	}

	/// State shared by a [`FakeSession`] and all of its clones.
	#[derive(Default)]
	struct FakeSessionState {
		/// ALPN reported by `protocol()`.
		protocol: Option<&'static str>,
		/// The single stream pair returned by `open_bi()`; taken on the first call.
		control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
		/// Every `(code, reason)` passed to `close()`, in call order.
		close_events: Mutex<Vec<(u32, String)>>,
		/// Notified on every `close()` call.
		close_notify: tokio::sync::Notify,
		/// Bytes written to the control stream (shared with its send half).
		control_writes: Arc<Mutex<Vec<u8>>>,
	}

	impl FakeSession {
		/// Create a session that reports `protocol` as its ALPN and whose control stream yields
		/// `server_control_bytes` to the reader.
		fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
			let writes = Arc::new(Mutex::new(Vec::new()));
			let send = FakeSendStream { writes: writes.clone() };
			let recv = FakeRecvStream {
				data: VecDeque::from(server_control_bytes),
			};
			let state = FakeSessionState {
				protocol,
				control_stream: Mutex::new(Some((send, recv))),
				close_events: Mutex::new(Vec::new()),
				close_notify: tokio::sync::Notify::new(),
				control_writes: writes,
			};
			Self { state: Arc::new(state) }
		}

		/// Snapshot of everything written to the control stream so far.
		fn control_writes(&self) -> Vec<u8> {
			self.state.control_writes.lock().unwrap().clone()
		}

		/// Wait until `close()` has been called at least once and return the first close.
		async fn wait_for_first_close(&self) -> (u32, String) {
			loop {
				// Create the `Notified` future before checking, so a `close()` that lands between
				// the check and the `.await` still wakes this loop.
				let notified = self.state.close_notify.notified();
				if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
					return close;
				}
				notified.await;
			}
		}
	}

	impl web_transport_trait::Session for FakeSession {
		type SendStream = FakeSendStream;
		type RecvStream = FakeRecvStream;
		type Error = FakeError;

		async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
			std::future::pending().await
		}

		async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			std::future::pending().await
		}

		async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			// Only the control stream is scripted; any further `open_bi()` fails.
			self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
		}

		async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
			std::future::pending().await
		}

		fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
			Ok(())
		}

		async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
			std::future::pending().await
		}

		fn max_datagram_size(&self) -> usize {
			1200
		}

		fn protocol(&self) -> Option<&str> {
			self.state.protocol
		}

		fn close(&self, code: u32, reason: &str) {
			self.state.close_events.lock().unwrap().push((code, reason.to_string()));
			self.state.close_notify.notify_waiters();
		}

		async fn closed(&self) -> Self::Error {
			// Resolve immediately if the session is already closed. Awaiting a bare `notified()`
			// would hang forever when `close()` ran before this future was first polled, because
			// `notify_waiters` only wakes futures that already exist.
			loop {
				let notified = self.state.close_notify.notified();
				let already_closed = !self.state.close_events.lock().unwrap().is_empty();
				if already_closed {
					return FakeError;
				}
				notified.await;
			}
		}
	}

	/// Send half of the fake control stream; appends every write to a shared buffer.
	#[derive(Clone, Default)]
	struct FakeSendStream {
		writes: Arc<Mutex<Vec<u8>>>,
	}

	impl web_transport_trait::SendStream for FakeSendStream {
		type Error = FakeError;

		async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
			self.writes.lock().unwrap().put_slice(buf);
			Ok(buf.len())
		}

		fn set_priority(&mut self, _order: u8) {}

		fn finish(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}

		fn reset(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Receive half of the fake control stream; replays pre-scripted bytes, then reports EOF.
	struct FakeRecvStream {
		data: VecDeque<u8>,
	}

	impl web_transport_trait::RecvStream for FakeRecvStream {
		type Error = FakeError;

		async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
			if self.data.is_empty() {
				return Ok(None);
			}

			let size = dst.len().min(self.data.len());
			for (slot, byte) in dst.iter_mut().zip(self.data.drain(..size)) {
				*slot = byte;
			}
			Ok(Some(size))
		}

		fn stop(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Encode the bytes a server sends on the control stream: a SERVER_SETUP (draft-14 framing)
	/// selecting `negotiated`, followed by a SessionInfo frame encoded with that Lite version.
	///
	/// Panics if `negotiated` is not a Lite version.
	fn mock_server_setup(negotiated: Version) -> Vec<u8> {
		let mut encoded = Vec::new();
		let server = setup::Server {
			version: negotiated.into(),
			parameters: Bytes::new(),
		};
		server
			.encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
			.unwrap();

		// Add a setup-stream SessionInfo frame using the negotiated Lite version.
		let info = lite::SessionInfo { bitrate: Some(1) };
		let lite_v = lite::Version::try_from(negotiated).expect("mock_server_setup requires a Lite version");
		info.encode(&mut encoded, lite_v).unwrap();

		encoded
	}

	/// Connect over `protocol` (expected to be `ALPN_LITE` or no ALPN) against a server that
	/// selects Lite-01, and check both the advertised versions and the post-SETUP switch.
	async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
		let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
		let client = Client::new().with_versions(
			[
				Version::Lite(lite::Version::Lite03),
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
			.into(),
		);

		// Bind (rather than `_`) so the session is not dropped until the end of the test.
		let _session = client.connect(fake.clone()).await.unwrap();

		// Verify the client setup was encoded using Draft14 framing (ALPN_LITE fallback path).
		// Lite-03 is expected to be absent: it has no SETUP stream, so it is filtered out of the
		// in-band offer.
		let mut setup_bytes = Bytes::from(fake.control_writes());
		let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
		let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
		assert_eq!(
			advertised,
			vec![
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
		);

		// The first close comes from the background lite session task. A cancel code (rather
		// than a decode error) means the SessionInfo frame decoded successfully once the stream
		// switched to the negotiated Lite version.
		let (code, _) = fake.wait_for_first_close().await;
		assert_eq!(code, Error::Cancel.to_code());
	}

	#[tokio::test(start_paused = true)]
	async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
	}

	#[tokio::test(start_paused = true)]
	async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(None).await;
	}
}