use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::time::Duration;
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use crate::DriverError;
/// The byte transport behind a connection: one of the socket kinds below.
///
/// Which variants exist depends on the build: `Unix` is compiled only on Unix
/// targets and `Tls` only when the `tls` cargo feature is enabled.
pub(crate) enum Stream {
/// A TCP socket used directly.
Tcp(TcpStream),
/// A Unix domain socket.
#[cfg(unix)]
Unix(UnixStream),
/// A rustls client session that owns the TCP socket it runs over.
// NOTE(review): presumably boxed to keep the enum small — confirm.
#[cfg(feature = "tls")]
Tls(Box<rustls::StreamOwned<rustls::ClientConnection, TcpStream>>),
}
/// Forwards [`Read`] to whichever transport this [`Stream`] wraps.
impl Read for Stream {
    /// Reads bytes from the wrapped transport into `buf`.
    ///
    /// Returns whatever the wrapped transport's own `read` returns.
    // `#[inline]` is a hint; `#[inline(always)]` would force the whole match
    // into every call site regardless of what the optimizer thinks is best.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Stream::Tcp(s) => s.read(buf),
            #[cfg(unix)]
            Stream::Unix(s) => s.read(buf),
            #[cfg(feature = "tls")]
            Stream::Tls(s) => s.read(buf),
        }
    }

    /// Scatter-reads from the wrapped transport into `bufs`.
    ///
    /// Forwarded explicitly because the trait's default implementation only
    /// fills the first non-empty buffer, which would hide any real vectored
    /// read support of the wrapped transport.
    #[inline]
    fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
        match self {
            Stream::Tcp(s) => s.read_vectored(bufs),
            #[cfg(unix)]
            Stream::Unix(s) => s.read_vectored(bufs),
            #[cfg(feature = "tls")]
            Stream::Tls(s) => s.read_vectored(bufs),
        }
    }
}
/// Forwards [`Write`] to whichever transport this [`Stream`] wraps.
impl Write for Stream {
    /// Writes some prefix of `buf` to the wrapped transport and returns the
    /// number of bytes accepted.
    // `#[inline]` is a hint; `#[inline(always)]` would force the whole match
    // into every call site regardless of what the optimizer thinks is best.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Stream::Tcp(s) => s.write(buf),
            #[cfg(unix)]
            Stream::Unix(s) => s.write(buf),
            #[cfg(feature = "tls")]
            Stream::Tls(s) => s.write(buf),
        }
    }

    /// Gather-writes `bufs` to the wrapped transport.
    ///
    /// Forwarded explicitly because the trait's default implementation only
    /// writes the first non-empty buffer, which would hide any real vectored
    /// write support of the wrapped transport.
    #[inline]
    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
        match self {
            Stream::Tcp(s) => s.write_vectored(bufs),
            #[cfg(unix)]
            Stream::Unix(s) => s.write_vectored(bufs),
            #[cfg(feature = "tls")]
            Stream::Tls(s) => s.write_vectored(bufs),
        }
    }

    /// Writes all of `buf`, delegating to the wrapped transport's own
    /// `write_all` rather than looping over [`Stream::write`].
    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        match self {
            Stream::Tcp(s) => s.write_all(buf),
            #[cfg(unix)]
            Stream::Unix(s) => s.write_all(buf),
            #[cfg(feature = "tls")]
            Stream::Tls(s) => s.write_all(buf),
        }
    }

    /// Flushes the wrapped transport.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        match self {
            Stream::Tcp(s) => s.flush(),
            #[cfg(unix)]
            Stream::Unix(s) => s.flush(),
            #[cfg(feature = "tls")]
            Stream::Tls(s) => s.flush(),
        }
    }
}
impl Stream {
pub(crate) fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
match self {
Stream::Tcp(s) => s.set_read_timeout(dur),
#[cfg(unix)]
Stream::Unix(s) => s.set_read_timeout(dur),
#[cfg(feature = "tls")]
Stream::Tls(s) => s.sock.set_read_timeout(dur),
}
}
#[allow(dead_code)] pub(crate) fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
match self {
Stream::Tcp(s) => s.set_write_timeout(dur),
#[cfg(unix)]
Stream::Unix(s) => s.set_write_timeout(dur),
#[cfg(feature = "tls")]
Stream::Tls(s) => s.sock.set_write_timeout(dur),
}
}
#[allow(dead_code)] pub(crate) fn set_nodelay(&self) -> Result<(), DriverError> {
match self {
Stream::Tcp(s) => s.set_nodelay(true).map_err(DriverError::Io),
#[cfg(unix)]
Stream::Unix(_) => Ok(()),
#[cfg(feature = "tls")]
Stream::Tls(s) => s.sock.set_nodelay(true).map_err(DriverError::Io),
}
}
pub(crate) fn set_keepalive(&self) -> Result<(), DriverError> {
match self {
Stream::Tcp(s) => set_tcp_keepalive(s),
#[cfg(unix)]
Stream::Unix(_) => Ok(()),
#[cfg(feature = "tls")]
Stream::Tls(s) => set_tcp_keepalive(&s.sock),
}
}
}
/// Enables TCP keepalive on `tcp` with a keepalive time of 60 seconds and a
/// keepalive interval of 15 seconds.
///
/// # Errors
///
/// Returns [`DriverError::Io`] if the socket option cannot be applied.
fn set_tcp_keepalive(tcp: &TcpStream) -> Result<(), DriverError> {
    // Values passed to `TcpKeepalive::with_time` and `with_interval`.
    const KEEPALIVE_TIME: Duration = Duration::from_secs(60);
    const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(15);

    let params = socket2::TcpKeepalive::new()
        .with_time(KEEPALIVE_TIME)
        .with_interval(KEEPALIVE_INTERVAL);

    // `SockRef` borrows the std socket, so ownership stays with the caller.
    socket2::SockRef::from(tcp)
        .set_tcp_keepalive(&params)
        .map_err(DriverError::Io)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `Stream` implements both `Read` and `Write`;
    /// nothing is executed at runtime beyond an empty function call.
    #[test]
    fn stream_tcp_read_write_traits() {
        fn requires_io<S>()
        where
            S: Read + Write,
        {
        }

        requires_io::<Stream>();
    }
}