Skip to main content

http/
stream.rs

1use std::{
2    io::{self, Read, Write},
3    net::TcpStream,
4    time::Duration,
5};
6
7#[cfg(feature = "tls")]
8use rustls::{ConnectionCommon, SideData};
9
/// A readable and writable byte stream that HTTP traffic can be carried over.
///
/// Both methods ship with do-nothing defaults that always return `Ok(())`, so
/// an empty `impl HttpStream for T {}` is enough for any `Read + Write` type.
/// Socket-backed types (such as `TcpStream` in this module) override them.
pub trait HttpStream: Read + Write {
    /// Puts the stream back into blocking mode.
    ///
    /// Default: does nothing and returns `Ok(())`.
    fn set_blocking(&mut self) -> io::Result<()> {
        io::Result::Ok(())
    }

    /// Limits how long reads may wait, using the given timeout.
    ///
    /// Default: ignores the timeout entirely and returns `Ok(())`.
    fn set_non_blocking(&mut self, _timeout: Duration) -> io::Result<()> {
        io::Result::Ok(())
    }
}
20
21pub trait IntoHttpStream {
22    type Stream: HttpStream + 'static;
23    fn into_http_stream(self) -> Self::Stream;
24}
25
/// An in-memory stream: reads are served from a fixed input buffer, and every
/// write is appended to a separate output buffer.
///
/// No I/O ever happens, which makes it convenient for driving stream-based code
/// from canned data. This module converts `String` and `&str` into it via
/// `IntoHttpStream`.
#[derive(Debug, Clone, Default)]
pub struct StringStream {
    /// Bytes handed out by [`Read`]; never modified after construction.
    input: Vec<u8>,
    /// Position of the next unread byte in `input`; never exceeds `input.len()`.
    offset: usize,
    /// Everything passed to [`Write`], in order.
    output: Vec<u8>,
}

impl Read for StringStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = self.peek(buf)?;
        self.offset += n;
        Ok(n)
    }

    /// Appends all unread input to `buf` in a single copy.
    ///
    /// Returns the number of bytes appended (`0` once the input is exhausted).
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        let remaining = self.remaining();
        let n = remaining.len();
        // `extend_from_slice` already reserves the needed capacity once.
        buf.extend_from_slice(remaining);
        self.offset += n;
        Ok(n)
    }
}

impl Write for StringStream {
    /// Appends `buf` to the output buffer.
    ///
    /// This never fails and always accepts all of `buf`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.output.extend_from_slice(buf);
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        // Output lives in memory, so there is nothing to flush.
        Ok(())
    }
}

impl StringStream {
    /// Copies as much unread input as fits into `buf` without consuming it.
    ///
    /// Returns the number of bytes copied. A later [`Read::read`] returns the
    /// same bytes again.
    ///
    /// # Errors
    ///
    /// Never fails; the `io::Result` mirrors the signature of [`Read::read`].
    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        let remaining = self.remaining();
        let n = buf.len().min(remaining.len());
        buf[..n].copy_from_slice(&remaining[..n]);
        Ok(n)
    }

    /// Returns every byte written to this stream so far.
    #[must_use]
    pub fn output(&self) -> &[u8] {
        &self.output
    }

    /// Consumes the stream and returns the bytes written to it.
    #[must_use]
    pub fn into_output(self) -> Vec<u8> {
        self.output
    }

    /// The not-yet-read part of the input.
    ///
    /// Uses checked slicing, so an out-of-range `offset` yields an empty slice
    /// instead of a panic.
    fn remaining(&self) -> &[u8] {
        self.input.get(self.offset..).unwrap_or(&[])
    }
}
70
71impl HttpStream for StringStream {}
72
73impl IntoHttpStream for String {
74    type Stream = StringStream;
75
76    fn into_http_stream(self) -> Self::Stream {
77        let src_vec = self.into_bytes();
78        StringStream {
79            input: src_vec,
80            offset: 0,
81            output: Vec::new(),
82        }
83    }
84}
85
86impl IntoHttpStream for &str {
87    type Stream = StringStream;
88
89    fn into_http_stream(self) -> Self::Stream {
90        self.to_string().into_http_stream()
91    }
92}
93
94impl<S: HttpStream + 'static> IntoHttpStream for S {
95    type Stream = S;
96
97    fn into_http_stream(self) -> Self::Stream {
98        self
99    }
100}
101
102impl HttpStream for TcpStream {
103    fn set_non_blocking(&mut self, timeout: Duration) -> io::Result<()> {
104        self.set_read_timeout(Some(timeout))
105    }
106
107    fn set_blocking(&mut self) -> io::Result<()> {
108        self.set_read_timeout(None)
109    }
110}
111
/// A placeholder stream that carries no data.
///
/// Reads always report end of input, and writes are silently discarded. This
/// mirrors [`std::io::empty`] for reading and [`std::io::sink`] for writing.
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyStream;

impl Read for DummyStream {
    /// Always returns `Ok(0)` (end of input) and leaves `buf` untouched.
    fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
        Ok(0)
    }
}

impl Write for DummyStream {
    /// Discards `buf` while reporting it as fully written.
    ///
    /// The return value must be `Ok(buf.len())`, not `Ok(0)`. Per the `Write`
    /// contract, `Ok(0)` for a non-empty buffer means the writer can no longer
    /// accept data, which makes `write_all` / `write!` fail with
    /// `ErrorKind::WriteZero`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
129
130impl HttpStream for DummyStream {}
131
132pub fn dummy() -> Box<dyn HttpStream> {
133    Box::new(DummyStream)
134}
135
/// TLS streams delegate their blocking controls to the wrapped socket `sock`.
#[cfg(feature = "tls")]
impl<C, SD, S> HttpStream for rustls::StreamOwned<C, S>
where
    S: HttpStream,
    SD: SideData,
    C: core::ops::DerefMut<Target = ConnectionCommon<SD>>,
{
    fn set_blocking(&mut self) -> io::Result<()> {
        HttpStream::set_blocking(&mut self.sock)
    }

    fn set_non_blocking(&mut self, timeout: Duration) -> io::Result<()> {
        HttpStream::set_non_blocking(&mut self.sock, timeout)
    }
}
150
#[cfg(test)]
mod test {
    use std::io::{Read, Write};

    use super::IntoHttpStream;

    #[test]
    fn string_stream() {
        let mut hi = String::from("Hello world!").into_http_stream();
        let mut vec = Vec::new();
        hi.read_to_end(&mut vec).unwrap();

        // `std::str::from_utf8` (not the inherent `str::from_utf8`, which only
        // exists since Rust 1.87) keeps this compiling on older toolchains.
        assert_eq!(std::str::from_utf8(&vec).unwrap(), "Hello world!");

        // Once drained, further reads report end of input.
        assert_eq!(hi.read(&mut [0; 10]).unwrap(), 0);
    }

    #[test]
    fn string_stream_peek_does_not_consume() {
        let mut stream = "abcdef".into_http_stream();
        let mut buf = [0u8; 4];

        assert_eq!(stream.peek(&mut buf).unwrap(), 4);
        assert_eq!(&buf, b"abcd");

        // The peeked bytes are still there for the next read.
        assert_eq!(stream.read(&mut buf).unwrap(), 4);
        assert_eq!(&buf, b"abcd");

        // A short tail fills only part of the buffer.
        assert_eq!(stream.read(&mut buf).unwrap(), 2);
        assert_eq!(&buf[..2], b"ef");
    }

    #[test]
    fn string_stream_records_writes() {
        let mut stream = "".into_http_stream();
        stream.write_all(b"HTTP/1.1 200 OK\r\n").unwrap();
        stream.flush().unwrap();
        // Child modules may read the parent's private fields.
        assert_eq!(stream.output, b"HTTP/1.1 200 OK\r\n");
    }
}