Skip to main content

moq_net/
client.rs

1use crate::{
2	ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_18, ALPN_LITE, ALPN_LITE_03, ALPN_LITE_04, Error, NEGOTIATED,
3	OriginConsumer, OriginProducer, Session, Version, Versions,
4	coding::{self, Decode, Encode, Stream},
5	ietf, lite, setup,
6};
7
8/// A MoQ client session builder.
9#[derive(Default, Clone)]
10pub struct Client {
11	publish: Option<OriginConsumer>,
12	consume: Option<OriginProducer>,
13	versions: Versions,
14}
15
16impl Client {
17	pub fn new() -> Self {
18		Default::default()
19	}
20
21	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22		self.publish = publish.into();
23		self
24	}
25
26	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27		self.consume = consume.into();
28		self
29	}
30
31	/// Set both publish and consume from an `OriginProducer`.
32	///
33	/// This is equivalent to calling `with_publish(origin.consume())` and `with_consume(origin)`.
34	pub fn with_origin(self, origin: OriginProducer) -> Self {
35		let consumer = origin.consume();
36		self.with_publish(consumer).with_consume(origin)
37	}
38
39	pub fn with_versions(mut self, versions: Versions) -> Self {
40		self.versions = versions;
41		self
42	}
43
44	/// Perform the MoQ handshake as a client negotiating the version.
45	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
46		if self.publish.is_none() && self.consume.is_none() {
47			tracing::warn!("not publishing or consuming anything");
48		}
49
50		// If ALPN was used to negotiate the version, use the appropriate encoding.
51		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
52		let (encoding, supported) = match session.protocol() {
53			Some(ALPN_18) => {
54				let v = self
55					.versions
56					.select(Version::Ietf(ietf::Version::Draft18))
57					.ok_or(Error::Version)?;
58
59				// Draft-17+: SETUP is exchanged in the background by the session.
60				ietf::start(
61					session.clone(),
62					None,
63					None,
64					true,
65					self.publish.clone(),
66					self.consume.clone(),
67					ietf::Version::Draft18,
68				)?;
69
70				tracing::debug!(version = ?v, "connected");
71				return Ok(Session::new(session, v, None));
72			}
73			Some(ALPN_17) => {
74				let v = self
75					.versions
76					.select(Version::Ietf(ietf::Version::Draft17))
77					.ok_or(Error::Version)?;
78
79				// Draft-17+: SETUP is exchanged in the background by the session.
80				ietf::start(
81					session.clone(),
82					None,
83					None,
84					true,
85					self.publish.clone(),
86					self.consume.clone(),
87					ietf::Version::Draft17,
88				)?;
89
90				tracing::debug!(version = ?v, "connected");
91				return Ok(Session::new(session, v, None));
92			}
93			Some(ALPN_16) => {
94				let v = self
95					.versions
96					.select(Version::Ietf(ietf::Version::Draft16))
97					.ok_or(Error::Version)?;
98				(v, v.into())
99			}
100			Some(ALPN_15) => {
101				let v = self
102					.versions
103					.select(Version::Ietf(ietf::Version::Draft15))
104					.ok_or(Error::Version)?;
105				(v, v.into())
106			}
107			Some(ALPN_14) => {
108				let v = self
109					.versions
110					.select(Version::Ietf(ietf::Version::Draft14))
111					.ok_or(Error::Version)?;
112				(v, v.into())
113			}
114			Some(ALPN_LITE_04) => {
115				self.versions
116					.select(Version::Lite(lite::Version::Lite04))
117					.ok_or(Error::Version)?;
118
119				let recv_bw = lite::start(
120					session.clone(),
121					None,
122					self.publish.clone(),
123					self.consume.clone(),
124					lite::Version::Lite04,
125				)?;
126
127				return Ok(Session::new(session, lite::Version::Lite04.into(), recv_bw));
128			}
129			Some(ALPN_LITE_03) => {
130				self.versions
131					.select(Version::Lite(lite::Version::Lite03))
132					.ok_or(Error::Version)?;
133
134				// Starting with draft-03, there's no more SETUP control stream.
135				let recv_bw = lite::start(
136					session.clone(),
137					None,
138					self.publish.clone(),
139					self.consume.clone(),
140					lite::Version::Lite03,
141				)?;
142
143				return Ok(Session::new(session, lite::Version::Lite03.into(), recv_bw));
144			}
145			Some(ALPN_LITE) | None => {
146				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
147				(Version::Ietf(ietf::Version::Draft14), supported)
148			}
149			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
150		};
151
152		let mut stream = Stream::open(&session, encoding).await?;
153
154		// The encoding is always an IETF version for SETUP negotiation.
155		let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
156
157		let mut parameters = ietf::Parameters::default();
158		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
159		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
160		let parameters = parameters.encode_bytes(ietf_encoding)?;
161
162		let client = setup::Client {
163			versions: supported.clone().into(),
164			parameters,
165		};
166
167		stream.writer.encode(&client).await?;
168
169		let mut server: setup::Server = stream.reader.decode().await?;
170
171		let version = supported
172			.iter()
173			.find(|v| coding::Version::from(**v) == server.version)
174			.copied()
175			.ok_or(Error::Version)?;
176
177		let recv_bw = match version {
178			Version::Lite(v) => {
179				let stream = stream.with_version(v);
180				lite::start(
181					session.clone(),
182					Some(stream),
183					self.publish.clone(),
184					self.consume.clone(),
185					v,
186				)?
187			}
188			Version::Ietf(v) => {
189				// Decode the parameters to get the initial request ID.
190				let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
191				let request_id_max = parameters
192					.get_varint(ietf::ParameterVarInt::MaxRequestId)
193					.map(ietf::RequestId);
194
195				let stream = stream.with_version(v);
196				ietf::start(
197					session.clone(),
198					Some(stream),
199					request_id_max,
200					true,
201					self.publish.clone(),
202					self.consume.clone(),
203					v,
204				)?;
205				None
206			}
207		};
208
209		Ok(Session::new(session, version, recv_bw))
210	}
211}
212
#[cfg(test)]
mod tests {
	use super::*;
	use std::{
		collections::VecDeque,
		sync::{Arc, Mutex},
	};

	use crate::coding::{Decode, Encode};
	use bytes::{BufMut, Bytes};

	/// The only error the fake transport ever produces.
	#[derive(Debug, Clone, Default)]
	struct FakeError;

	impl std::fmt::Display for FakeError {
		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
			write!(f, "fake transport error")
		}
	}

	impl std::error::Error for FakeError {}

	impl web_transport_trait::Error for FakeError {
		fn session_error(&self) -> Option<(u32, String)> {
			Some((0, "closed".to_string()))
		}
	}

	/// In-memory transport session with one scripted bidirectional (control) stream.
	///
	/// Every other stream or datagram receive operation pends forever.
	#[derive(Clone, Default)]
	struct FakeSession {
		state: Arc<FakeSessionState>,
	}

	#[derive(Default)]
	struct FakeSessionState {
		/// ALPN reported by `protocol()`.
		protocol: Option<&'static str>,
		/// Handed out by the first `open_bi()` call; later calls fail.
		control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
		/// Every `close()` call, in order.
		close_events: Mutex<Vec<(u32, String)>>,
		/// Woken on every `close()` call.
		close_notify: tokio::sync::Notify,
		/// Everything the client wrote to the control stream.
		control_writes: Arc<Mutex<Vec<u8>>>,
	}

	impl FakeSessionState {
		/// The first recorded `close()` call, if any.
		fn first_close(&self) -> Option<(u32, String)> {
			self.close_events.lock().unwrap().first().cloned()
		}
	}

	impl FakeSession {
		/// Build a session reporting `protocol` whose control stream yields `server_control_bytes`.
		fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
			let control_writes = Arc::new(Mutex::new(Vec::new()));
			let send = FakeSendStream {
				writes: Arc::clone(&control_writes),
			};
			let recv = FakeRecvStream {
				data: server_control_bytes.into(),
			};

			Self {
				state: Arc::new(FakeSessionState {
					protocol,
					control_stream: Mutex::new(Some((send, recv))),
					control_writes,
					..Default::default()
				}),
			}
		}

		/// Snapshot of the bytes written to the control stream so far.
		fn control_writes(&self) -> Vec<u8> {
			self.state.control_writes.lock().unwrap().clone()
		}

		/// Resolve with the first `close()` call, waiting for one if none happened yet.
		async fn wait_for_first_close(&self) -> (u32, String) {
			loop {
				// Register interest before checking so a concurrent close() is not missed.
				let notified = self.state.close_notify.notified();
				if let Some(close) = self.state.first_close() {
					return close;
				}
				notified.await;
			}
		}
	}

	impl web_transport_trait::Session for FakeSession {
		type SendStream = FakeSendStream;
		type RecvStream = FakeRecvStream;
		type Error = FakeError;

		async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
			std::future::pending().await
		}

		async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			std::future::pending().await
		}

		async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
		}

		async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
			std::future::pending().await
		}

		fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
			Ok(())
		}

		async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
			std::future::pending().await
		}

		fn max_datagram_size(&self) -> usize {
			1200
		}

		fn protocol(&self) -> Option<&str> {
			self.state.protocol
		}

		fn close(&self, code: u32, reason: &str) {
			self.state.close_events.lock().unwrap().push((code, reason.to_string()));
			self.state.close_notify.notify_waiters();
		}

		async fn closed(&self) -> Self::Error {
			// Resolve immediately if the session was already closed; otherwise wait for close().
			// Only waiting on the notification would hang forever when close() came first.
			loop {
				let notified = self.state.close_notify.notified();
				let already_closed = self.state.first_close().is_some();
				if already_closed {
					return FakeError;
				}
				notified.await;
			}
		}
	}

	/// Send half of the control stream; records every written byte.
	#[derive(Clone, Default)]
	struct FakeSendStream {
		writes: Arc<Mutex<Vec<u8>>>,
	}

	impl web_transport_trait::SendStream for FakeSendStream {
		type Error = FakeError;

		async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
			self.writes.lock().unwrap().put_slice(buf);
			Ok(buf.len())
		}

		fn set_priority(&mut self, _order: u8) {}

		fn finish(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}

		fn reset(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Receive half of the control stream; replays scripted server bytes, then reports EOF.
	struct FakeRecvStream {
		data: VecDeque<u8>,
	}

	impl web_transport_trait::RecvStream for FakeRecvStream {
		type Error = FakeError;

		async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
			if self.data.is_empty() {
				return Ok(None);
			}

			let size = dst.len().min(self.data.len());
			for (slot, byte) in dst.iter_mut().zip(self.data.drain(..size)) {
				*slot = byte;
			}
			Ok(Some(size))
		}

		fn stop(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Server control-stream bytes: a draft-14-framed SETUP reply choosing `negotiated`,
	/// followed by a SessionInfo frame encoded with that (lite) version.
	fn mock_server_setup(negotiated: Version) -> Vec<u8> {
		let mut encoded = Vec::new();
		let server = setup::Server {
			version: negotiated.into(),
			parameters: Bytes::new(),
		};
		server
			.encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
			.unwrap();

		// Add a setup-stream SessionInfo frame using the negotiated Lite version.
		let info = lite::SessionInfo { bitrate: Some(1) };
		let lite_v = lite::Version::try_from(negotiated).unwrap();
		info.encode(&mut encoded, lite_v).unwrap();

		encoded
	}

	async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
		let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
		let client = Client::new().with_versions(
			[
				Version::Lite(lite::Version::Lite03),
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
			.into(),
		);

		let _session = client.connect(fake.clone()).await.unwrap();

		// Verify the client setup was encoded using Draft14 framing (ALPN_LITE fallback path).
		let mut setup_bytes = Bytes::from(fake.control_writes());
		let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
		let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
		assert_eq!(
			advertised,
			vec![
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
		);

		// The first close comes from the background lite session task. A cancel code means the
		// SessionInfo frame decoded successfully after the stream switched to the lite version.
		let (code, _) = fake.wait_for_first_close().await;
		assert_eq!(code, Error::Cancel.to_code());
	}

	#[tokio::test(start_paused = true)]
	async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
	}

	#[tokio::test(start_paused = true)]
	async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(None).await;
	}

	#[tokio::test]
	async fn unknown_alpn_is_rejected_before_setup() {
		let fake = FakeSession::new(Some("not-a-moq-alpn"), Vec::new());

		let result = Client::new().connect(fake.clone()).await;

		assert!(matches!(&result, Err(Error::UnknownAlpn(alpn)) if alpn == "not-a-moq-alpn"));
		// Rejected before any control stream was written.
		assert!(fake.control_writes().is_empty());
	}
}