#[cfg(feature = "ssl")]
use openssl::ssl::{SslConnector, SslStream};
use std::io;
use std::io::{Read, Write};
use std::net;
use std::net::TcpStream;
use std::time::Duration;
use std::sync::Arc;
/// Byte-stream transport used to talk to a single node.
///
/// An implementor is a `Read + Write` stream that is `Send + Sync` and `Sized`
/// (so that `try_clone` can return it by value).
pub trait CDRSTransport: Sized + Read + Write + Send + Sync {
    /// Produces another transport for the same endpoint.
    ///
    /// NOTE(review): every implementation in this module opens a brand-new
    /// connection here instead of duplicating the existing socket handle.
    fn try_clone(&self) -> io::Result<Self>;

    /// Shuts the connection down. `close` selects which half (read, write or
    /// both); an implementation is free to ignore it.
    fn close(&mut self, close: net::Shutdown) -> io::Result<()>;

    /// Applies `dur` as both the read and the write timeout of the connection.
    fn set_timeout(&mut self, dur: Option<Duration>) -> io::Result<()>;

    /// Reports whether the underlying connection still looks usable.
    fn is_alive(&self) -> bool;
}
/// Plain, unencrypted TCP transport.
///
/// Keeps the address string it was created from alongside the connected socket.
pub struct TransportTcp {
    tcp: TcpStream,
    addr: String,
}

impl TransportTcp {
    /// Connects to `addr` (any string `TcpStream::connect` accepts, e.g.
    /// `"127.0.0.1:9042"`) and wraps the resulting socket.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `TcpStream::connect` unchanged.
    pub fn new(addr: &str) -> io::Result<TransportTcp> {
        let tcp = TcpStream::connect(addr)?;
        Ok(TransportTcp {
            tcp,
            addr: addr.to_owned(),
        })
    }
}
impl Read for TransportTcp {
    /// Reads directly from the wrapped TCP socket.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let socket = &mut self.tcp;
        socket.read(buf)
    }
}
impl Write for TransportTcp {
    /// Writes straight to the TCP socket; like any `Write::write`, it may accept
    /// fewer than `buf.len()` bytes.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.tcp, buf)
    }

    /// Forwards to the socket's own `flush`.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.tcp)
    }
}
impl CDRSTransport for TransportTcp {
fn try_clone(&self) -> io::Result<TransportTcp> {
TcpStream::connect(self.addr.as_str()).map(|socket| TransportTcp {
tcp: socket,
addr: self.addr.clone(),
})
}
fn close(&mut self, close: net::Shutdown) -> io::Result<()> {
self.tcp.shutdown(close)
}
fn set_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
self.tcp
.set_read_timeout(dur)
.and_then(|_| self.tcp.set_write_timeout(dur))
}
fn is_alive(&self) -> bool {
self.tcp.peer_addr().is_ok()
}
}
/// TLS transport backed by `rustls` (enabled with the `rust-tls` feature).
///
/// Keeps the address, server name and shared client config next to the live
/// stream so an identical connection can be opened again later.
#[cfg(feature = "rust-tls")]
pub struct TransportRustls {
    inner: rustls::StreamOwned<rustls::ClientSession, net::TcpStream>,
    config: Arc<rustls::ClientConfig>,
    addr: net::SocketAddr,
    dns_name: webpki::DNSName,
}

#[cfg(feature = "rust-tls")]
impl TransportRustls {
    /// Connects a TCP socket to `addr` and attaches a new rustls client
    /// session for `dns_name`, built from `config`.
    ///
    /// NOTE(review): no handshake is driven explicitly here; presumably it
    /// happens lazily on the first read/write through `StreamOwned` — confirm.
    ///
    /// # Errors
    ///
    /// Propagates the error from `TcpStream::connect`.
    pub fn new(
        addr: net::SocketAddr,
        dns_name: webpki::DNSName,
        config: Arc<rustls::ClientConfig>,
    ) -> io::Result<Self> {
        let tcp = net::TcpStream::connect(addr)?;
        let session = rustls::ClientSession::new(&config, dns_name.as_ref());
        let inner = rustls::StreamOwned::new(session, tcp);
        Ok(TransportRustls {
            inner,
            config,
            addr,
            dns_name,
        })
    }
}
#[cfg(feature = "rust-tls")]
impl Read for TransportRustls {
    /// Reads application data through the rustls stream.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.inner, buf)
    }
}
#[cfg(feature = "rust-tls")]
impl Write for TransportRustls {
    /// Writes application data through the rustls stream.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.inner, buf)
    }

    /// Flushes the rustls stream.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
#[cfg(feature = "rust-tls")]
impl CDRSTransport for TransportRustls {
    /// Builds a fresh connection from the stored address, server name and
    /// shared config; the current stream is not reused.
    #[inline]
    fn try_clone(&self) -> io::Result<Self> {
        let dns_name = self.dns_name.clone();
        let config = Arc::clone(&self.config);
        TransportRustls::new(self.addr, dns_name, config)
    }

    /// Shuts down the underlying TCP socket directly. This method does not
    /// itself send a TLS close_notify alert.
    fn close(&mut self, how: net::Shutdown) -> io::Result<()> {
        let tcp = self.inner.get_mut();
        tcp.shutdown(how)
    }

    /// Applies `dur` to the TCP socket's read timeout, then its write timeout.
    fn set_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
        let tcp = self.inner.get_mut();
        tcp.set_read_timeout(dur)?;
        tcp.set_write_timeout(dur)
    }

    /// Alive while the TCP socket can still report its peer address.
    fn is_alive(&self) -> bool {
        let tcp = self.inner.get_ref();
        tcp.peer_addr().is_ok()
    }
}
/// TLS transport backed by OpenSSL (enabled with the `ssl` feature).
///
/// Wraps an `SslStream` over a plain `TcpStream`. It keeps the connector and
/// the original address string so that an identical connection can be opened
/// again later.
#[cfg(feature = "ssl")]
pub struct TransportTls {
    ssl: SslStream<TcpStream>,
    connector: SslConnector,
    addr: String,
}

#[cfg(feature = "ssl")]
impl TransportTls {
    /// Opens a TCP connection to `addr` and performs the TLS handshake over it.
    ///
    /// `addr` is a `host:port` string. IPv6 hosts are written in brackets,
    /// e.g. `[::1]:9042`. The host part, with port and brackets removed, is
    /// passed to `SslConnector::connect` as the domain. Per the openssl crate
    /// docs, that domain drives SNI and hostname verification.
    ///
    /// # Errors
    ///
    /// - The TCP connect error is returned unchanged.
    /// - A failed handshake is wrapped in an `io::ErrorKind::Other` error.
    pub fn new(addr: &str, connector: &SslConnector) -> io::Result<TransportTls> {
        let socket = net::TcpStream::connect(addr)?;
        let ssl = connector
            .connect(Self::host_of(addr), socket)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
        Ok(TransportTls {
            ssl,
            connector: connector.clone(),
            addr: addr.to_string(),
        })
    }

    /// Extracts the host from a `host:port` string: everything before the
    /// last `:`, with IPv6 brackets trimmed. A string without `:` is returned
    /// as-is.
    ///
    /// Splitting on the *first* `:` would turn `[::1]:9042` into `[`, which
    /// breaks certificate verification for IPv6 peers.
    fn host_of(addr: &str) -> &str {
        let host = match addr.rfind(':') {
            Some(idx) => &addr[..idx],
            None => addr,
        };
        host.trim_start_matches('[').trim_end_matches(']')
    }
}
#[cfg(feature = "ssl")]
impl Read for TransportTls {
    /// Reads decrypted application data from the OpenSSL stream.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let stream = &mut self.ssl;
        stream.read(buf)
    }
}
#[cfg(feature = "ssl")]
impl Write for TransportTls {
    /// Encrypts and writes application data via the OpenSSL stream.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.ssl, buf)
    }

    /// Flushes the OpenSSL stream.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.ssl)
    }
}
#[cfg(feature = "ssl")]
impl CDRSTransport for TransportTls {
    /// Opens a brand-new TLS connection to the stored address using the same
    /// connector; the existing stream is not shared.
    ///
    /// Delegates to `TransportTls::new` so that connect and handshake logic
    /// (including host extraction) lives in exactly one place.
    fn try_clone(&self) -> io::Result<TransportTls> {
        TransportTls::new(&self.addr, &self.connector)
    }

    /// Sends a TLS close_notify via `SslStream::shutdown`.
    ///
    /// NOTE: the requested `net::Shutdown` mode is ignored. The TCP socket is
    /// not shut down here; it is closed when the stream is dropped.
    ///
    /// # Errors
    ///
    /// An OpenSSL shutdown failure is wrapped in an `io::ErrorKind::Other` error.
    fn close(&mut self, _close: net::Shutdown) -> io::Result<()> {
        self.ssl
            .shutdown()
            .map(|_| ())
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
    }

    /// Applies `dur` to the TCP socket's read timeout, then its write timeout.
    fn set_timeout(&mut self, dur: Option<Duration>) -> io::Result<()> {
        let stream = self.ssl.get_mut();
        stream.set_read_timeout(dur)?;
        stream.set_write_timeout(dur)
    }

    /// Alive while the TCP socket can still report its peer address.
    fn is_alive(&self) -> bool {
        self.ssl.get_ref().peer_addr().is_ok()
    }
}