// pipa/http/request.rs

1use std::io::{ErrorKind, Read, Write};
2use std::os::unix::io::RawFd;
3
4use crate::http::body_reader::{BodyMode, BodyReader};
5use crate::http::compression::Decompressor;
6use crate::http::conn::Connection;
7use crate::http::headers::Headers;
8use crate::http::method::HttpMethod;
9use crate::http::response::HttpResponse;
10use crate::http::status::HttpStatus;
11use crate::http::url::Url;
12
/// Position of an [`HttpRequest`] in its request/response state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RequestState {
    /// Initial state of a newly created request, before any bytes are written.
    /// Presumably the connection is still being established (TODO confirm with caller).
    Connecting,
    /// The serialized request is being written to the connection.
    WritingRequest,
    /// Waiting for the response status line and header block.
    ReadingStatus,
    /// Waiting for a header block that is not preceded by a status line.
    /// NOTE(review): `HttpRequest` itself never enters this state.
    ReadingHeaders,
    /// The response head has been parsed; the body is being received.
    ReadingBody,
    /// The response has been delivered (or the request is a placeholder).
    Done,
}
22
23#[derive(Debug)]
24pub enum RequestEvent {
25    NeedRead,
26    NeedWrite,
27    Complete(HttpResponse),
28    Error(String),
29}
30
31pub struct HttpRequest {
32    pub url: Url,
33    pub state: RequestState,
34    pub method: HttpMethod,
35    pub req_headers: Headers,
36    pub body: Option<Vec<u8>>,
37    pub conn: Option<Connection>,
38    write_buf: Vec<u8>,
39    write_pos: usize,
40    read_buf: [u8; 8192],
41    read_data: Vec<u8>,
42    parse_buf: Vec<u8>,
43    resp_headers: Option<Headers>,
44    resp_status: Option<(u16, String)>,
45    body_reader: Option<BodyReader>,
46    decompressor: Option<Decompressor>,
47    pub redirect_count: u32,
48    max_redirects: u32,
49}
50
51impl HttpRequest {
52    pub fn dummy() -> Self {
53        HttpRequest {
54            url: Url {
55                scheme: String::new(),
56                host: String::new(),
57                port: 0,
58                path: String::new(),
59                query: String::new(),
60                full: String::new(),
61            },
62            state: RequestState::Done,
63            method: HttpMethod::GET,
64            req_headers: Headers::new(),
65            body: None,
66            conn: None,
67            write_buf: Vec::new(),
68            write_pos: 0,
69            read_buf: [0u8; 8192],
70            read_data: Vec::new(),
71            parse_buf: Vec::new(),
72            resp_headers: None,
73            resp_status: None,
74            body_reader: None,
75            decompressor: None,
76            redirect_count: 0,
77            max_redirects: 20,
78        }
79    }
80
81    pub fn new(url: Url, method: HttpMethod, req_headers: Headers, body: Option<Vec<u8>>) -> Self {
82        HttpRequest {
83            url,
84            state: RequestState::Connecting,
85            method,
86            req_headers,
87            body,
88            conn: None,
89            write_buf: Vec::new(),
90            write_pos: 0,
91            read_buf: [0u8; 8192],
92            read_data: Vec::new(),
93            parse_buf: Vec::new(),
94            resp_headers: None,
95            resp_status: None,
96            body_reader: None,
97            decompressor: None,
98            redirect_count: 0,
99            max_redirects: 20,
100        }
101    }
102
103    pub fn fd(&self) -> Option<RawFd> {
104        self.conn.as_ref().map(|c| c.raw_fd())
105    }
106
107    pub fn max_redirects(&self) -> u32 {
108        self.max_redirects
109    }
110
111    pub fn try_advance(&mut self) -> Result<RequestEvent, String> {
112        loop {
113            match self.state {
114                RequestState::WritingRequest => {
115                    let conn = self.conn.as_mut().unwrap();
116                    let remaining = &self.write_buf[self.write_pos..];
117                    if remaining.is_empty() {
118                        self.state = RequestState::ReadingStatus;
119                        self.read_data.clear();
120                        continue;
121                    }
122                    match conn.write(remaining) {
123                        Ok(n) => {
124                            self.write_pos += n;
125                            if self.write_pos >= self.write_buf.len() {
126                                self.state = RequestState::ReadingStatus;
127                                self.read_data.clear();
128                            }
129                        }
130                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
131                            return Ok(RequestEvent::NeedWrite);
132                        }
133                        Err(e) => {
134                            return Err(format!("write error: {e}"));
135                        }
136                    }
137                }
138
139                RequestState::ReadingStatus => {
140                    let conn = self.conn.as_mut().unwrap();
141                    match conn.read(&mut self.read_buf) {
142                        Ok(0) => {
143                            return Err("connection closed while reading status".into());
144                        }
145                        Ok(n) => {
146                            self.read_data.extend_from_slice(&self.read_buf[..n]);
147                            if let Some(headers_end) =
148                                self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
149                            {
150                                let line_end = self
151                                    .read_data
152                                    .windows(2)
153                                    .position(|w| w == b"\r\n")
154                                    .unwrap_or(0);
155                                let status =
156                                    self.parse_status_from_bytes(&self.read_data[..line_end])?;
157                                self.resp_status = Some(status);
158                                let header_bytes = &self.read_data[line_end + 2..headers_end + 4];
159                                let (headers, _) = Headers::from_bytes(header_bytes)?;
160                                let body_mode = self.determine_body_mode(&headers);
161                                self.resp_headers = Some(headers.clone());
162                                let enc = headers.get("content-encoding").map(|s| s.to_string());
163                                self.decompressor = Some(Decompressor::new(enc.as_deref()));
164                                self.body_reader = Some(BodyReader::new(body_mode));
165                                self.parse_buf = self.read_data[headers_end + 4..].to_vec();
166                                self.state = RequestState::ReadingBody;
167                                continue;
168                            }
169                        }
170                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
171                            return Ok(RequestEvent::NeedRead);
172                        }
173                        Err(e) => {
174                            return Err(format!("read error: {e}"));
175                        }
176                    }
177                }
178
179                RequestState::ReadingHeaders => {
180                    let conn = self.conn.as_mut().unwrap();
181                    match conn.read(&mut self.read_buf) {
182                        Ok(0) => {
183                            return Err("connection closed while reading headers".into());
184                        }
185                        Ok(n) => {
186                            self.read_data.extend_from_slice(&self.read_buf[..n]);
187                            if let Some(headers_end) =
188                                self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
189                            {
190                                let header_data = &self.read_data[..headers_end + 2];
191                                let (headers, _) = Headers::from_bytes(header_data)?;
192                                let body_mode = self.determine_body_mode(&headers);
193                                self.resp_headers = Some(headers.clone());
194                                let enc = headers.get("content-encoding").map(|s| s.to_string());
195                                self.decompressor = Some(Decompressor::new(enc.as_deref()));
196                                self.body_reader = Some(BodyReader::new(body_mode));
197                                self.parse_buf = self.read_data[headers_end + 4..].to_vec();
198                                self.state = RequestState::ReadingBody;
199                                continue;
200                            }
201                        }
202                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
203                            return Ok(RequestEvent::NeedRead);
204                        }
205                        Err(e) => {
206                            return Err(format!("read error: {e}"));
207                        }
208                    }
209                }
210
211                RequestState::ReadingBody => {
212                    let conn = self.conn.as_mut().unwrap();
213
214                    if !self.parse_buf.is_empty() {
215                        let data = std::mem::take(&mut self.parse_buf);
216                        let _ = self.body_reader.as_mut().unwrap().feed(&data)?;
217                    }
218
219                    if self.body_reader.as_ref().unwrap().is_done() {
220                        return self.finalize_response();
221                    }
222
223                    match conn.read(&mut self.read_buf) {
224                        Ok(0) => {
225                            self.body_reader.as_mut().unwrap().finish();
226                            return self.finalize_response();
227                        }
228                        Ok(n) => {
229                            let _ = self
230                                .body_reader
231                                .as_mut()
232                                .unwrap()
233                                .feed(&self.read_buf[..n])?;
234                            if self.body_reader.as_ref().unwrap().is_done() {
235                                return self.finalize_response();
236                            }
237                        }
238                        Err(e) if e.kind() == ErrorKind::WouldBlock => {
239                            return Ok(RequestEvent::NeedRead);
240                        }
241                        Err(e) => {
242                            return Err(format!("read error: {e}"));
243                        }
244                    }
245                }
246
247                RequestState::Done => {
248                    return Err("request already completed".into());
249                }
250
251                RequestState::Connecting => {
252                    unreachable!("Connecting should have resolved to another state");
253                }
254            }
255        }
256    }
257
258    fn parse_status_from_bytes(&self, line: &[u8]) -> Result<(u16, String), String> {
259        if line.len() < 12 || !line.starts_with(b"HTTP/") {
260            return Err(format!(
261                "malformed status line: {:?}",
262                String::from_utf8_lossy(line)
263            ));
264        }
265        let http_end = line
266            .iter()
267            .position(|&b| b == b' ')
268            .ok_or("no space after HTTP/x.y")?;
269        let status_start = http_end + 1;
270        if status_start >= line.len() {
271            return Err("missing status code".into());
272        }
273        let status_end = line[status_start..]
274            .iter()
275            .position(|&b| b == b' ')
276            .map(|p| status_start + p)
277            .unwrap_or(line.len());
278        let code_str = String::from_utf8_lossy(&line[status_start..status_end]);
279        let code: u16 = code_str
280            .parse()
281            .map_err(|e| format!("invalid status code: {e}"))?;
282        let reason = if status_end < line.len() {
283            String::from_utf8_lossy(&line[status_end + 1..])
284                .trim()
285                .to_string()
286        } else {
287            String::new()
288        };
289        Ok((code, reason))
290    }
291
292    fn determine_body_mode(&self, headers: &Headers) -> BodyMode {
293        if let Some(len_str) = headers.get("content-length") {
294            if let Ok(len) = len_str.parse::<usize>() {
295                return BodyMode::ContentLength(len);
296            }
297        }
298        if let Some(te) = headers.get("transfer-encoding") {
299            if te.contains("chunked") {
300                return BodyMode::Chunked;
301            }
302        }
303        BodyMode::ConnectionClose
304    }
305
306    pub fn build_request(&mut self) {
307        let target = self.url.request_target();
308        let mut buf = Vec::new();
309        buf.extend_from_slice(self.method.as_str().as_bytes());
310        buf.extend_from_slice(b" ");
311        buf.extend_from_slice(target.as_bytes());
312        buf.extend_from_slice(b" HTTP/1.1\r\n");
313        buf.extend_from_slice(self.req_headers.to_request_bytes().as_ref());
314        if !self.req_headers.contains("host") {
315            buf.extend_from_slice(b"Host: ");
316            buf.extend_from_slice(self.url.host.as_bytes());
317            if (self.url.port != 80 && self.url.port != 443)
318                || (self.url.port == 80 && self.url.is_tls())
319                || (self.url.port == 443 && !self.url.is_tls())
320            {
321                buf.extend_from_slice(b":");
322                buf.extend_from_slice(self.url.port.to_string().as_bytes());
323            }
324            buf.extend_from_slice(b"\r\n");
325        }
326        if !self.req_headers.contains("user-agent") {
327            buf.extend_from_slice(b"User-Agent: pipa/0.1\r\n");
328        }
329        if !self.req_headers.contains("accept") {
330            buf.extend_from_slice(b"Accept: */*\r\n");
331        }
332        if let Some(ref body) = self.body {
333            if !self.req_headers.contains("content-length") {
334                buf.extend_from_slice(b"Content-Length: ");
335                buf.extend_from_slice(body.len().to_string().as_bytes());
336                buf.extend_from_slice(b"\r\n");
337            }
338            buf.extend_from_slice(b"\r\n");
339            buf.extend_from_slice(body);
340        } else {
341            buf.extend_from_slice(b"\r\n");
342        }
343        self.write_buf = buf;
344        self.write_pos = 0;
345    }
346
347    fn finalize_response(&mut self) -> Result<RequestEvent, String> {
348        self.state = RequestState::Done;
349        let (code, status_text) = self.resp_status.take().unwrap_or((0, String::new()));
350        let headers = self.resp_headers.take().unwrap_or_default();
351        let mut body_reader = self
352            .body_reader
353            .take()
354            .unwrap_or_else(|| BodyReader::new(BodyMode::None));
355        let body = body_reader.take_body();
356
357        let decompressed = if let Some(ref mut decomp) = self.decompressor {
358            decomp.decompress(&body)?
359        } else {
360            body
361        };
362
363        let mut final_reader = BodyReader::new(BodyMode::ContentLength(decompressed.len()));
364        let _ = final_reader.feed(&decompressed)?;
365
366        let resp = HttpResponse::new(
367            HttpStatus(code),
368            status_text,
369            headers,
370            final_reader,
371            self.url.full.clone(),
372        );
373        Ok(RequestEvent::Complete(resp))
374    }
375
376    pub fn wants_read(&self) -> bool {
377        matches!(
378            self.state,
379            RequestState::ReadingStatus
380                | RequestState::ReadingHeaders
381                | RequestState::ReadingBody
382                | RequestState::Connecting
383        )
384    }
385
386    pub fn wants_write(&self) -> bool {
387        matches!(self.state, RequestState::WritingRequest)
388    }
389}