use crate::error::*;
use async_trait::async_trait;
use pravega_client_shared::PravegaNodeUri;
use snafu::ResultExt;
use std::fmt;
use std::fmt::{Debug, Formatter};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use uuid::Uuid;
/// A bidirectional byte connection to a Pravega node.
///
/// This module implements it for [`TokioConnection`] (plain TCP) and
/// [`TlsConnection`] (TLS over TCP).
#[async_trait]
pub trait Connection: Send + Sync + Debug {
    /// Returns the identifier assigned to this connection.
    fn get_uuid(&self) -> Uuid;

    /// Returns the endpoint this connection talks to.
    fn get_endpoint(&self) -> PravegaNodeUri;

    /// Writes the whole `payload` to the connection.
    async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError>;

    /// Fills `buf` completely with bytes read from the connection.
    async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError>;

    /// Hands the underlying stream over to independent read and write halves.
    ///
    /// The implementations in this module take the stream out of the connection,
    /// so later calls to the I/O methods or to `split` panic.
    fn split(&mut self) -> (Box<dyn ConnectionReadHalf>, Box<dyn ConnectionWriteHalf>);

    /// Reports whether the connection may still be used.
    ///
    /// In this module that means: marked recyclable, not yet split, and the
    /// underlying socket still reports a peer address.
    fn is_valid(&self) -> bool;

    /// Sets whether this connection may be recycled.
    fn can_recycle(&mut self, can_recycle: bool);
}
/// A plain TCP (non-TLS) connection to a Pravega node.
pub struct TokioConnection {
    /// Identifier of this connection; copied into both halves produced by `split`.
    pub uuid: Uuid,
    /// Remote endpoint; attached to send/read errors as context.
    pub endpoint: PravegaNodeUri,
    /// The socket; becomes `None` once `split` has moved it into read/write halves.
    pub stream: Option<TcpStream>,
    /// Whether this connection may be recycled; `is_valid` returns false when unset.
    pub can_recycle: bool,
}
#[async_trait]
impl Connection for TokioConnection {
    /// Writes the whole `payload` to the TCP stream.
    ///
    /// Unlike the TLS implementation, no explicit flush follows the write.
    ///
    /// # Panics
    /// Panics if the stream is gone (e.g. after `split`).
    ///
    /// # Errors
    /// A failed write is reported through the `SendData` context, tagged with
    /// this connection's endpoint.
    async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> {
        assert!(self.stream.is_some());
        let endpoint = self.endpoint.clone();
        self.stream
            .as_mut()
            .expect("get connection")
            .write_all(payload)
            .await
            .context(SendData { endpoint })?;
        Ok(())
    }

    /// Reads exactly `buf.len()` bytes from the TCP stream into `buf`.
    ///
    /// # Panics
    /// Panics if the stream is gone (e.g. after `split`).
    ///
    /// # Errors
    /// A failed read is reported through the `ReadData` context, tagged with
    /// this connection's endpoint.
    async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> {
        assert!(self.stream.is_some());
        let endpoint = self.endpoint.clone();
        self.stream
            .as_mut()
            .expect("get connection")
            .read_exact(buf)
            .await
            .context(ReadData { endpoint })?;
        Ok(())
    }

    /// Moves the TCP stream into separate boxed read and write halves.
    ///
    /// Both halves carry this connection's uuid and endpoint. Afterwards
    /// `self.stream` is `None`, so `is_valid` returns false and further sends,
    /// reads or splits panic.
    ///
    /// # Panics
    /// Panics if the stream has already been taken.
    fn split(&mut self) -> (Box<dyn ConnectionReadHalf>, Box<dyn ConnectionWriteHalf>) {
        assert!(self.stream.is_some());
        let (read_half, write_half) = tokio::io::split(self.stream.take().expect("take connection"));
        let read = Box::new(ConnectionReadHalfTokio {
            uuid: self.uuid,
            endpoint: self.endpoint.clone(),
            read_half: Some(read_half),
        }) as Box<dyn ConnectionReadHalf>;
        let write = Box::new(ConnectionWriteHalfTokio {
            uuid: self.uuid,
            endpoint: self.endpoint.clone(),
            write_half: Some(write_half),
        }) as Box<dyn ConnectionWriteHalf>;
        (read, write)
    }

    /// Returns a clone of the remote endpoint.
    fn get_endpoint(&self) -> PravegaNodeUri {
        self.endpoint.clone()
    }

    /// Returns this connection's uuid.
    fn get_uuid(&self) -> Uuid {
        self.uuid
    }

    /// True only if the connection is recyclable, still owns its stream
    /// (has not been split), and the stream passes `Validate::is_valid`.
    fn is_valid(&self) -> bool {
        // A single pattern match replaces the `is_some()` check followed by `expect()`.
        self.can_recycle && matches!(&self.stream, Some(stream) if stream.is_valid())
    }

    /// Records whether this connection may be recycled.
    fn can_recycle(&mut self, can_recycle: bool) {
        self.can_recycle = can_recycle
    }
}
impl Debug for TokioConnection {
    /// Formats the connection id and endpoint; the stream itself is omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Report this type's own name so plain TCP connections are not
        // mistaken for `TlsConnection` in debug output.
        f.debug_struct("TokioConnection")
            .field("connection id", &self.uuid)
            .field("pravega endpoint", &self.endpoint)
            .finish()
    }
}
/// A TLS-over-TCP connection to a Pravega node.
pub struct TlsConnection {
    /// Identifier of this connection; copied into both halves produced by `split`.
    pub uuid: Uuid,
    /// Remote endpoint; attached to send/read errors as context.
    pub endpoint: PravegaNodeUri,
    /// The TLS stream; becomes `None` once `split` has moved it into read/write halves.
    pub stream: Option<TlsStream<TcpStream>>,
    /// Whether this connection may be recycled; `is_valid` returns false when unset.
    pub can_recycle: bool,
}
#[async_trait]
impl Connection for TlsConnection {
    /// Writes the whole `payload` to the TLS stream and then flushes it.
    ///
    /// # Panics
    /// Panics if the stream is gone (e.g. after `split`).
    ///
    /// # Errors
    /// A failed write or flush is reported through the `SendData` context,
    /// tagged with this connection's endpoint.
    async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> {
        assert!(self.stream.is_some());
        let endpoint = self.endpoint.clone();
        // Look the stream up once and reuse it for both the write and the flush.
        let stream = self.stream.as_mut().expect("get connection");
        stream.write_all(payload).await.context(SendData {
            endpoint: endpoint.clone(),
        })?;
        // Flush explicitly (the plain TCP variant does not); presumably so data
        // buffered by the TLS layer reaches the socket.
        stream.flush().await.context(SendData { endpoint })?;
        Ok(())
    }

    /// Reads exactly `buf.len()` bytes from the TLS stream into `buf`.
    ///
    /// # Panics
    /// Panics if the stream is gone (e.g. after `split`).
    ///
    /// # Errors
    /// A failed read is reported through the `ReadData` context, tagged with
    /// this connection's endpoint.
    async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> {
        assert!(self.stream.is_some());
        let endpoint = self.endpoint.clone();
        self.stream
            .as_mut()
            .expect("get connection")
            .read_exact(buf)
            .await
            .context(ReadData { endpoint })?;
        Ok(())
    }

    /// Moves the TLS stream into separate boxed read and write halves.
    ///
    /// Both halves carry this connection's uuid and endpoint. Afterwards
    /// `self.stream` is `None`, so `is_valid` returns false and further sends,
    /// reads or splits panic.
    ///
    /// # Panics
    /// Panics if the stream has already been taken.
    fn split(&mut self) -> (Box<dyn ConnectionReadHalf>, Box<dyn ConnectionWriteHalf>) {
        assert!(self.stream.is_some());
        let (read_half, write_half) = tokio::io::split(self.stream.take().expect("take connection"));
        let read = Box::new(ConnectionReadHalfTls {
            uuid: self.uuid,
            endpoint: self.endpoint.clone(),
            read_half: Some(read_half),
        }) as Box<dyn ConnectionReadHalf>;
        let write = Box::new(ConnectionWriteHalfTls {
            uuid: self.uuid,
            endpoint: self.endpoint.clone(),
            write_half: Some(write_half),
        }) as Box<dyn ConnectionWriteHalf>;
        (read, write)
    }

    /// Returns a clone of the remote endpoint.
    fn get_endpoint(&self) -> PravegaNodeUri {
        self.endpoint.clone()
    }

    /// Returns this connection's uuid.
    fn get_uuid(&self) -> Uuid {
        self.uuid
    }

    /// True only if the connection is recyclable, still owns its stream
    /// (has not been split), and the stream passes `Validate::is_valid`.
    fn is_valid(&self) -> bool {
        // A single pattern match replaces the `is_some()` check followed by `expect()`.
        self.can_recycle && matches!(&self.stream, Some(stream) if stream.is_valid())
    }

    /// Records whether this connection may be recycled.
    fn can_recycle(&mut self, can_recycle: bool) {
        self.can_recycle = can_recycle;
    }
}
impl Debug for TlsConnection {
    /// Shows the connection id and the endpoint; the TLS stream is left out.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("TlsConnection");
        builder.field("connection id", &self.uuid);
        builder.field("pravega endpoint", &self.endpoint);
        builder.finish()
    }
}
/// The read side of a connection produced by `Connection::split`.
#[async_trait]
pub trait ConnectionReadHalf: Send + Sync {
    /// Returns the uuid this half was created with.
    fn get_id(&self) -> Uuid;

    /// Fills `buf` completely with bytes read from the peer.
    async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError>;
}
/// Read half of a split [`TokioConnection`].
pub struct ConnectionReadHalfTokio {
    /// Uuid of the originating connection, returned by `get_id`.
    uuid: Uuid,
    /// Endpoint attached to read errors as context.
    endpoint: PravegaNodeUri,
    /// The reading side of the TCP stream; `read_async` panics if this is `None`.
    read_half: Option<ReadHalf<TcpStream>>,
}
#[async_trait]
impl ConnectionReadHalf for ConnectionReadHalfTokio {
    /// Returns the uuid this half was created with (the parent connection's
    /// uuid when produced by `split`).
    fn get_id(&self) -> Uuid {
        self.uuid
    }

    /// Fills `buf` completely from the TCP read half.
    ///
    /// # Panics
    /// Panics if the read half is absent.
    ///
    /// # Errors
    /// Read failures are wrapped in the `ReadData` context with this half's endpoint.
    async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> {
        let endpoint = self.endpoint.clone();
        let reader = match self.read_half.as_mut() {
            Some(reader) => reader,
            None => panic!("should not try to read when read half is gone"),
        };
        reader.read_exact(buf).await.context(ReadData { endpoint })?;
        Ok(())
    }
}
/// Read half of a split [`TlsConnection`].
pub struct ConnectionReadHalfTls {
    /// Uuid of the originating connection, returned by `get_id`.
    uuid: Uuid,
    /// Endpoint attached to read errors as context.
    endpoint: PravegaNodeUri,
    /// The reading side of the TLS stream; `read_async` panics if this is `None`.
    read_half: Option<ReadHalf<TlsStream<TcpStream>>>,
}
#[async_trait]
impl ConnectionReadHalf for ConnectionReadHalfTls {
    /// Returns the uuid this half was created with (the parent connection's
    /// uuid when produced by `split`).
    fn get_id(&self) -> Uuid {
        self.uuid
    }

    /// Fills `buf` completely from the TLS read half.
    ///
    /// # Panics
    /// Panics if the read half is absent.
    ///
    /// # Errors
    /// Read failures are wrapped in the `ReadData` context with this half's endpoint.
    async fn read_async(&mut self, buf: &mut [u8]) -> Result<(), ConnectionError> {
        let endpoint = self.endpoint.clone();
        let reader = match self.read_half.as_mut() {
            Some(reader) => reader,
            None => panic!("should not try to read when read half is gone"),
        };
        reader.read_exact(buf).await.context(ReadData { endpoint })?;
        Ok(())
    }
}
/// The write side of a connection produced by `Connection::split`.
#[async_trait]
pub trait ConnectionWriteHalf: Send + Sync + Debug {
    /// Returns the uuid this half was created with.
    fn get_id(&self) -> Uuid;

    /// Writes the whole `payload` to the peer.
    async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError>;
}
/// Write half of a split [`TokioConnection`].
#[derive(Debug)]
pub struct ConnectionWriteHalfTokio {
    /// Uuid of the originating connection, returned by `get_id`.
    uuid: Uuid,
    /// Endpoint attached to send errors as context.
    endpoint: PravegaNodeUri,
    /// The writing side of the TCP stream; `send_async` panics if this is `None`.
    write_half: Option<WriteHalf<TcpStream>>,
}
#[async_trait]
impl ConnectionWriteHalf for ConnectionWriteHalfTokio {
    /// Returns the uuid this half was created with (the parent connection's
    /// uuid when produced by `split`).
    fn get_id(&self) -> Uuid {
        self.uuid
    }

    /// Writes the whole `payload` to the TCP write half, without a flush.
    ///
    /// # Panics
    /// Panics if the write half is absent.
    ///
    /// # Errors
    /// Write failures are wrapped in the `SendData` context with this half's endpoint.
    async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> {
        let endpoint = self.endpoint.clone();
        match self.write_half.as_mut() {
            Some(writer) => {
                writer.write_all(payload).await.context(SendData { endpoint })?;
                Ok(())
            }
            None => panic!("should not try to write when write half is gone"),
        }
    }
}
/// Write half of a split [`TlsConnection`].
#[derive(Debug)]
pub struct ConnectionWriteHalfTls {
    /// Uuid of the originating connection, returned by `get_id`.
    uuid: Uuid,
    /// Endpoint attached to send errors as context.
    endpoint: PravegaNodeUri,
    /// The writing side of the TLS stream; `send_async` panics if this is `None`.
    write_half: Option<WriteHalf<TlsStream<TcpStream>>>,
}
#[async_trait]
impl ConnectionWriteHalf for ConnectionWriteHalfTls {
    /// Returns the uuid this half was created with (the parent connection's
    /// uuid when produced by `split`).
    fn get_id(&self) -> Uuid {
        self.uuid
    }

    /// Writes the whole `payload` to the TLS write half and then flushes it.
    ///
    /// # Panics
    /// Panics if the write half is absent.
    ///
    /// # Errors
    /// Write or flush failures are wrapped in the `SendData` context with this
    /// half's endpoint.
    async fn send_async(&mut self, payload: &[u8]) -> Result<(), ConnectionError> {
        let endpoint = self.endpoint.clone();
        let writer = match self.write_half.as_mut() {
            Some(writer) => writer,
            None => panic!("should not try to write when write half is gone"),
        };
        writer.write_all(payload).await.context(SendData {
            endpoint: endpoint.clone(),
        })?;
        // Flush after writing (the plain TCP half does not); presumably so data
        // buffered by the TLS layer reaches the socket.
        writer.flush().await.context(SendData { endpoint })?;
        Ok(())
    }
}
/// Liveness check for an underlying transport stream.
///
/// The implementations in this module succeed when the socket can still
/// report its peer address.
pub trait Validate {
    /// Returns `true` when the stream appears usable.
    fn is_valid(&self) -> bool;
}
impl Validate for TcpStream {
    /// Treats the socket as valid when `peer_addr` succeeds.
    ///
    /// NOTE(review): a heuristic only; a successful `peer_addr` may not prove
    /// that the remote side has not closed the connection.
    fn is_valid(&self) -> bool {
        self.peer_addr().is_ok()
    }
}
impl Validate for TlsStream<TcpStream> {
fn is_valid(&self) -> bool {
let (io, _session) = self.get_ref();
io.peer_addr().map_or_else(|_e| false, |_addr| true)
}
}