#[cfg(feature = "tls")]
use std::io::ErrorKind::{ConnectionAborted, WouldBlock};
use std::io::{Read, Result as IoResult, Write};
use std::net::TcpStream;
#[cfg(feature = "tls")]
use std::sync::Arc;
#[cfg(any(feature = "native-tls", feature = "tls"))]
use http::uri::Scheme;
#[cfg(feature = "native-tls")]
use native_tls::{HandshakeError, TlsConnector, TlsStream};
#[cfg(feature = "tls")]
use rustls::{ClientConfig, ClientSession, Session, StreamOwned};
#[cfg(feature = "tls")]
use webpki::DNSNameRef;
#[cfg(feature = "tls")]
use webpki_roots::TLS_SERVER_ROOTS;
use super::{happy_eyeballs::connect, timeout::Timeout, Error, Options};
pub enum Stream {
Tcp(TcpStream),
TcpWithTimeout(TcpStream, Timeout),
#[cfg(feature = "native-tls")]
NativeTls(TlsStream<TcpStream>),
#[cfg(feature = "native-tls")]
NativeTlsWithTimeout(TlsStream<TcpStream>, Timeout),
#[cfg(feature = "tls")]
Rustls(Box<StreamOwned<ClientSession, TcpStream>>),
#[cfg(feature = "tls")]
RustlsWithTimeout(Box<StreamOwned<ClientSession, TcpStream>>, Timeout),
}
impl Stream {
pub fn new(
#[cfg(any(feature = "native-tls", feature = "tls"))] scheme: &Scheme,
host: &str,
port: u16,
opts: &Options,
) -> Result<Self, Error> {
let stream = connect(host, port, &opts)?;
match opts.timeout {
#[cfg(feature = "native-tls")]
None if scheme == &Scheme::HTTPS => {
let stream = perform_native_tls_handshake(stream, host, opts.tls_connector)?;
Ok(Self::NativeTls(stream))
}
#[cfg(feature = "tls")]
None if scheme == &Scheme::HTTPS => {
let stream = perform_rustls_handshake(stream, host, opts.client_config)?;
Ok(Self::Rustls(Box::new(stream)))
}
None => Ok(Self::Tcp(stream)),
#[cfg(feature = "native-tls")]
Some(timeout) if scheme == &Scheme::HTTPS => {
let timeout = Timeout::start(&stream, timeout)?;
let stream = perform_native_tls_handshake(stream, host, opts.tls_connector)?;
Ok(Self::NativeTlsWithTimeout(stream, timeout))
}
#[cfg(feature = "tls")]
Some(timeout) if scheme == &Scheme::HTTPS => {
let timeout = Timeout::start(&stream, timeout)?;
let stream = perform_rustls_handshake(stream, host, opts.client_config)?;
Ok(Self::RustlsWithTimeout(Box::new(stream), timeout))
}
Some(timeout) => {
let timeout = Timeout::start(&stream, timeout)?;
Ok(Self::TcpWithTimeout(stream, timeout))
}
}
}
}
#[cfg(feature = "native-tls")]
fn perform_native_tls_handshake(
stream: TcpStream,
host: &str,
tls_connector: Option<&TlsConnector>,
) -> Result<TlsStream<TcpStream>, Error> {
let handshake = match tls_connector {
Some(tls_connector) => tls_connector.connect(host, stream),
None => TlsConnector::new()?.connect(host, stream),
};
match handshake {
Ok(stream) => Ok(stream),
Err(HandshakeError::Failure(err)) => Err(err.into()),
Err(HandshakeError::WouldBlock(mut stream)) => loop {
match stream.handshake() {
Ok(stream) => return Ok(stream),
Err(HandshakeError::Failure(err)) => return Err(err.into()),
Err(HandshakeError::WouldBlock(stream1)) => stream = stream1,
}
},
}
}
#[cfg(feature = "tls")]
fn perform_rustls_handshake(
mut stream: TcpStream,
host: &str,
client_config: Option<&Arc<ClientConfig>>,
) -> Result<StreamOwned<ClientSession, TcpStream>, Error> {
let name = DNSNameRef::try_from_ascii_str(host)?;
let mut session = match client_config {
Some(client_config) => ClientSession::new(client_config, name),
None => {
let mut client_config = ClientConfig::new();
client_config
.root_store
.add_server_trust_anchors(&TLS_SERVER_ROOTS);
ClientSession::new(&Arc::new(client_config), name)
}
};
while let Err(err) = session.complete_io(&mut stream) {
if err.kind() != WouldBlock || !session.is_handshaking() {
return Err(err.into());
}
}
Ok(StreamOwned::new(session, stream))
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::TcpWithTimeout(stream, timeout) => timeout.read(stream, buf),
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) => stream.read(buf),
#[cfg(feature = "native-tls")]
Self::NativeTlsWithTimeout(stream, timeout) => timeout.read(stream, buf),
#[cfg(feature = "tls")]
Self::Rustls(stream) => {
let res = stream.read(buf);
handle_close_notify(res, stream)
}
#[cfg(feature = "tls")]
Self::RustlsWithTimeout(stream, timeout) => {
let res = timeout.read(stream, buf);
handle_close_notify(res, stream)
}
}
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self {
Self::Tcp(stream) | Self::TcpWithTimeout(stream, _) => stream.write(buf),
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) | Self::NativeTlsWithTimeout(stream, _) => stream.write(buf),
#[cfg(feature = "tls")]
Self::Rustls(stream) | Self::RustlsWithTimeout(stream, _) => stream.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match self {
Self::Tcp(stream) | Self::TcpWithTimeout(stream, _) => stream.flush(),
#[cfg(feature = "native-tls")]
Self::NativeTls(stream) | Self::NativeTlsWithTimeout(stream, _) => stream.flush(),
#[cfg(feature = "tls")]
Self::Rustls(stream) | Self::RustlsWithTimeout(stream, _) => stream.flush(),
}
}
}
#[cfg(feature = "tls")]
fn handle_close_notify(
res: IoResult<usize>,
stream: &mut StreamOwned<ClientSession, TcpStream>,
) -> IoResult<usize> {
match res {
Err(err) if err.kind() == ConnectionAborted => {
stream.sess.send_close_notify();
stream.sess.complete_io(&mut stream.sock)?;
Ok(0)
}
res => res,
}
}