Skip to main content

trojan_server/
ws.rs

//! WebSocket transport support.
//!
//! This module provides WebSocket upgrade handling for the server.
//! The `WsIo` adapter lives in `trojan_core::transport` and is re-exported here.
5
6use bytes::Bytes;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_tungstenite::{
9    WebSocketStream, accept_hdr_async_with_config,
10    tungstenite::{
11        handshake::server::{Request, Response},
12        protocol::WebSocketConfig,
13    },
14};
15use tracing::{debug, warn};
16use trojan_config::WebSocketConfig as WsCfg;
17
18use crate::error::ServerError;
19use crate::util::PrefixedStream;
20
21// Re-export WsIo from trojan-core for convenience
22pub use trojan_core::transport::WsIo;
23
/// Starting size of the buffer used to read HTTP headers during a WebSocket
/// upgrade.
pub const INITIAL_BUFFER_SIZE: usize = 2 * 1024;

/// The blank line (CR LF CR LF) that terminates an HTTP/1.x header block.
const HTTP_HEADER_END: &[u8] = b"\r\n\r\n";
28
/// Result of inspecting buffered bytes for WebSocket upgrade.
///
/// Every variant carries only `'static` data, so the enum is cheap to copy and
/// compare (e.g. in tests or when matching on a stored verdict).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsInspect {
    /// Need more data to determine protocol.
    NeedMore,
    /// Not HTTP traffic, proceed as raw Trojan.
    NotHttp,
    /// HTTP but not WebSocket upgrade, fallback to HTTP backend.
    HttpFallback,
    /// Valid WebSocket upgrade request.
    Upgrade,
    /// Reject with reason (e.g., path/host mismatch).
    Reject(&'static str),
}
42
43/// Inspect buffered bytes for WebSocket upgrade in mixed mode.
44pub fn inspect_mixed(buf: &[u8], cfg: &WsCfg) -> WsInspect {
45    let header_end = find_header_end(buf);
46    if header_end.is_none() {
47        return WsInspect::NeedMore;
48    }
49    let header_end = header_end.unwrap();
50    let header_bytes = &buf[..header_end];
51    let header_str = match std::str::from_utf8(header_bytes) {
52        Ok(v) => v,
53        Err(_) => return WsInspect::NotHttp,
54    };
55    let mut lines = header_str.split("\r\n");
56    let request_line = match lines.next() {
57        Some(v) => v,
58        None => return WsInspect::NotHttp,
59    };
60    let mut parts = request_line.split_whitespace();
61    let method = parts.next().unwrap_or("");
62    let path = parts.next().unwrap_or("");
63    let version = parts.next().unwrap_or("");
64    if !version.starts_with("HTTP/") {
65        return WsInspect::NotHttp;
66    }
67    if method != "GET" {
68        return WsInspect::HttpFallback;
69    }
70
71    let mut upgrade = false;
72    let mut connection_upgrade = false;
73    let mut ws_key = false;
74    let mut host: Option<&str> = None;
75
76    for line in lines {
77        if let Some((name, value)) = line.split_once(':') {
78            let name = name.trim().to_ascii_lowercase();
79            let value_trim = value.trim();
80            let value_lower = value_trim.to_ascii_lowercase();
81            match name.as_str() {
82                "upgrade" => {
83                    if value_lower.contains("websocket") {
84                        upgrade = true;
85                    }
86                }
87                "connection" => {
88                    if value_lower.contains("upgrade") {
89                        connection_upgrade = true;
90                    }
91                }
92                "sec-websocket-key" => {
93                    if !value_trim.is_empty() {
94                        ws_key = true;
95                    }
96                }
97                "host" => {
98                    host = Some(value_trim);
99                }
100                _ => {}
101            }
102        }
103    }
104
105    if !upgrade || !connection_upgrade || !ws_key {
106        return WsInspect::HttpFallback;
107    }
108
109    if !path_matches(cfg, path) || !host_matches(cfg, host) {
110        return WsInspect::Reject("websocket path/host mismatch");
111    }
112
113    WsInspect::Upgrade
114}
115
116/// Accept a WebSocket upgrade on the given stream.
117pub async fn accept_ws<S>(
118    stream: S,
119    initial: Bytes,
120    cfg: &WsCfg,
121) -> Result<WebSocketStream<PrefixedStream<S>>, ServerError>
122where
123    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
124{
125    let max_frame = if cfg.max_frame_bytes == 0 {
126        None
127    } else {
128        Some(cfg.max_frame_bytes)
129    };
130    let ws_cfg = WebSocketConfig {
131        max_frame_size: max_frame,
132        max_message_size: max_frame,
133        ..WebSocketConfig::default()
134    };
135    let prefixed = PrefixedStream::new(initial, stream);
136    let ws = accept_hdr_async_with_config(
137        prefixed,
138        |req: &Request, resp: Response| {
139            debug!(path = %req.uri().path(), "websocket upgrade");
140            Ok(resp)
141        },
142        Some(ws_cfg),
143    )
144    .await
145    .map_err(|e| {
146        ServerError::Io(std::io::Error::new(
147            std::io::ErrorKind::InvalidData,
148            format!("websocket handshake failed: {e}"),
149        ))
150    })?;
151    Ok(ws)
152}
153
154/// Send an HTTP 400 Bad Request response to reject the connection.
155pub async fn send_reject<S>(mut stream: S, reason: &'static str) -> Result<(), ServerError>
156where
157    S: AsyncWrite + Unpin,
158{
159    warn!(reason, "websocket rejected");
160    let response = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";
161    tokio::io::AsyncWriteExt::write_all(&mut stream, response).await?;
162    Ok(())
163}
164
165fn find_header_end(buf: &[u8]) -> Option<usize> {
166    buf.windows(HTTP_HEADER_END.len())
167        .position(|w| w == HTTP_HEADER_END)
168        .map(|idx| idx + HTTP_HEADER_END.len())
169}
170
171fn path_matches(cfg: &WsCfg, path: &str) -> bool {
172    let path_only = path.split('?').next().unwrap_or("");
173    path_only == cfg.path
174}
175
176fn host_matches(cfg: &WsCfg, host: Option<&str>) -> bool {
177    let expected = match cfg.host.as_deref() {
178        Some(v) => v,
179        None => return true,
180    };
181    let host = match host {
182        Some(v) => v,
183        None => return false,
184    };
185    let host_only = host.split(':').next().unwrap_or("");
186    host_only.eq_ignore_ascii_case(expected)
187}