// pipa/http/ws/conn.rs

use std::collections::VecDeque;
use std::io;
use std::io::{ErrorKind, Read, Write};
use std::os::unix::io::RawFd;
use std::sync::mpsc;

use crate::http::conn::Connection;
use crate::http::headers::Headers;
use crate::http::status::HttpStatus;
use crate::http::url::Url;
use crate::http::ws::frame::{OpCode, WsFrame};
use crate::http::ws::handshake::WsHandshake;
12
/// Lifecycle stage of a `WsConnection`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsState {
    /// Waiting for the connect thread to hand over the established connection.
    Connecting,
    /// Writing the upgrade request and waiting for the server's response.
    Handshake,
    /// Handshake accepted; messages can be sent and received.
    Open,
    /// A Close frame has been queued and is being flushed.
    Closing,
    /// Terminal state; no further I/O is performed.
    Closed,
}
21
/// Something observable that happened while driving a `WsConnection`.
#[derive(Clone, Debug)]
pub enum WsEvent {
    /// The server accepted the upgrade handshake.
    Open,
    /// Payload of a received Text or Binary frame; the flag is `true` for Text.
    Message(Vec<u8>, bool),
    /// The connection is closed: status code and reason. Code 1006 is used when the peer
    /// dropped the socket without sending a Close frame.
    Close(u16, String),
    /// An error described by a message.
    /// NOTE(review): not constructed in this file — confirm where it is emitted.
    Error(String),
}
29
30pub struct WsConnection {
31    pub url: Url,
32    pub state: WsState,
33    conn: Option<Connection>,
34    connect_rx: Option<mpsc::Receiver<Result<Connection, String>>>,
35    write_buf: Vec<u8>,
36    write_pos: usize,
37    read_buf: [u8; 8192],
38    read_data: Vec<u8>,
39    key: String,
40    pending_frames: VecDeque<WsFrame>,
41    close_code: u16,
42    close_reason: String,
43    pub ready_state: u8,
44}
45
46impl WsConnection {
47    pub fn new(url: Url) -> Self {
48        WsConnection {
49            url,
50            state: WsState::Connecting,
51            conn: None,
52            connect_rx: None,
53            write_buf: Vec::new(),
54            write_pos: 0,
55            read_buf: [0u8; 8192],
56            read_data: Vec::new(),
57            key: String::new(),
58            pending_frames: VecDeque::new(),
59            close_code: 0,
60            close_reason: String::new(),
61            ready_state: 0,
62        }
63    }
64
65    pub fn fd(&self) -> Option<RawFd> {
66        self.conn.as_ref().map(|c| c.raw_fd())
67    }
68
69    pub fn set_connect_rx(&mut self, rx: mpsc::Receiver<Result<Connection, String>>) {
70        self.connect_rx = Some(rx);
71        self.state = WsState::Connecting;
72    }
73
74    pub fn try_advance(&mut self) -> Result<Option<WsEvent>, String> {
75        loop {
76            match self.state {
77                WsState::Connecting => match self.connect_rx.as_ref().unwrap().try_recv() {
78                    Ok(result) => {
79                        let conn = result?;
80                        conn.set_nonblocking(true)?;
81                        self.conn = Some(conn);
82                        self.key = WsHandshake::generate_key();
83                        self.build_handshake_request();
84                        self.state = WsState::Handshake;
85                    }
86                    Err(mpsc::TryRecvError::Empty) => {
87                        return Ok(None);
88                    }
89                    Err(mpsc::TryRecvError::Disconnected) => {
90                        return Err("connect thread disconnected".into());
91                    }
92                },
93
94                WsState::Handshake => {
95                    if !self.write_buf.is_empty() {
96                        let conn = self.conn.as_mut().unwrap();
97                        let remaining = &self.write_buf[self.write_pos..];
98                        if !remaining.is_empty() {
99                            match conn.write(remaining) {
100                                Ok(n) => {
101                                    self.write_pos += n;
102                                    if self.write_pos >= self.write_buf.len() {
103                                        self.write_buf.clear();
104                                        self.write_pos = 0;
105                                    } else {
106                                        return Ok(None);
107                                    }
108                                }
109                                Err(e) if e.kind() == ErrorKind::WouldBlock => {
110                                    return Ok(None);
111                                }
112                                Err(e) => return Err(format!("ws handshake write: {e}")),
113                            }
114                        }
115                    }
116
117                    let conn = self.conn.as_mut().unwrap();
118                    match conn.read(&mut self.read_buf) {
119                        Ok(0) => return Err("connection closed during handshake".into()),
120                        Ok(n) => {
121                            self.read_data.extend_from_slice(&self.read_buf[..n]);
122                            if let Some(pos) =
123                                self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
124                            {
125                                let header_data = &self.read_data[..pos + 4];
126                                let (headers, _) = Headers::from_bytes(header_data)?;
127                                let status_line_end = self
128                                    .read_data
129                                    .windows(2)
130                                    .position(|w| w == b"\r\n")
131                                    .unwrap_or(0);
132                                let status_line = &self.read_data[..status_line_end];
133                                let code = if status_line.len() >= 12 {
134                                    let s = &status_line[9..12];
135                                    String::from_utf8_lossy(s).parse::<u16>().unwrap_or(0)
136                                } else {
137                                    0
138                                };
139                                let status = HttpStatus(code);
140                                let accept = WsHandshake::validate_response(status, &headers)?;
141                                if !WsHandshake::verify_accept(&self.key, &accept) {
142                                    return Err("WebSocket accept mismatch".into());
143                                }
144                                self.state = WsState::Open;
145                                self.ready_state = 1;
146                                self.read_data.clear();
147                                return Ok(Some(WsEvent::Open));
148                            }
149                        }
150                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
151                            return Ok(None);
152                        }
153                        Err(e) => return Err(format!("ws handshake read: {e}")),
154                    }
155                }
156
157                WsState::Open => {
158                    if let Some(frame) = self.pending_frames.pop_front() {
159                        self.write_buf = frame.encode();
160                        self.write_pos = 0;
161                    }
162
163                    if !self.write_buf.is_empty() {
164                        let conn = self.conn.as_mut().unwrap();
165                        let remaining = &self.write_buf[self.write_pos..];
166                        if !remaining.is_empty() {
167                            match conn.write(remaining) {
168                                Ok(n) => {
169                                    self.write_pos += n;
170                                }
171                                Err(e) if e.kind() == ErrorKind::WouldBlock => {
172                                    return Ok(None);
173                                }
174                                Err(e) => return Err(format!("ws write error: {e}")),
175                            }
176                        }
177                        if self.write_pos >= self.write_buf.len() {
178                            self.write_buf.clear();
179                            self.write_pos = 0;
180                        }
181                        if !self.pending_frames.is_empty() {
182                            return Ok(None);
183                        }
184                    }
185
186                    let conn = self.conn.as_mut().unwrap();
187                    match conn.read(&mut self.read_buf) {
188                        Ok(0) => {
189                            self.state = WsState::Closed;
190                            self.ready_state = 3;
191                            return Ok(Some(WsEvent::Close(1006, "connection closed".into())));
192                        }
193                        Ok(n) => {
194                            self.read_data.extend_from_slice(&self.read_buf[..n]);
195                            let frames = WsFrame::parse_all(&self.read_data)?;
196                            if !frames.is_empty() {
197                                let consumed = self.calculate_consumed(&frames);
198                                self.read_data.drain(..consumed);
199                                for frame in frames {
200                                    match frame.opcode {
201                                        OpCode::Text | OpCode::Binary => {
202                                            let is_text = frame.opcode == OpCode::Text;
203                                            return Ok(Some(WsEvent::Message(
204                                                frame.payload,
205                                                is_text,
206                                            )));
207                                        }
208                                        OpCode::Ping => {
209                                            let pong = WsFrame::new_pong(frame.payload);
210                                            self.pending_frames.push_back(pong);
211                                            return Ok(None);
212                                        }
213                                        OpCode::Close => {
214                                            let (code, reason) =
215                                                Self::parse_close_payload(&frame.payload);
216                                            self.close_code = code;
217                                            self.close_reason = reason.clone();
218                                            let close_frame = WsFrame::new_close(code, &reason);
219                                            self.pending_frames.push_back(close_frame);
220                                            self.state = WsState::Closing;
221                                            self.ready_state = 2;
222                                            return Ok(None);
223                                        }
224                                        OpCode::Pong => {}
225                                        OpCode::Continuation => {}
226                                    }
227                                }
228                            }
229                        }
230                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
231                            return Ok(None);
232                        }
233                        Err(e) => return Err(format!("ws read error: {e}")),
234                    }
235                }
236
237                WsState::Closing => {
238                    if let Some(frame) = self.pending_frames.pop_front() {
239                        self.write_buf = frame.encode();
240                        self.write_pos = 0;
241                    }
242                    if !self.write_buf.is_empty() {
243                        let conn = self.conn.as_mut().unwrap();
244                        let remaining = &self.write_buf[self.write_pos..];
245                        if !remaining.is_empty() {
246                            let _ = conn.write(remaining);
247                        }
248                        self.write_buf.clear();
249                        self.write_pos = 0;
250                    }
251                    let conn = self.conn.as_mut().unwrap();
252                    let _ = conn.read(&mut self.read_buf);
253                    self.state = WsState::Closed;
254                    self.ready_state = 3;
255                    return Ok(Some(WsEvent::Close(
256                        self.close_code,
257                        self.close_reason.clone(),
258                    )));
259                }
260
261                WsState::Closed => {
262                    return Ok(None);
263                }
264            }
265        }
266    }
267
268    pub fn send_text(&mut self, data: &str) {
269        let frame = WsFrame::new_text(data.as_bytes().to_vec());
270        self.pending_frames.push_back(frame);
271    }
272
273    pub fn send_binary(&mut self, data: &[u8]) {
274        let frame = WsFrame::new_binary(data.to_vec());
275        self.pending_frames.push_back(frame);
276    }
277
278    pub fn close(&mut self, code: u16, reason: &str) {
279        if self.state == WsState::Open {
280            let frame = WsFrame::new_close(code, reason);
281            self.pending_frames.push_back(frame);
282            self.state = WsState::Closing;
283            self.ready_state = 2;
284            self.close_code = code;
285            self.close_reason = reason.to_string();
286        }
287    }
288
289    pub fn wants_read(&self) -> bool {
290        matches!(
291            self.state,
292            WsState::Connecting | WsState::Handshake | WsState::Open
293        )
294    }
295
296    pub fn wants_write(&self) -> bool {
297        matches!(
298            self.state,
299            WsState::Handshake | WsState::Open | WsState::Closing
300        ) && (!self.write_buf.is_empty() || !self.pending_frames.is_empty())
301    }
302
303    fn build_handshake_request(&mut self) {
304        let host = format!(
305            "{}:{}",
306            self.url.host,
307            if self.url.port != 80 && self.url.port != 443 {
308                self.url.port
309            } else {
310                0
311            }
312        );
313        let host = if host.ends_with(":0") {
314            self.url.host.clone()
315        } else {
316            host
317        };
318        let path = self.url.request_target();
319        let mut headers = WsHandshake::build_request(&host, &path, &self.key);
320        if !headers.contains("user-agent") {
321            headers.set("User-Agent", "pipa/0.1");
322        }
323        let mut buf = Vec::new();
324        buf.extend_from_slice(b"GET ");
325        buf.extend_from_slice(path.as_bytes());
326        buf.extend_from_slice(b" HTTP/1.1\r\n");
327        buf.extend_from_slice(headers.to_request_bytes().as_ref());
328        buf.extend_from_slice(b"\r\n");
329        self.write_buf = buf;
330        self.write_pos = 0;
331    }
332
333    fn parse_close_payload(payload: &[u8]) -> (u16, String) {
334        if payload.len() >= 2 {
335            let code = u16::from_be_bytes([payload[0], payload[1]]);
336            let reason = if payload.len() > 2 {
337                String::from_utf8_lossy(&payload[2..]).to_string()
338            } else {
339                String::new()
340            };
341            (code, reason)
342        } else {
343            (1005, String::new())
344        }
345    }
346
347    fn calculate_consumed(&self, frames: &[WsFrame]) -> usize {
348        let mut total = 0usize;
349        for frame in frames {
350            let mut frame_size = 2;
351            let payload_len = frame.payload.len();
352            if payload_len >= 126 && payload_len <= 0xFFFF {
353                frame_size += 2;
354            } else if payload_len > 0xFFFF {
355                frame_size += 8;
356            }
357            if frame.mask.is_some() {
358                frame_size += 4;
359            }
360            frame_size += payload_len;
361            total += frame_size;
362        }
363        total
364    }
365}