1use std::io::{ErrorKind, Read, Write};
2use std::os::unix::io::RawFd;
3use std::sync::mpsc;
4
5use crate::http::body_reader::{BodyMode, BodyReader};
6use crate::http::compression::Decompressor;
7use crate::http::conn::Connection;
8use crate::http::headers::Headers;
9use crate::http::method::HttpMethod;
10use crate::http::response::HttpResponse;
11use crate::http::status::HttpStatus;
12use crate::http::url::Url;
13
/// Stage of an `HttpRequest` in its request/response lifecycle.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum RequestState {
    /// Waiting for the connection result on the connect channel.
    Resolving,
    /// Connection in progress; `HttpRequest::try_advance` never moves into this state itself.
    Connecting,
    /// Sending the serialized request bytes.
    WritingRequest,
    /// Reading the response head, starting with the status line.
    ReadingStatus,
    /// Reading the response header block.
    ReadingHeaders,
    /// Reading the response body.
    ReadingBody,
    /// The exchange has finished; no further progress is possible.
    Done,
}
24
25#[derive(Debug)]
26pub enum RequestEvent {
27 NeedRead,
28 NeedWrite,
29 Complete(HttpResponse),
30 Error(String),
31}
32
33pub struct HttpRequest {
34 pub url: Url,
35 pub state: RequestState,
36 pub method: HttpMethod,
37 pub req_headers: Headers,
38 pub body: Option<Vec<u8>>,
39 pub conn: Option<Connection>,
40 connect_rx: Option<mpsc::Receiver<Result<Connection, String>>>,
41 write_buf: Vec<u8>,
42 write_pos: usize,
43 read_buf: [u8; 8192],
44 read_data: Vec<u8>,
45 parse_buf: Vec<u8>,
46 resp_headers: Option<Headers>,
47 resp_status: Option<(u16, String)>,
48 body_reader: Option<BodyReader>,
49 decompressor: Option<Decompressor>,
50 pub redirect_count: u32,
51 max_redirects: u32,
52}
53
54impl HttpRequest {
55 pub fn dummy() -> Self {
56 use crate::http::method::HttpMethod;
57 use crate::http::url::Url;
58 HttpRequest {
59 url: Url {
60 scheme: String::new(),
61 host: String::new(),
62 port: 0,
63 path: String::new(),
64 query: String::new(),
65 full: String::new(),
66 },
67 state: RequestState::Done,
68 method: HttpMethod::GET,
69 req_headers: Headers::new(),
70 body: None,
71 conn: None,
72 connect_rx: None,
73 write_buf: Vec::new(),
74 write_pos: 0,
75 read_buf: [0u8; 8192],
76 read_data: Vec::new(),
77 parse_buf: Vec::new(),
78 resp_headers: None,
79 resp_status: None,
80 body_reader: None,
81 decompressor: None,
82 redirect_count: 0,
83 max_redirects: 20,
84 }
85 }
86
87 pub fn new(url: Url, method: HttpMethod, req_headers: Headers, body: Option<Vec<u8>>) -> Self {
88 HttpRequest {
89 url,
90 state: RequestState::Resolving,
91 method,
92 req_headers,
93 body,
94 conn: None,
95 connect_rx: None,
96 write_buf: Vec::new(),
97 write_pos: 0,
98 read_buf: [0u8; 8192],
99 read_data: Vec::new(),
100 parse_buf: Vec::new(),
101 resp_headers: None,
102 resp_status: None,
103 body_reader: None,
104 decompressor: None,
105 redirect_count: 0,
106 max_redirects: 20,
107 }
108 }
109
110 pub fn fd(&self) -> Option<RawFd> {
111 self.conn.as_ref().map(|c| c.raw_fd())
112 }
113
114 pub fn max_redirects(&self) -> u32 {
115 self.max_redirects
116 }
117
118 pub fn set_connect_rx(&mut self, rx: mpsc::Receiver<Result<Connection, String>>) {
119 self.connect_rx = Some(rx);
120 self.state = RequestState::Resolving;
121 }
122
123 pub fn try_advance(&mut self) -> Result<RequestEvent, String> {
124 loop {
125 match self.state {
126 RequestState::Resolving => match self.connect_rx.as_ref().unwrap().try_recv() {
127 Ok(result) => {
128 let conn = result?;
129 conn.set_nonblocking(true)?;
130 self.conn = Some(conn);
131 self.build_request();
132 self.state = RequestState::WritingRequest;
133 }
134 Err(mpsc::TryRecvError::Empty) => {
135 return Ok(RequestEvent::NeedRead);
136 }
137 Err(mpsc::TryRecvError::Disconnected) => {
138 return Err("connect thread disconnected".into());
139 }
140 },
141
142 RequestState::WritingRequest => {
143 let conn = self.conn.as_mut().unwrap();
144 let remaining = &self.write_buf[self.write_pos..];
145 if remaining.is_empty() {
146 self.state = RequestState::ReadingStatus;
147 self.read_data.clear();
148 continue;
149 }
150 match conn.write(remaining) {
151 Ok(n) => {
152 self.write_pos += n;
153 if self.write_pos >= self.write_buf.len() {
154 self.state = RequestState::ReadingStatus;
155 self.read_data.clear();
156 }
157 }
158 Err(e) if e.kind() == ErrorKind::WouldBlock => {
159 return Ok(RequestEvent::NeedWrite);
160 }
161 Err(e) => {
162 return Err(format!("write error: {e}"));
163 }
164 }
165 }
166
167 RequestState::ReadingStatus => {
168 let conn = self.conn.as_mut().unwrap();
169 match conn.read(&mut self.read_buf) {
170 Ok(0) => {
171 return Err("connection closed while reading status".into());
172 }
173 Ok(n) => {
174 self.read_data.extend_from_slice(&self.read_buf[..n]);
175 if let Some(status_line_end) =
176 self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
177 {
178 let headers_start = status_line_end + 4;
179 if let Some(status) = self.parse_status_line()? {
180 self.resp_status = Some(status);
181 let header_data = &self.read_data[headers_start..];
182 let (headers, _) =
183 Headers::from_bytes(&self.read_data[headers_start..])?;
184 let body_mode = self.determine_body_mode(&headers);
185 self.resp_headers = Some(headers.clone());
186 let enc =
187 headers.get("content-encoding").map(|s| s.to_string());
188 self.decompressor = Some(Decompressor::new(enc.as_deref()));
189 self.body_reader = Some(BodyReader::new(body_mode));
190 self.parse_buf = self.read_data
191 [headers_start + header_data.len()..]
192 .to_vec();
193 self.state = RequestState::ReadingBody;
194 continue;
195 }
196 } else if let Some(line_end) =
197 self.read_data.windows(2).position(|w| w == b"\r\n")
198 {
199 if line_end == self.read_data.len() - 2
200 || self.read_data[line_end + 2..]
201 .windows(4)
202 .any(|w| w == b"\r\n\r\n")
203 {
204 continue;
205 }
206 }
207 }
208 Err(e) if e.kind() == ErrorKind::WouldBlock => {
209 return Ok(RequestEvent::NeedRead);
210 }
211 Err(e) => {
212 return Err(format!("read error: {e}"));
213 }
214 }
215 }
216
217 RequestState::ReadingHeaders => {
218 let conn = self.conn.as_mut().unwrap();
219 match conn.read(&mut self.read_buf) {
220 Ok(0) => {
221 return Err("connection closed while reading headers".into());
222 }
223 Ok(n) => {
224 self.read_data.extend_from_slice(&self.read_buf[..n]);
225 if let Some(headers_end) =
226 self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
227 {
228 let header_data = &self.read_data[..headers_end + 2];
229 let (headers, _) = Headers::from_bytes(header_data)?;
230 let body_mode = self.determine_body_mode(&headers);
231 self.resp_headers = Some(headers.clone());
232 let enc = headers.get("content-encoding").map(|s| s.to_string());
233 self.decompressor = Some(Decompressor::new(enc.as_deref()));
234 self.body_reader = Some(BodyReader::new(body_mode));
235 self.parse_buf = self.read_data[headers_end + 4..].to_vec();
236 self.state = RequestState::ReadingBody;
237 continue;
238 }
239 }
240 Err(e) if e.kind() == ErrorKind::WouldBlock => {
241 return Ok(RequestEvent::NeedRead);
242 }
243 Err(e) => {
244 return Err(format!("read error: {e}"));
245 }
246 }
247 }
248
249 RequestState::ReadingBody => {
250 let conn = self.conn.as_mut().unwrap();
251
252 if !self.parse_buf.is_empty() {
253 let data = std::mem::take(&mut self.parse_buf);
254 let _ = self.body_reader.as_mut().unwrap().feed(&data)?;
255 }
256
257 if self.body_reader.as_ref().unwrap().is_done() {
258 return self.finalize_response();
259 }
260
261 match conn.read(&mut self.read_buf) {
262 Ok(0) => {
263 self.body_reader.as_mut().unwrap().finish();
264 return self.finalize_response();
265 }
266 Ok(n) => {
267 let _ = self
268 .body_reader
269 .as_mut()
270 .unwrap()
271 .feed(&self.read_buf[..n])?;
272 if self.body_reader.as_ref().unwrap().is_done() {
273 return self.finalize_response();
274 }
275 }
276 Err(e) if e.kind() == ErrorKind::WouldBlock => {
277 return Ok(RequestEvent::NeedRead);
278 }
279 Err(e) => {
280 return Err(format!("read error: {e}"));
281 }
282 }
283 }
284
285 RequestState::Done => {
286 return Err("request already completed".into());
287 }
288
289 RequestState::Connecting => {
290 unreachable!("Connecting should have resolved to another state");
291 }
292 }
293 }
294 }
295
296 fn parse_status_line(&self) -> Result<Option<(u16, String)>, String> {
297 let line_end = self.read_data.windows(2).position(|w| w == b"\r\n");
298 match line_end {
299 Some(end) => {
300 let line = &self.read_data[..end];
301 if line.len() < 12 || !line.starts_with(b"HTTP/") {
302 return Err(format!(
303 "malformed status line: {:?}",
304 String::from_utf8_lossy(line)
305 ));
306 }
307 let http_end = line
308 .iter()
309 .position(|&b| b == b' ')
310 .ok_or("no space after HTTP/x.y")?;
311 let status_start = http_end + 1;
312 if status_start >= line.len() {
313 return Err("missing status code".into());
314 }
315 let status_end = line[status_start..]
316 .iter()
317 .position(|&b| b == b' ')
318 .map(|p| status_start + p)
319 .unwrap_or(line.len());
320 let code_str = String::from_utf8_lossy(&line[status_start..status_end]);
321 let code: u16 = code_str
322 .parse()
323 .map_err(|e| format!("invalid status code: {e}"))?;
324 let reason = if status_end < line.len() {
325 String::from_utf8_lossy(&line[status_end + 1..])
326 .trim()
327 .to_string()
328 } else {
329 String::new()
330 };
331 Ok(Some((code, reason)))
332 }
333 None => Ok(None),
334 }
335 }
336
337 fn determine_body_mode(&self, headers: &Headers) -> BodyMode {
338 if let Some(len_str) = headers.get("content-length") {
339 if let Ok(len) = len_str.parse::<usize>() {
340 return BodyMode::ContentLength(len);
341 }
342 }
343 if let Some(te) = headers.get("transfer-encoding") {
344 if te.contains("chunked") {
345 return BodyMode::Chunked;
346 }
347 }
348 BodyMode::ConnectionClose
349 }
350
351 fn build_request(&mut self) {
352 let target = self.url.request_target();
353 let mut buf = Vec::new();
354 buf.extend_from_slice(self.method.as_str().as_bytes());
355 buf.extend_from_slice(b" ");
356 buf.extend_from_slice(target.as_bytes());
357 buf.extend_from_slice(b" HTTP/1.1\r\n");
358 buf.extend_from_slice(self.req_headers.to_request_bytes().as_ref());
359 if !self.req_headers.contains("host") {
360 buf.extend_from_slice(b"Host: ");
361 buf.extend_from_slice(self.url.host.as_bytes());
362 if (self.url.port != 80 && self.url.port != 443)
363 || (self.url.port == 80 && self.url.is_tls())
364 || (self.url.port == 443 && !self.url.is_tls())
365 {
366 buf.extend_from_slice(b":");
367 buf.extend_from_slice(self.url.port.to_string().as_bytes());
368 }
369 buf.extend_from_slice(b"\r\n");
370 }
371 if !self.req_headers.contains("user-agent") {
372 buf.extend_from_slice(b"User-Agent: pipa/0.1\r\n");
373 }
374 if !self.req_headers.contains("accept") {
375 buf.extend_from_slice(b"Accept: */*\r\n");
376 }
377 if let Some(ref body) = self.body {
378 if !self.req_headers.contains("content-length") {
379 buf.extend_from_slice(b"Content-Length: ");
380 buf.extend_from_slice(body.len().to_string().as_bytes());
381 buf.extend_from_slice(b"\r\n");
382 }
383 buf.extend_from_slice(b"\r\n");
384 buf.extend_from_slice(body);
385 } else {
386 buf.extend_from_slice(b"\r\n");
387 }
388 self.write_buf = buf;
389 self.write_pos = 0;
390 }
391
392 fn finalize_response(&mut self) -> Result<RequestEvent, String> {
393 self.state = RequestState::Done;
394 let (code, status_text) = self.resp_status.take().unwrap_or((0, String::new()));
395 let headers = self.resp_headers.take().unwrap_or_default();
396 let mut body_reader = self
397 .body_reader
398 .take()
399 .unwrap_or_else(|| BodyReader::new(BodyMode::None));
400 let body = body_reader.take_body();
401
402 let decompressed = if let Some(ref mut decomp) = self.decompressor {
403 decomp.decompress(&body)?
404 } else {
405 body
406 };
407
408 let mut final_reader = BodyReader::new(BodyMode::ContentLength(decompressed.len()));
409 let _ = final_reader.feed(&decompressed)?;
410
411 let resp = HttpResponse::new(
412 HttpStatus(code),
413 status_text,
414 headers,
415 final_reader,
416 self.url.full.clone(),
417 );
418 Ok(RequestEvent::Complete(resp))
419 }
420
421 pub fn wants_read(&self) -> bool {
422 matches!(
423 self.state,
424 RequestState::ReadingStatus
425 | RequestState::ReadingHeaders
426 | RequestState::ReadingBody
427 | RequestState::Resolving
428 )
429 }
430
431 pub fn wants_write(&self) -> bool {
432 matches!(self.state, RequestState::WritingRequest)
433 }
434}