lynx/
streams.rs

1#[cfg(test)]
2use std::io::Cursor;
3use std::io::{self, Read, Write};
4use std::net::TcpStream;
5
6#[cfg(feature = "charsets")]
7use encoding_rs::{self, CoderResult};
8#[cfg(feature = "tls")]
9use native_tls::{HandshakeError, TlsConnector, TlsStream};
10use url::Url;
11
12#[cfg(feature = "charsets")]
13use crate::charsets::Charset;
14use crate::{HttpError, HttpResult};
15
/// A byte stream to an HTTP server, carried over one of several transports.
pub enum BaseStream {
    /// Unencrypted TCP connection; `BaseStream::connect` uses it for `http` URLs.
    Plain(TcpStream),
    /// TLS-wrapped TCP connection for `https` URLs (only with the `tls` feature).
    #[cfg(feature = "tls")]
    Tls(TlsStream<TcpStream>),
    /// In-memory stream for unit tests; reads are served from the wrapped bytes.
    #[cfg(test)]
    Mock(Cursor<Vec<u8>>),
}
23
24impl BaseStream {
25    pub fn connect(url: &Url) -> HttpResult<BaseStream> {
26        let host = url.host_str().ok_or(HttpError::InvalidUrl("url has no host"))?;
27        let port = url
28            .port_or_known_default()
29            .ok_or(HttpError::InvalidUrl("url has no port"))?;
30
31        debug!("trying to connect to {}:{}", host, port);
32
33        Ok(match url.scheme() {
34            "http" => BaseStream::Plain(TcpStream::connect((host, port))?),
35            #[cfg(feature = "tls")]
36            "https" => BaseStream::connect_tls(host, port)?,
37            _ => return Err(HttpError::InvalidUrl("url contains unsupported scheme")),
38        })
39    }
40
41    #[cfg(feature = "tls")]
42    fn connect_tls(host: &str, port: u16) -> HttpResult<BaseStream> {
43        let connector = TlsConnector::new()?;
44        let stream = TcpStream::connect((host, port))?;
45        let tls_stream = match connector.connect(host, stream) {
46            Ok(stream) => stream,
47            Err(HandshakeError::Failure(err)) => return Err(err.into()),
48            Err(HandshakeError::WouldBlock(_)) => panic!("socket configured in non-blocking mode"),
49        };
50        Ok(BaseStream::Tls(tls_stream))
51    }
52
53    #[cfg(test)]
54    pub fn mock(bytes: Vec<u8>) -> BaseStream {
55        BaseStream::Mock(Cursor::new(bytes))
56    }
57}
58
59impl Read for BaseStream {
60    #[inline]
61    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
62        match self {
63            BaseStream::Plain(s) => s.read(buf),
64            #[cfg(feature = "tls")]
65            BaseStream::Tls(s) => s.read(buf),
66            #[cfg(test)]
67            BaseStream::Mock(s) => s.read(buf),
68        }
69    }
70}
71
72impl Write for BaseStream {
73    #[inline]
74    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
75        match self {
76            BaseStream::Plain(s) => s.write(buf),
77            #[cfg(feature = "tls")]
78            BaseStream::Tls(s) => s.write(buf),
79            #[cfg(test)]
80            _ => Ok(0),
81        }
82    }
83
84    #[inline]
85    fn flush(&mut self) -> io::Result<()> {
86        match self {
87            BaseStream::Plain(s) => s.flush(),
88            #[cfg(feature = "tls")]
89            BaseStream::Tls(s) => s.flush(),
90            #[cfg(test)]
91            _ => Ok(()),
92        }
93    }
94}
95
/// Incrementally decodes bytes in some charset into a UTF-8 `String`.
///
/// Feed it bytes through its `Write` impl, then call `take` to obtain the text.
#[cfg(feature = "charsets")]
pub struct StreamDecoder {
    /// Stateful charset decoder; may hold a partial multi-byte sequence between writes.
    decoder: encoding_rs::Decoder,
    /// Text decoded so far.
    output: String,
}
101
#[cfg(feature = "charsets")]
impl StreamDecoder {
    /// Creates a decoder for `charset` with an initial 1024-byte output buffer.
    pub fn new(charset: Charset) -> StreamDecoder {
        StreamDecoder {
            output: String::with_capacity(1024),
            decoder: charset.new_decoder(),
        }
    }

    /// Finishes decoding and returns all text decoded so far.
    ///
    /// Signals end of input to the decoder. Any incomplete byte sequence still buffered
    /// inside it is therefore emitted (encoding_rs substitutes U+FFFD for malformed input)
    /// instead of being lost.
    pub fn take(mut self) -> String {
        loop {
            // With no input and `last = true` this only flushes the decoder's pending state.
            // `decode_to_string` writes solely into the String's spare capacity, so a
            // nearly-full buffer yields `OutputFull`; grow it and retry instead of ignoring
            // the result.
            let (result, _read, _replaced) =
                self.decoder.decode_to_string(&[], &mut self.output, true);
            match result {
                CoderResult::InputEmpty => break,
                CoderResult::OutputFull => {
                    let extra = self.output.capacity().max(4);
                    self.output.reserve(extra);
                }
            }
        }
        self.output
    }
}
116
#[cfg(feature = "charsets")]
impl Write for StreamDecoder {
    /// Decodes all of `buf` into the internal output buffer.
    ///
    /// Always consumes the whole input and returns `Ok(buf.len())`. Because `last = false`
    /// is passed, a multi-byte sequence split across calls stays pending in the decoder.
    fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
        let len = buf.len();
        while !buf.is_empty() {
            // The middle tuple element counts *input* bytes consumed, not bytes written.
            let (result, read, _replaced) =
                self.decoder.decode_to_string(buf, &mut self.output, false);
            buf = &buf[read..];
            match result {
                CoderResult::InputEmpty => {}
                // `decode_to_string` only fills the String's spare capacity; grow it
                // (roughly doubling) so the next pass can make progress.
                CoderResult::OutputFull => self.output.reserve(self.output.capacity()),
            }
        }
        Ok(len)
    }

    /// No-op: decoded text stays buffered until `StreamDecoder::take`.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
139
#[cfg(test)]
#[cfg(feature = "charsets")]
mod tests {
    use super::StreamDecoder;
    use crate::charsets;
    use std::io::Write;

    #[test]
    fn test_stream_decoder_utf8() {
        let mut decoder = StreamDecoder::new(charsets::UTF_8);
        decoder.write_all("québec".as_bytes()).unwrap();
        assert_eq!(decoder.take(), "québec");
    }

    #[test]
    fn test_stream_decoder_latin1() {
        let mut decoder = StreamDecoder::new(charsets::WINDOWS_1252);
        decoder.write_all(&[201]).unwrap();
        assert_eq!(decoder.take(), "É");
    }

    #[test]
    fn test_stream_decoder_large_buffer() {
        let mut decoder = StreamDecoder::new(charsets::WINDOWS_1252);
        // 10_000 two-byte UTF-8 chars forces the output past its initial 1024-byte capacity.
        let buf = vec![201u8; 10_000];
        decoder.write_all(&buf).unwrap();
        let text = decoder.take();
        // Check the count too: a decoder that dropped data would still pass the char check.
        assert_eq!(text.chars().count(), 10_000);
        assert!(text.chars().all(|c| c == 'É'));
    }
}