1use std::sync::Arc;
2
3use crate::{
4 coding::{self, Stream},
5 ietf, lite, Error, OriginConsumer, OriginProducer,
6};
7
8pub struct Session<S: web_transport_trait::Session> {
9 session: S,
10}
11
12const SUPPORTED: [coding::Version; 2] = [coding::Version::LITE_LATEST, coding::Version::IETF_LATEST];
14
15impl<S: web_transport_trait::Session> Session<S> {
16 fn new(session: S) -> Self {
17 Self { session }
18 }
19
20 pub async fn connect(
25 session: S,
26 publish: impl Into<Option<OriginConsumer>>,
27 subscribe: impl Into<Option<OriginProducer>>,
28 ) -> Result<Self, Error> {
29 let mut stream = Stream::open(&session).await?;
30
31 stream.writer.encode(&lite::ControlType::ClientCompat).await?;
33
34 let mut extensions = coding::Extensions::default();
36 extensions.set(ietf::Role::Both);
37
38 let client = lite::ClientSetup {
39 versions: SUPPORTED.into(),
40 extensions,
41 };
42
43 stream.writer.encode(&client).await?;
44
45 let server_compat: lite::ControlType = stream.reader.decode().await?;
47 if server_compat != lite::ControlType::ServerCompat {
48 return Err(Error::UnexpectedStream);
49 }
50
51 let server: lite::ServerSetup = stream.reader.decode().await?;
52
53 tracing::debug!(version = ?server.version, "connected");
54
55 match server.version {
56 coding::Version::LITE_LATEST => {
57 lite::start(session.clone(), stream, publish.into(), subscribe.into()).await?;
58 }
59 coding::Version::IETF_LATEST => {
60 ietf::start(session.clone(), stream, publish.into(), subscribe.into()).await?;
61 }
62 _ => return Err(Error::Version(client.versions, [server.version].into())),
63 }
64
65 Ok(Self::new(session))
66 }
67
68 pub async fn accept(
73 session: S,
74 publish: impl Into<Option<OriginConsumer>>,
75 subscribe: impl Into<Option<OriginProducer>>,
76 ) -> Result<Self, Error> {
77 let mut stream = Stream::accept(&session).await?;
78 let kind: lite::ControlType = stream.reader.decode().await?;
79
80 if kind != lite::ControlType::Session && kind != lite::ControlType::ClientCompat {
81 return Err(Error::UnexpectedStream);
82 }
83
84 let client: lite::ClientSetup = stream.reader.decode().await?;
85
86 let version = client
87 .versions
88 .iter()
89 .find(|v| SUPPORTED.contains(v))
90 .copied()
91 .ok_or_else(|| Error::Version(client.versions, SUPPORTED.into()))?;
92
93 let server = lite::ServerSetup {
94 version,
95 extensions: Default::default(),
96 };
97
98 if kind == lite::ControlType::ClientCompat {
100 stream.writer.encode(&lite::ControlType::ServerCompat).await?;
102 }
103
104 stream.writer.encode(&server).await?;
105
106 tracing::debug!(version = ?server.version, "connected");
107
108 match version {
109 coding::Version::LITE_LATEST => {
110 lite::start(session.clone(), stream, publish.into(), subscribe.into()).await?;
111 }
112 coding::Version::IETF_LATEST => {
113 ietf::start(session.clone(), stream, publish.into(), subscribe.into()).await?;
114 }
115 _ => unreachable!(),
116 }
117
118 Ok(Self::new(session))
119 }
120
121 pub fn close(self, err: Error) {
123 self.session.close(err.to_code(), err.to_string().as_ref());
124 }
125
126 pub async fn closed(&self) -> Result<(), Error> {
128 match self.session.closed().await {
129 Ok(()) => Ok(()),
130 Err(err) => Err(Error::Transport(Arc::new(err))),
131 }
132 }
133}