moq_lite/
server.rs

1// TODO: Uncomment when observability feature is merged
2// use std::sync::Arc;
3
4use crate::{
5	Error, OriginConsumer, OriginProducer, Session, VERSIONS,
6	coding::{Decode, Encode, Stream},
7	ietf, lite, setup,
8};
9
10/// A MoQ server session builder.
11#[derive(Default, Clone)]
12pub struct Server {
13	publish: Option<OriginConsumer>,
14	consume: Option<OriginProducer>,
15	// TODO: Uncomment when observability feature is merged
16	// stats: Option<Arc<dyn crate::Stats>>,
17}
18
19impl Server {
20	pub fn new() -> Self {
21		Default::default()
22	}
23
24	pub fn with_publish(mut self, publish: impl Into<Option<OriginConsumer>>) -> Self {
25		self.publish = publish.into();
26		self
27	}
28
29	pub fn with_consume(mut self, consume: impl Into<Option<OriginProducer>>) -> Self {
30		self.consume = consume.into();
31		self
32	}
33
34	// TODO: Uncomment when observability feature is merged
35	// pub fn with_stats(mut self, stats: impl Into<Option<Arc<dyn crate::Stats>>>) -> Self {
36	// 	self.stats = stats.into();
37	// 	self
38	// }
39
40	/// Perform the MoQ handshake as a server for the given session.
41	pub async fn accept<S: web_transport_trait::Session>(&self, session: S) -> Result<Session, Error> {
42		if self.publish.is_none() && self.consume.is_none() {
43			tracing::warn!("not publishing or consuming anything");
44		}
45
46		// Accept with an initial version; we'll switch to the negotiated version later
47		let mut stream = Stream::accept(&session, ()).await?;
48		let mut client: setup::Client = stream.reader.decode().await?;
49		tracing::trace!(?client, "received client setup");
50
51		// Choose the version to use
52		let version = client
53			.versions
54			.iter()
55			.find(|v| VERSIONS.contains(v))
56			.copied()
57			.ok_or_else(|| Error::Version(client.versions.clone(), VERSIONS.into()))?;
58
59		// Only encode parameters if we're using the IETF draft because it has max_request_id
60		let parameters = if ietf::Version::try_from(version).is_ok() && client.kind == setup::ClientKind::Ietf14 {
61			let mut parameters = ietf::Parameters::default();
62			parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
63			parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
64			parameters.encode_bytes(())
65		} else {
66			lite::Parameters::default().encode_bytes(())
67		};
68
69		let server = setup::Server { version, parameters };
70		tracing::trace!(?server, "sending server setup");
71
72		let mut stream = stream.with_version(client.kind.reply());
73		stream.writer.encode(&server).await?;
74
75		if let Ok(version) = lite::Version::try_from(version) {
76			let stream = stream.with_version(version);
77			lite::start(
78				session.clone(),
79				stream,
80				self.publish.clone(),
81				self.consume.clone(),
82				version,
83			)
84			.await?;
85		} else if let Ok(version) = ietf::Version::try_from(version) {
86			// Decode the client's parameters to get their max request ID.
87			let parameters = ietf::Parameters::decode(&mut client.parameters, version)?;
88			let request_id_max =
89				ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0));
90
91			let stream = stream.with_version(version);
92			ietf::start(
93				session.clone(),
94				stream,
95				request_id_max,
96				false,
97				self.publish.clone(),
98				self.consume.clone(),
99				version,
100			)
101			.await?;
102		} else {
103			// unreachable, but just in case
104			return Err(Error::Version(client.versions, VERSIONS.into()));
105		}
106
107		tracing::debug!(?version, "connected");
108
109		Ok(Session::new(session))
110	}
111}