1use anyhow::Context;
2use std::collections::HashSet;
3use std::sync::{Arc, LazyLock, Mutex};
4use std::time;
5use url::Url;
6use web_transport_ws::{tokio_tungstenite, tungstenite};
7
/// Process-wide set of `(host, port)` pairs for which a WebSocket connection has succeeded.
///
/// `connect` consults this set to decide whether to skip the configured delay, and
/// adds to it after every successful WebSocket connection.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Mutex::default);
10
/// WebSocket fallback settings for the client.
///
/// Can be filled in from command-line flags / environment variables (`clap::Args`)
/// or from a serialized config (`serde`). When deserializing, missing fields take
/// their value from the `Default` impl and unknown fields are rejected.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[non_exhaustive]
pub struct ClientWebSocket {
    /// Whether WebSocket connections may be attempted at all.
    // NOTE(review): as a bare `bool` flag whose default is already `true`, it looks
    // like this can only be switched off through the environment variable (or a
    // config file), not from the command line — confirm against clap's behaviour.
    #[arg(
        id = "websocket-enabled",
        long = "websocket-enabled",
        env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
        default_value = "true"
    )]
    pub enabled: bool,

    /// How long to wait before attempting the WebSocket connection (default: 200ms).
    ///
    /// The wait is skipped for a host/port where a WebSocket connection has already
    /// succeeded in this process. `None` means no wait at all.
    #[arg(
        id = "websocket-delay",
        long = "websocket-delay",
        env = "MOQ_CLIENT_WEBSOCKET_DELAY",
        default_value = "200ms",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delay: Option<time::Duration>,
}
38
39impl Default for ClientWebSocket {
40 fn default() -> Self {
41 Self {
42 enabled: true,
43 delay: Some(time::Duration::from_millis(200)),
44 }
45 }
46}
47
48pub(crate) async fn race_handle(
49 config: &ClientWebSocket,
50 tls: &rustls::ClientConfig,
51 url: Url,
52) -> Option<anyhow::Result<web_transport_ws::Session>> {
53 if !config.enabled {
54 return None;
55 }
56 let res = connect(config, tls, url).await;
57 if let Err(err) = &res {
58 tracing::warn!(%err, "WebSocket connection failed");
59 }
60 Some(res)
61}
62
63pub(crate) async fn connect(
64 config: &ClientWebSocket,
65 tls: &rustls::ClientConfig,
66 mut url: Url,
67) -> anyhow::Result<web_transport_ws::Session> {
68 anyhow::ensure!(config.enabled, "WebSocket support is disabled");
69
70 let host = url.host_str().context("missing hostname")?.to_string();
71 let port = url.port().unwrap_or_else(|| match url.scheme() {
72 "https" | "wss" | "moql" | "moqt" => 443,
73 "http" | "ws" => 80,
74 _ => 443,
75 });
76 let key = (host, port);
77
78 match config.delay {
82 Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
83 tokio::time::sleep(delay).await;
84 tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
85 }
86 _ => {}
87 }
88
89 let needs_tls = match url.scheme() {
91 "http" => {
92 url.set_scheme("ws").expect("failed to set scheme");
93 false
94 }
95 "https" | "moql" | "moqt" => {
96 url.set_scheme("wss").expect("failed to set scheme");
97 true
98 }
99 "ws" => false,
100 "wss" => true,
101 _ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
102 };
103
104 tracing::debug!(%url, "connecting via WebSocket");
105
106 let connector = if needs_tls {
108 Some(tokio_tungstenite::Connector::Rustls(Arc::new(tls.clone())))
109 } else {
110 None
111 };
112
113 let (ws_stream, _response) = tokio_tungstenite::connect_async_tls_with_config(
115 url.as_str(),
116 Some(tungstenite::protocol::WebSocketConfig {
117 max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), accept_unmasked_frames: false,
120 ..Default::default()
121 }),
122 false, connector,
124 )
125 .await
126 .context("failed to connect WebSocket")?;
127
128 let session = web_transport_ws::Session::new(ws_stream, false);
131
132 tracing::warn!(%url, "using WebSocket fallback");
133 WEBSOCKET_WON.lock().unwrap().insert(key);
134
135 Ok(session)
136}