1#![warn(missing_debug_implementations)]
2
3use chunked_transfer::Decoder;
4use core::convert::TryFrom;
5use http::{Request, Version};
6use native_tls::TlsConnector;
7use std::io::{self, prelude::*};
8use std::net::TcpStream;
9use std::sync::Arc;
10
11#[derive(Debug)]
12pub enum Error {
13 FailedToConnect,
14 FailedToHandshake,
15 FailedToWrite,
16 StreamBroken,
17 WouldBlock,
19
20 IO(io::Error),
21 Request(RequestError),
22 Response(ResponseError),
23}
24
25impl From<io::Error> for Error {
26 fn from(error: io::Error) -> Self {
27 match error.kind() {
28 io::ErrorKind::WouldBlock => Error::WouldBlock,
29 _ => Error::IO(error),
30 }
31 }
32}
33
34pub type Result<T> = std::result::Result<T, Error>;
35#[derive(Debug)]
36pub enum RequestError {
37 FailedToGetIP,
38 NoHost,
39 FailedToConstructRequest(http::Error),
40 FailedToSetNoDelay,
41}
/// Problems found while parsing the server's response.
#[derive(Debug)]
pub enum ResponseError {
    /// A header name was not a valid HTTP token.
    InvalidHeaderName,
    /// A header value contained forbidden bytes.
    InvalidHeaderValue,
    /// The status code bytes were not UTF-8.
    InvalidHeaderUtf8,
    /// The status code was not a `u16`.
    InvalidStatusCode,
    /// The final `http::Response` could not be assembled.
    FailedToConstructResponse,
    /// A redirect status arrived without a `Location` header.
    RedirectMissingLocation,
    /// The `Location` header was not a usable URI.
    RedirectBrokenLocation,
}
52
/// Which part of a response `Client::write` copies to its writer.
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
pub enum Content {
    /// Only the bytes after the header block.
    Body,
    /// Only the status line and headers.
    Header,
    /// Everything that was received.
    Both,
}
59#[derive(Debug, Ord, PartialOrd, Eq, PartialEq)]
60pub struct Config {
61 redirect_policy: RedirectPolicy,
62 header: Content,
63}
64impl Default for Config {
65 fn default() -> Self {
66 Config {
67 redirect_policy: RedirectPolicy::default(),
68 header: Content::Both,
69 }
70 }
71}
72impl Config {
73 pub fn no_header() -> Self {
74 Config {
75 redirect_policy: RedirectPolicy::default(),
76 header: Content::Body,
77 }
78 }
79}
/// How a [`Client`] reacts to a 3xx response carrying a `Location` header.
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq)]
pub enum RedirectPolicy {
    /// Never follow; the redirect response itself is returned.
    Stay,
    /// Follow while fewer than this many redirects have been followed.
    Max(u32),
    /// Follow every redirect, without limit.
    Continue,
}
86impl Default for RedirectPolicy {
87 fn default() -> Self {
88 RedirectPolicy::Max(10)
89 }
90}
91
92#[derive(Debug)]
93enum Connector {
94 Raw(TcpStream),
95 TLS(native_tls::TlsStream<TcpStream>),
96}
97impl Connector {
98 pub fn set_read_timeout(
99 &mut self,
100 dur: std::option::Option<std::time::Duration>,
101 ) -> io::Result<()> {
102 match self {
103 Self::Raw(stream) => stream.set_read_timeout(dur),
104 Self::TLS(tls_stream) => tls_stream.get_mut().set_read_timeout(dur),
105 }
106 }
107}
108impl Write for Connector {
109 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
110 match self {
111 Self::Raw(stream) => stream.write(buf),
112 Self::TLS(tls_stream) => tls_stream.write(buf),
113 }
114 }
115 fn flush(&mut self) -> io::Result<()> {
116 match self {
117 Self::Raw(stream) => stream.flush(),
118 Self::TLS(tls_stream) => tls_stream.flush(),
119 }
120 }
121}
122impl Read for Connector {
123 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
124 match self {
125 Self::Raw(stream) => stream.read(buf),
126 Self::TLS(tls_stream) => tls_stream.read(buf),
127 }
128 }
129}
130
131#[derive(Debug)]
132pub struct Client {
133 stream: Connector,
134 config: Arc<Config>,
135 request: Option<http::Request<Vec<u8>>>,
136 redirects: u32,
137}
138impl Client {
139 pub fn connect(config: Arc<Config>, host: &str, port: u16, use_https: bool) -> Result<Self> {
140 let tcp_stream = match TcpStream::connect(format!("{}:{}", host, port)) {
141 Ok(stream) => stream,
142 Err(err) => {
143 return Err(Error::IO(err));
144 }
145 };
146
147 let stream = if use_https {
148 let connector = match TlsConnector::new() {
149 Ok(conn) => conn,
150 Err(..) => {
151 return Err(Error::FailedToConnect);
152 }
153 };
154 Connector::TLS(match connector.connect(host, tcp_stream) {
155 Ok(stream) => stream,
156 Err(..) => {
157 return Err(Error::FailedToHandshake);
158 }
159 })
160 } else {
161 Connector::Raw(tcp_stream)
162 };
163 Ok(Self {
164 config,
165 stream,
166 request: None,
167 redirects: 0,
168 })
169 }
170 pub fn request(&mut self, request: Request<Vec<u8>>) -> Result<()> {
171 self.request = Some(request);
172 self._request()
173 }
174 fn _request(&mut self) -> Result<()> {
177 let request = self.request.as_ref().unwrap();
178 let uri = match request.uri().path() {
179 uri if !uri.is_empty() => uri,
180 _ => "/",
181 };
182 let domain = match request.uri().host() {
183 Some(dom) => dom,
184 None => {
185 return Err(Error::Request(RequestError::NoHost));
186 }
187 };
188 let method = request.method().as_str();
189 let version = match request.version() {
190 Version::HTTP_09 => "HTTP/0.9",
191 Version::HTTP_10 => "HTTP/1.0",
192 Version::HTTP_11 => "HTTP/1.1",
193 Version::HTTP_2 => "HTTP/2.0",
194 Version::HTTP_3 => "HTTP/3.0",
195 _ => "HTTP/1.1",
196 };
197 let mut http_req = Vec::new();
198 http_req.extend(
199 format!(
200 "{} {} {}\r\n\
201 Host: {}\r\n\
202 Connection: keep-alive\r\n\
203 Accept-Encoding: identity\r\n",
204 method, uri, version, domain,
205 )
206 .as_bytes(),
207 );
208 for (name, value) in request.headers().iter() {
210 http_req.extend(name.as_str().as_bytes());
211 http_req.extend(b": ");
212 http_req.extend(value.as_bytes());
213 http_req.extend(LINE_ENDING);
214 }
215 http_req.extend(LINE_ENDING);
216
217 match self
218 .stream
219 .write_all(&http_req[..])
220 .and(self.stream.flush())
221 {
222 Ok(()) => (),
223 Err(err) => {
224 return Err(Error::IO(err));
225 }
226 };
227 self.stream
229 .set_read_timeout(Some(std::time::Duration::from_millis(100)))?;
230 Ok(())
231 }
232 fn _handle(&mut self) -> Result<(Vec<u8>, usize, Vec<u8>, u16, http::HeaderMap)> {
233 let mut bytes = Self::read_to_vec(&mut self.stream)?;
234
235 let mut version = Vec::new();
236 let mut status_code = Vec::new();
237 let mut reason_phrase = Vec::new();
238 let mut headers = http::HeaderMap::with_capacity(32);
239 let mut key = Vec::new();
240 let mut value = Vec::new();
241
242 let mut segment = 0;
243 let mut newlines = 0;
244
245 let mut last_byte = 0;
246
247 for byte in bytes.iter() {
249 last_byte += 1;
250 if *byte == 32 {
251 if segment != -1 {
253 segment += 1;
254 continue;
255 }
256 }
257 if *byte == 10 {
258 newlines += 1;
260 segment = -2;
261 if !key.is_empty() || !value.is_empty() {
262 headers.insert(
263 match http::header::HeaderName::from_bytes(&key) {
264 Ok(name) => name,
265 Err(..) => {
266 return Err(Error::Response(ResponseError::InvalidHeaderName));
267 }
268 },
269 match http::header::HeaderValue::from_bytes(&value) {
270 Ok(value) => value,
271 Err(..) => {
272 return Err(Error::Response(ResponseError::InvalidHeaderValue));
273 }
274 },
275 );
276 key.clear();
277 value.clear();
278 }
279 if newlines == 2 {
281 break;
282 }
283 continue;
284 } else if *byte != 13 {
285 newlines = 0;
286 }
287 if *byte == 13 || (*byte == 58 && segment != -1) {
289 continue;
290 }
291
292 match segment {
293 0 => version.push(*byte),
294 1 => status_code.push(*byte),
295 2 => reason_phrase.push(*byte),
296 -2 => key.push(*byte),
297 -1 => value.push(*byte),
298 _ => {}
299 };
300 }
301
302 if headers
303 .get("transfer-encoding")
304 .and_then(|header| header.to_str().ok())
305 .map(|string| string.to_ascii_lowercase() == "chunked")
306 .unwrap_or(false)
307 {
308 let mut buffer = Vec::with_capacity(bytes.len());
309 buffer.extend(&bytes[..last_byte]);
310 let mut decoder = Decoder::new(&bytes[last_byte..]);
311 if let Ok(..) = decoder.read_to_end(&mut buffer) {
312 bytes = buffer;
313 }
314 }
315
316 let status = match String::from_utf8(status_code) {
317 Ok(s) => match s.parse::<u16>() {
318 Err(..) => {
319 return Err(Error::Response(ResponseError::InvalidStatusCode));
320 }
321 Ok(parsed) => parsed,
322 },
323 Err(..) => {
324 return Err(Error::Response(ResponseError::InvalidHeaderUtf8));
325 }
326 };
327
328 if status >= 300
329 && status < 400
330 && status != 305
331 && headers.contains_key("location")
332 && (match self.config.redirect_policy {
333 RedirectPolicy::Stay => false,
334 RedirectPolicy::Max(redirects) => self.redirects < redirects,
335 RedirectPolicy::Continue => true,
336 })
337 {
338 self.redirects += 1;
339 let mutable_uri = match &mut self.request {
340 Some(request) => request.uri_mut(),
341 None => unreachable!(),
342 };
343 *mutable_uri = match headers.get("location") {
344 Some(location) => match http::Uri::try_from(match location.to_str() {
345 Ok(location) => location,
346 Err(..) => {
347 return Err(Error::Response(ResponseError::RedirectBrokenLocation));
348 }
349 }) {
350 Ok(location) => location,
351 Err(..) => {
352 return Err(Error::Response(ResponseError::RedirectBrokenLocation));
353 }
354 },
355 None => {
356 return Err(Error::Response(ResponseError::RedirectMissingLocation));
357 }
358 };
359 self._request()?;
360 return Err(Error::WouldBlock);
361 }
362 Ok((bytes, last_byte, version, status, headers))
363 }
364
365 pub(crate) fn read_to_vec(reader: &mut dyn Read) -> Result<Vec<u8>> {
366 const BYTES_ADD: usize = 8 * 1024;
367
368 let mut bytes = Vec::with_capacity(BYTES_ADD);
369 unsafe { bytes.set_len(BYTES_ADD) };
370 let mut began_recieving = false;
371 let mut read = 0;
372 loop {
373 match reader.read(&mut bytes[read..]) {
374 Err(err) if err.kind() == io::ErrorKind::Interrupted => {
375 std::thread::yield_now();
376 continue;
377 }
378 Err(err)
379 if err.kind() == io::ErrorKind::WouldBlock
380 || err.kind() == io::ErrorKind::TimedOut =>
381 {
382 if began_recieving {
383 break;
384 } else {
385 std::thread::yield_now();
386 continue;
387 }
388 }
389
390 Err(err) => {
391 return Err(Error::IO(err));
392 }
393 Ok(just_read) => {
394 began_recieving = true;
395 read += just_read;
396
397 if read == bytes.len() {
398 bytes.reserve(BYTES_ADD);
399 unsafe { bytes.set_len(bytes.capacity()) };
400 }
401 }
402 };
403 }
404 unsafe { bytes.set_len(read) };
405 Ok(bytes)
406 }
407
408 pub fn done(&mut self) -> bool {
409 let result = self.stream.read(&mut [0; 0]).is_ok();
410 result
411 }
412 pub fn wait(&mut self) -> Result<http::Response<Vec<u8>>> {
413 let mut response = http::Response::builder();
414 let (bytes, last_byte, version, status, headers) = self._handle()?;
415 response = response
416 .version(match &version[..] {
417 b"HTTP/0.9" => Version::HTTP_09,
418 b"HTTP/1.0" => Version::HTTP_10,
419 b"HTTP/1.1" => Version::HTTP_11,
420 b"HTTP/2.0" => Version::HTTP_2,
421 b"HTTP/3.0" => Version::HTTP_3,
422 _ => Version::HTTP_11,
423 })
424 .status(status);
425
426 for (name, value) in headers.iter() {
427 response = response.header(name, value);
428 }
429
430 let mut body: Vec<u8> = bytes.into_iter().skip(last_byte).collect();
431 body.truncate(body.len());
432
433 match response.body(body) {
434 Ok(res) => Ok(res),
435 Err(..) => Err(Error::Response(ResponseError::FailedToConstructResponse)),
436 }
437 }
438 pub fn follow_redirects(&mut self) -> Result<http::Response<Vec<u8>>> {
439 loop {
440 match self.wait() {
441 Err(Error::WouldBlock) => continue,
442 Err(err) => return Err(err),
443 Ok(result) => return Ok(result),
444 }
445 }
446 }
447 pub fn write(&mut self, writer: &mut dyn Write) -> Result<()> {
448 let (bytes, last_byte, _, _, _) = self._handle()?;
449
450 let start_at = match self.config.header {
451 Content::Body => last_byte,
452 _ => 0,
453 };
454 let end_at = match self.config.header {
455 Content::Header => last_byte,
456 _ => bytes.len(),
457 };
458
459 writer
460 .write_all(&bytes[start_at..end_at])
461 .map_err(|err| err.into())
462 }
463 pub fn follow_redirects_write(&mut self, writer: &mut dyn Write) -> Result<()> {
464 loop {
465 match self.write(writer) {
466 Err(Error::WouldBlock) => continue,
467 Err(err) => return Err(err),
468 Ok(result) => return Ok(result),
469 }
470 }
471 }
472}
473
/// CRLF terminator for HTTP/1.x header lines and the blank line that ends the head.
const LINE_ENDING: &[u8] = &[b'\r', b'\n'];
475
476pub fn get(url: &str, user_agent: &str) -> Result<Client> {
477 let req = match Request::get(url)
478 .header("User-Agent", user_agent)
479 .body(Vec::new())
480 {
481 Ok(req) => req,
482 Err(err) => return Err(Error::Request(RequestError::FailedToConstructRequest(err))),
483 };
484 let host = match req.uri().host() {
485 Some(host) => host,
486 None => return Err(Error::Request(RequestError::NoHost)),
487 };
488 let port = req.uri().port_u16().unwrap_or(443);
489 let mut result = Client::connect(
490 Arc::new(Config::default()),
491 host,
492 port,
493 if port == 443 { true } else { false },
494 )?;
495 result.request(req)?;
496 Ok(result)
497}
498pub fn request(request: http::Request<Vec<u8>>, config: Config) -> Result<Client> {
499 let host = match request.uri().host() {
500 Some(host) => host,
501 None => return Err(Error::Request(RequestError::NoHost)),
502 };
503 let port = request.uri().port_u16().unwrap_or(443);
504 let mut result = Client::connect(
505 Arc::new(config),
506 host,
507 port,
508 if port == 443 { true } else { false },
509 )?;
510 result.request(request)?;
511 Ok(result)
512}