Skip to main content

pipa/http/
request.rs

1use std::io::{ErrorKind, Read, Write};
2use std::os::unix::io::RawFd;
3use std::sync::mpsc;
4
5use crate::http::body_reader::{BodyMode, BodyReader};
6use crate::http::compression::Decompressor;
7use crate::http::conn::Connection;
8use crate::http::headers::Headers;
9use crate::http::method::HttpMethod;
10use crate::http::response::HttpResponse;
11use crate::http::status::HttpStatus;
12use crate::http::url::Url;
13
/// Phase of an in-flight [`HttpRequest`], advanced by `HttpRequest::try_advance`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestState {
    /// Waiting for the connect channel to hand over a connection.
    Resolving,
    /// Declared for completeness; `try_advance` never enters or drives this state.
    Connecting,
    /// Flushing the serialized request to the connection.
    WritingRequest,
    /// Reading the response head, starting with the status line.
    ReadingStatus,
    /// Reading header lines only (no status line expected).
    ReadingHeaders,
    /// Reading the response body.
    ReadingBody,
    /// The response has been produced; the request must not be advanced again.
    Done,
}
24
25#[derive(Debug)]
26pub enum RequestEvent {
27    NeedRead,
28    NeedWrite,
29    Complete(HttpResponse),
30    Error(String),
31}
32
33pub struct HttpRequest {
34    pub url: Url,
35    pub state: RequestState,
36    pub method: HttpMethod,
37    pub req_headers: Headers,
38    pub body: Option<Vec<u8>>,
39    pub conn: Option<Connection>,
40    connect_rx: Option<mpsc::Receiver<Result<Connection, String>>>,
41    write_buf: Vec<u8>,
42    write_pos: usize,
43    read_buf: [u8; 8192],
44    read_data: Vec<u8>,
45    parse_buf: Vec<u8>,
46    resp_headers: Option<Headers>,
47    resp_status: Option<(u16, String)>,
48    body_reader: Option<BodyReader>,
49    decompressor: Option<Decompressor>,
50    pub redirect_count: u32,
51    max_redirects: u32,
52}
53
54impl HttpRequest {
55    pub fn dummy() -> Self {
56        use crate::http::method::HttpMethod;
57        use crate::http::url::Url;
58        HttpRequest {
59            url: Url {
60                scheme: String::new(),
61                host: String::new(),
62                port: 0,
63                path: String::new(),
64                query: String::new(),
65                full: String::new(),
66            },
67            state: RequestState::Done,
68            method: HttpMethod::GET,
69            req_headers: Headers::new(),
70            body: None,
71            conn: None,
72            connect_rx: None,
73            write_buf: Vec::new(),
74            write_pos: 0,
75            read_buf: [0u8; 8192],
76            read_data: Vec::new(),
77            parse_buf: Vec::new(),
78            resp_headers: None,
79            resp_status: None,
80            body_reader: None,
81            decompressor: None,
82            redirect_count: 0,
83            max_redirects: 20,
84        }
85    }
86
87    pub fn new(url: Url, method: HttpMethod, req_headers: Headers, body: Option<Vec<u8>>) -> Self {
88        HttpRequest {
89            url,
90            state: RequestState::Resolving,
91            method,
92            req_headers,
93            body,
94            conn: None,
95            connect_rx: None,
96            write_buf: Vec::new(),
97            write_pos: 0,
98            read_buf: [0u8; 8192],
99            read_data: Vec::new(),
100            parse_buf: Vec::new(),
101            resp_headers: None,
102            resp_status: None,
103            body_reader: None,
104            decompressor: None,
105            redirect_count: 0,
106            max_redirects: 20,
107        }
108    }
109
110    pub fn fd(&self) -> Option<RawFd> {
111        self.conn.as_ref().map(|c| c.raw_fd())
112    }
113
114    pub fn max_redirects(&self) -> u32 {
115        self.max_redirects
116    }
117
118    pub fn set_connect_rx(&mut self, rx: mpsc::Receiver<Result<Connection, String>>) {
119        self.connect_rx = Some(rx);
120        self.state = RequestState::Resolving;
121    }
122
123    pub fn try_advance(&mut self) -> Result<RequestEvent, String> {
124        loop {
125            match self.state {
126                RequestState::Resolving => match self.connect_rx.as_ref().unwrap().try_recv() {
127                    Ok(result) => {
128                        let conn = result?;
129                        conn.set_nonblocking(true)?;
130                        self.conn = Some(conn);
131                        self.build_request();
132                        self.state = RequestState::WritingRequest;
133                    }
134                    Err(mpsc::TryRecvError::Empty) => {
135                        return Ok(RequestEvent::NeedRead);
136                    }
137                    Err(mpsc::TryRecvError::Disconnected) => {
138                        return Err("connect thread disconnected".into());
139                    }
140                },
141
142                RequestState::WritingRequest => {
143                    let conn = self.conn.as_mut().unwrap();
144                    let remaining = &self.write_buf[self.write_pos..];
145                    if remaining.is_empty() {
146                        self.state = RequestState::ReadingStatus;
147                        self.read_data.clear();
148                        continue;
149                    }
150                    match conn.write(remaining) {
151                        Ok(n) => {
152                            self.write_pos += n;
153                            if self.write_pos >= self.write_buf.len() {
154                                self.state = RequestState::ReadingStatus;
155                                self.read_data.clear();
156                            }
157                        }
158                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
159                            return Ok(RequestEvent::NeedWrite);
160                        }
161                        Err(e) => {
162                            return Err(format!("write error: {e}"));
163                        }
164                    }
165                }
166
167                RequestState::ReadingStatus => {
168                    let conn = self.conn.as_mut().unwrap();
169                    match conn.read(&mut self.read_buf) {
170                        Ok(0) => {
171                            return Err("connection closed while reading status".into());
172                        }
173                        Ok(n) => {
174                            self.read_data.extend_from_slice(&self.read_buf[..n]);
175                            if let Some(status_line_end) =
176                                self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
177                            {
178                                let headers_start = status_line_end + 4;
179                                if let Some(status) = self.parse_status_line()? {
180                                    self.resp_status = Some(status);
181                                    let header_data = &self.read_data[headers_start..];
182                                    let (headers, _) =
183                                        Headers::from_bytes(&self.read_data[headers_start..])?;
184                                    let body_mode = self.determine_body_mode(&headers);
185                                    self.resp_headers = Some(headers.clone());
186                                    let enc =
187                                        headers.get("content-encoding").map(|s| s.to_string());
188                                    self.decompressor = Some(Decompressor::new(enc.as_deref()));
189                                    self.body_reader = Some(BodyReader::new(body_mode));
190                                    self.parse_buf = self.read_data
191                                        [headers_start + header_data.len()..]
192                                        .to_vec();
193                                    self.state = RequestState::ReadingBody;
194                                    continue;
195                                }
196                            } else if let Some(line_end) =
197                                self.read_data.windows(2).position(|w| w == b"\r\n")
198                            {
199                                if line_end == self.read_data.len() - 2
200                                    || self.read_data[line_end + 2..]
201                                        .windows(4)
202                                        .any(|w| w == b"\r\n\r\n")
203                                {
204                                    continue;
205                                }
206                            }
207                        }
208                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
209                            return Ok(RequestEvent::NeedRead);
210                        }
211                        Err(e) => {
212                            return Err(format!("read error: {e}"));
213                        }
214                    }
215                }
216
217                RequestState::ReadingHeaders => {
218                    let conn = self.conn.as_mut().unwrap();
219                    match conn.read(&mut self.read_buf) {
220                        Ok(0) => {
221                            return Err("connection closed while reading headers".into());
222                        }
223                        Ok(n) => {
224                            self.read_data.extend_from_slice(&self.read_buf[..n]);
225                            if let Some(headers_end) =
226                                self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
227                            {
228                                let header_data = &self.read_data[..headers_end + 2];
229                                let (headers, _) = Headers::from_bytes(header_data)?;
230                                let body_mode = self.determine_body_mode(&headers);
231                                self.resp_headers = Some(headers.clone());
232                                let enc = headers.get("content-encoding").map(|s| s.to_string());
233                                self.decompressor = Some(Decompressor::new(enc.as_deref()));
234                                self.body_reader = Some(BodyReader::new(body_mode));
235                                self.parse_buf = self.read_data[headers_end + 4..].to_vec();
236                                self.state = RequestState::ReadingBody;
237                                continue;
238                            }
239                        }
240                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
241                            return Ok(RequestEvent::NeedRead);
242                        }
243                        Err(e) => {
244                            return Err(format!("read error: {e}"));
245                        }
246                    }
247                }
248
249                RequestState::ReadingBody => {
250                    let conn = self.conn.as_mut().unwrap();
251
252                    if !self.parse_buf.is_empty() {
253                        let data = std::mem::take(&mut self.parse_buf);
254                        let _ = self.body_reader.as_mut().unwrap().feed(&data)?;
255                    }
256
257                    if self.body_reader.as_ref().unwrap().is_done() {
258                        return self.finalize_response();
259                    }
260
261                    match conn.read(&mut self.read_buf) {
262                        Ok(0) => {
263                            self.body_reader.as_mut().unwrap().finish();
264                            return self.finalize_response();
265                        }
266                        Ok(n) => {
267                            let _ = self
268                                .body_reader
269                                .as_mut()
270                                .unwrap()
271                                .feed(&self.read_buf[..n])?;
272                            if self.body_reader.as_ref().unwrap().is_done() {
273                                return self.finalize_response();
274                            }
275                        }
276                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
277                            return Ok(RequestEvent::NeedRead);
278                        }
279                        Err(e) => {
280                            return Err(format!("read error: {e}"));
281                        }
282                    }
283                }
284
285                RequestState::Done => {
286                    return Err("request already completed".into());
287                }
288
289                RequestState::Connecting => {
290                    unreachable!("Connecting should have resolved to another state");
291                }
292            }
293        }
294    }
295
296    fn parse_status_line(&self) -> Result<Option<(u16, String)>, String> {
297        let line_end = self.read_data.windows(2).position(|w| w == b"\r\n");
298        match line_end {
299            Some(end) => {
300                let line = &self.read_data[..end];
301                if line.len() < 12 || !line.starts_with(b"HTTP/") {
302                    return Err(format!(
303                        "malformed status line: {:?}",
304                        String::from_utf8_lossy(line)
305                    ));
306                }
307                let http_end = line
308                    .iter()
309                    .position(|&b| b == b' ')
310                    .ok_or("no space after HTTP/x.y")?;
311                let status_start = http_end + 1;
312                if status_start >= line.len() {
313                    return Err("missing status code".into());
314                }
315                let status_end = line[status_start..]
316                    .iter()
317                    .position(|&b| b == b' ')
318                    .map(|p| status_start + p)
319                    .unwrap_or(line.len());
320                let code_str = String::from_utf8_lossy(&line[status_start..status_end]);
321                let code: u16 = code_str
322                    .parse()
323                    .map_err(|e| format!("invalid status code: {e}"))?;
324                let reason = if status_end < line.len() {
325                    String::from_utf8_lossy(&line[status_end + 1..])
326                        .trim()
327                        .to_string()
328                } else {
329                    String::new()
330                };
331                Ok(Some((code, reason)))
332            }
333            None => Ok(None),
334        }
335    }
336
337    fn determine_body_mode(&self, headers: &Headers) -> BodyMode {
338        if let Some(len_str) = headers.get("content-length") {
339            if let Ok(len) = len_str.parse::<usize>() {
340                return BodyMode::ContentLength(len);
341            }
342        }
343        if let Some(te) = headers.get("transfer-encoding") {
344            if te.contains("chunked") {
345                return BodyMode::Chunked;
346            }
347        }
348        BodyMode::ConnectionClose
349    }
350
351    fn build_request(&mut self) {
352        let target = self.url.request_target();
353        let mut buf = Vec::new();
354        buf.extend_from_slice(self.method.as_str().as_bytes());
355        buf.extend_from_slice(b" ");
356        buf.extend_from_slice(target.as_bytes());
357        buf.extend_from_slice(b" HTTP/1.1\r\n");
358        buf.extend_from_slice(self.req_headers.to_request_bytes().as_ref());
359        if !self.req_headers.contains("host") {
360            buf.extend_from_slice(b"Host: ");
361            buf.extend_from_slice(self.url.host.as_bytes());
362            if (self.url.port != 80 && self.url.port != 443)
363                || (self.url.port == 80 && self.url.is_tls())
364                || (self.url.port == 443 && !self.url.is_tls())
365            {
366                buf.extend_from_slice(b":");
367                buf.extend_from_slice(self.url.port.to_string().as_bytes());
368            }
369            buf.extend_from_slice(b"\r\n");
370        }
371        if !self.req_headers.contains("user-agent") {
372            buf.extend_from_slice(b"User-Agent: pipa/0.1\r\n");
373        }
374        if !self.req_headers.contains("accept") {
375            buf.extend_from_slice(b"Accept: */*\r\n");
376        }
377        if let Some(ref body) = self.body {
378            if !self.req_headers.contains("content-length") {
379                buf.extend_from_slice(b"Content-Length: ");
380                buf.extend_from_slice(body.len().to_string().as_bytes());
381                buf.extend_from_slice(b"\r\n");
382            }
383            buf.extend_from_slice(b"\r\n");
384            buf.extend_from_slice(body);
385        } else {
386            buf.extend_from_slice(b"\r\n");
387        }
388        self.write_buf = buf;
389        self.write_pos = 0;
390    }
391
392    fn finalize_response(&mut self) -> Result<RequestEvent, String> {
393        self.state = RequestState::Done;
394        let (code, status_text) = self.resp_status.take().unwrap_or((0, String::new()));
395        let headers = self.resp_headers.take().unwrap_or_default();
396        let mut body_reader = self
397            .body_reader
398            .take()
399            .unwrap_or_else(|| BodyReader::new(BodyMode::None));
400        let body = body_reader.take_body();
401
402        let decompressed = if let Some(ref mut decomp) = self.decompressor {
403            decomp.decompress(&body)?
404        } else {
405            body
406        };
407
408        let mut final_reader = BodyReader::new(BodyMode::ContentLength(decompressed.len()));
409        let _ = final_reader.feed(&decompressed)?;
410
411        let resp = HttpResponse::new(
412            HttpStatus(code),
413            status_text,
414            headers,
415            final_reader,
416            self.url.full.clone(),
417        );
418        Ok(RequestEvent::Complete(resp))
419    }
420
421    pub fn wants_read(&self) -> bool {
422        matches!(
423            self.state,
424            RequestState::ReadingStatus
425                | RequestState::ReadingHeaders
426                | RequestState::ReadingBody
427                | RequestState::Resolving
428        )
429    }
430
431    pub fn wants_write(&self) -> bool {
432        matches!(self.state, RequestState::WritingRequest)
433    }
434}