Skip to main content

moq_lite/
client.rs

1// TODO: Uncomment when observability feature is merged
2// use std::sync::Arc;
3
4use crate::{
5	Error, NEGOTIATED, OriginConsumer, OriginProducer, Session, Version,
6	coding::{self, Decode, Encode, Stream},
7	ietf, lite, setup,
8};
9
10/// A MoQ client session builder.
11#[derive(Default, Clone)]
12pub struct Client {
13	publish: Option<OriginConsumer>,
14	consume: Option<OriginProducer>,
15	// TODO: Uncomment when observability feature is merged
16	// stats: Option<Arc<dyn crate::Stats>>,
17}
18
19impl Client {
20	pub fn new() -> Self {
21		Default::default()
22	}
23
24	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
25		self.publish = publish.into();
26		self
27	}
28
29	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
30		self.consume = consume.into();
31		self
32	}
33
34	// TODO: Uncomment when observability feature is merged
35	// pub fn with_stats(mut self, stats: impl Into<Option<Arc<dyn crate::Stats>>>) -> Self {
36	// 	self.stats = stats.into();
37	// 	self
38	// }
39
40	/// Perform the MoQ handshake as a client negotiating the version.
41	pub async fn connect<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
42		if self.publish.is_none() && self.consume.is_none() {
43			tracing::warn!("not publishing or consuming anything");
44		}
45
46		// If ALPN was used to negotiate the version, use the appropriate encoding.
47		// Default to IETF 14 if no ALPN was used and we'll negotiate the version later.
48		let (encoding, supported) = match session.protocol() {
49			Some(p) if p == ietf::ALPN_16 => (
50				Version::Ietf(ietf::Version::Draft16),
51				vec![ietf::Version::Draft16.into()],
52			),
53			Some(p) if p == ietf::ALPN_15 => (
54				Version::Ietf(ietf::Version::Draft15),
55				vec![ietf::Version::Draft15.into()],
56			),
57			Some(p) if p == ietf::ALPN_14 => (
58				Version::Ietf(ietf::Version::Draft14),
59				vec![ietf::Version::Draft14.into()],
60			),
61			Some(p) if p == lite::ALPN => (Version::Ietf(ietf::Version::Draft14), NEGOTIATED.to_vec()),
62			None => (Version::Ietf(ietf::Version::Draft14), NEGOTIATED.to_vec()),
63			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
64		};
65
66		let mut stream = Stream::open(&session, encoding).await?;
67
68		let ietf_version = match encoding {
69			Version::Ietf(v) => v,
70			_ => ietf::Version::Draft14,
71		};
72		let mut parameters = ietf::Parameters::default();
73		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
74		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
75		let parameters = parameters.encode_bytes(ietf_version);
76
77		let client = setup::Client {
78			versions: supported.clone().into(),
79			parameters,
80		};
81
82		// TODO pretty print the parameters.
83		tracing::trace!(?client, "sending client setup");
84		stream.writer.encode(&client).await?;
85
86		let mut server: setup::Server = stream.reader.decode().await?;
87		tracing::trace!(?server, "received server setup");
88
89		let version = supported
90			.iter()
91			.find(|v| coding::Version::from(**v) == server.version)
92			.copied()
93			.ok_or(Error::Version)?;
94
95		match version {
96			Version::Lite(version) => {
97				let stream = stream.with_version(version);
98				lite::start(
99					session.clone(),
100					stream,
101					self.publish.clone(),
102					self.consume.clone(),
103					version,
104				)
105				.await?;
106			}
107			Version::Ietf(version) => {
108				// Decode the parameters to get the initial request ID.
109				let parameters = ietf::Parameters::decode(&mut server.parameters, version)?;
110				let request_id_max = ietf::RequestId(
111					parameters
112						.get_varint(ietf::ParameterVarInt::MaxRequestId)
113						.unwrap_or_default(),
114				);
115
116				let stream = stream.with_version(version);
117				ietf::start(
118					session.clone(),
119					stream,
120					request_id_max,
121					true,
122					self.publish.clone(),
123					self.consume.clone(),
124					version,
125				)
126				.await?;
127			}
128		}
129
130		tracing::debug!(version = ?server.version, "connected");
131
132		Ok(Session::new(session))
133	}
134}