moq_lite/
session.rs

1use std::sync::Arc;
2
3use crate::{
4	coding::{self, Decode, Encode, Stream},
5	ietf, lite, setup, Error, OriginConsumer, OriginProducer,
6};
7
8pub struct Session<S: web_transport_trait::Session> {
9	session: S,
10}
11
12/// The versions of MoQ that are supported by this implementation.
13///
14/// Ordered by preference, with the client's preference taking priority.
15pub const VERSIONS: [coding::Version; 3] = [
16	lite::Version::Draft02.coding(),
17	lite::Version::Draft01.coding(),
18	ietf::Version::Draft14.coding(),
19];
20
21/// The ALPN strings for supported versions.
22pub const ALPNS: [&str; 2] = [lite::ALPN, ietf::ALPN];
23
24impl<S: web_transport_trait::Session> Session<S> {
25	fn new(session: S) -> Self {
26		Self { session }
27	}
28
29	/// Perform the MoQ handshake as a client, negotiating the version.
30	///
31	/// Publishing is performed with [OriginConsumer] and subscribing with [OriginProducer].
32	/// The connection remains active until the session is closed.
33	pub async fn connect(
34		session: S,
35		publish: impl Into<Option<OriginConsumer>>,
36		subscribe: impl Into<Option<OriginProducer>>,
37	) -> Result<Self, Error> {
38		let mut stream = Stream::open(&session, setup::ServerKind::Ietf14).await?;
39
40		let mut parameters = ietf::Parameters::default();
41		parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
42		parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
43		let parameters = parameters.encode_bytes(());
44
45		let client = setup::Client {
46			// Unfortunately, we have to pick a single draft range to support.
47			// moq-lite can support this handshake.
48			kind: setup::ClientKind::Ietf14,
49			versions: VERSIONS.into(),
50			parameters,
51		};
52
53		// TODO pretty print the parameters.
54		tracing::trace!(?client, "sending client setup");
55		stream.writer.encode(&client).await?;
56
57		let mut server: setup::Server = stream.reader.decode().await?;
58		tracing::trace!(?server, "received server setup");
59
60		if let Ok(version) = lite::Version::try_from(server.version) {
61			let stream = stream.with_version(version);
62			lite::start(session.clone(), stream, publish.into(), subscribe.into(), version).await?;
63		} else if let Ok(version) = ietf::Version::try_from(server.version) {
64			// Decode the parameters to get the initial request ID.
65			let parameters = ietf::Parameters::decode(&mut server.parameters, version)?;
66			let request_id_max =
67				ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0));
68
69			let stream = stream.with_version(version);
70			ietf::start(
71				session.clone(),
72				stream,
73				request_id_max,
74				true,
75				publish.into(),
76				subscribe.into(),
77				version,
78			)
79			.await?;
80		} else {
81			// unreachable, but just in case
82			return Err(Error::Version(client.versions, [server.version].into()));
83		}
84
85		tracing::debug!(version = ?server.version, "connected");
86
87		Ok(Self::new(session))
88	}
89
90	/// Perform the MoQ handshake as a server.
91	///
92	/// Publishing is performed with [OriginConsumer] and subscribing with [OriginProducer].
93	/// The connection remains active until the session is closed.
94	pub async fn accept(
95		session: S,
96		publish: impl Into<Option<OriginConsumer>>,
97		subscribe: impl Into<Option<OriginProducer>>,
98	) -> Result<Self, Error> {
99		// Accept with an initial version; we'll switch to the negotiated version later
100		let mut stream = Stream::accept(&session, ()).await?;
101		let client: setup::Client = stream.reader.decode().await?;
102		tracing::trace!(?client, "received client setup");
103
104		// Choose the version to use
105		let version = client
106			.versions
107			.iter()
108			.find(|v| VERSIONS.contains(v))
109			.copied()
110			.ok_or_else(|| Error::Version(client.versions.clone(), VERSIONS.into()))?;
111
112		// Only encode parameters if we're using the IETF draft because it has max_request_id
113		let parameters = if ietf::Version::try_from(version).is_ok() && client.kind == setup::ClientKind::Ietf14 {
114			let mut parameters = ietf::Parameters::default();
115			parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
116			parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
117			parameters.encode_bytes(())
118		} else {
119			lite::Parameters::default().encode_bytes(())
120		};
121
122		let mut server = setup::Server { version, parameters };
123		tracing::trace!(?server, "sending server setup");
124
125		let mut stream = stream.with_version(client.kind.reply());
126		stream.writer.encode(&server).await?;
127
128		if let Ok(version) = lite::Version::try_from(version) {
129			let stream = stream.with_version(version);
130			lite::start(session.clone(), stream, publish.into(), subscribe.into(), version).await?;
131		} else if let Ok(version) = ietf::Version::try_from(version) {
132			// Decode the parameters to get the initial request ID.
133			let parameters = ietf::Parameters::decode(&mut server.parameters, version)?;
134			let request_id_max =
135				ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0));
136
137			let stream = stream.with_version(version);
138			ietf::start(
139				session.clone(),
140				stream,
141				request_id_max,
142				false,
143				publish.into(),
144				subscribe.into(),
145				version,
146			)
147			.await?;
148		} else {
149			// unreachable, but just in case
150			return Err(Error::Version(client.versions, VERSIONS.into()));
151		}
152
153		tracing::debug!(?version, "connected");
154
155		Ok(Self::new(session))
156	}
157
158	/// Close the underlying transport session.
159	pub fn close(self, err: Error) {
160		self.session.close(err.to_code(), err.to_string().as_ref());
161	}
162
163	/// Block until the transport session is closed.
164	// TODO Remove the Result the next time we make a breaking change.
165	pub async fn closed(&self) -> Result<(), Error> {
166		Err(Error::Transport(Arc::new(self.session.closed().await)))
167	}
168}