1use qmux::tokio_tungstenite;
2use qmux::tokio_tungstenite::tungstenite::{self, client::IntoClientRequest, http};
3use std::collections::HashSet;
4use std::sync::{Arc, LazyLock, Mutex};
5use std::{net, time};
6use url::Url;
7
/// Errors produced while establishing (client side) or accepting (server side)
/// WebSocket connections that carry qmux sessions.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// An I/O error, e.g. from binding, querying, or accepting on the TCP listener.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// WebSocket support is turned off (`Client::enabled` is false); returned by `connect`.
    #[error("WebSocket support is disabled")]
    Disabled,

    /// The URL handed to `connect` has no host component.
    #[error("missing hostname")]
    MissingHostname,

    /// The URL scheme is not one of `http`, `https`, `ws`, or `wss`; carries the scheme.
    #[error("unsupported URL scheme for WebSocket: {0}")]
    UnsupportedScheme(String),

    /// Wraps a `qmux::Error` raised while connecting.
    ///
    /// NOTE(review): this variant is not constructed in the code visible alongside it,
    /// and its message is identical to `WebSocketConnect`'s, so the two cannot be
    /// told apart from the Display output alone.
    #[error("failed to connect WebSocket")]
    Connect(#[source] qmux::Error),

    /// The URL could not be turned into a WebSocket client handshake request.
    #[error("failed to build WebSocket request")]
    BuildRequest(#[source] tungstenite::Error),

    /// The joined subprotocol list is not a valid `Sec-WebSocket-Protocol` header value.
    #[error("failed to build WebSocket protocols header")]
    ProtocolHeader(#[source] http::header::InvalidHeaderValue),

    /// The WebSocket handshake failed for a reason not mapped to `ConnectRejected`.
    #[error("failed to connect WebSocket")]
    WebSocketConnect(#[source] tungstenite::Error),

    /// The server answered the handshake with an HTTP status that
    /// `crate::ConnectError::from_status_u16` recognises.
    #[error(transparent)]
    ConnectRejected(#[from] crate::ConnectError),

    /// The server-side qmux accept on an incoming TCP connection failed.
    #[error("WebSocket accept failed")]
    Accept(#[source] qmux::Error),
}
42
43type Result<T> = std::result::Result<T, Error>;
44
/// Process-wide set of `(host, port)` endpoints that have already connected
/// successfully over WebSocket.
///
/// `connect` skips its fallback delay for endpoints found here and inserts each
/// endpoint after a successful WebSocket handshake.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
47
// Client-side WebSocket fallback settings. Exposed as clap arguments (with
// environment-variable overrides) and (de)serializable with serde; missing
// serde fields fall back to `Client::default()`, unknown fields are rejected.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[group(id = "websocket-client")]
#[non_exhaustive]
pub struct Client {
    // Whether WebSocket connections may be attempted at all. When false,
    // `connect` returns `Error::Disabled` and `race_handle` returns `None`.
    //
    // NOTE(review): for a `bool` field clap's default action may make this a
    // presence-only flag, so with `default_value = "true"` it might not be
    // possible to turn it off on the command line — confirm the configured action.
    #[arg(
        id = "websocket-enabled",
        long = "websocket-enabled",
        env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
        default_value = "true"
    )]
    pub enabled: bool,

    // How long `connect` waits before attempting the WebSocket handshake.
    // The wait is skipped for a (host, port) that already connected over
    // WebSocket in this process; `None` means no wait at all.
    // Parsed and (de)serialized as a human-readable duration (e.g. "200ms");
    // omitted from serialized output when `None`.
    #[arg(
        id = "websocket-delay",
        long = "websocket-delay",
        env = "MOQ_CLIENT_WEBSOCKET_DELAY",
        default_value = "200ms",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delay: Option<time::Duration>,
}
76
77impl Default for Client {
78 fn default() -> Self {
79 Self {
80 enabled: true,
81 delay: Some(time::Duration::from_millis(200)),
82 }
83 }
84}
85
86pub(crate) async fn race_handle(
87 config: &Client,
88 tls: &rustls::ClientConfig,
89 url: Url,
90 alpns: &[&str],
91) -> Option<Result<qmux::Session>> {
92 if !config.enabled {
93 return None;
94 }
95
96 match url.scheme() {
99 "http" | "https" | "ws" | "wss" => {}
100 _ => return None,
101 }
102
103 let res = connect(config, tls, url, alpns).await;
104 if let Err(err) = &res {
105 tracing::warn!(%err, "WebSocket connection failed");
106 }
107 Some(res)
108}
109
110pub(crate) async fn connect(
111 config: &Client,
112 tls: &rustls::ClientConfig,
113 mut url: Url,
114 alpns: &[&str],
115) -> Result<qmux::Session> {
116 if !config.enabled {
117 return Err(Error::Disabled);
118 }
119
120 let host = url.host_str().ok_or(Error::MissingHostname)?.to_string();
121 let port = url.port().unwrap_or_else(|| match url.scheme() {
122 "https" | "wss" | "moql" | "moqt" => 443,
123 "http" | "ws" => 80,
124 _ => 443,
125 });
126 let key = (host, port);
127
128 match config.delay {
132 Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
133 tokio::time::sleep(delay).await;
134 tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
135 }
136 _ => {}
137 }
138
139 let needs_tls = match url.scheme() {
142 "http" => {
143 url.set_scheme("ws").expect("failed to set scheme");
144 false
145 }
146 "https" => {
147 url.set_scheme("wss").expect("failed to set scheme");
148 true
149 }
150 "ws" => false,
151 "wss" => true,
152 _ => return Err(Error::UnsupportedScheme(url.scheme().to_string())),
153 };
154
155 tracing::debug!(%url, "connecting via WebSocket");
156
157 let connector = if needs_tls {
159 tokio_tungstenite::Connector::Rustls(Arc::new(tls.clone()))
160 } else {
161 tokio_tungstenite::Connector::Plain
162 };
163
164 let mut request = url.as_str().into_client_request().map_err(Error::BuildRequest)?;
165 let protocols = websocket_subprotocols(alpns).join(", ");
166 request.headers_mut().insert(
167 http::header::SEC_WEBSOCKET_PROTOCOL,
168 http::HeaderValue::from_str(&protocols).map_err(Error::ProtocolHeader)?,
169 );
170
171 let (socket, response) = if needs_tls {
172 tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
173 .await
174 .map_err(map_websocket_error)?
175 } else {
176 tokio_tungstenite::connect_async_with_config(request, None, false)
177 .await
178 .map_err(map_websocket_error)?
179 };
180
181 let alpn = response
182 .headers()
183 .get(http::header::SEC_WEBSOCKET_PROTOCOL)
184 .and_then(|header| header.to_str().ok())
185 .map(str::to_owned);
186 let bare = qmux::ws::Bare::new(socket).with_keep_alive(qmux::KeepAlive::default());
187 let bare = match alpn.as_deref() {
188 Some(alpn) => bare.with_alpn(alpn),
189 None => bare,
190 };
191 let session = bare.connect();
192
193 tracing::warn!(%url, "using WebSocket fallback");
194 WEBSOCKET_WON.lock().unwrap().insert(key);
195
196 Ok(session)
197}
198
199fn websocket_subprotocols(alpns: &[&str]) -> Vec<String> {
200 let mut protocols = Vec::with_capacity(qmux::ALPNS.len() + qmux::PREFIXES.len() * alpns.len());
201 for (&bare, &prefix) in qmux::ALPNS.iter().zip(qmux::PREFIXES) {
202 protocols.push(bare.to_string());
203 protocols.extend(alpns.iter().map(|alpn| format!("{prefix}{alpn}")));
204 }
205 protocols
206}
207
208impl Error {
209 pub(crate) fn connect_error(&self) -> Option<crate::ConnectError> {
210 match self {
211 Self::ConnectRejected(err) => Some(*err),
212 _ => None,
213 }
214 }
215}
216
217fn map_websocket_error(err: tungstenite::Error) -> Error {
218 if let tungstenite::Error::Http(response) = &err
219 && let Some(err) = crate::ConnectError::from_status_u16(response.status().as_u16())
220 {
221 return err.into();
222 }
223
224 Error::WebSocketConnect(err)
225}
226
/// Server-side WebSocket endpoint: accepts TCP connections and hands each one to
/// a qmux server, producing `qmux::Session`s.
pub struct Listener {
    /// TCP socket that incoming connections are accepted from.
    listener: tokio::net::TcpListener,
    /// qmux server configured with the accepted protocols; cloned for each accept.
    server: qmux::Server,
}
235
236impl Listener {
237 pub async fn bind(addr: net::SocketAddr) -> Result<Self> {
238 Self::bind_with_alpns(addr, moq_net::ALPNS).await
239 }
240
241 pub async fn bind_with_alpns(addr: net::SocketAddr, alpns: &[&str]) -> Result<Self> {
242 let listener = tokio::net::TcpListener::bind(addr).await?;
243 let server = qmux::Server::new().with_protocols(alpns);
244 Ok(Self { listener, server })
245 }
246
247 pub fn local_addr(&self) -> Result<net::SocketAddr> {
248 Ok(self.listener.local_addr()?)
249 }
250
251 pub async fn accept(&self) -> Option<Result<qmux::Session>> {
252 match self.listener.accept().await {
253 Ok((stream, addr)) => {
254 tracing::debug!(%addr, "accepted WebSocket TCP connection");
255 let server = self.server.clone();
256 Some(server.accept(stream).await.map_err(Error::Accept))
257 }
258 Err(e) => Some(Err(e.into())),
259 }
260 }
261}