Skip to main content

moq_net/
client.rs

1use crate::{
2	ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_18, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, ALPN_LITE_05_WIP, Error,
3	NEGOTIATED, OriginConsumer, OriginProducer, Session, StatsHandle, Version, Versions,
4	coding::{self, Decode, Encode, Stream},
5	ietf, lite, setup,
6};
7
8/// A MoQ client session builder.
9#[derive(Default, Clone)]
10pub struct Client {
11	publish: Option<OriginConsumer>,
12	consume: Option<OriginProducer>,
13	stats: StatsHandle,
14	versions: Versions,
15}
16
17impl Client {
18	pub fn new() -> Self {
19		Default::default()
20	}
21
22	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
23		self.publish = publish.into();
24		self
25	}
26
27	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
28		self.consume = consume.into();
29		self
30	}
31
32	/// Attach a tier-scoped [`StatsHandle`]. Per-broadcast and per-subscription
33	/// counters will be bumped through this handle for the lifetime of the session.
34	/// Pass [`StatsHandle::default`] (a no-op handle) to opt out.
35	pub fn with_stats(mut self, stats: StatsHandle) -> Self {
36		self.stats = stats;
37		self
38	}
39
40	/// Set both publish and consume from an `OriginProducer`.
41	///
42	/// This is equivalent to calling `with_publish(origin.consume())` and `with_consume(origin)`.
43	pub fn with_origin(self, origin: OriginProducer) -> Self {
44		let consumer = origin.consume();
45		self.with_publish(consumer).with_consume(origin)
46	}
47
48	pub fn with_versions(mut self, versions: Versions) -> Self {
49		self.versions = versions;
50		self
51	}
52
53	/// Perform the MoQ handshake as a client negotiating the version.
54	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
55		if self.publish.is_none() && self.consume.is_none() {
56			tracing::warn!("not publishing or consuming anything");
57		}
58
59		// If ALPN was used to negotiate the version, use the appropriate encoding.
60		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
61		let (encoding, supported) = match session.protocol() {
62			Some(ALPN_18) => {
63				let v = self
64					.versions
65					.select(Version::Ietf(ietf::Version::Draft18))
66					.ok_or(Error::Version)?;
67
68				// Draft-17+: SETUP is exchanged in the background by the session.
69				ietf::start(
70					session.clone(),
71					None,
72					None,
73					true,
74					self.publish.clone(),
75					self.consume.clone(),
76					ietf::Version::Draft18,
77				)?;
78
79				tracing::debug!(version = ?v, "connected");
80				return Ok(Session::new(session, v, None));
81			}
82			Some(ALPN_17) => {
83				let v = self
84					.versions
85					.select(Version::Ietf(ietf::Version::Draft17))
86					.ok_or(Error::Version)?;
87
88				// Draft-17+: SETUP is exchanged in the background by the session.
89				ietf::start(
90					session.clone(),
91					None,
92					None,
93					true,
94					self.publish.clone(),
95					self.consume.clone(),
96					ietf::Version::Draft17,
97				)?;
98
99				tracing::debug!(version = ?v, "connected");
100				return Ok(Session::new(session, v, None));
101			}
102			Some(ALPN_16) => {
103				let v = self
104					.versions
105					.select(Version::Ietf(ietf::Version::Draft16))
106					.ok_or(Error::Version)?;
107				(v, v.into())
108			}
109			Some(ALPN_15) => {
110				let v = self
111					.versions
112					.select(Version::Ietf(ietf::Version::Draft15))
113					.ok_or(Error::Version)?;
114				(v, v.into())
115			}
116			Some(ALPN_14) => {
117				let v = self
118					.versions
119					.select(Version::Ietf(ietf::Version::Draft14))
120					.ok_or(Error::Version)?;
121				(v, v.into())
122			}
123			Some(ALPN_LITE_05_WIP) => {
124				self.versions
125					.select(Version::Lite(lite::Version::Lite05Wip))
126					.ok_or(Error::Version)?;
127
128				let recv_bw = lite::start(
129					session.clone(),
130					None,
131					self.publish.clone(),
132					self.consume.clone(),
133					self.stats.clone(),
134					lite::Version::Lite05Wip,
135				)?;
136
137				return Ok(Session::new(session, lite::Version::Lite05Wip.into(), recv_bw));
138			}
139			Some(ALPN_LITE_04) => {
140				self.versions
141					.select(Version::Lite(lite::Version::Lite04))
142					.ok_or(Error::Version)?;
143
144				let recv_bw = lite::start(
145					session.clone(),
146					None,
147					self.publish.clone(),
148					self.consume.clone(),
149					self.stats.clone(),
150					lite::Version::Lite04,
151				)?;
152
153				return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
154			}
155			Some(ALPN_LITE_03) => {
156				self.versions
157					.select(Version::Lite(lite::Version::Lite03))
158					.ok_or(Error::Version)?;
159
160				// Starting with draft-03, there's no more SETUP control stream.
161				let recv_bw = lite::start(
162					session.clone(),
163					None,
164					self.publish.clone(),
165					self.consume.clone(),
166					self.stats.clone(),
167					lite::Version::Lite03,
168				)?;
169
170				return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
171			}
172			Some(ALPN_LITE) | None => {
173				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
174				(Version::Ietf(ietf::Version::Draft14), supported)
175			}
176			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
177		};
178
179		let mut stream = Stream::open(&session, encoding).await?;
180
181		// The encoding is always an IETF version for SETUP negotiation.
182		let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
183
184		let mut parameters = ietf::Parameters::default();
185		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
186		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
187		let parameters = parameters.encode_bytes(ietf_encoding)?;
188
189		let client = setup::Client {
190			versions: supported.clone().into(),
191			parameters,
192		};
193
194		stream.writer.encode(&client).await?;
195
196		let mut server: setup::Server = stream.reader.decode().await?;
197
198		let version = supported
199			.iter()
200			.find(|v| coding::Version::from(**v) == server.version)
201			.copied()
202			.ok_or(Error::Version)?;
203
204		let recv_bw = match version {
205			Version::Lite(v) => {
206				let stream = stream.with_version(v);
207				lite::start(
208					session.clone(),
209					Some(stream),
210					self.publish.clone(),
211					self.consume.clone(),
212					self.stats.clone(),
213					v,
214				)?
215			}
216			Version::Ietf(v) => {
217				// Decode the parameters to get the initial request ID.
218				let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
219				let request_id_max = parameters
220					.get_varint(ietf::ParameterVarInt::MaxRequestId)
221					.map(ietf::RequestId);
222
223				let stream = stream.with_version(v);
224				ietf::start(
225					session.clone(),
226					Some(stream),
227					request_id_max,
228					true,
229					self.publish.clone(),
230					self.consume.clone(),
231					v,
232				)?;
233				None
234			}
235		};
236
237		Ok(Session::new(session, version, recv_bw))
238	}
239}
240
#[cfg(test)]
mod tests {
	use super::*;
	use std::{
		collections::VecDeque,
		sync::{Arc, Mutex},
	};

	use crate::coding::{Decode, Encode};
	use bytes::{BufMut, Bytes};

	/// Error returned by every failing fake transport operation.
	#[derive(Debug, Clone, Default)]
	struct FakeError;

	impl std::fmt::Display for FakeError {
		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
			write!(f, "fake transport error")
		}
	}

	impl std::error::Error for FakeError {}

	impl web_transport_trait::Error for FakeError {
		fn session_error(&self) -> Option<(u32, String)> {
			Some((0, "closed".to_string()))
		}
	}

	/// In-memory session that hands out one scripted control stream and records
	/// every `close()` call.
	#[derive(Clone, Default)]
	struct FakeSession {
		state: Arc<FakeSessionState>,
	}

	#[derive(Default)]
	struct FakeSessionState {
		/// ALPN reported by `protocol()`.
		protocol: Option<&'static str>,
		/// The single stream pair returned by `open_bi()`; `None` once taken.
		control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
		/// Every `(code, reason)` passed to `close()`, in call order.
		close_events: Mutex<Vec<(u32, String)>>,
		/// Notified on every `close()` call.
		close_notify: tokio::sync::Notify,
		/// Bytes written to the control stream (shared with its send half).
		control_writes: Arc<Mutex<Vec<u8>>>,
	}

	impl FakeSession {
		/// Build a session reporting `protocol` whose control stream yields `server_control_bytes`.
		fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
			let writes = Arc::new(Mutex::new(Vec::new()));
			let send = FakeSendStream { writes: writes.clone() };
			let recv = FakeRecvStream {
				data: VecDeque::from(server_control_bytes),
			};
			let state = FakeSessionState {
				protocol,
				control_stream: Mutex::new(Some((send, recv))),
				close_events: Mutex::new(Vec::new()),
				close_notify: tokio::sync::Notify::new(),
				control_writes: writes,
			};
			Self { state: Arc::new(state) }
		}

		/// Snapshot of everything written to the control stream so far.
		fn control_writes(&self) -> Vec<u8> {
			self.state.control_writes.lock().unwrap().clone()
		}

		/// Wait until at least one `close()` has happened and return the first one.
		async fn wait_for_first_close(&self) -> (u32, String) {
			loop {
				// Create the waiter before checking, so a close() racing with the
				// check still wakes us (notify_waiters reaches existing Notified futures).
				let notified = self.state.close_notify.notified();
				// Release the mutex guard in its own statement, before any await.
				let first = self.state.close_events.lock().unwrap().first().cloned();
				if let Some(close) = first {
					return close;
				}
				notified.await;
			}
		}
	}

	impl web_transport_trait::Session for FakeSession {
		type SendStream = FakeSendStream;
		type RecvStream = FakeRecvStream;
		type Error = FakeError;

		async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
			std::future::pending().await
		}

		async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			std::future::pending().await
		}

		async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
		}

		async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
			std::future::pending().await
		}

		fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
			Ok(())
		}

		async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
			std::future::pending().await
		}

		fn max_datagram_size(&self) -> usize {
			1200
		}

		fn protocol(&self) -> Option<&str> {
			self.state.protocol
		}

		fn close(&self, code: u32, reason: &str) {
			self.state.close_events.lock().unwrap().push((code, reason.to_string()));
			self.state.close_notify.notify_waiters();
		}

		async fn closed(&self) -> Self::Error {
			// Resolve immediately if close() already ran instead of waiting for
			// another close() that may never come.
			self.wait_for_first_close().await;
			FakeError
		}
	}

	/// Send half of the control stream; appends every write to a shared buffer.
	#[derive(Clone, Default)]
	struct FakeSendStream {
		writes: Arc<Mutex<Vec<u8>>>,
	}

	impl web_transport_trait::SendStream for FakeSendStream {
		type Error = FakeError;

		async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
			self.writes.lock().unwrap().put_slice(buf);
			Ok(buf.len())
		}

		fn set_priority(&mut self, _order: u8) {}

		fn finish(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}

		fn reset(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Receive half of the control stream; yields scripted bytes, then end-of-stream.
	struct FakeRecvStream {
		data: VecDeque<u8>,
	}

	impl web_transport_trait::RecvStream for FakeRecvStream {
		type Error = FakeError;

		async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
			if self.data.is_empty() {
				return Ok(None);
			}

			let size = dst.len().min(self.data.len());
			for (slot, byte) in dst.iter_mut().zip(self.data.drain(..size)) {
				*slot = byte;
			}
			Ok(Some(size))
		}

		fn stop(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Encode a draft-14-framed SERVER_SETUP choosing `negotiated`, followed by a
	/// SessionInfo frame encoded with that (Lite) version. Panics if `negotiated` is not Lite.
	fn mock_server_setup(negotiated: Version) -> Vec<u8> {
		let mut encoded = Vec::new();
		let server = setup::Server {
			version: negotiated.into(),
			parameters: Bytes::new(),
		};
		server
			.encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
			.unwrap();

		// Add a setup-stream SessionInfo frame using the negotiated Lite version.
		let info = lite::SessionInfo { bitrate: Some(1) };
		let lite_v = lite::Version::try_from(negotiated).unwrap();
		info.encode(&mut encoded, lite_v).unwrap();

		encoded
	}

	/// Connect with `protocol` and check the draft-14 SETUP fallback advertises only
	/// negotiable versions and then switches to the server-chosen Lite version.
	async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
		let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
		let client = Client::new().with_versions(
			[
				Version::Lite(lite::Version::Lite03),
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
			.into(),
		);

		let _session = client.connect(fake.clone()).await.unwrap();

		// Verify the client setup was encoded using Draft14 framing (ALPN_LITE fallback path).
		let mut setup_bytes = Bytes::from(fake.control_writes());
		let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
		let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
		assert_eq!(
			advertised,
			vec![
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
		);

		// The first close comes from the background lite session task.
		// Code 0 ("cancelled") means SessionInfo decoded successfully after the
		// stream switched to the negotiated version.
		let (code, _) = fake.wait_for_first_close().await;
		assert_eq!(code, Error::Cancel.to_code());
	}

	#[tokio::test(start_paused = true)]
	async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
	}

	#[tokio::test(start_paused = true)]
	async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(None).await;
	}
}