Skip to main content

moq_lite/
client.rs

1use crate::{
2	ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, Error, NEGOTIATED, OriginConsumer, OriginProducer,
3	Session, Version, Versions,
4	coding::{self, Decode, Encode, Reader, Stream, Writer},
5	ietf, lite, setup,
6};
7
8/// A MoQ client session builder.
9#[derive(Default, Clone)]
10pub struct Client {
11	publish: Option<OriginConsumer>,
12	consume: Option<OriginProducer>,
13	versions: Versions,
14}
15
16impl Client {
17	pub fn new() -> Self {
18		Default::default()
19	}
20
21	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22		self.publish = publish.into();
23		self
24	}
25
26	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27		self.consume = consume.into();
28		self
29	}
30
31	pub fn with_versions(mut self, versions: Versions) -> Self {
32		self.versions = versions;
33		self
34	}
35
36	/// Perform the MoQ handshake as a client negotiating the version.
37	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
38		if self.publish.is_none() && self.consume.is_none() {
39			tracing::warn!("not publishing or consuming anything");
40		}
41
42		// If ALPN was used to negotiate the version, use the appropriate encoding.
43		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
44		let (encoding, supported) = match session.protocol() {
45			Some(ALPN_17) => {
46				let v = self
47					.versions
48					.select(Version::Ietf(ietf::Version::Draft17))
49					.ok_or(Error::Version)?;
50
51				let ietf_v = ietf::Version::Draft17;
52
53				// Draft-17: SETUP uses uni streams
54				let mut parameters = ietf::Parameters::default();
55				parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
56				let parameters = parameters.encode_bytes(ietf_v)?;
57
58				let client_setup = setup::Client {
59					versions: coding::Versions::from([v.into()]),
60					parameters,
61				};
62
63				// Send and receive SETUP concurrently on uni streams
64				let send_fut = async {
65					let send = session.open_uni().await.map_err(Error::from_transport)?;
66					let mut writer: Writer<S::SendStream, Version> = Writer::new(send, v);
67					// Write SETUP message (includes stream type 0x2F00)
68					writer.encode(&client_setup).await?;
69					Ok::<_, Error>(writer)
70				};
71
72				let recv_fut = async {
73					let recv = session.accept_uni().await.map_err(Error::from_transport)?;
74					let mut reader: Reader<S::RecvStream, Version> = Reader::new(recv, v);
75					// Read SETUP message (includes stream type 0x2F00)
76					let server: setup::Server = reader.decode().await?;
77					Ok::<_, Error>((reader, server))
78				};
79
80				let (send_result, recv_result) = tokio::join!(send_fut, recv_fut);
81				let writer = send_result?;
82				let (reader, _server_setup) = recv_result?;
83
84				// Construct a Stream from the uni streams for GOAWAY/control
85				let stream = Stream {
86					writer: writer.with_version(ietf_v),
87					reader: reader.with_version(ietf_v),
88				};
89
90				ietf::start(
91					session.clone(),
92					stream,
93					None, // Draft-17 removed MaxRequestId
94					true,
95					self.publish.clone(),
96					self.consume.clone(),
97					ietf_v,
98				)?;
99
100				tracing::debug!(version = ?v, "connected");
101				return Ok(Session::new(session, v));
102			}
103			Some(ALPN_16) => {
104				let v = self
105					.versions
106					.select(Version::Ietf(ietf::Version::Draft16))
107					.ok_or(Error::Version)?;
108				(v, v.into())
109			}
110			Some(ALPN_15) => {
111				let v = self
112					.versions
113					.select(Version::Ietf(ietf::Version::Draft15))
114					.ok_or(Error::Version)?;
115				(v, v.into())
116			}
117			Some(ALPN_14) => {
118				let v = self
119					.versions
120					.select(Version::Ietf(ietf::Version::Draft14))
121					.ok_or(Error::Version)?;
122				(v, v.into())
123			}
124			Some(ALPN_LITE_03) => {
125				self.versions
126					.select(Version::Lite(lite::Version::Lite03))
127					.ok_or(Error::Version)?;
128
129				// Starting with draft-03, there's no more SETUP control stream.
130				lite::start(
131					session.clone(),
132					None,
133					self.publish.clone(),
134					self.consume.clone(),
135					lite::Version::Lite03,
136				)?;
137
138				return Ok(Session::new(session, lite::Version::Lite03.into()));
139			}
140			Some(ALPN_LITE) | None => {
141				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
142				(Version::Ietf(ietf::Version::Draft14), supported)
143			}
144			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
145		};
146
147		let mut stream = Stream::open(&session, encoding).await?;
148
149		// The encoding is always an IETF version for SETUP negotiation.
150		let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
151
152		let mut parameters = ietf::Parameters::default();
153		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
154		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
155		let parameters = parameters.encode_bytes(ietf_encoding)?;
156
157		let client = setup::Client {
158			versions: supported.clone().into(),
159			parameters,
160		};
161
162		// TODO pretty print the parameters.
163		tracing::trace!(?client, "sending client setup");
164		stream.writer.encode(&client).await?;
165
166		let mut server: setup::Server = stream.reader.decode().await?;
167		tracing::trace!(?server, "received server setup");
168
169		let version = supported
170			.iter()
171			.find(|v| coding::Version::from(**v) == server.version)
172			.copied()
173			.ok_or(Error::Version)?;
174
175		match version {
176			Version::Lite(v) => {
177				let stream = stream.with_version(v);
178				lite::start(
179					session.clone(),
180					Some(stream),
181					self.publish.clone(),
182					self.consume.clone(),
183					v,
184				)?;
185			}
186			Version::Ietf(v) => {
187				// Decode the parameters to get the initial request ID.
188				let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
189				let request_id_max = parameters
190					.get_varint(ietf::ParameterVarInt::MaxRequestId)
191					.map(ietf::RequestId);
192
193				let stream = stream.with_version(v);
194				ietf::start(
195					session.clone(),
196					stream,
197					request_id_max,
198					true,
199					self.publish.clone(),
200					self.consume.clone(),
201					v,
202				)?;
203			}
204		}
205
206		Ok(Session::new(session, version))
207	}
208}
209
#[cfg(test)]
mod tests {
	use super::*;
	use std::{
		collections::VecDeque,
		sync::{Arc, Mutex},
	};

	use crate::coding::{Decode, Encode};
	use bytes::{BufMut, Bytes};

	/// Error returned by every failing fake transport operation.
	#[derive(Debug, Clone, Default)]
	struct FakeError;

	impl std::fmt::Display for FakeError {
		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
			write!(f, "fake transport error")
		}
	}

	impl std::error::Error for FakeError {}

	impl web_transport_trait::Error for FakeError {
		fn session_error(&self) -> Option<(u32, String)> {
			Some((0, "closed".to_string()))
		}
	}

	/// In-memory transport session: hands out one scripted bidirectional control
	/// stream via `open_bi()` and records every `close()` call.
	#[derive(Clone, Default)]
	struct FakeSession {
		state: Arc<FakeSessionState>,
	}

	#[derive(Default)]
	struct FakeSessionState {
		/// ALPN reported by `protocol()`.
		protocol: Option<&'static str>,
		/// The single stream returned by `open_bi()`; taken on first use.
		control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
		/// Every `(code, reason)` passed to `close()`, in call order.
		close_events: Mutex<Vec<(u32, String)>>,
		/// Woken (via `notify_waiters`) on every `close()` call.
		close_notify: tokio::sync::Notify,
		/// Bytes written to the control stream (shared with its send half).
		control_writes: Arc<Mutex<Vec<u8>>>,
	}

	impl FakeSession {
		/// Build a session whose control stream replays `server_control_bytes` to the reader.
		fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
			let writes = Arc::new(Mutex::new(Vec::new()));
			let send = FakeSendStream { writes: writes.clone() };
			let recv = FakeRecvStream {
				data: VecDeque::from(server_control_bytes),
			};
			let state = FakeSessionState {
				protocol,
				control_stream: Mutex::new(Some((send, recv))),
				close_events: Mutex::new(Vec::new()),
				close_notify: tokio::sync::Notify::new(),
				control_writes: writes,
			};
			Self { state: Arc::new(state) }
		}

		/// Snapshot of everything written to the control stream so far.
		fn control_writes(&self) -> Vec<u8> {
			self.state.control_writes.lock().unwrap().clone()
		}

		/// Wait until `close()` has been called at least once and return its arguments.
		async fn wait_for_first_close(&self) -> (u32, String) {
			loop {
				// Register before checking so a concurrent close() is not missed.
				let notified = self.state.close_notify.notified();
				if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
					return close;
				}
				notified.await;
			}
		}
	}

	impl web_transport_trait::Session for FakeSession {
		type SendStream = FakeSendStream;
		type RecvStream = FakeRecvStream;
		type Error = FakeError;

		async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
			std::future::pending().await
		}

		async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			std::future::pending().await
		}

		async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
			self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
		}

		async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
			std::future::pending().await
		}

		fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
			Ok(())
		}

		async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
			std::future::pending().await
		}

		fn max_datagram_size(&self) -> usize {
			1200
		}

		fn protocol(&self) -> Option<&str> {
			self.state.protocol
		}

		fn close(&self, code: u32, reason: &str) {
			self.state.close_events.lock().unwrap().push((code, reason.to_string()));
			self.state.close_notify.notify_waiters();
		}

		async fn closed(&self) -> Self::Error {
			loop {
				// `notify_waiters()` stores no permit, so a close() that happened before
				// this future was polled would otherwise be missed and we'd hang forever.
				// Register interest first, then check the recorded events.
				let notified = self.state.close_notify.notified();
				let already_closed = !self.state.close_events.lock().unwrap().is_empty();
				if already_closed {
					return FakeError;
				}
				notified.await;
			}
		}
	}

	/// Send half that appends every write to a shared buffer.
	#[derive(Clone, Default)]
	struct FakeSendStream {
		writes: Arc<Mutex<Vec<u8>>>,
	}

	impl web_transport_trait::SendStream for FakeSendStream {
		type Error = FakeError;

		async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
			self.writes.lock().unwrap().put_slice(buf);
			Ok(buf.len())
		}

		fn set_priority(&mut self, _order: u8) {}

		fn finish(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}

		fn reset(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Receive half that replays a fixed byte buffer, then reports end-of-stream.
	struct FakeRecvStream {
		data: VecDeque<u8>,
	}

	impl web_transport_trait::RecvStream for FakeRecvStream {
		type Error = FakeError;

		async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
			if self.data.is_empty() {
				return Ok(None);
			}

			let size = dst.len().min(self.data.len());
			for slot in dst.iter_mut().take(size) {
				*slot = self.data.pop_front().unwrap();
			}
			Ok(Some(size))
		}

		fn stop(&mut self, _code: u32) {}

		async fn closed(&mut self) -> Result<(), Self::Error> {
			Ok(())
		}
	}

	/// Encode a server SETUP (draft-14 framing) choosing `negotiated`, followed by a
	/// lite `SessionInfo` frame encoded with that lite version.
	fn mock_server_setup(negotiated: Version) -> Vec<u8> {
		let mut encoded = Vec::new();
		let server = setup::Server {
			version: negotiated.into(),
			parameters: Bytes::new(),
		};
		server
			.encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
			.unwrap();

		// Add a setup-stream SessionInfo frame using the negotiated Lite version.
		let info = lite::SessionInfo { bitrate: Some(1) };
		let lite_v = lite::Version::try_from(negotiated).unwrap();
		info.encode(&mut encoded, lite_v).unwrap();

		encoded
	}

	async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
		let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
		let client = Client::new().with_versions(
			[
				Version::Lite(lite::Version::Lite03),
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
			.into(),
		);

		let _session = client.connect(fake.clone()).await.unwrap();

		// Verify the client setup was encoded using Draft14 framing (ALPN_LITE fallback path).
		let mut setup_bytes = Bytes::from(fake.control_writes());
		let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
		let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
		assert_eq!(
			advertised,
			vec![
				Version::Lite(lite::Version::Lite02),
				Version::Lite(lite::Version::Lite01),
				Version::Ietf(ietf::Version::Draft14),
			]
		);

		// The first close comes from the background lite session task.
		// Code 0 ("cancelled") means SessionInfo decoded successfully after set_version().
		let (code, _) = fake.wait_for_first_close().await;
		assert_eq!(code, Error::Cancel.to_code());
	}

	#[tokio::test(start_paused = true)]
	async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
	}

	#[tokio::test(start_paused = true)]
	async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
		run_alpn_lite_fallback_case(None).await;
	}
}