1use std::sync::Arc;
2
3use crate::{
4 coding::{self, Decode, Encode, Stream},
5 ietf, lite, setup, Error, OriginConsumer, OriginProducer,
6};
7
8pub struct Session<S: web_transport_trait::Session> {
9 session: S,
10}
11
12pub const VERSIONS: [coding::Version; 3] = [
16 lite::Version::Draft02.coding(),
17 lite::Version::Draft01.coding(),
18 ietf::Version::Draft14.coding(),
19];
20
21pub const ALPNS: [&str; 2] = [lite::ALPN, ietf::ALPN];
23
24impl<S: web_transport_trait::Session> Session<S> {
25 fn new(session: S) -> Self {
26 Self { session }
27 }
28
29 pub async fn connect(
34 session: S,
35 publish: impl Into<Option<OriginConsumer>>,
36 subscribe: impl Into<Option<OriginProducer>>,
37 ) -> Result<Self, Error> {
38 let mut stream = Stream::open(&session, setup::ServerKind::Ietf14).await?;
39
40 let mut parameters = ietf::Parameters::default();
41 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
42 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
43 let parameters = parameters.encode_bytes(());
44
45 let client = setup::Client {
46 kind: setup::ClientKind::Ietf14,
49 versions: VERSIONS.into(),
50 parameters,
51 };
52
53 tracing::trace!(?client, "sending client setup");
55 stream.writer.encode(&client).await?;
56
57 let mut server: setup::Server = stream.reader.decode().await?;
58 tracing::trace!(?server, "received server setup");
59
60 if let Ok(version) = lite::Version::try_from(server.version) {
61 let stream = stream.with_version(version);
62 lite::start(session.clone(), stream, publish.into(), subscribe.into(), version).await?;
63 } else if let Ok(version) = ietf::Version::try_from(server.version) {
64 let parameters = ietf::Parameters::decode(&mut server.parameters, version)?;
66 let request_id_max =
67 ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0));
68
69 let stream = stream.with_version(version);
70 ietf::start(
71 session.clone(),
72 stream,
73 request_id_max,
74 true,
75 publish.into(),
76 subscribe.into(),
77 version,
78 )
79 .await?;
80 } else {
81 return Err(Error::Version(client.versions, [server.version].into()));
83 }
84
85 tracing::debug!(version = ?server.version, "connected");
86
87 Ok(Self::new(session))
88 }
89
90 pub async fn accept(
95 session: S,
96 publish: impl Into<Option<OriginConsumer>>,
97 subscribe: impl Into<Option<OriginProducer>>,
98 ) -> Result<Self, Error> {
99 let mut stream = Stream::accept(&session, ()).await?;
101 let client: setup::Client = stream.reader.decode().await?;
102 tracing::trace!(?client, "received client setup");
103
104 let version = client
106 .versions
107 .iter()
108 .find(|v| VERSIONS.contains(v))
109 .copied()
110 .ok_or_else(|| Error::Version(client.versions.clone(), VERSIONS.into()))?;
111
112 let parameters = if ietf::Version::try_from(version).is_ok() && client.kind == setup::ClientKind::Ietf14 {
114 let mut parameters = ietf::Parameters::default();
115 parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64);
116 parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec());
117 parameters.encode_bytes(())
118 } else {
119 lite::Parameters::default().encode_bytes(())
120 };
121
122 let mut server = setup::Server { version, parameters };
123 tracing::trace!(?server, "sending server setup");
124
125 let mut stream = stream.with_version(client.kind.reply());
126 stream.writer.encode(&server).await?;
127
128 if let Ok(version) = lite::Version::try_from(version) {
129 let stream = stream.with_version(version);
130 lite::start(session.clone(), stream, publish.into(), subscribe.into(), version).await?;
131 } else if let Ok(version) = ietf::Version::try_from(version) {
132 let parameters = ietf::Parameters::decode(&mut server.parameters, version)?;
134 let request_id_max =
135 ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0));
136
137 let stream = stream.with_version(version);
138 ietf::start(
139 session.clone(),
140 stream,
141 request_id_max,
142 false,
143 publish.into(),
144 subscribe.into(),
145 version,
146 )
147 .await?;
148 } else {
149 return Err(Error::Version(client.versions, VERSIONS.into()));
151 }
152
153 tracing::debug!(?version, "connected");
154
155 Ok(Self::new(session))
156 }
157
158 pub fn close(self, err: Error) {
160 self.session.close(err.to_code(), err.to_string().as_ref());
161 }
162
163 pub async fn closed(&self) -> Result<(), Error> {
166 Err(Error::Transport(Arc::new(self.session.closed().await)))
167 }
168}