Skip to main content

trojan_server/
ws.rs

1//! WebSocket transport support.
2//!
3//! This module provides WebSocket upgrade handling for the server.
//! The [`WsIo`] adapter is provided by `trojan_core::transport` and re-exported here.
5
6use bytes::Bytes;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_tungstenite::{
9    WebSocketStream, accept_hdr_async_with_config,
10    tungstenite::{
11        handshake::server::{Request, Response},
12        protocol::WebSocketConfig,
13    },
14};
15use tracing::{debug, warn};
16use trojan_config::WebSocketConfig as WsCfg;
17
18use crate::error::ServerError;
19use crate::util::PrefixedStream;
20
21// Re-export WsIo from trojan-core for convenience
22pub use trojan_core::transport::WsIo;
23
/// Starting capacity, in bytes, of the buffer used to read the HTTP request
/// headers during a WebSocket upgrade.
pub const INITIAL_BUFFER_SIZE: usize = 2 * 1024;

/// Empty line (`CRLF CRLF`) that terminates an HTTP header block.
const HTTP_HEADER_END: &[u8] = b"\r\n\r\n";
28
/// Result of inspecting buffered bytes for a WebSocket upgrade.
///
/// Besides `Debug`, the enum derives `Clone`/`Copy` (every variant is plain
/// data) and `PartialEq`/`Eq`, so callers and tests can compare results
/// directly instead of pattern-matching each time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsInspect {
    /// More bytes are needed before the protocol can be determined.
    NeedMore,
    /// Not HTTP traffic; proceed as raw Trojan.
    NotHttp,
    /// HTTP, but not a WebSocket upgrade; fall back to the HTTP backend.
    HttpFallback,
    /// A valid WebSocket upgrade request.
    Upgrade,
    /// Reject the request, carrying a static reason (e.g. path/host mismatch).
    Reject(&'static str),
}
43
44/// Inspect buffered bytes for WebSocket upgrade in mixed mode.
45pub fn inspect_mixed(buf: &[u8], cfg: &WsCfg) -> WsInspect {
46    // Quick check: if the buffer doesn't start with a plausible HTTP method,
47    // it's definitely not HTTP. Trojan headers start with a hex hash which
48    // will never match these prefixes (except edge cases caught below).
49    if buf.len() >= 3 && !could_be_http_method(buf) {
50        return WsInspect::NotHttp;
51    }
52
53    let header_end = find_header_end(buf);
54    if header_end.is_none() {
55        // If we've read a significant amount of data without finding \r\n\r\n,
56        // and it doesn't look like HTTP is still being received, treat as not HTTP.
57        // HTTP request lines are typically under 8KB. Trojan headers are ~70 bytes.
58        if buf.len() >= 256 {
59            return WsInspect::NotHttp;
60        }
61        return WsInspect::NeedMore;
62    }
63    let header_end = header_end.unwrap();
64    let header_bytes = &buf[..header_end];
65    let header_str = match std::str::from_utf8(header_bytes) {
66        Ok(v) => v,
67        Err(_) => return WsInspect::NotHttp,
68    };
69    let mut lines = header_str.split("\r\n");
70    let request_line = match lines.next() {
71        Some(v) => v,
72        None => return WsInspect::NotHttp,
73    };
74    let mut parts = request_line.split_whitespace();
75    let method = parts.next().unwrap_or("");
76    let path = parts.next().unwrap_or("");
77    let version = parts.next().unwrap_or("");
78    if !version.starts_with("HTTP/") {
79        return WsInspect::NotHttp;
80    }
81    if method != "GET" {
82        return WsInspect::HttpFallback;
83    }
84
85    let mut upgrade = false;
86    let mut connection_upgrade = false;
87    let mut ws_key = false;
88    let mut host: Option<&str> = None;
89
90    for line in lines {
91        if let Some((name, value)) = line.split_once(':') {
92            let name = name.trim().to_ascii_lowercase();
93            let value_trim = value.trim();
94            let value_lower = value_trim.to_ascii_lowercase();
95            match name.as_str() {
96                "upgrade" => {
97                    if value_lower.contains("websocket") {
98                        upgrade = true;
99                    }
100                }
101                "connection" => {
102                    if value_lower.contains("upgrade") {
103                        connection_upgrade = true;
104                    }
105                }
106                "sec-websocket-key" => {
107                    if !value_trim.is_empty() {
108                        ws_key = true;
109                    }
110                }
111                "host" => {
112                    host = Some(value_trim);
113                }
114                _ => {}
115            }
116        }
117    }
118
119    if !upgrade || !connection_upgrade || !ws_key {
120        return WsInspect::HttpFallback;
121    }
122
123    if !path_matches(cfg, path) || !host_matches(cfg, host) {
124        return WsInspect::Reject("websocket path/host mismatch");
125    }
126
127    WsInspect::Upgrade
128}
129
130/// Accept a WebSocket upgrade on the given stream.
131pub async fn accept_ws<S>(
132    stream: S,
133    initial: Bytes,
134    cfg: &WsCfg,
135) -> Result<WebSocketStream<PrefixedStream<S>>, ServerError>
136where
137    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
138{
139    let max_frame = if cfg.max_frame_bytes == 0 {
140        None
141    } else {
142        Some(cfg.max_frame_bytes)
143    };
144    let mut ws_cfg = WebSocketConfig::default();
145    ws_cfg.max_frame_size = max_frame;
146    ws_cfg.max_message_size = max_frame;
147    let prefixed = PrefixedStream::new(initial, stream);
148    let ws = accept_hdr_async_with_config(
149        prefixed,
150        |req: &Request, resp: Response| {
151            debug!(path = %req.uri().path(), "websocket upgrade");
152            Ok(resp)
153        },
154        Some(ws_cfg),
155    )
156    .await
157    .map_err(|e| {
158        ServerError::Io(std::io::Error::new(
159            std::io::ErrorKind::InvalidData,
160            format!("websocket handshake failed: {e}"),
161        ))
162    })?;
163    Ok(ws)
164}
165
166/// Send an HTTP 400 Bad Request response to reject the connection.
167pub async fn send_reject<S>(mut stream: S, reason: &'static str) -> Result<(), ServerError>
168where
169    S: AsyncWrite + Unpin,
170{
171    warn!(reason, "websocket rejected");
172    let response = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";
173    tokio::io::AsyncWriteExt::write_all(&mut stream, response).await?;
174    Ok(())
175}
176
/// Check whether `buf` could plausibly start with an HTTP method.
///
/// The first three bytes are compared, case-sensitively, against the prefixes
/// of GET, POST, PUT, DELETE, HEAD, OPTIONS, PATCH, CONNECT and TRACE. A
/// buffer shorter than three bytes never matches.
fn could_be_http_method(buf: &[u8]) -> bool {
    const METHOD_PREFIXES: [&[u8]; 9] = [
        b"GET", b"POS", b"PUT", b"DEL", b"HEA", b"OPT", b"PAT", b"CON", b"TRA",
    ];
    METHOD_PREFIXES
        .iter()
        .any(|&prefix| buf.starts_with(prefix))
}
190
191fn find_header_end(buf: &[u8]) -> Option<usize> {
192    buf.windows(HTTP_HEADER_END.len())
193        .position(|w| w == HTTP_HEADER_END)
194        .map(|idx| idx + HTTP_HEADER_END.len())
195}
196
197fn path_matches(cfg: &WsCfg, path: &str) -> bool {
198    let path_only = path.split('?').next().unwrap_or("");
199    path_only == cfg.path
200}
201
202fn host_matches(cfg: &WsCfg, host: Option<&str>) -> bool {
203    let expected = match cfg.host.as_deref() {
204        Some(v) => v,
205        None => return true,
206    };
207    let host = match host {
208        Some(v) => v,
209        None => return false,
210    };
211    let host_only = host.split(':').next().unwrap_or("");
212    host_only.eq_ignore_ascii_case(expected)
213}