tinyquest/
lib.rs

1#![warn(missing_debug_implementations)]
2
3use chunked_transfer::Decoder;
4use core::convert::TryFrom;
5use http::{Request, Version};
6use native_tls::TlsConnector;
7use std::io::{self, prelude::*};
8use std::net::TcpStream;
9use std::sync::Arc;
10
11#[derive(Debug)]
12pub enum Error {
13    FailedToConnect,
14    FailedToHandshake,
15    FailedToWrite,
16    StreamBroken,
17    /// If the client had to make a new request due to the redirect policy.
18    WouldBlock,
19
20    IO(io::Error),
21    Request(RequestError),
22    Response(ResponseError),
23}
24
25impl From<io::Error> for Error {
26    fn from(error: io::Error) -> Self {
27        match error.kind() {
28            io::ErrorKind::WouldBlock => Error::WouldBlock,
29            _ => Error::IO(error),
30        }
31    }
32}
33
34pub type Result<T> = std::result::Result<T, Error>;
35#[derive(Debug)]
36pub enum RequestError {
37    FailedToGetIP,
38    NoHost,
39    FailedToConstructRequest(http::Error),
40    FailedToSetNoDelay,
41}
/// Reasons a response could not be parsed or acted upon.
#[derive(Debug)]
pub enum ResponseError {
    /// A header name is not a valid HTTP header name.
    InvalidHeaderName,
    /// A header value contains bytes that are not allowed in a header value.
    InvalidHeaderValue,
    /// The status code bytes are not valid UTF-8.
    InvalidHeaderUtf8,
    /// The status code is not a number that fits in a `u16`.
    InvalidStatusCode,
    /// The `http` response builder rejected the parsed parts.
    FailedToConstructResponse,
    /// A redirect response has no `Location` header.
    RedirectMissingLocation,
    /// The `Location` header of a redirect is not a usable URI.
    RedirectBrokenLocation,
}
52
/// Which part of a response is passed on to a writer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Content {
    /// Only the bytes after the header section.
    Body,
    /// Only the status line and headers.
    Header,
    /// The complete response.
    Both,
}
59#[derive(Debug, Ord, PartialOrd, Eq, PartialEq)]
60pub struct Config {
61    redirect_policy: RedirectPolicy,
62    header: Content,
63}
64impl Default for Config {
65    fn default() -> Self {
66        Config {
67            redirect_policy: RedirectPolicy::default(),
68            header: Content::Both,
69        }
70    }
71}
72impl Config {
73    pub fn no_header() -> Self {
74        Config {
75            redirect_policy: RedirectPolicy::default(),
76            header: Content::Body,
77        }
78    }
79}
/// How the client reacts to a redirect response that carries a `Location`
/// header.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum RedirectPolicy {
    /// Never follow redirects.
    Stay,
    /// Follow redirects while fewer than this many have been followed.
    Max(u32),
    /// Always follow redirects.
    Continue,
}
86impl Default for RedirectPolicy {
87    fn default() -> Self {
88        RedirectPolicy::Max(10)
89    }
90}
91
92#[derive(Debug)]
93enum Connector {
94    Raw(TcpStream),
95    TLS(native_tls::TlsStream<TcpStream>),
96}
97impl Connector {
98    pub fn set_read_timeout(
99        &mut self,
100        dur: std::option::Option<std::time::Duration>,
101    ) -> io::Result<()> {
102        match self {
103            Self::Raw(stream) => stream.set_read_timeout(dur),
104            Self::TLS(tls_stream) => tls_stream.get_mut().set_read_timeout(dur),
105        }
106    }
107}
108impl Write for Connector {
109    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
110        match self {
111            Self::Raw(stream) => stream.write(buf),
112            Self::TLS(tls_stream) => tls_stream.write(buf),
113        }
114    }
115    fn flush(&mut self) -> io::Result<()> {
116        match self {
117            Self::Raw(stream) => stream.flush(),
118            Self::TLS(tls_stream) => tls_stream.flush(),
119        }
120    }
121}
122impl Read for Connector {
123    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
124        match self {
125            Self::Raw(stream) => stream.read(buf),
126            Self::TLS(tls_stream) => tls_stream.read(buf),
127        }
128    }
129}
130
131#[derive(Debug)]
132pub struct Client {
133    stream: Connector,
134    config: Arc<Config>,
135    request: Option<http::Request<Vec<u8>>>,
136    redirects: u32,
137}
138impl Client {
139    pub fn connect(config: Arc<Config>, host: &str, port: u16, use_https: bool) -> Result<Self> {
140        let tcp_stream = match TcpStream::connect(format!("{}:{}", host, port)) {
141            Ok(stream) => stream,
142            Err(err) => {
143                return Err(Error::IO(err));
144            }
145        };
146
147        let stream = if use_https {
148            let connector = match TlsConnector::new() {
149                Ok(conn) => conn,
150                Err(..) => {
151                    return Err(Error::FailedToConnect);
152                }
153            };
154            Connector::TLS(match connector.connect(host, tcp_stream) {
155                Ok(stream) => stream,
156                Err(..) => {
157                    return Err(Error::FailedToHandshake);
158                }
159            })
160        } else {
161            Connector::Raw(tcp_stream)
162        };
163        Ok(Self {
164            config,
165            stream,
166            request: None,
167            redirects: 0,
168        })
169    }
170    pub fn request(&mut self, request: Request<Vec<u8>>) -> Result<()> {
171        self.request = Some(request);
172        self._request()
173    }
174    /// # Panics
175    /// This function will panic if the internal `request` parameter is None.
176    fn _request(&mut self) -> Result<()> {
177        let request = self.request.as_ref().unwrap();
178        let uri = match request.uri().path() {
179            uri if !uri.is_empty() => uri,
180            _ => "/",
181        };
182        let domain = match request.uri().host() {
183            Some(dom) => dom,
184            None => {
185                return Err(Error::Request(RequestError::NoHost));
186            }
187        };
188        let method = request.method().as_str();
189        let version = match request.version() {
190            Version::HTTP_09 => "HTTP/0.9",
191            Version::HTTP_10 => "HTTP/1.0",
192            Version::HTTP_11 => "HTTP/1.1",
193            Version::HTTP_2 => "HTTP/2.0",
194            Version::HTTP_3 => "HTTP/3.0",
195            _ => "HTTP/1.1",
196        };
197        let mut http_req = Vec::new();
198        http_req.extend(
199            format!(
200                "{} {} {}\r\n\
201                Host: {}\r\n\
202                Connection: keep-alive\r\n\
203                Accept-Encoding: identity\r\n",
204                method, uri, version, domain,
205            )
206            .as_bytes(),
207        );
208        // Add headers
209        for (name, value) in request.headers().iter() {
210            http_req.extend(name.as_str().as_bytes());
211            http_req.extend(b": ");
212            http_req.extend(value.as_bytes());
213            http_req.extend(LINE_ENDING);
214        }
215        http_req.extend(LINE_ENDING);
216
217        match self
218            .stream
219            .write_all(&http_req[..])
220            .and(self.stream.flush())
221        {
222            Ok(()) => (),
223            Err(err) => {
224                return Err(Error::IO(err));
225            }
226        };
227        // Not optimal; make it smart and read chunked encoding later!
228        self.stream
229            .set_read_timeout(Some(std::time::Duration::from_millis(100)))?;
230        Ok(())
231    }
232    fn _handle(&mut self) -> Result<(Vec<u8>, usize, Vec<u8>, u16, http::HeaderMap)> {
233        let mut bytes = Self::read_to_vec(&mut self.stream)?;
234
235        let mut version = Vec::new();
236        let mut status_code = Vec::new();
237        let mut reason_phrase = Vec::new();
238        let mut headers = http::HeaderMap::with_capacity(32);
239        let mut key = Vec::new();
240        let mut value = Vec::new();
241
242        let mut segment = 0;
243        let mut newlines = 0;
244
245        let mut last_byte = 0;
246
247        // Parse header
248        for byte in bytes.iter() {
249            last_byte += 1;
250            if *byte == 32 {
251                // Space
252                if segment != -1 {
253                    segment += 1;
254                    continue;
255                }
256            }
257            if *byte == 10 {
258                // Line Feed
259                newlines += 1;
260                segment = -2;
261                if !key.is_empty() || !value.is_empty() {
262                    headers.insert(
263                        match http::header::HeaderName::from_bytes(&key) {
264                            Ok(name) => name,
265                            Err(..) => {
266                                return Err(Error::Response(ResponseError::InvalidHeaderName));
267                            }
268                        },
269                        match http::header::HeaderValue::from_bytes(&value) {
270                            Ok(value) => value,
271                            Err(..) => {
272                                return Err(Error::Response(ResponseError::InvalidHeaderValue));
273                            }
274                        },
275                    );
276                    key.clear();
277                    value.clear();
278                }
279                // If double newline, it's body-time!
280                if newlines == 2 {
281                    break;
282                }
283                continue;
284            } else if *byte != 13 {
285                newlines = 0;
286            }
287            // Filter out CR and colon
288            if *byte == 13 || (*byte == 58 && segment != -1) {
289                continue;
290            }
291
292            match segment {
293                0 => version.push(*byte),
294                1 => status_code.push(*byte),
295                2 => reason_phrase.push(*byte),
296                -2 => key.push(*byte),
297                -1 => value.push(*byte),
298                _ => {}
299            };
300        }
301
302        if headers
303            .get("transfer-encoding")
304            .and_then(|header| header.to_str().ok())
305            .map(|string| string.to_ascii_lowercase() == "chunked")
306            .unwrap_or(false)
307        {
308            let mut buffer = Vec::with_capacity(bytes.len());
309            buffer.extend(&bytes[..last_byte]);
310            let mut decoder = Decoder::new(&bytes[last_byte..]);
311            if let Ok(..) = decoder.read_to_end(&mut buffer) {
312                bytes = buffer;
313            }
314        }
315
316        let status = match String::from_utf8(status_code) {
317            Ok(s) => match s.parse::<u16>() {
318                Err(..) => {
319                    return Err(Error::Response(ResponseError::InvalidStatusCode));
320                }
321                Ok(parsed) => parsed,
322            },
323            Err(..) => {
324                return Err(Error::Response(ResponseError::InvalidHeaderUtf8));
325            }
326        };
327
328        if status >= 300
329            && status < 400
330            && status != 305
331            && headers.contains_key("location")
332            && (match self.config.redirect_policy {
333                RedirectPolicy::Stay => false,
334                RedirectPolicy::Max(redirects) => self.redirects < redirects,
335                RedirectPolicy::Continue => true,
336            })
337        {
338            self.redirects += 1;
339            let mutable_uri = match &mut self.request {
340                Some(request) => request.uri_mut(),
341                None => unreachable!(),
342            };
343            *mutable_uri = match headers.get("location") {
344                Some(location) => match http::Uri::try_from(match location.to_str() {
345                    Ok(location) => location,
346                    Err(..) => {
347                        return Err(Error::Response(ResponseError::RedirectBrokenLocation));
348                    }
349                }) {
350                    Ok(location) => location,
351                    Err(..) => {
352                        return Err(Error::Response(ResponseError::RedirectBrokenLocation));
353                    }
354                },
355                None => {
356                    return Err(Error::Response(ResponseError::RedirectMissingLocation));
357                }
358            };
359            self._request()?;
360            return Err(Error::WouldBlock);
361        }
362        Ok((bytes, last_byte, version, status, headers))
363    }
364
365    pub(crate) fn read_to_vec(reader: &mut dyn Read) -> Result<Vec<u8>> {
366        const BYTES_ADD: usize = 8 * 1024;
367
368        let mut bytes = Vec::with_capacity(BYTES_ADD);
369        unsafe { bytes.set_len(BYTES_ADD) };
370        let mut began_recieving = false;
371        let mut read = 0;
372        loop {
373            match reader.read(&mut bytes[read..]) {
374                Err(err) if err.kind() == io::ErrorKind::Interrupted => {
375                    std::thread::yield_now();
376                    continue;
377                }
378                Err(err)
379                    if err.kind() == io::ErrorKind::WouldBlock
380                        || err.kind() == io::ErrorKind::TimedOut =>
381                {
382                    if began_recieving {
383                        break;
384                    } else {
385                        std::thread::yield_now();
386                        continue;
387                    }
388                }
389
390                Err(err) => {
391                    return Err(Error::IO(err));
392                }
393                Ok(just_read) => {
394                    began_recieving = true;
395                    read += just_read;
396
397                    if read == bytes.len() {
398                        bytes.reserve(BYTES_ADD);
399                        unsafe { bytes.set_len(bytes.capacity()) };
400                    }
401                }
402            };
403        }
404        unsafe { bytes.set_len(read) };
405        Ok(bytes)
406    }
407
408    pub fn done(&mut self) -> bool {
409        let result = self.stream.read(&mut [0; 0]).is_ok();
410        result
411    }
412    pub fn wait(&mut self) -> Result<http::Response<Vec<u8>>> {
413        let mut response = http::Response::builder();
414        let (bytes, last_byte, version, status, headers) = self._handle()?;
415        response = response
416            .version(match &version[..] {
417                b"HTTP/0.9" => Version::HTTP_09,
418                b"HTTP/1.0" => Version::HTTP_10,
419                b"HTTP/1.1" => Version::HTTP_11,
420                b"HTTP/2.0" => Version::HTTP_2,
421                b"HTTP/3.0" => Version::HTTP_3,
422                _ => Version::HTTP_11,
423            })
424            .status(status);
425
426        for (name, value) in headers.iter() {
427            response = response.header(name, value);
428        }
429
430        let mut body: Vec<u8> = bytes.into_iter().skip(last_byte).collect();
431        body.truncate(body.len());
432
433        match response.body(body) {
434            Ok(res) => Ok(res),
435            Err(..) => Err(Error::Response(ResponseError::FailedToConstructResponse)),
436        }
437    }
438    pub fn follow_redirects(&mut self) -> Result<http::Response<Vec<u8>>> {
439        loop {
440            match self.wait() {
441                Err(Error::WouldBlock) => continue,
442                Err(err) => return Err(err),
443                Ok(result) => return Ok(result),
444            }
445        }
446    }
447    pub fn write(&mut self, writer: &mut dyn Write) -> Result<()> {
448        let (bytes, last_byte, _, _, _) = self._handle()?;
449
450        let start_at = match self.config.header {
451            Content::Body => last_byte,
452            _ => 0,
453        };
454        let end_at = match self.config.header {
455            Content::Header => last_byte,
456            _ => bytes.len(),
457        };
458
459        writer
460            .write_all(&bytes[start_at..end_at])
461            .map_err(|err| err.into())
462    }
463    pub fn follow_redirects_write(&mut self, writer: &mut dyn Write) -> Result<()> {
464        loop {
465            match self.write(writer) {
466                Err(Error::WouldBlock) => continue,
467                Err(err) => return Err(err),
468                Ok(result) => return Ok(result),
469            }
470        }
471    }
472}
473
/// CR LF, the terminator of every line in the header section of an HTTP
/// message.
const LINE_ENDING: &[u8] = &[b'\r', b'\n'];
475
476pub fn get(url: &str, user_agent: &str) -> Result<Client> {
477    let req = match Request::get(url)
478        .header("User-Agent", user_agent)
479        .body(Vec::new())
480    {
481        Ok(req) => req,
482        Err(err) => return Err(Error::Request(RequestError::FailedToConstructRequest(err))),
483    };
484    let host = match req.uri().host() {
485        Some(host) => host,
486        None => return Err(Error::Request(RequestError::NoHost)),
487    };
488    let port = req.uri().port_u16().unwrap_or(443);
489    let mut result = Client::connect(
490        Arc::new(Config::default()),
491        host,
492        port,
493        if port == 443 { true } else { false },
494    )?;
495    result.request(req)?;
496    Ok(result)
497}
498pub fn request(request: http::Request<Vec<u8>>, config: Config) -> Result<Client> {
499    let host = match request.uri().host() {
500        Some(host) => host,
501        None => return Err(Error::Request(RequestError::NoHost)),
502    };
503    let port = request.uri().port_u16().unwrap_or(443);
504    let mut result = Client::connect(
505        Arc::new(config),
506        host,
507        port,
508        if port == 443 { true } else { false },
509    )?;
510    result.request(request)?;
511    Ok(result)
512}