1use std::collections::VecDeque;
2use std::io::{ErrorKind, Read, Write};
3use std::os::unix::io::RawFd;
4use std::sync::mpsc;
5
6use crate::http::conn::Connection;
7use crate::http::headers::Headers;
8use crate::http::status::HttpStatus;
9use crate::http::url::Url;
10use crate::http::ws::frame::{OpCode, WsFrame};
11use crate::http::ws::handshake::WsHandshake;
12
/// Lifecycle phase of a [`WsConnection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsState {
    /// Waiting for the connect result on the receiver installed via `set_connect_rx`.
    Connecting,
    /// The HTTP upgrade request is being written and the server response read.
    Handshake,
    /// Handshake accepted; frames can be sent and received.
    Open,
    /// A Close frame has been queued (ours or a reply to the peer's) and is being flushed.
    Closing,
    /// Terminal state; `try_advance` makes no further progress.
    Closed,
}
21
/// Notification produced by [`WsConnection::try_advance`].
#[derive(Clone, Debug)]
pub enum WsEvent {
    /// The opening handshake completed successfully.
    Open,
    /// A data frame arrived: `(payload, is_text)`, where `is_text` is `true`
    /// for text frames and `false` for binary frames.
    Message(Vec<u8>, bool),
    /// The connection finished closing: `(status code, reason)`.
    Close(u16, String),
    /// An error description (not constructed in this section of code).
    Error(String),
}
29
30pub struct WsConnection {
31 pub url: Url,
32 pub state: WsState,
33 conn: Option<Connection>,
34 connect_rx: Option<mpsc::Receiver<Result<Connection, String>>>,
35 write_buf: Vec<u8>,
36 write_pos: usize,
37 read_buf: [u8; 8192],
38 read_data: Vec<u8>,
39 key: String,
40 pending_frames: VecDeque<WsFrame>,
41 close_code: u16,
42 close_reason: String,
43 pub ready_state: u8,
44}
45
46impl WsConnection {
47 pub fn new(url: Url) -> Self {
48 WsConnection {
49 url,
50 state: WsState::Connecting,
51 conn: None,
52 connect_rx: None,
53 write_buf: Vec::new(),
54 write_pos: 0,
55 read_buf: [0u8; 8192],
56 read_data: Vec::new(),
57 key: String::new(),
58 pending_frames: VecDeque::new(),
59 close_code: 0,
60 close_reason: String::new(),
61 ready_state: 0,
62 }
63 }
64
65 pub fn fd(&self) -> Option<RawFd> {
66 self.conn.as_ref().map(|c| c.raw_fd())
67 }
68
69 pub fn set_connect_rx(&mut self, rx: mpsc::Receiver<Result<Connection, String>>) {
70 self.connect_rx = Some(rx);
71 self.state = WsState::Connecting;
72 }
73
74 pub fn try_advance(&mut self) -> Result<Option<WsEvent>, String> {
75 loop {
76 match self.state {
77 WsState::Connecting => match self.connect_rx.as_ref().unwrap().try_recv() {
78 Ok(result) => {
79 let conn = result?;
80 conn.set_nonblocking(true)?;
81 self.conn = Some(conn);
82 self.key = WsHandshake::generate_key();
83 self.build_handshake_request();
84 self.state = WsState::Handshake;
85 }
86 Err(mpsc::TryRecvError::Empty) => {
87 return Ok(None);
88 }
89 Err(mpsc::TryRecvError::Disconnected) => {
90 return Err("connect thread disconnected".into());
91 }
92 },
93
94 WsState::Handshake => {
95 if !self.write_buf.is_empty() {
96 let conn = self.conn.as_mut().unwrap();
97 let remaining = &self.write_buf[self.write_pos..];
98 if !remaining.is_empty() {
99 match conn.write(remaining) {
100 Ok(n) => {
101 self.write_pos += n;
102 if self.write_pos >= self.write_buf.len() {
103 self.write_buf.clear();
104 self.write_pos = 0;
105 } else {
106 return Ok(None);
107 }
108 }
109 Err(e) if e.kind() == ErrorKind::WouldBlock => {
110 return Ok(None);
111 }
112 Err(e) => return Err(format!("ws handshake write: {e}")),
113 }
114 }
115 }
116
117 let conn = self.conn.as_mut().unwrap();
118 match conn.read(&mut self.read_buf) {
119 Ok(0) => return Err("connection closed during handshake".into()),
120 Ok(n) => {
121 self.read_data.extend_from_slice(&self.read_buf[..n]);
122 if let Some(pos) =
123 self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
124 {
125 let header_data = &self.read_data[..pos + 4];
126 let (headers, _) = Headers::from_bytes(header_data)?;
127 let status_line_end = self
128 .read_data
129 .windows(2)
130 .position(|w| w == b"\r\n")
131 .unwrap_or(0);
132 let status_line = &self.read_data[..status_line_end];
133 let code = if status_line.len() >= 12 {
134 let s = &status_line[9..12];
135 String::from_utf8_lossy(s).parse::<u16>().unwrap_or(0)
136 } else {
137 0
138 };
139 let status = HttpStatus(code);
140 let accept = WsHandshake::validate_response(status, &headers)?;
141 if !WsHandshake::verify_accept(&self.key, &accept) {
142 return Err("WebSocket accept mismatch".into());
143 }
144 self.state = WsState::Open;
145 self.ready_state = 1;
146 self.read_data.clear();
147 return Ok(Some(WsEvent::Open));
148 }
149 }
150 Err(e) if e.kind() == ErrorKind::WouldBlock => {
151 return Ok(None);
152 }
153 Err(e) => return Err(format!("ws handshake read: {e}")),
154 }
155 }
156
157 WsState::Open => {
158 if let Some(frame) = self.pending_frames.pop_front() {
159 self.write_buf = frame.encode();
160 self.write_pos = 0;
161 }
162
163 if !self.write_buf.is_empty() {
164 let conn = self.conn.as_mut().unwrap();
165 let remaining = &self.write_buf[self.write_pos..];
166 if !remaining.is_empty() {
167 match conn.write(remaining) {
168 Ok(n) => {
169 self.write_pos += n;
170 }
171 Err(e) if e.kind() == ErrorKind::WouldBlock => {
172 return Ok(None);
173 }
174 Err(e) => return Err(format!("ws write error: {e}")),
175 }
176 }
177 if self.write_pos >= self.write_buf.len() {
178 self.write_buf.clear();
179 self.write_pos = 0;
180 }
181 if !self.pending_frames.is_empty() {
182 return Ok(None);
183 }
184 }
185
186 let conn = self.conn.as_mut().unwrap();
187 match conn.read(&mut self.read_buf) {
188 Ok(0) => {
189 self.state = WsState::Closed;
190 self.ready_state = 3;
191 return Ok(Some(WsEvent::Close(1006, "connection closed".into())));
192 }
193 Ok(n) => {
194 self.read_data.extend_from_slice(&self.read_buf[..n]);
195 let frames = WsFrame::parse_all(&self.read_data)?;
196 if !frames.is_empty() {
197 let consumed = self.calculate_consumed(&frames);
198 self.read_data.drain(..consumed);
199 for frame in frames {
200 match frame.opcode {
201 OpCode::Text | OpCode::Binary => {
202 let is_text = frame.opcode == OpCode::Text;
203 return Ok(Some(WsEvent::Message(
204 frame.payload,
205 is_text,
206 )));
207 }
208 OpCode::Ping => {
209 let pong = WsFrame::new_pong(frame.payload);
210 self.pending_frames.push_back(pong);
211 return Ok(None);
212 }
213 OpCode::Close => {
214 let (code, reason) =
215 Self::parse_close_payload(&frame.payload);
216 self.close_code = code;
217 self.close_reason = reason.clone();
218 let close_frame = WsFrame::new_close(code, &reason);
219 self.pending_frames.push_back(close_frame);
220 self.state = WsState::Closing;
221 self.ready_state = 2;
222 return Ok(None);
223 }
224 OpCode::Pong => {}
225 OpCode::Continuation => {}
226 }
227 }
228 }
229 }
230 Err(e) if e.kind() == ErrorKind::WouldBlock => {
231 return Ok(None);
232 }
233 Err(e) => return Err(format!("ws read error: {e}")),
234 }
235 }
236
237 WsState::Closing => {
238 if let Some(frame) = self.pending_frames.pop_front() {
239 self.write_buf = frame.encode();
240 self.write_pos = 0;
241 }
242 if !self.write_buf.is_empty() {
243 let conn = self.conn.as_mut().unwrap();
244 let remaining = &self.write_buf[self.write_pos..];
245 if !remaining.is_empty() {
246 let _ = conn.write(remaining);
247 }
248 self.write_buf.clear();
249 self.write_pos = 0;
250 }
251 let conn = self.conn.as_mut().unwrap();
252 let _ = conn.read(&mut self.read_buf);
253 self.state = WsState::Closed;
254 self.ready_state = 3;
255 return Ok(Some(WsEvent::Close(
256 self.close_code,
257 self.close_reason.clone(),
258 )));
259 }
260
261 WsState::Closed => {
262 return Ok(None);
263 }
264 }
265 }
266 }
267
268 pub fn send_text(&mut self, data: &str) {
269 let frame = WsFrame::new_text(data.as_bytes().to_vec());
270 self.pending_frames.push_back(frame);
271 }
272
273 pub fn send_binary(&mut self, data: &[u8]) {
274 let frame = WsFrame::new_binary(data.to_vec());
275 self.pending_frames.push_back(frame);
276 }
277
278 pub fn close(&mut self, code: u16, reason: &str) {
279 if self.state == WsState::Open {
280 let frame = WsFrame::new_close(code, reason);
281 self.pending_frames.push_back(frame);
282 self.state = WsState::Closing;
283 self.ready_state = 2;
284 self.close_code = code;
285 self.close_reason = reason.to_string();
286 }
287 }
288
289 pub fn wants_read(&self) -> bool {
290 matches!(
291 self.state,
292 WsState::Connecting | WsState::Handshake | WsState::Open
293 )
294 }
295
296 pub fn wants_write(&self) -> bool {
297 matches!(
298 self.state,
299 WsState::Handshake | WsState::Open | WsState::Closing
300 ) && (!self.write_buf.is_empty() || !self.pending_frames.is_empty())
301 }
302
303 fn build_handshake_request(&mut self) {
304 let host = format!(
305 "{}:{}",
306 self.url.host,
307 if self.url.port != 80 && self.url.port != 443 {
308 self.url.port
309 } else {
310 0
311 }
312 );
313 let host = if host.ends_with(":0") {
314 self.url.host.clone()
315 } else {
316 host
317 };
318 let path = self.url.request_target();
319 let mut headers = WsHandshake::build_request(&host, &path, &self.key);
320 if !headers.contains("user-agent") {
321 headers.set("User-Agent", "pipa/0.1");
322 }
323 let mut buf = Vec::new();
324 buf.extend_from_slice(b"GET ");
325 buf.extend_from_slice(path.as_bytes());
326 buf.extend_from_slice(b" HTTP/1.1\r\n");
327 buf.extend_from_slice(headers.to_request_bytes().as_ref());
328 buf.extend_from_slice(b"\r\n");
329 self.write_buf = buf;
330 self.write_pos = 0;
331 }
332
333 fn parse_close_payload(payload: &[u8]) -> (u16, String) {
334 if payload.len() >= 2 {
335 let code = u16::from_be_bytes([payload[0], payload[1]]);
336 let reason = if payload.len() > 2 {
337 String::from_utf8_lossy(&payload[2..]).to_string()
338 } else {
339 String::new()
340 };
341 (code, reason)
342 } else {
343 (1005, String::new())
344 }
345 }
346
347 fn calculate_consumed(&self, frames: &[WsFrame]) -> usize {
348 let mut total = 0usize;
349 for frame in frames {
350 let mut frame_size = 2;
351 let payload_len = frame.payload.len();
352 if payload_len >= 126 && payload_len <= 0xFFFF {
353 frame_size += 2;
354 } else if payload_len > 0xFFFF {
355 frame_size += 8;
356 }
357 if frame.mask.is_some() {
358 frame_size += 4;
359 }
360 frame_size += payload_len;
361 total += frame_size;
362 }
363 total
364 }
365}