// moq_lite/session/mod.rs

1use crate::{message, Error, OriginConsumer, OriginProducer};
2
3mod publisher;
4mod reader;
5mod stream;
6mod subscriber;
7mod writer;
8
9use publisher::*;
10use reader::*;
11use stream::*;
12use subscriber::*;
13use tokio::sync::oneshot;
14use writer::*;
15
16/// A MoQ session, constructed with [OriginProducer] and [OriginConsumer] halves.
17///
18/// This simplifies the state machine and immediately rejects any subscriptions that don't match the origin prefix.
19/// You probably want to use [Session] unless you're writing a relay.
20pub struct Session {
21	pub webtransport: web_transport::Session,
22}
23
24impl Session {
25	async fn new(
26		mut session: web_transport::Session,
27		stream: Stream,
28		// We will publish any local broadcasts from this origin.
29		publish: Option<OriginConsumer>,
30		// We will consume any remote broadcasts, inserting them into this origin.
31		subscribe: Option<OriginProducer>,
32	) -> Result<Self, Error> {
33		let publisher = Publisher::new(session.clone(), publish);
34		let subscriber = Subscriber::new(session.clone(), subscribe);
35
36		let this = Self {
37			webtransport: session.clone(),
38		};
39
40		let init = oneshot::channel();
41
42		web_async::spawn(async move {
43			let res = tokio::select! {
44				res = Self::run_session(stream) => res,
45				res = publisher.run() => res,
46				res = subscriber.run(init.0) => res,
47			};
48
49			match res {
50				Err(Error::WebTransport(web_transport::Error::Session(_))) => {
51					tracing::info!("session terminated");
52					session.close(1, "");
53				}
54				Err(err) => {
55					tracing::warn!(%err, "session error");
56					session.close(err.to_code(), &err.to_string());
57				}
58				_ => {
59					tracing::info!("session closed");
60					session.close(0, "");
61				}
62			}
63		});
64
65		// Wait until receiving the initial announcements to prevent some race conditions.
66		// Otherwise, `consume()` might return not found if we don't wait long enough, so just wait.
67		// If the announce stream fails or is closed, this will return an error instead of hanging.
68		// TODO return a better error
69		init.1.await.map_err(|_| Error::Cancel)?;
70
71		Ok(this)
72	}
73
74	/// Perform the MoQ handshake as a client.
75	pub async fn connect(
76		session: impl Into<web_transport::Session>,
77		publish: impl Into<Option<OriginConsumer>>,
78		subscribe: impl Into<Option<OriginProducer>>,
79	) -> Result<Self, Error> {
80		let mut session = session.into();
81		let mut stream = Stream::open(&mut session, message::ControlType::Session).await?;
82		Self::connect_setup(&mut stream).await?;
83		let session = Self::new(session, stream, publish.into(), subscribe.into()).await?;
84		Ok(session)
85	}
86
87	async fn connect_setup(setup: &mut Stream) -> Result<(), Error> {
88		let client = message::ClientSetup {
89			versions: [message::Version::CURRENT].into(),
90			extensions: Default::default(),
91		};
92
93		setup.writer.encode(&client).await?;
94		let server: message::ServerSetup = setup.reader.decode().await?;
95
96		tracing::debug!(version = ?server.version, "connected");
97
98		Ok(())
99	}
100
101	/// Perform the MoQ handshake as a server
102	pub async fn accept<
103		T: Into<web_transport::Session>,
104		P: Into<Option<OriginConsumer>>,
105		C: Into<Option<OriginProducer>>,
106	>(
107		session: T,
108		publish: P,
109		subscribe: C,
110	) -> Result<Self, Error> {
111		let mut session = session.into();
112		let mut stream = Stream::accept(&mut session).await?;
113		let kind = stream.reader.decode().await?;
114
115		Self::accept_setup(kind, &mut stream).await?;
116		let session = Self::new(session, stream, publish.into(), subscribe.into()).await?;
117		Ok(session)
118	}
119
120	async fn accept_setup(kind: message::ControlType, control: &mut Stream) -> Result<(), Error> {
121		if kind != message::ControlType::Session && kind != message::ControlType::ClientCompat {
122			return Err(Error::UnexpectedStream(kind));
123		}
124
125		let client: message::ClientSetup = control.reader.decode().await?;
126
127		if !client.versions.contains(&message::Version::CURRENT) {
128			return Err(Error::Version(client.versions, [message::Version::CURRENT].into()));
129		}
130
131		let server = message::ServerSetup {
132			version: message::Version::CURRENT,
133			extensions: Default::default(),
134		};
135
136		// Backwards compatibility with moq-transport-10
137		if kind == message::ControlType::ClientCompat {
138			// Write a 0x41 just to be backwards compatible.
139			control.writer.encode(&message::ControlType::ServerCompat).await?;
140		}
141
142		control.writer.encode(&server).await?;
143
144		tracing::debug!(version = ?server.version, "connected");
145
146		Ok(())
147	}
148
149	// TODO do something useful with this
150	async fn run_session(mut stream: Stream) -> Result<(), Error> {
151		while let Some(_info) = stream.reader.decode_maybe::<message::SessionInfo>().await? {}
152		Err(Error::Cancel)
153	}
154
155	/// Close the underlying WebTransport session.
156	pub fn close(mut self, err: Error) {
157		self.webtransport.close(err.to_code(), &err.to_string());
158	}
159
160	/// Block until the WebTransport session is closed.
161	pub async fn closed(&self) -> Error {
162		self.webtransport.closed().await.into()
163	}
164}