1use qmux::tokio_tungstenite;
2use std::collections::HashSet;
3use std::sync::{Arc, LazyLock, Mutex};
4use std::{net, time};
5use url::Url;
6
7#[derive(Debug, thiserror::Error)]
9#[non_exhaustive]
10pub enum Error {
11 #[error(transparent)]
12 Io(#[from] std::io::Error),
13
14 #[error("WebSocket support is disabled")]
15 Disabled,
16
17 #[error("missing hostname")]
18 MissingHostname,
19
20 #[error("unsupported URL scheme for WebSocket: {0}")]
21 UnsupportedScheme(String),
22
23 #[error("failed to connect WebSocket")]
24 Connect(#[source] qmux::Error),
25
26 #[error("WebSocket accept failed")]
27 Accept(#[source] qmux::Error),
28}
29
30type Result<T> = std::result::Result<T, Error>;
31
/// `(host, port)` pairs for which a WebSocket connection has already succeeded.
/// `connect` inserts into this set after a successful session and skips its
/// configured delay for any key already present. Entries are never removed.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
34
// Client-side WebSocket fallback settings, configurable through CLI flags, environment
// variables, or serde (missing fields fall back to `Default`; unknown fields are rejected).
//
// Plain `//` comments are used on this clap-derived struct because `///` doc comments
// would be picked up as CLI help text.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[group(id = "websocket-client")]
#[non_exhaustive]
pub struct Client {
    // Whether the WebSocket transport may be used at all. When false, `race_handle`
    // returns `None` and `connect` returns `Error::Disabled`.
    //
    // NOTE(review): with clap's derive a plain `bool` field is a `SetTrue` flag by default;
    // combined with `default_value = "true"` it may be impossible to switch this off from
    // the command line. Confirm it can be disabled (e.g. MOQ_CLIENT_WEBSOCKET_ENABLED=false).
    #[arg(
        id = "websocket-enabled",
        long = "websocket-enabled",
        env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
        default_value = "true"
    )]
    pub enabled: bool,

    // How long `connect` sleeps before attempting a WebSocket connection. The sleep is
    // skipped for any (host, port) that previously connected over WebSocket; `None` means
    // no sleep at all. Parsed as a human-readable duration (e.g. "200ms").
    #[arg(
        id = "websocket-delay",
        long = "websocket-delay",
        env = "MOQ_CLIENT_WEBSOCKET_DELAY",
        default_value = "200ms",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delay: Option<time::Duration>,
}
63
64impl Default for Client {
65 fn default() -> Self {
66 Self {
67 enabled: true,
68 delay: Some(time::Duration::from_millis(200)),
69 }
70 }
71}
72
73pub(crate) async fn race_handle(
74 config: &Client,
75 tls: &rustls::ClientConfig,
76 url: Url,
77 alpns: &[&str],
78) -> Option<Result<qmux::Session>> {
79 if !config.enabled {
80 return None;
81 }
82
83 match url.scheme() {
86 "http" | "https" | "ws" | "wss" => {}
87 _ => return None,
88 }
89
90 let res = connect(config, tls, url, alpns).await;
91 if let Err(err) = &res {
92 tracing::warn!(%err, "WebSocket connection failed");
93 }
94 Some(res)
95}
96
97pub(crate) async fn connect(
98 config: &Client,
99 tls: &rustls::ClientConfig,
100 mut url: Url,
101 alpns: &[&str],
102) -> Result<qmux::Session> {
103 if !config.enabled {
104 return Err(Error::Disabled);
105 }
106
107 let host = url.host_str().ok_or(Error::MissingHostname)?.to_string();
108 let port = url.port().unwrap_or_else(|| match url.scheme() {
109 "https" | "wss" | "moql" | "moqt" => 443,
110 "http" | "ws" => 80,
111 _ => 443,
112 });
113 let key = (host, port);
114
115 match config.delay {
119 Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
120 tokio::time::sleep(delay).await;
121 tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
122 }
123 _ => {}
124 }
125
126 let needs_tls = match url.scheme() {
129 "http" => {
130 url.set_scheme("ws").expect("failed to set scheme");
131 false
132 }
133 "https" => {
134 url.set_scheme("wss").expect("failed to set scheme");
135 true
136 }
137 "ws" => false,
138 "wss" => true,
139 _ => return Err(Error::UnsupportedScheme(url.scheme().to_string())),
140 };
141
142 tracing::debug!(%url, "connecting via WebSocket");
143
144 let connector = if needs_tls {
146 tokio_tungstenite::Connector::Rustls(Arc::new(tls.clone()))
147 } else {
148 tokio_tungstenite::Connector::Plain
149 };
150
151 let session = qmux::Client::new()
152 .with_protocols(alpns)
153 .with_connector(connector)
154 .with_keep_alive(qmux::KeepAlive::default()) .connect(url.as_str())
156 .await
157 .map_err(Error::Connect)?;
158
159 tracing::warn!(%url, "using WebSocket fallback");
160 WEBSOCKET_WON.lock().unwrap().insert(key);
161
162 Ok(session)
163}
164
/// Server side of the WebSocket transport: accepts TCP connections and upgrades
/// each one into a qmux session.
pub struct Listener {
    /// TCP socket that incoming connections arrive on.
    listener: tokio::net::TcpListener,
    /// qmux server, configured with the protocols passed at bind time, that performs
    /// the WebSocket accept on each TCP stream.
    server: qmux::Server,
}
173
174impl Listener {
175 pub async fn bind(addr: net::SocketAddr) -> Result<Self> {
176 Self::bind_with_alpns(addr, moq_net::ALPNS).await
177 }
178
179 pub async fn bind_with_alpns(addr: net::SocketAddr, alpns: &[&str]) -> Result<Self> {
180 let listener = tokio::net::TcpListener::bind(addr).await?;
181 let server = qmux::Server::new().with_protocols(alpns);
182 Ok(Self { listener, server })
183 }
184
185 pub fn local_addr(&self) -> Result<net::SocketAddr> {
186 Ok(self.listener.local_addr()?)
187 }
188
189 pub async fn accept(&self) -> Option<Result<qmux::Session>> {
190 match self.listener.accept().await {
191 Ok((stream, addr)) => {
192 tracing::debug!(%addr, "accepted WebSocket TCP connection");
193 let server = self.server.clone();
194 Some(server.accept(stream).await.map_err(Error::Accept))
195 }
196 Err(e) => Some(Err(e.into())),
197 }
198 }
199}