Skip to main content

moq_lite/
client.rs

1// TODO: Uncomment when observability feature is merged
2// use std::sync::Arc;
3
4use crate::{
5	Error, NEGOTIATED, OriginConsumer, OriginProducer, Session, Version, Versions,
6	coding::{self, Decode, Encode, Stream},
7	ietf, lite, setup,
8};
9
10/// A MoQ client session builder.
11#[derive(Default, Clone)]
12pub struct Client {
13	publish: Option<OriginConsumer>,
14	consume: Option<OriginProducer>,
15	versions: Versions,
16	// TODO: Uncomment when observability feature is merged
17	// stats: Option<Arc<dyn crate::Stats>>,
18}
19
20impl Client {
21	pub fn new() -> Self {
22		Default::default()
23	}
24
25	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
26		self.publish = publish.into();
27		self
28	}
29
30	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
31		self.consume = consume.into();
32		self
33	}
34
35	pub fn with_versions(mut self, versions: Versions) -> Self {
36		self.versions = versions;
37		self
38	}
39
40	// TODO: Uncomment when observability feature is merged
41	// pub fn with_stats(mut self, stats: impl Into<Option<Arc<dyn crate::Stats>>>) -> Self {
42	// 	self.stats = stats.into();
43	// 	self
44	// }
45
46	/// Perform the MoQ handshake as a client negotiating the version.
47	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
48		if self.publish.is_none() && self.consume.is_none() {
49			tracing::warn!("not publishing or consuming anything");
50		}
51
52		// If ALPN was used to negotiate the version, use the appropriate encoding.
53		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
54		let (encoding, supported) = match session.protocol() {
55			Some(ietf::ALPN_16) => {
56				let v = self
57					.versions
58					.select(ietf::Version::Draft16.into())
59					.ok_or(Error::Version)?;
60				(v, v.into())
61			}
62			Some(ietf::ALPN_15) => {
63				let v = self
64					.versions
65					.select(ietf::Version::Draft15.into())
66					.ok_or(Error::Version)?;
67				(v, v.into())
68			}
69			Some(ietf::ALPN_14) => {
70				let v = self
71					.versions
72					.select(ietf::Version::Draft14.into())
73					.ok_or(Error::Version)?;
74				(v, v.into())
75			}
76			Some(lite::ALPN_03) => {
77				self.versions
78					.select(lite::Version::Draft03.into())
79					.ok_or(Error::Version)?;
80
81				// Starting with draft-03, there's no more SETUP control stream.
82				lite::start(
83					session.clone(),
84					None,
85					self.publish.clone(),
86					self.consume.clone(),
87					lite::Version::Draft03,
88				)?;
89
90				tracing::debug!(version = ?lite::Version::Draft03, "connected");
91
92				return Ok(Session::new(session));
93			}
94			Some(lite::ALPN) | None => {
95				let supported = self.versions.filter(&NEGOTIATED.into()).ok_or(Error::Version)?;
96				(ietf::Version::Draft14.into(), supported)
97			}
98			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
99		};
100
101		let mut stream = Stream::open(&session, encoding).await?;
102
103		let ietf_version = match encoding {
104			Version::Ietf(v) => v,
105			_ => ietf::Version::Draft14,
106		};
107		let mut parameters = ietf::Parameters::default();
108		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
109		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
110		let parameters = parameters.encode_bytes(ietf_version)?;
111
112		let client = setup::Client {
113			versions: supported.clone().into(),
114			parameters,
115		};
116
117		// TODO pretty print the parameters.
118		tracing::trace!(?client, "sending client setup");
119		stream.writer.encode(&client).await?;
120
121		let mut server: setup::Server = stream.reader.decode().await?;
122		tracing::trace!(?server, "received server setup");
123
124		let version = supported
125			.iter()
126			.find(|v| coding::Version::from(**v) == server.version)
127			.copied()
128			.ok_or(Error::Version)?;
129
130		match version {
131			Version::Lite(version) => {
132				let stream = stream.with_version(version);
133				lite::start(
134					session.clone(),
135					Some(stream),
136					self.publish.clone(),
137					self.consume.clone(),
138					version,
139				)?;
140			}
141			Version::Ietf(version) => {
142				// Decode the parameters to get the initial request ID.
143				let parameters = ietf::Parameters::decode(&mut server.parameters, version)?;
144				let request_id_max = ietf::RequestId(
145					parameters
146						.get_varint(ietf::ParameterVarInt::MaxRequestId)
147						.unwrap_or_default(),
148				);
149
150				let stream = stream.with_version(version);
151				ietf::start(
152					session.clone(),
153					stream,
154					request_id_max,
155					true,
156					self.publish.clone(),
157					self.consume.clone(),
158					version,
159				)?;
160			}
161		}
162
163		tracing::debug!(version = ?server.version, "connected");
164
165		Ok(Session::new(session))
166	}
167}