1use bytes::Bytes;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_tungstenite::{
9 WebSocketStream, accept_hdr_async_with_config,
10 tungstenite::{
11 handshake::server::{Request, Response},
12 protocol::WebSocketConfig,
13 },
14};
15use tracing::{debug, warn};
16use trojan_config::WebSocketConfig as WsCfg;
17
18use crate::error::ServerError;
19use crate::util::PrefixedStream;
20
21pub use trojan_core::transport::WsIo;
23
/// Initial buffer size in bytes (2 KiB).
///
/// Not referenced inside this module; presumably used by callers as the
/// initial read-buffer capacity when sniffing a connection — TODO confirm.
pub const INITIAL_BUFFER_SIZE: usize = 2 * 1024;

/// The empty line (`CR LF CR LF`) that terminates an HTTP/1.x header block.
const HTTP_HEADER_END: &[u8] = &[b'\r', b'\n', b'\r', b'\n'];
28
/// Verdict produced by `inspect_mixed` for the first bytes of a connection.
///
/// Derives `Debug`/`PartialEq` so callers can log the verdict and compare it
/// in tests; all payloads are `'static`, so the type is also `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsInspect {
    /// The HTTP header block is incomplete; more bytes are needed to decide.
    NeedMore,
    /// The bytes are not an HTTP request (e.g. not UTF-8, or the request line
    /// has no `HTTP/` version token).
    NotHttp,
    /// An HTTP request that is not a (GET) WebSocket upgrade.
    HttpFallback,
    /// A WebSocket upgrade request whose path and host match the configuration.
    Upgrade,
    /// A WebSocket upgrade request that must be refused; carries the reason.
    Reject(&'static str),
}
42
43pub fn inspect_mixed(buf: &[u8], cfg: &WsCfg) -> WsInspect {
45 let header_end = find_header_end(buf);
46 if header_end.is_none() {
47 return WsInspect::NeedMore;
48 }
49 let header_end = header_end.unwrap();
50 let header_bytes = &buf[..header_end];
51 let header_str = match std::str::from_utf8(header_bytes) {
52 Ok(v) => v,
53 Err(_) => return WsInspect::NotHttp,
54 };
55 let mut lines = header_str.split("\r\n");
56 let request_line = match lines.next() {
57 Some(v) => v,
58 None => return WsInspect::NotHttp,
59 };
60 let mut parts = request_line.split_whitespace();
61 let method = parts.next().unwrap_or("");
62 let path = parts.next().unwrap_or("");
63 let version = parts.next().unwrap_or("");
64 if !version.starts_with("HTTP/") {
65 return WsInspect::NotHttp;
66 }
67 if method != "GET" {
68 return WsInspect::HttpFallback;
69 }
70
71 let mut upgrade = false;
72 let mut connection_upgrade = false;
73 let mut ws_key = false;
74 let mut host: Option<&str> = None;
75
76 for line in lines {
77 if let Some((name, value)) = line.split_once(':') {
78 let name = name.trim().to_ascii_lowercase();
79 let value_trim = value.trim();
80 let value_lower = value_trim.to_ascii_lowercase();
81 match name.as_str() {
82 "upgrade" => {
83 if value_lower.contains("websocket") {
84 upgrade = true;
85 }
86 }
87 "connection" => {
88 if value_lower.contains("upgrade") {
89 connection_upgrade = true;
90 }
91 }
92 "sec-websocket-key" => {
93 if !value_trim.is_empty() {
94 ws_key = true;
95 }
96 }
97 "host" => {
98 host = Some(value_trim);
99 }
100 _ => {}
101 }
102 }
103 }
104
105 if !upgrade || !connection_upgrade || !ws_key {
106 return WsInspect::HttpFallback;
107 }
108
109 if !path_matches(cfg, path) || !host_matches(cfg, host) {
110 return WsInspect::Reject("websocket path/host mismatch");
111 }
112
113 WsInspect::Upgrade
114}
115
116pub async fn accept_ws<S>(
118 stream: S,
119 initial: Bytes,
120 cfg: &WsCfg,
121) -> Result<WebSocketStream<PrefixedStream<S>>, ServerError>
122where
123 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
124{
125 let max_frame = if cfg.max_frame_bytes == 0 {
126 None
127 } else {
128 Some(cfg.max_frame_bytes)
129 };
130 let ws_cfg = WebSocketConfig {
131 max_frame_size: max_frame,
132 max_message_size: max_frame,
133 ..WebSocketConfig::default()
134 };
135 let prefixed = PrefixedStream::new(initial, stream);
136 let ws = accept_hdr_async_with_config(
137 prefixed,
138 |req: &Request, resp: Response| {
139 debug!(path = %req.uri().path(), "websocket upgrade");
140 Ok(resp)
141 },
142 Some(ws_cfg),
143 )
144 .await
145 .map_err(|e| {
146 ServerError::Io(std::io::Error::new(
147 std::io::ErrorKind::InvalidData,
148 format!("websocket handshake failed: {e}"),
149 ))
150 })?;
151 Ok(ws)
152}
153
154pub async fn send_reject<S>(mut stream: S, reason: &'static str) -> Result<(), ServerError>
156where
157 S: AsyncWrite + Unpin,
158{
159 warn!(reason, "websocket rejected");
160 let response = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n";
161 tokio::io::AsyncWriteExt::write_all(&mut stream, response).await?;
162 Ok(())
163}
164
165fn find_header_end(buf: &[u8]) -> Option<usize> {
166 buf.windows(HTTP_HEADER_END.len())
167 .position(|w| w == HTTP_HEADER_END)
168 .map(|idx| idx + HTTP_HEADER_END.len())
169}
170
171fn path_matches(cfg: &WsCfg, path: &str) -> bool {
172 let path_only = path.split('?').next().unwrap_or("");
173 path_only == cfg.path
174}
175
176fn host_matches(cfg: &WsCfg, host: Option<&str>) -> bool {
177 let expected = match cfg.host.as_deref() {
178 Some(v) => v,
179 None => return true,
180 };
181 let host = match host {
182 Some(v) => v,
183 None => return false,
184 };
185 let host_only = host.split(':').next().unwrap_or("");
186 host_only.eq_ignore_ascii_case(expected)
187}