#[cfg(feature = "ssl")]
use std::error::Error as StdError;
use std::io::{ErrorKind, Error};
#[cfg(feature = "ssl")]
use std::sync::{Arc, Mutex};
#[cfg(feature = "ssl")]
use std::result::Result as StdResult;
use std::io::{Write, Read, Result, BufReader, BufWriter};
use std::net::{SocketAddr, ToSocketAddrs, TcpStream};
#[cfg(test)]
use std::net::Shutdown;
use std::os::unix::prelude::AsRawFd;
#[cfg(feature = "ssl")]
use openssl::ssl::{SslConnectorBuilder, SslMethod, SslStream, SSL_VERIFY_PEER, SSL_VERIFY_NONE};
#[cfg(feature = "ssl")]
use openssl::error::ErrorStack;
#[cfg(feature = "ssl")]
use openssl::x509;
use std::str::FromStr;
use net::config;
use uuid::Uuid;
use std::time::Duration;
/// A client connection to the configured server, over plain TCP or (with the
/// `ssl` feature) TLS, with buffered read and write halves.
pub struct Connection {
    /// Random UUID in URN form, assigned when the connection is created.
    id: String,
    /// Buffered read half of the transport.
    pub reader: BufReader<NetStream>,
    /// Buffered write half. For plain TCP this wraps a `try_clone` of the reader's
    /// socket; for SSL it shares the reader's `Arc<Mutex<..>>` stream.
    pub writer: BufWriter<NetStream>,
    /// Copy of the configuration used to open (and re-open) this connection.
    config: config::Config,
    /// Remote address as a formatted `SocketAddr`, or empty if the OS could not report it.
    peer_address: String,
    /// Local address as a formatted `SocketAddr`, or empty if the OS could not report it.
    local_address: String,
}
impl Connection {
fn new(
reader: BufReader<NetStream>,
writer: BufWriter<NetStream>,
config: &config::Config,
peer_address: String,
local_address: String,
) -> Connection {
Connection {
id: Uuid::new_v4().to_urn_string(),
reader: reader,
writer: writer,
config: config.clone(),
peer_address: peer_address,
local_address: local_address,
}
}
pub fn get_peer_address(&self) -> &String {
&self.peer_address
}
pub fn get_local_address(&self) -> &String {
&self.local_address
}
pub fn connect(config: &config::Config) -> Result<Connection> {
if config.use_ssl.unwrap_or(false) {
Connection::connect_ssl_internal(config)
} else {
Connection::connect_internal(config)
}
}
pub fn reconnect(&mut self) -> Result<Connection> {
if self.config.use_ssl.unwrap_or(false) {
Connection::connect_ssl_internal(&self.config)
} else {
Connection::connect_internal(&self.config)
}
}
pub fn id(&self) -> &String {
&self.id
}
pub fn is_valid(&self) -> bool {
match self.reader.get_ref() {
&NetStream::UnsecuredTcpStream(ref tcp) => {
debug!("TCP FD:{}", tcp.as_raw_fd());
if tcp.as_raw_fd() < 0 { false } else { true }
}
#[cfg(feature = "ssl")]
&NetStream::SslTcpStream(ref ssl) => {
let fd = ssl.lock().unwrap().get_ref().as_raw_fd();
debug!("SSL FD:{}", fd);
if fd < 0 {
return false;
} else {
return true;
}
}
}
}
fn host_to_sock_address(host: &str, port: u16) -> Result<SocketAddr> {
let server = match (host, port).to_socket_addrs() {
Ok(mut host_iter) => {
match host_iter.next() {
Some(mut host_addr) => return Ok(host_addr),
None => {
let err_str = format!("Failed to parse {}:{}. ", host, port);
error!("{}", err_str);
return Err(Error::new(ErrorKind::Other, err_str));
}
}
}
Err(e) => {
let err_str = format!("Failed to parse {}:{}. Error:{}", host, port, e);
error!("{}", err_str);
return Err(Error::new(ErrorKind::Other, err_str));
}
};
let err_str = format!("Failed to parse {}:{}. ", host, port);
error!("{}", err_str);
return Err(Error::new(ErrorKind::Other, err_str));
}
fn connect_internal(config: &config::Config) -> Result<Connection> {
let host: &str = &config.server.clone();
let port = config.port;
error!("Connecting to server {}:{}", host, port);
let mut stream_socket;
let server = try!(Connection::host_to_sock_address(host, port));
if config.connect_timeout.is_some() {
stream_socket = try!(TcpStream::connect_timeout(
&server,
Duration::from_millis(config.connect_timeout.unwrap()),
));
} else {
stream_socket = try!(TcpStream::connect(server));
}
stream_socket.set_nodelay(true);
if config.read_timeout.is_some() {
stream_socket.set_read_timeout(Some(Duration::from_millis(
config.read_timeout.unwrap(),
)));
}
if config.write_timeout.is_some() {
stream_socket.set_write_timeout(Some(Duration::from_millis(
config.write_timeout.unwrap(),
)));
}
let writer_socket = try!(stream_socket.try_clone());
let peer_address = match stream_socket.peer_addr() {
Ok(sock_addr) => sock_addr.to_string(),
Err(_) => String::from(""),
};
let local_address = match stream_socket.local_addr() {
Ok(sock_addr) => sock_addr.to_string(),
Err(_) => String::from(""),
};
Ok(Connection::new(
BufReader::new(NetStream::UnsecuredTcpStream(stream_socket)),
BufWriter::new(NetStream::UnsecuredTcpStream(writer_socket)),
config,
peer_address,
local_address,
))
}
#[cfg(not(feature = "ssl"))]
fn connect_ssl_internal(config: &config::Config) -> Result<Connection> {
panic!(
"Cannot connect to {}:{} over SSL without compiling with SSL support.",
config.server.clone(),
config.port
)
}
#[cfg(feature = "ssl")]
fn connect_ssl_internal(config: &config::Config) -> Result<Connection> {
let host: &str = &config.server.clone();
let port = config.port;
info!("Connecting to server {}:{}", host, port);
let mut socket;
let server = try!(Connection::host_to_sock_address(host, port));
if config.connect_timeout.is_some() {
socket = try!(TcpStream::connect_timeout(
&server,
Duration::from_millis(config.connect_timeout.unwrap()),
));
} else {
socket = try!(TcpStream::connect(server));
}
socket.set_nodelay(true);
let peer_address = match socket.peer_addr() {
Ok(sock_addr) => sock_addr.to_string(),
Err(_) => String::from(""),
};
let local_address = match socket.local_addr() {
Ok(sock_addr) => sock_addr.to_string(),
Err(_) => String::from(""),
};
if config.read_timeout.is_some() {
socket.set_read_timeout(Some(Duration::from_millis(config.read_timeout.unwrap())));
}
if config.write_timeout.is_some() {
socket.set_write_timeout(Some(Duration::from_millis(config.write_timeout.unwrap())));
}
let mut ssl_connector_builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap();
{
let ctx = ssl_connector_builder.builder_mut();
ctx.set_default_verify_paths().unwrap();
if config.verify.unwrap_or(false) {
ctx.set_verify(SSL_VERIFY_PEER);
} else {
ctx.set_verify(SSL_VERIFY_NONE);
}
if config.verify_depth.unwrap_or(0) > 0 {
ctx.set_verify_depth(config.verify_depth.unwrap());
}
if config.certificate_file.is_some() {
try!(ssl_to_io(ctx.set_certificate_file(
config.certificate_file.as_ref().unwrap(),
x509::X509_FILETYPE_PEM,
)));
}
if config.private_key_file.is_some() {
try!(ssl_to_io(ctx.set_private_key_file(
config.private_key_file.as_ref().unwrap(),
x509::X509_FILETYPE_PEM,
)));
}
if config.ca_file.is_some() {
try!(ssl_to_io(ctx.set_ca_file(config.ca_file.as_ref().unwrap())));
}
}
let ssl_connector = ssl_connector_builder.build();
let stream_socket_result =
match ssl_connector.connect(&*format!("{}:{}", host, port), socket) {
Ok(s) => s,
Err(e) => {
return Err(Error::new(
ErrorKind::Other,
&format!(
"An SSL error occurred. ({}:{})",
e.description(),
e.cause().unwrap()
)
[..],
));
}
};
let stream_socket = Arc::new(Mutex::new(stream_socket_result));
let writer_stream = Arc::clone(&stream_socket);
Ok(Connection::new(
BufReader::new(NetStream::SslTcpStream(stream_socket)),
BufWriter::new(NetStream::SslTcpStream(writer_stream)),
config,
peer_address,
local_address,
))
}
}
/// Converts an OpenSSL `ErrorStack` result into an `io::Result`.
///
/// Any error becomes `ErrorKind::Other`, so OpenSSL calls can be chained with `?`
/// in functions returning `io::Result`.
#[cfg(feature = "ssl")]
fn ssl_to_io<T>(res: StdResult<T, ErrorStack>) -> Result<T> {
    // Format the stack via `Display`, which lists the queued OpenSSL errors;
    // `description()` only yields a fixed generic string.
    res.map_err(|e| Error::new(ErrorKind::Other, format!("An SSL error occurred. ({})", e)))
}
/// Transport underlying a `Connection`: either a plain TCP socket or, with the
/// `ssl` feature, a TLS stream shared between reader and writer via `Arc<Mutex<..>>`.
pub enum NetStream {
    /// Unencrypted TCP socket.
    UnsecuredTcpStream(TcpStream),
    /// TLS stream over TCP; the mutex serializes reader/writer access to it.
    #[cfg(feature = "ssl")]
    SslTcpStream(Arc<Mutex<SslStream<TcpStream>>>),
}
// Reads are forwarded to the wrapped transport. For the SSL variant the shared
// stream is locked for the duration of each call (panicking if the lock is poisoned).
impl Read for NetStream {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        match *self {
            NetStream::UnsecuredTcpStream(ref mut tcp) => tcp.read(buf),
            #[cfg(feature = "ssl")]
            NetStream::SslTcpStream(ref shared) => shared.lock().unwrap().read(buf),
        }
    }

    // Forwarding `read_exact` directly means the SSL variant takes the lock once
    // for the whole exact read, instead of once per partial `read` as the
    // default `Read::read_exact` would.
    fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
        match *self {
            NetStream::UnsecuredTcpStream(ref mut tcp) => tcp.read_exact(buf),
            #[cfg(feature = "ssl")]
            NetStream::SslTcpStream(ref shared) => shared.lock().unwrap().read_exact(buf),
        }
    }
}
impl Write for NetStream {
fn write(&mut self, buf: &[u8]) -> Result<(usize)> {
match self {
&mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.write(buf),
#[cfg(feature = "ssl")]
&mut NetStream::SslTcpStream(ref mut stream) => {
stream.lock().unwrap().write(buf)
}
}
}
fn write_all(&mut self, buf: &[u8]) -> Result<()> {
match self {
&mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.write_all(buf),
#[cfg(feature = "ssl")]
&mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().write_all(buf),
}
}
fn flush(&mut self) -> Result<()> {
match self {
&mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.flush(),
#[cfg(feature = "ssl")]
&mut NetStream::SslTcpStream(ref mut stream) => stream.lock().unwrap().flush(),
}
}
}
// Test-only cleanup: flush buffered output, then shut down the underlying transport.
// All steps are best-effort: errors are ignored and nothing here may panic.
#[cfg(test)]
impl Drop for Connection {
    fn drop(&mut self) {
        info!("Drop for Connection:Dropping connection id: {}", self.id);
        // `drop` runs before the fields are dropped. Without this flush, the
        // `BufWriter`'s own drop-time flush would hit an already shut-down socket
        // and the buffered bytes would be silently lost.
        let _ = self.writer.flush();
        match *self.reader.get_mut() {
            NetStream::UnsecuredTcpStream(ref stream) => {
                let _ = stream.shutdown(Shutdown::Read);
                let _ = stream.shutdown(Shutdown::Write);
            }
            #[cfg(feature = "ssl")]
            NetStream::SslTcpStream(ref ssl) => {
                // Skip the shutdown on a poisoned lock: panicking in `drop` while
                // already unwinding would abort the process.
                if let Ok(mut stream) = ssl.lock() {
                    let _ = stream.shutdown();
                }
            }
        }
    }
}