Skip to main content

moq_net/
client.rs

1use crate::{
2	ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_18, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, ALPN_LITE_05_WIP, Error,
3	NEGOTIATED, OriginConsumer, OriginProducer, Session, StatsHandle, Version, Versions,
4	coding::{self, Decode, Encode, Stream},
5	ietf, lite, setup,
6};
7
8/// A MoQ client session builder.
9#[derive(Default, Clone)]
10pub struct Client {
11	publish: Option<OriginConsumer>,
12	consume: Option<OriginProducer>,
13	stats: StatsHandle,
14	versions: Versions,
15}
16
17impl Client {
18	pub fn new() -> Self {
19		Default::default()
20	}
21
22	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
23		self.publish = publish.into();
24		self
25	}
26
27	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
28		self.consume = consume.into();
29		self
30	}
31
32	/// Attach a tier-scoped [`StatsHandle`]. Per-broadcast and per-subscription
33	/// counters will be bumped through this handle for the lifetime of the session.
34	/// Pass [`StatsHandle::default`] (a no-op handle) to opt out.
35	pub fn with_stats(mut self, stats: StatsHandle) -> Self {
36		self.stats = stats;
37		self
38	}
39
40	/// Set both publish and consume from an `OriginProducer`.
41	///
42	/// This is equivalent to calling `with_publish(origin.consume())` and `with_consume(origin)`.
43	pub fn with_origin(self, origin: OriginProducer) -> Self {
44		let consumer = origin.consume();
45		self.with_publish(consumer).with_consume(origin)
46	}
47
48	pub fn with_versions(mut self, versions: Versions) -> Self {
49		self.versions = versions;
50		self
51	}
52
53	/// Perform the MoQ handshake as a client negotiating the version.
54	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
55		if self.publish.is_none() && self.consume.is_none() {
56			tracing::warn!("not publishing or consuming anything");
57		}
58
59		// If ALPN was used to negotiate the version, use the appropriate encoding.
60		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
61		let (encoding, supported) = match session.protocol() {
62			Some(ALPN_18) => {
63				let v = self
64					.versions
65					.select(Version::Ietf(ietf::Version::Draft18))
66					.ok_or(Error::Version)?;
67
68				// Draft-17+: SETUP is exchanged in the background by the session.
69				ietf::start(
70					session.clone(),
71					None,
72					None,
73					true,
74					self.publish.clone(),
75					self.consume.clone(),
76					self.stats.clone(),
77					ietf::Version::Draft18,
78				)?;
79
80				tracing::debug!(version = ?v, "connected");
81				return Ok(Session::new(session, v, None));
82			}
83			Some(ALPN_17) => {
84				let v = self
85					.versions
86					.select(Version::Ietf(ietf::Version::Draft17))
87					.ok_or(Error::Version)?;
88
89				// Draft-17+: SETUP is exchanged in the background by the session.
90				ietf::start(
91					session.clone(),
92					None,
93					None,
94					true,
95					self.publish.clone(),
96					self.consume.clone(),
97					self.stats.clone(),
98					ietf::Version::Draft17,
99				)?;
100
101				tracing::debug!(version = ?v, "connected");
102				return Ok(Session::new(session, v, None));
103			}
104			Some(ALPN_16) => {
105				let v = self
106					.versions
107					.select(Version::Ietf(ietf::Version::Draft16))
108					.ok_or(Error::Version)?;
109				(v, v.into())
110			}
111			Some(ALPN_15) => {
112				let v = self
113					.versions
114					.select(Version::Ietf(ietf::Version::Draft15))
115					.ok_or(Error::Version)?;
116				(v, v.into())
117			}
118			Some(ALPN_14) => {
119				let v = self
120					.versions
121					.select(Version::Ietf(ietf::Version::Draft14))
122					.ok_or(Error::Version)?;
123				(v, v.into())
124			}
125			Some(ALPN_LITE_05_WIP) => {
126				self.versions
127					.select(Version::Lite(lite::Version::Lite05Wip))
128					.ok_or(Error::Version)?;
129
130				let recv_bw = lite::start(
131					session.clone(),
132					None,
133					self.publish.clone(),
134					self.consume.clone(),
135					self.stats.clone(),
136					lite::Version::Lite05Wip,
137				)?;
138
139				return Ok(Session::new(session, lite::Version::Lite05Wip.into(), recv_bw));
140			}
141			Some(ALPN_LITE_04) => {
142				self.versions
143					.select(Version::Lite(lite::Version::Lite04))
144					.ok_or(Error::Version)?;
145
146				let recv_bw = lite::start(
147					session.clone(),
148					None,
149					self.publish.clone(),
150					self.consume.clone(),
151					self.stats.clone(),
152					lite::Version::Lite04,
153				)?;
154
155				return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
156			}
157			Some(ALPN_LITE_03) => {
158				self.versions
159					.select(Version::Lite(lite::Version::Lite03))
160					.ok_or(Error::Version)?;
161
162				// Starting with draft-03, there's no more SETUP control stream.
163				let recv_bw = lite::start(
164					session.clone(),
165					None,
166					self.publish.clone(),
167					self.consume.clone(),
168					self.stats.clone(),
169					lite::Version::Lite03,
170				)?;
171
172				return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
173			}
174			Some(ALPN_LITE) | None => {
175				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
176				(Version::Ietf(ietf::Version::Draft14), supported)
177			}
178			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
179		};
180
181		let mut stream = Stream::open(&session, encoding).await?;
182
183		// The encoding is always an IETF version for SETUP negotiation.
184		let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
185
186		let mut parameters = ietf::Parameters::default();
187		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
188		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
189		let parameters = parameters.encode_bytes(ietf_encoding)?;
190
191		let client = setup::Client {
192			versions: supported.clone().into(),
193			parameters,
194		};
195
196		stream.writer.encode(&client).await?;
197
198		let mut server: setup::Server = stream.reader.decode().await?;
199
200		let version = supported
201			.iter()
202			.find(|v| coding::Version::from(**v) == server.version)
203			.copied()
204			.ok_or(Error::Version)?;
205
206		let recv_bw = match version {
207			Version::Lite(v) => {
208				let stream = stream.with_version(v);
209				lite::start(
210					session.clone(),
211					Some(stream),
212					self.publish.clone(),
213					self.consume.clone(),
214					self.stats.clone(),
215					v,
216				)?
217			}
218			Version::Ietf(v) => {
219				// Decode the parameters to get the initial request ID.
220				let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
221				let request_id_max = parameters
222					.get_varint(ietf::ParameterVarInt::MaxRequestId)
223					.map(ietf::RequestId);
224
225				let stream = stream.with_version(v);
226				ietf::start(
227					session.clone(),
228					Some(stream),
229					request_id_max,
230					true,
231					self.publish.clone(),
232					self.consume.clone(),
233					self.stats.clone(),
234					v,
235				)?;
236				None
237			}
238		};
239
240		Ok(Session::new(session, version, recv_bw))
241	}
242}
243
#[cfg(test)]
mod tests {
	use super::*;
	use std::{
		collections::VecDeque,
		sync::{Arc, Mutex},
	};

	use crate::coding::{Decode, Encode};
	use bytes::{BufMut, Bytes};

	/// The only error the fake transport ever produces.
	#[derive(Debug, Clone, Default)]
	struct FakeError;

	impl std::fmt::Display for FakeError {
		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
			f.write_str("fake transport error")
		}
	}

	impl std::error::Error for FakeError {}

	impl web_transport_trait::Error for FakeError {
		fn session_error(&self) -> Option<(u32, String)> {
			Some((0, String::from("closed")))
		}
	}

	/// In-memory session exposing exactly one scripted bidirectional stream.
	#[derive(Clone, Default)]
	struct FakeSession {
		state: Arc<FakeSessionState>,
	}

	/// State shared between all clones of a [`FakeSession`].
	#[derive(Default)]
	struct FakeSessionState {
		// Value reported by `protocol()`.
		protocol: Option<&'static str>,
		// Handed out by the first `open_bi()` call; later calls fail.
		control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
		// Every `(code, reason)` passed to `close()`, in call order.
		close_events: Mutex<Vec<(u32, String)>>,
		// Signalled on every `close()`.
		close_notify: tokio::sync::Notify,
		// Same buffer the control send stream writes into.
		control_writes: Arc<Mutex<Vec<u8>>>,
	}

	impl FakeSession {
		/// Build a session reporting `protocol` whose control stream yields `server_control_bytes`.
		fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
			let control_writes = Arc::new(Mutex::new(Vec::new()));
			let outgoing = FakeSendStream {
				writes: Arc::clone(&control_writes),
			};
			let incoming = FakeRecvStream {
				data: VecDeque::from(server_control_bytes),
			};

			Self {
				state: Arc::new(FakeSessionState {
					protocol,
					control_stream: Mutex::new(Some((outgoing, incoming))),
					control_writes,
					..Default::default()
				}),
			}
		}

		/// Snapshot of everything written to the control stream so far.
		fn control_writes(&self) -> Vec<u8> {
			self.state.control_writes.lock().unwrap().to_vec()
		}

		/// Wait until `close()` has been called at least once and return the first call's arguments.
		async fn wait_for_first_close(&self) -> (u32, String) {
			loop {
				// Create the waiter before inspecting the log, as the original does,
				// so a `close()` between the check and the await is not missed.
				let wakeup = self.state.close_notify.notified();
				let earliest = self.state.close_events.lock().unwrap().first().cloned();
				match earliest {
					Some(event) => return event,
					None => wakeup.await,
				}
			}
		}
	}

	impl web_transport_trait::Session for FakeSession {
		type SendStream = FakeSendStream;
		type RecvStream = FakeRecvStream;
		type Error = FakeError;

		// The peer never opens a stream.
		async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
			std::future::pending().await
		}

		async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			std::future::pending().await
		}

		// Only the first call succeeds; it takes the scripted control stream.
		async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			let scripted = self.state.control_stream.lock().unwrap().take();
			scripted.ok_or(FakeError)
		}

		async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
			std::future::pending().await
		}

		fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
			Ok(())
		}

		async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
			std::future::pending().await
		}

		fn max_datagram_size(&self) -> usize {
			1200
		}

		fn protocol(&self) -> Option<&str> {
			self.state.protocol
		}

		// Record the close and wake everyone currently waiting on it.
		fn close(&self, code: u32, reason: &str) {
			let event = (code, reason.to_string());
			self.state.close_events.lock().unwrap().push(event);
			self.state.close_notify.notify_waiters();
		}

		async fn closed(&self) -> Self::Error {
			self.state.close_notify.notified().await;
			FakeError
		}
	}

	/// Send half that appends everything written to a shared buffer.
	#[derive(Clone, Default)]
	struct FakeSendStream {
		writes: Arc<Mutex<Vec<u8>>>,
	}

	impl web_transport_trait::SendStream for FakeSendStream {
		type Error = FakeError;

		async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
			let written = buf.len();
			self.writes.lock().unwrap().put_slice(buf);
			Ok(written)
		}

		fn set_priority(&mut self, _order: u8) {}

		fn finish(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}

		fn reset(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Receive half that replays a fixed byte sequence and then reports end of stream.
	struct FakeRecvStream {
		data: VecDeque<u8>,
	}

	impl web_transport_trait::RecvStream for FakeRecvStream {
		type Error = FakeError;

		async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
			if self.data.is_empty() {
				return Ok(None);
			}

			// Move as many queued bytes as fit into `dst`.
			let count = dst.len().min(self.data.len());
			for (slot, byte) in dst.iter_mut().zip(self.data.drain(..count)) {
				*slot = byte;
			}
			Ok(Some(count))
		}

		fn stop(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Bytes a server would send on the control stream: a Draft14-encoded SETUP
	/// answer choosing `negotiated`, followed by a `SessionInfo` frame encoded
	/// with the negotiated lite version.
	///
	/// Panics if `negotiated` is not a lite version.
	fn mock_server_setup(negotiated: Version) -> Vec<u8> {
		let mut wire = Vec::new();

		let answer = setup::Server {
			version: negotiated.into(),
			parameters: Bytes::new(),
		};
		answer.encode(&mut wire, Version::Ietf(ietf::Version::Draft14)).unwrap();

		// Add a setup-stream SessionInfo frame using the negotiated Lite version.
		let lite_version = lite::Version::try_from(negotiated).unwrap();
		let session_info = lite::SessionInfo { bitrate: Some(1) };
		session_info.encode(&mut wire, lite_version).unwrap();

		wire
	}

	/// Connect over a session reporting `protocol`, with a server that picks Lite01,
	/// then check what the client advertised and how the session was closed.
	async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
		let server_bytes = mock_server_setup(Version::Lite(lite::Version::Lite01));
		let fake = FakeSession::new(protocol, server_bytes);

		let offered = [
			Version::Lite(lite::Version::Lite03),
			Version::Lite(lite::Version::Lite02),
			Version::Lite(lite::Version::Lite01),
			Version::Ietf(ietf::Version::Draft14),
		];
		let client = Client::new().with_versions(offered.into());

		let _session = client.connect(fake.clone()).await.unwrap();

		// Verify the client setup was encoded using Draft14 framing (ALPN_LITE fallback path).
		let mut sent = Bytes::from(fake.control_writes());
		let client_setup = setup::Client::decode(&mut sent, Version::Ietf(ietf::Version::Draft14)).unwrap();
		let advertised: Vec<Version> = client_setup
			.versions
			.iter()
			.map(|raw| Version::try_from(*raw).unwrap())
			.collect();

		// Lite03 was offered to the builder but must not show up in the SETUP message.
		assert_eq!(
			advertised,
			vec![
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
		);

		// The first close comes from the background lite session task.
		// Code 0 ("cancelled") means SessionInfo decoded successfully once the
		// stream had been switched to the negotiated version.
		let (close_code, _reason) = fake.wait_for_first_close().await;
		assert_eq!(close_code, Error::Cancel.to_code());
	}

	#[tokio::test(start_paused = true)]
	async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
	}

	#[tokio::test(start_paused = true)]
	async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(None).await;
	}
}