Skip to main content

moq_lite/
client.rs

1use crate::{
2	ALPN_14, ALPN_15, ALPN_16, ALPN_17, ALPN_LITE, ALPN_LITE_03, Error, NEGOTIATED, OriginConsumer, OriginProducer,
3	Session, Version, Versions,
4	coding::{self, Decode, Encode, Reader, Stream, Writer},
5	ietf, lite, setup,
6};
7
8/// A MoQ client session builder.
9#[derive(Default, Clone)]
10pub struct Client {
11	publish: Option<OriginConsumer>,
12	consume: Option<OriginProducer>,
13	versions: Versions,
14}
15
16impl Client {
17	pub fn new() -> Self {
18		Default::default()
19	}
20
21	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
22		self.publish = publish.into();
23		self
24	}
25
26	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
27		self.consume = consume.into();
28		self
29	}
30
31	pub fn with_versions(mut self, versions: Versions) -> Self {
32		self.versions = versions;
33		self
34	}
35
36	/// Perform the MoQ handshake as a client negotiating the version.
37	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
38		if self.publish.is_none() && self.consume.is_none() {
39			tracing::warn!("not publishing or consuming anything");
40		}
41
42		// If ALPN was used to negotiate the version, use the appropriate encoding.
43		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
44		let (encoding, supported) = match session.protocol() {
45			Some(ALPN_17) => {
46				let v = self
47					.versions
48					.select(Version::Ietf(ietf::Version::Draft17))
49					.ok_or(Error::Version)?;
50
51				let ietf_v = ietf::Version::Draft17;
52
53				// Draft-17: SETUP uses uni streams
54				let mut parameters = ietf::Parameters::default();
55				parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
56				let parameters = parameters.encode_bytes(ietf_v)?;
57
58				let client_setup = setup::Client {
59					versions: coding::Versions::from([v.into()]),
60					parameters,
61				};
62
63				// Send and receive SETUP concurrently on uni streams
64				let send_fut = async {
65					let send = session.open_uni().await.map_err(Error::from_transport)?;
66					let mut writer: Writer<S::SendStream, Version> = Writer::new(send, v);
67					// Write stream type 0x2F00
68					writer.encode(&0x2F00u64).await?;
69					// Write SETUP message
70					writer.encode(&client_setup).await?;
71					Ok::<_, Error>(writer)
72				};
73
74				let recv_fut = async {
75					let recv = session.accept_uni().await.map_err(Error::from_transport)?;
76					let mut reader: Reader<S::RecvStream, Version> = Reader::new(recv, v);
77					// Read stream type
78					let stream_type: u64 = reader.decode().await?;
79					if stream_type != 0x2F00 {
80						return Err(Error::UnexpectedStream);
81					}
82					// Read SETUP message
83					let server: setup::Server = reader.decode().await?;
84					Ok::<_, Error>((reader, server))
85				};
86
87				let (send_result, recv_result) = tokio::join!(send_fut, recv_fut);
88				let writer = send_result?;
89				let (reader, _server_setup) = recv_result?;
90
91				// Construct a Stream from the uni streams for GOAWAY/control
92				let stream = Stream {
93					writer: writer.with_version(ietf_v),
94					reader: reader.with_version(ietf_v),
95				};
96
97				ietf::start(
98					session.clone(),
99					stream,
100					None, // Draft-17 removed MaxRequestId
101					true,
102					self.publish.clone(),
103					self.consume.clone(),
104					ietf_v,
105				)?;
106
107				tracing::debug!(version = ?v, "connected");
108				return Ok(Session::new(session, v));
109			}
110			Some(ALPN_16) => {
111				let v = self
112					.versions
113					.select(Version::Ietf(ietf::Version::Draft16))
114					.ok_or(Error::Version)?;
115				(v, v.into())
116			}
117			Some(ALPN_15) => {
118				let v = self
119					.versions
120					.select(Version::Ietf(ietf::Version::Draft15))
121					.ok_or(Error::Version)?;
122				(v, v.into())
123			}
124			Some(ALPN_14) => {
125				let v = self
126					.versions
127					.select(Version::Ietf(ietf::Version::Draft14))
128					.ok_or(Error::Version)?;
129				(v, v.into())
130			}
131			Some(ALPN_LITE_03) => {
132				self.versions
133					.select(Version::Lite(lite::Version::Lite03))
134					.ok_or(Error::Version)?;
135
136				// Starting with draft-03, there's no more SETUP control stream.
137				lite::start(
138					session.clone(),
139					None,
140					self.publish.clone(),
141					self.consume.clone(),
142					lite::Version::Lite03,
143				)?;
144
145				return Ok(Session::new(session, lite::Version::Lite03.into()));
146			}
147			Some(ALPN_LITE) | None => {
148				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
149				(Version::Ietf(ietf::Version::Draft14), supported)
150			}
151			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
152		};
153
154		let mut stream = Stream::open(&session, encoding).await?;
155
156		// The encoding is always an IETF version for SETUP negotiation.
157		let ietf_encoding = ietf::Version::try_from(encoding).map_err(|_| Error::Version)?;
158
159		let mut parameters = ietf::Parameters::default();
160		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
161		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
162		let parameters = parameters.encode_bytes(ietf_encoding)?;
163
164		let client = setup::Client {
165			versions: supported.clone().into(),
166			parameters,
167		};
168
169		// TODO pretty print the parameters.
170		tracing::trace!(?client, "sending client setup");
171		stream.writer.encode(&client).await?;
172
173		let mut server: setup::Server = stream.reader.decode().await?;
174		tracing::trace!(?server, "received server setup");
175
176		let version = supported
177			.iter()
178			.find(|v| coding::Version::from(**v) == server.version)
179			.copied()
180			.ok_or(Error::Version)?;
181
182		match version {
183			Version::Lite(v) => {
184				let stream = stream.with_version(v);
185				lite::start(
186					session.clone(),
187					Some(stream),
188					self.publish.clone(),
189					self.consume.clone(),
190					v,
191				)?;
192			}
193			Version::Ietf(v) => {
194				// Decode the parameters to get the initial request ID.
195				let parameters = ietf::Parameters::decode(&mut server.parameters, v)?;
196				let request_id_max = parameters
197					.get_varint(ietf::ParameterVarInt::MaxRequestId)
198					.map(ietf::RequestId);
199
200				let stream = stream.with_version(v);
201				ietf::start(
202					session.clone(),
203					stream,
204					request_id_max,
205					true,
206					self.publish.clone(),
207					self.consume.clone(),
208					v,
209				)?;
210			}
211		}
212
213		Ok(Session::new(session, version))
214	}
215}
216
217#[cfg(test)]
218mod tests {
219	use super::*;
220	use std::{
221		collections::VecDeque,
222		sync::{Arc, Mutex},
223	};
224
225	use crate::coding::{Decode, Encode};
226	use bytes::{BufMut, Bytes};
227
228	#[derive(Debug, Clone, Default)]
229	struct FakeError;
230
231	impl std::fmt::Display for FakeError {
232		fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233			write!(f, "fake transport error")
234		}
235	}
236
237	impl std::error::Error for FakeError {}
238
239	impl web_transport_trait::Error for FakeError {
240		fn session_error(&self) -> Option<(u32, String)> {
241			Some((0, "closed".to_string()))
242		}
243	}
244
245	#[derive(Clone, Default)]
246	struct FakeSession {
247		state: Arc<FakeSessionState>,
248	}
249
250	#[derive(Default)]
251	struct FakeSessionState {
252		protocol: Option<&'static str>,
253		control_stream: Mutex<Option<(FakeSendStream, FakeRecvStream)>>,
254		close_events: Mutex<Vec<(u32, String)>>,
255		close_notify: tokio::sync::Notify,
256		control_writes: Arc<Mutex<Vec<u8>>>,
257	}
258
259	impl FakeSession {
260		fn new(protocol: Option<&'static str>, server_control_bytes: Vec<u8>) -> Self {
261			let writes = Arc::new(Mutex::new(Vec::new()));
262			let send = FakeSendStream { writes: writes.clone() };
263			let recv = FakeRecvStream {
264				data: VecDeque::from(server_control_bytes),
265			};
266			let state = FakeSessionState {
267				protocol,
268				control_stream: Mutex::new(Some((send, recv))),
269				close_events: Mutex::new(Vec::new()),
270				close_notify: tokio::sync::Notify::new(),
271				control_writes: writes,
272			};
273			Self { state: Arc::new(state) }
274		}
275
276		fn control_writes(&self) -> Vec<u8> {
277			self.state.control_writes.lock().unwrap().clone()
278		}
279
280		async fn wait_for_first_close(&self) -> (u32, String) {
281			loop {
282				let notified = self.state.close_notify.notified();
283				if let Some(close) = self.state.close_events.lock().unwrap().first().cloned() {
284					return close;
285				}
286				notified.await;
287			}
288		}
289	}
290
291	impl web_transport_trait::Session for FakeSession {
292		type SendStream = FakeSendStream;
293		type RecvStream = FakeRecvStream;
294		type Error = FakeError;
295
296		async fn accept_uni(&self) -> Result<Self::RecvStream, Self::Error> {
297			std::future::pending().await
298		}
299
300		async fn accept_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
301			std::future::pending().await
302		}
303
304		async fn open_bi(&self) -> Result<(Self::SendStream, Self::RecvStream), Self::Error> {
305			self.state.control_stream.lock().unwrap().take().ok_or(FakeError)
306		}
307
308		async fn open_uni(&self) -> Result<Self::SendStream, Self::Error> {
309			std::future::pending().await
310		}
311
312		fn send_datagram(&self, _payload: Bytes) -> Result<(), Self::Error> {
313			Ok(())
314		}
315
316		async fn recv_datagram(&self) -> Result<Bytes, Self::Error> {
317			std::future::pending().await
318		}
319
320		fn max_datagram_size(&self) -> usize {
321			1200
322		}
323
324		fn protocol(&self) -> Option<&str> {
325			self.state.protocol
326		}
327
328		fn close(&self, code: u32, reason: &str) {
329			self.state.close_events.lock().unwrap().push((code, reason.to_string()));
330			self.state.close_notify.notify_waiters();
331		}
332
333		async fn closed(&self) -> Self::Error {
334			self.state.close_notify.notified().await;
335			FakeError
336		}
337	}
338
339	#[derive(Clone, Default)]
340	struct FakeSendStream {
341		writes: Arc<Mutex<Vec<u8>>>,
342	}
343
344	impl web_transport_trait::SendStream for FakeSendStream {
345		type Error = FakeError;
346
347		async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
348			self.writes.lock().unwrap().put_slice(buf);
349			Ok(buf.len())
350		}
351
352		fn set_priority(&mut self, _order: u8) {}
353
354		fn finish(&mut self) -> Result<(), Self::Error> {
355			Ok(())
356		}
357
358		fn reset(&mut self, _code: u32) {}
359
360		async fn closed(&mut self) -> Result<(), Self::Error> {
361			Ok(())
362		}
363	}
364
365	struct FakeRecvStream {
366		data: VecDeque<u8>,
367	}
368
369	impl web_transport_trait::RecvStream for FakeRecvStream {
370		type Error = FakeError;
371
372		async fn read(&mut self, dst: &mut [u8]) -> Result<Option<usize>, Self::Error> {
373			if self.data.is_empty() {
374				return Ok(None);
375			}
376
377			let size = dst.len().min(self.data.len());
378			for slot in dst.iter_mut().take(size) {
379				*slot = self.data.pop_front().unwrap();
380			}
381			Ok(Some(size))
382		}
383
384		fn stop(&mut self, _code: u32) {}
385
386		async fn closed(&mut self) -> Result<(), Self::Error> {
387			Ok(())
388		}
389	}
390
391	fn mock_server_setup(negotiated: Version) -> Vec<u8> {
392		let mut encoded = Vec::new();
393		let server = setup::Server {
394			version: negotiated.into(),
395			parameters: Bytes::new(),
396		};
397		server
398			.encode(&mut encoded, Version::Ietf(ietf::Version::Draft14))
399			.unwrap();
400
401		// Add a setup-stream SessionInfo frame using the negotiated Lite version.
402		let info = lite::SessionInfo { bitrate: Some(1) };
403		let lite_v = lite::Version::try_from(negotiated).unwrap();
404		info.encode(&mut encoded, lite_v).unwrap();
405
406		encoded
407	}
408
409	async fn run_alpn_lite_fallback_case(protocol: Option<&'static str>) {
410		let fake = FakeSession::new(protocol, mock_server_setup(Version::Lite(lite::Version::Lite01)));
411		let client = Client::new().with_versions(
412			[
413				Version::Lite(lite::Version::Lite03),
414				Version::Lite(lite::Version::Lite02),
415				Version::Lite(lite::Version::Lite01),
416				Version::Ietf(ietf::Version::Draft14),
417			]
418			.into(),
419		);
420
421		let _session = client.connect(fake.clone()).await.unwrap();
422
423		// Verify the client setup was encoded using Draft14 framing (ALPN_LITE fallback path).
424		let mut setup_bytes = Bytes::from(fake.control_writes());
425		let setup = setup::Client::decode(&mut setup_bytes, Version::Ietf(ietf::Version::Draft14)).unwrap();
426		let advertised: Vec<Version> = setup.versions.iter().map(|v| Version::try_from(*v).unwrap()).collect();
427		assert_eq!(
428			advertised,
429			vec![
430				Version::Lite(lite::Version::Lite02),
431				Version::Lite(lite::Version::Lite01),
432				Version::Ietf(ietf::Version::Draft14),
433			]
434		);
435
436		// The first close comes from the background lite session task.
437		// Code 0 ("cancelled") means SessionInfo decoded successfully after set_version().
438		let (code, _) = fake.wait_for_first_close().await;
439		assert_eq!(code, Error::Cancel.to_code());
440	}
441
442	#[tokio::test(start_paused = true)]
443	async fn alpn_lite_falls_back_to_draft14_and_switches_version_post_setup() {
444		run_alpn_lite_fallback_case(Some(ALPN_LITE)).await;
445	}
446
447	#[tokio::test(start_paused = true)]
448	async fn no_alpn_falls_back_to_draft14_and_switches_version_post_setup() {
449		run_alpn_lite_fallback_case(None).await;
450	}
451}