1#[cfg(test)]
2use std::io::Cursor;
3use std::io::{self, Read, Write};
4use std::net::TcpStream;
5
6#[cfg(feature = "charsets")]
7use encoding_rs::{self, CoderResult};
8#[cfg(feature = "tls")]
9use native_tls::{HandshakeError, TlsConnector, TlsStream};
10use url::Url;
11
12#[cfg(feature = "charsets")]
13use crate::charsets::Charset;
14use crate::{HttpError, HttpResult};
15
16pub enum BaseStream {
17 Plain(TcpStream),
18 #[cfg(feature = "tls")]
19 Tls(TlsStream<TcpStream>),
20 #[cfg(test)]
21 Mock(Cursor<Vec<u8>>),
22}
23
24impl BaseStream {
25 pub fn connect(url: &Url) -> HttpResult<BaseStream> {
26 let host = url.host_str().ok_or(HttpError::InvalidUrl("url has no host"))?;
27 let port = url
28 .port_or_known_default()
29 .ok_or(HttpError::InvalidUrl("url has no port"))?;
30
31 debug!("trying to connect to {}:{}", host, port);
32
33 Ok(match url.scheme() {
34 "http" => BaseStream::Plain(TcpStream::connect((host, port))?),
35 #[cfg(feature = "tls")]
36 "https" => BaseStream::connect_tls(host, port)?,
37 _ => return Err(HttpError::InvalidUrl("url contains unsupported scheme")),
38 })
39 }
40
41 #[cfg(feature = "tls")]
42 fn connect_tls(host: &str, port: u16) -> HttpResult<BaseStream> {
43 let connector = TlsConnector::new()?;
44 let stream = TcpStream::connect((host, port))?;
45 let tls_stream = match connector.connect(host, stream) {
46 Ok(stream) => stream,
47 Err(HandshakeError::Failure(err)) => return Err(err.into()),
48 Err(HandshakeError::WouldBlock(_)) => panic!("socket configured in non-blocking mode"),
49 };
50 Ok(BaseStream::Tls(tls_stream))
51 }
52
53 #[cfg(test)]
54 pub fn mock(bytes: Vec<u8>) -> BaseStream {
55 BaseStream::Mock(Cursor::new(bytes))
56 }
57}
58
59impl Read for BaseStream {
60 #[inline]
61 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
62 match self {
63 BaseStream::Plain(s) => s.read(buf),
64 #[cfg(feature = "tls")]
65 BaseStream::Tls(s) => s.read(buf),
66 #[cfg(test)]
67 BaseStream::Mock(s) => s.read(buf),
68 }
69 }
70}
71
72impl Write for BaseStream {
73 #[inline]
74 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
75 match self {
76 BaseStream::Plain(s) => s.write(buf),
77 #[cfg(feature = "tls")]
78 BaseStream::Tls(s) => s.write(buf),
79 #[cfg(test)]
80 _ => Ok(0),
81 }
82 }
83
84 #[inline]
85 fn flush(&mut self) -> io::Result<()> {
86 match self {
87 BaseStream::Plain(s) => s.flush(),
88 #[cfg(feature = "tls")]
89 BaseStream::Tls(s) => s.flush(),
90 #[cfg(test)]
91 _ => Ok(()),
92 }
93 }
94}
95
/// A `Write` sink that decodes bytes in a given charset into a UTF-8 `String`.
#[cfg(feature = "charsets")]
pub struct StreamDecoder {
    // Decoded text accumulated so far.
    output: String,
    // Incremental decoder for the source charset; it may hold back an incomplete
    // trailing byte sequence between writes.
    decoder: encoding_rs::Decoder,
}

#[cfg(feature = "charsets")]
impl StreamDecoder {
    /// Creates a decoder that converts bytes encoded in `charset` to UTF-8 text.
    pub fn new(charset: Charset) -> StreamDecoder {
        StreamDecoder {
            output: String::with_capacity(1024),
            decoder: charset.new_decoder(),
        }
    }

    /// Finishes decoding and returns all text produced.
    ///
    /// Any bytes still buffered inside the decoder are flushed first. For example,
    /// a multi-byte sequence cut off at the end of the input becomes U+FFFD.
    pub fn take(mut self) -> String {
        // `decode_to_string` only fills the string's spare capacity and never
        // grows it. If that space is too small for the flushed tail it reports
        // `OutputFull` and must be called again with `last = true`; ignoring that
        // would silently drop the tail. Capacity starts at 1024 and never shrinks,
        // so each `reserve` adds real room and the loop terminates.
        while let (CoderResult::OutputFull, _, _) =
            self.decoder.decode_to_string(&[], &mut self.output, true)
        {
            self.output.reserve(self.output.capacity());
        }
        self.output
    }
}
116
#[cfg(feature = "charsets")]
impl Write for StreamDecoder {
    /// Decodes `buf` and appends the resulting text to the internal buffer.
    ///
    /// Always consumes the whole slice and returns `Ok(buf.len())`. Decoding runs
    /// with `last = false`, so an incomplete trailing byte sequence may stay
    /// buffered in the decoder until more bytes arrive or the decoder is finished.
    fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
        let len = buf.len();
        while !buf.is_empty() {
            // The second tuple element counts input bytes *consumed*, not bytes
            // written to the output.
            let (result, read, _) = self.decoder.decode_to_string(buf, &mut self.output, false);
            buf = &buf[read..];
            match result {
                CoderResult::InputEmpty => {}
                // `decode_to_string` only fills existing spare capacity, so grow
                // the buffer by at least its current capacity and retry the rest.
                CoderResult::OutputFull => self.output.reserve(self.output.capacity()),
            }
        }
        Ok(len)
    }

    /// No-op: decoded text stays in memory until it is taken.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
139
#[cfg(test)]
#[cfg(feature = "charsets")]
mod tests {
    use super::StreamDecoder;
    use crate::charsets;
    use std::io::Write;

    #[test]
    fn test_stream_decoder_utf8() {
        let mut decoder = StreamDecoder::new(charsets::UTF_8);
        decoder.write_all("québec".as_bytes()).unwrap();
        assert_eq!(decoder.take(), "québec");
    }

    #[test]
    fn test_stream_decoder_latin1() {
        let mut decoder = StreamDecoder::new(charsets::WINDOWS_1252);
        decoder.write_all(&[201]).unwrap();
        assert_eq!(decoder.take(), "É");
    }

    #[test]
    fn test_stream_decoder_large_buffer() {
        // 10_000 bytes of 0xC9 decode to 20_000 bytes of UTF-8, far beyond the
        // decoder's initial 1024-byte output buffer, so the grow path is exercised.
        let mut decoder = StreamDecoder::new(charsets::WINDOWS_1252);
        let buf = vec![201u8; 10_000];
        decoder.write_all(&buf).unwrap();
        // Compare the whole string: a per-char check alone would also pass on a
        // truncated or empty result.
        assert_eq!(decoder.take(), "É".repeat(10_000));
    }

    #[test]
    fn test_stream_decoder_truncated_utf8() {
        // A multi-byte sequence cut off at end of input must still surface as
        // U+FFFD when the decoder is finished.
        let mut decoder = StreamDecoder::new(charsets::UTF_8);
        decoder.write_all(&[b'a', 0xE2, 0x82]).unwrap();
        assert_eq!(decoder.take(), "a\u{FFFD}");
    }
}