1use std::io::{ErrorKind, Read, Write};
2use std::os::unix::io::RawFd;
3
4use crate::http::body_reader::{BodyMode, BodyReader};
5use crate::http::compression::Decompressor;
6use crate::http::conn::Connection;
7use crate::http::headers::Headers;
8use crate::http::method::HttpMethod;
9use crate::http::response::HttpResponse;
10use crate::http::status::HttpStatus;
11use crate::http::url::Url;
12
/// Phase of a single HTTP exchange as tracked by `HttpRequest`.
///
/// The variants are listed in the order a successful exchange passes through
/// them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestState {
    /// Initial state of a request created with `HttpRequest::new`.
    /// `try_advance` expects the state to have left `Connecting` before it
    /// is called.
    Connecting,
    /// The serialized request is being written to the connection.
    WritingRequest,
    /// Response bytes are being collected until the end of the header block;
    /// the status line and headers are then parsed together.
    ReadingStatus,
    /// Response bytes are being collected until the end of the header block;
    /// only headers are parsed (no status line is expected).
    ReadingHeaders,
    /// The response body is being read and fed to the body reader.
    ReadingBody,
    /// The exchange has finished; the request must not be advanced again.
    Done,
}
22
23#[derive(Debug)]
24pub enum RequestEvent {
25 NeedRead,
26 NeedWrite,
27 Complete(HttpResponse),
28 Error(String),
29}
30
31pub struct HttpRequest {
32 pub url: Url,
33 pub state: RequestState,
34 pub method: HttpMethod,
35 pub req_headers: Headers,
36 pub body: Option<Vec<u8>>,
37 pub conn: Option<Connection>,
38 write_buf: Vec<u8>,
39 write_pos: usize,
40 read_buf: [u8; 8192],
41 read_data: Vec<u8>,
42 parse_buf: Vec<u8>,
43 resp_headers: Option<Headers>,
44 resp_status: Option<(u16, String)>,
45 body_reader: Option<BodyReader>,
46 decompressor: Option<Decompressor>,
47 pub redirect_count: u32,
48 max_redirects: u32,
49}
50
51impl HttpRequest {
52 pub fn dummy() -> Self {
53 HttpRequest {
54 url: Url {
55 scheme: String::new(),
56 host: String::new(),
57 port: 0,
58 path: String::new(),
59 query: String::new(),
60 full: String::new(),
61 },
62 state: RequestState::Done,
63 method: HttpMethod::GET,
64 req_headers: Headers::new(),
65 body: None,
66 conn: None,
67 write_buf: Vec::new(),
68 write_pos: 0,
69 read_buf: [0u8; 8192],
70 read_data: Vec::new(),
71 parse_buf: Vec::new(),
72 resp_headers: None,
73 resp_status: None,
74 body_reader: None,
75 decompressor: None,
76 redirect_count: 0,
77 max_redirects: 20,
78 }
79 }
80
81 pub fn new(url: Url, method: HttpMethod, req_headers: Headers, body: Option<Vec<u8>>) -> Self {
82 HttpRequest {
83 url,
84 state: RequestState::Connecting,
85 method,
86 req_headers,
87 body,
88 conn: None,
89 write_buf: Vec::new(),
90 write_pos: 0,
91 read_buf: [0u8; 8192],
92 read_data: Vec::new(),
93 parse_buf: Vec::new(),
94 resp_headers: None,
95 resp_status: None,
96 body_reader: None,
97 decompressor: None,
98 redirect_count: 0,
99 max_redirects: 20,
100 }
101 }
102
103 pub fn fd(&self) -> Option<RawFd> {
104 self.conn.as_ref().map(|c| c.raw_fd())
105 }
106
107 pub fn max_redirects(&self) -> u32 {
108 self.max_redirects
109 }
110
111 pub fn try_advance(&mut self) -> Result<RequestEvent, String> {
112 loop {
113 match self.state {
114 RequestState::WritingRequest => {
115 let conn = self.conn.as_mut().unwrap();
116 let remaining = &self.write_buf[self.write_pos..];
117 if remaining.is_empty() {
118 self.state = RequestState::ReadingStatus;
119 self.read_data.clear();
120 continue;
121 }
122 match conn.write(remaining) {
123 Ok(n) => {
124 self.write_pos += n;
125 if self.write_pos >= self.write_buf.len() {
126 self.state = RequestState::ReadingStatus;
127 self.read_data.clear();
128 }
129 }
130 Err(e) if e.kind() == ErrorKind::WouldBlock => {
131 return Ok(RequestEvent::NeedWrite);
132 }
133 Err(e) => {
134 return Err(format!("write error: {e}"));
135 }
136 }
137 }
138
139 RequestState::ReadingStatus => {
140 let conn = self.conn.as_mut().unwrap();
141 match conn.read(&mut self.read_buf) {
142 Ok(0) => {
143 return Err("connection closed while reading status".into());
144 }
145 Ok(n) => {
146 self.read_data.extend_from_slice(&self.read_buf[..n]);
147 if let Some(headers_end) =
148 self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
149 {
150 let line_end = self
151 .read_data
152 .windows(2)
153 .position(|w| w == b"\r\n")
154 .unwrap_or(0);
155 let status =
156 self.parse_status_from_bytes(&self.read_data[..line_end])?;
157 self.resp_status = Some(status);
158 let header_bytes = &self.read_data[line_end + 2..headers_end + 4];
159 let (headers, _) = Headers::from_bytes(header_bytes)?;
160 let body_mode = self.determine_body_mode(&headers);
161 self.resp_headers = Some(headers.clone());
162 let enc = headers.get("content-encoding").map(|s| s.to_string());
163 self.decompressor = Some(Decompressor::new(enc.as_deref()));
164 self.body_reader = Some(BodyReader::new(body_mode));
165 self.parse_buf = self.read_data[headers_end + 4..].to_vec();
166 self.state = RequestState::ReadingBody;
167 continue;
168 }
169 }
170 Err(e) if e.kind() == ErrorKind::WouldBlock => {
171 return Ok(RequestEvent::NeedRead);
172 }
173 Err(e) => {
174 return Err(format!("read error: {e}"));
175 }
176 }
177 }
178
179 RequestState::ReadingHeaders => {
180 let conn = self.conn.as_mut().unwrap();
181 match conn.read(&mut self.read_buf) {
182 Ok(0) => {
183 return Err("connection closed while reading headers".into());
184 }
185 Ok(n) => {
186 self.read_data.extend_from_slice(&self.read_buf[..n]);
187 if let Some(headers_end) =
188 self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
189 {
190 let header_data = &self.read_data[..headers_end + 2];
191 let (headers, _) = Headers::from_bytes(header_data)?;
192 let body_mode = self.determine_body_mode(&headers);
193 self.resp_headers = Some(headers.clone());
194 let enc = headers.get("content-encoding").map(|s| s.to_string());
195 self.decompressor = Some(Decompressor::new(enc.as_deref()));
196 self.body_reader = Some(BodyReader::new(body_mode));
197 self.parse_buf = self.read_data[headers_end + 4..].to_vec();
198 self.state = RequestState::ReadingBody;
199 continue;
200 }
201 }
202 Err(e) if e.kind() == ErrorKind::WouldBlock => {
203 return Ok(RequestEvent::NeedRead);
204 }
205 Err(e) => {
206 return Err(format!("read error: {e}"));
207 }
208 }
209 }
210
211 RequestState::ReadingBody => {
212 let conn = self.conn.as_mut().unwrap();
213
214 if !self.parse_buf.is_empty() {
215 let data = std::mem::take(&mut self.parse_buf);
216 let _ = self.body_reader.as_mut().unwrap().feed(&data)?;
217 }
218
219 if self.body_reader.as_ref().unwrap().is_done() {
220 return self.finalize_response();
221 }
222
223 match conn.read(&mut self.read_buf) {
224 Ok(0) => {
225 self.body_reader.as_mut().unwrap().finish();
226 return self.finalize_response();
227 }
228 Ok(n) => {
229 let _ = self
230 .body_reader
231 .as_mut()
232 .unwrap()
233 .feed(&self.read_buf[..n])?;
234 if self.body_reader.as_ref().unwrap().is_done() {
235 return self.finalize_response();
236 }
237 }
238 Err(e) if e.kind() == ErrorKind::WouldBlock => {
239 return Ok(RequestEvent::NeedRead);
240 }
241 Err(e) => {
242 return Err(format!("read error: {e}"));
243 }
244 }
245 }
246
247 RequestState::Done => {
248 return Err("request already completed".into());
249 }
250
251 RequestState::Connecting => {
252 unreachable!("Connecting should have resolved to another state");
253 }
254 }
255 }
256 }
257
258 fn parse_status_from_bytes(&self, line: &[u8]) -> Result<(u16, String), String> {
259 if line.len() < 12 || !line.starts_with(b"HTTP/") {
260 return Err(format!(
261 "malformed status line: {:?}",
262 String::from_utf8_lossy(line)
263 ));
264 }
265 let http_end = line
266 .iter()
267 .position(|&b| b == b' ')
268 .ok_or("no space after HTTP/x.y")?;
269 let status_start = http_end + 1;
270 if status_start >= line.len() {
271 return Err("missing status code".into());
272 }
273 let status_end = line[status_start..]
274 .iter()
275 .position(|&b| b == b' ')
276 .map(|p| status_start + p)
277 .unwrap_or(line.len());
278 let code_str = String::from_utf8_lossy(&line[status_start..status_end]);
279 let code: u16 = code_str
280 .parse()
281 .map_err(|e| format!("invalid status code: {e}"))?;
282 let reason = if status_end < line.len() {
283 String::from_utf8_lossy(&line[status_end + 1..])
284 .trim()
285 .to_string()
286 } else {
287 String::new()
288 };
289 Ok((code, reason))
290 }
291
292 fn determine_body_mode(&self, headers: &Headers) -> BodyMode {
293 if let Some(len_str) = headers.get("content-length") {
294 if let Ok(len) = len_str.parse::<usize>() {
295 return BodyMode::ContentLength(len);
296 }
297 }
298 if let Some(te) = headers.get("transfer-encoding") {
299 if te.contains("chunked") {
300 return BodyMode::Chunked;
301 }
302 }
303 BodyMode::ConnectionClose
304 }
305
306 pub fn build_request(&mut self) {
307 let target = self.url.request_target();
308 let mut buf = Vec::new();
309 buf.extend_from_slice(self.method.as_str().as_bytes());
310 buf.extend_from_slice(b" ");
311 buf.extend_from_slice(target.as_bytes());
312 buf.extend_from_slice(b" HTTP/1.1\r\n");
313 buf.extend_from_slice(self.req_headers.to_request_bytes().as_ref());
314 if !self.req_headers.contains("host") {
315 buf.extend_from_slice(b"Host: ");
316 buf.extend_from_slice(self.url.host.as_bytes());
317 if (self.url.port != 80 && self.url.port != 443)
318 || (self.url.port == 80 && self.url.is_tls())
319 || (self.url.port == 443 && !self.url.is_tls())
320 {
321 buf.extend_from_slice(b":");
322 buf.extend_from_slice(self.url.port.to_string().as_bytes());
323 }
324 buf.extend_from_slice(b"\r\n");
325 }
326 if !self.req_headers.contains("user-agent") {
327 buf.extend_from_slice(b"User-Agent: pipa/0.1\r\n");
328 }
329 if !self.req_headers.contains("accept") {
330 buf.extend_from_slice(b"Accept: */*\r\n");
331 }
332 if let Some(ref body) = self.body {
333 if !self.req_headers.contains("content-length") {
334 buf.extend_from_slice(b"Content-Length: ");
335 buf.extend_from_slice(body.len().to_string().as_bytes());
336 buf.extend_from_slice(b"\r\n");
337 }
338 buf.extend_from_slice(b"\r\n");
339 buf.extend_from_slice(body);
340 } else {
341 buf.extend_from_slice(b"\r\n");
342 }
343 self.write_buf = buf;
344 self.write_pos = 0;
345 }
346
347 fn finalize_response(&mut self) -> Result<RequestEvent, String> {
348 self.state = RequestState::Done;
349 let (code, status_text) = self.resp_status.take().unwrap_or((0, String::new()));
350 let headers = self.resp_headers.take().unwrap_or_default();
351 let mut body_reader = self
352 .body_reader
353 .take()
354 .unwrap_or_else(|| BodyReader::new(BodyMode::None));
355 let body = body_reader.take_body();
356
357 let decompressed = if let Some(ref mut decomp) = self.decompressor {
358 decomp.decompress(&body)?
359 } else {
360 body
361 };
362
363 let mut final_reader = BodyReader::new(BodyMode::ContentLength(decompressed.len()));
364 let _ = final_reader.feed(&decompressed)?;
365
366 let resp = HttpResponse::new(
367 HttpStatus(code),
368 status_text,
369 headers,
370 final_reader,
371 self.url.full.clone(),
372 );
373 Ok(RequestEvent::Complete(resp))
374 }
375
376 pub fn wants_read(&self) -> bool {
377 matches!(
378 self.state,
379 RequestState::ReadingStatus
380 | RequestState::ReadingHeaders
381 | RequestState::ReadingBody
382 | RequestState::Connecting
383 )
384 }
385
386 pub fn wants_write(&self) -> bool {
387 matches!(self.state, RequestState::WritingRequest)
388 }
389}