Skip to main content

moq_lite/
server.rs

1// TODO: Uncomment when observability feature is merged
2// use std::sync::Arc;
3
4use crate::{
5	Error, NEGOTIATED, OriginConsumer, OriginProducer, Session, Version,
6	coding::{Decode, Encode, Stream},
7	ietf, lite, setup,
8};
9
10/// A MoQ server session builder.
11#[derive(Default, Clone)]
12pub struct Server {
13	publish: Option<OriginConsumer>,
14	consume: Option<OriginProducer>,
15	// TODO: Uncomment when observability feature is merged
16	// stats: Option<Arc<dyn crate::Stats>>,
17}
18
19impl Server {
20	pub fn new() -> Self {
21		Default::default()
22	}
23
24	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
25		self.publish = publish.into();
26		self
27	}
28
29	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
30		self.consume = consume.into();
31		self
32	}
33
34	// TODO: Uncomment when observability feature is merged
35	// pub fn with_stats(mut self, stats: impl Into<Option<Arc<dyn crate::Stats>>>) -> Self {
36	// 	self.stats = stats.into();
37	// 	self
38	// }
39
40	/// Perform the MoQ handshake as a server for the given session.
41	pub async fn accept<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
42		if self.publish.is_none() && self.consume.is_none() {
43			tracing::warn!("not publishing or consuming anything");
44		}
45
46		let (encoding, supported) = match session.protocol() {
47			Some(p) if p == ietf::ALPN_16 => (
48				Version::Ietf(ietf::Version::Draft16),
49				vec![ietf::Version::Draft16.into()],
50			),
51			Some(p) if p == ietf::ALPN_15 => (
52				Version::Ietf(ietf::Version::Draft15),
53				vec![ietf::Version::Draft15.into()],
54			),
55			Some(p) if p == ietf::ALPN_14 => (
56				Version::Ietf(ietf::Version::Draft14),
57				vec![ietf::Version::Draft14.into()],
58			),
59			Some(p) if p == lite::ALPN => (Version::Ietf(ietf::Version::Draft14), NEGOTIATED.to_vec()),
60			None => (Version::Ietf(ietf::Version::Draft14), NEGOTIATED.to_vec()),
61			Some(p) => return Err(Error::UnknownAlpn(p.to_string())),
62		};
63
64		let mut stream = Stream::accept(&session, encoding).await?;
65
66		let mut client: setup::Client = stream.reader.decode().await?;
67		tracing::trace!(?client, "received client setup");
68
69		// Choose the version to use
70		let version = client
71			.versions
72			.iter()
73			.flat_map(|v| Version::try_from(*v).ok())
74			.find(|v| supported.contains(v))
75			.ok_or(Error::Version)?;
76
77		// Only encode parameters if we're using the IETF draft because it has max_request_id
78		let parameters = match version {
79			Version::Ietf(ietf_version) => {
80				let mut parameters = ietf::Parameters::default();
81				parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
82				parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
83				parameters.encode_bytes(ietf_version)
84			}
85			Version::Lite(_) => lite::Parameters::default().encode_bytes(()),
86		};
87
88		let server = setup::Server {
89			version: version.into(),
90			parameters,
91		};
92		tracing::trace!(?server, "sending server setup");
93		stream.writer.encode(&server).await?;
94
95		match version {
96			Version::Lite(version) => {
97				let stream = stream.with_version(version);
98				lite::start(
99					session.clone(),
100					stream,
101					self.publish.clone(),
102					self.consume.clone(),
103					version,
104				)
105				.await?;
106			}
107			Version::Ietf(version) => {
108				// Decode the client's parameters to get their max request ID.
109				let parameters = ietf::Parameters::decode(&mut client.parameters, version)?;
110				let request_id_max =
111					ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0));
112
113				let stream = stream.with_version(version);
114				ietf::start(
115					session.clone(),
116					stream,
117					request_id_max,
118					false,
119					self.publish.clone(),
120					self.consume.clone(),
121					version,
122				)
123				.await?;
124			}
125		};
126
127		tracing::debug!(?version, "connected");
128
129		Ok(Session::new(session))
130	}
131}