use crate::{
client::utils,
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, Response},
net::{DataStream, SecretKey, Transport, TransportError, TransportWriteHalf},
};
use log::*;
use std::{
collections::HashMap,
convert,
sync::{Arc, Mutex},
};
use tokio::{
io,
net::TcpStream,
sync::{mpsc, oneshot},
task::{JoinError, JoinHandle},
time::Duration,
};
mod info;
pub use info::{SessionInfo, SessionInfoFile, SessionInfoParseError};
/// Shared table of in-flight requests: request id -> one-shot sender that completes
/// the matching `send` call. Shared between a `Session` and its response task.
type Callbacks = Arc<Mutex<HashMap<usize, oneshot::Sender<Response>>>>;

/// Client-side session over a transport whose incoming responses are dispatched by a
/// background task.
pub struct Session<T: DataStream> {
    /// Write half of the transport; requests are sent through it.
    t_write: TransportWriteHalf<T::Write>,

    /// Pending callbacks keyed by request id, also held by the response task.
    callbacks: Callbacks,

    /// Handle to the background task that reads and dispatches responses.
    response_task: JoinHandle<()>,

    /// Receives responses that had no registered callback when they arrived.
    /// `Some` after construction; callers may take it.
    pub broadcast: Option<mpsc::Receiver<Response>>,
}
impl Session<TcpStream> {
    /// Opens a TCP session to the address described by `info`.
    ///
    /// `info.auth_key` is wrapped in an `Arc` and handed to the transport's `connect`.
    /// The connected transport is then passed to [`Session::initialize`].
    ///
    /// # Errors
    ///
    /// Propagates failures from resolving the address or from establishing the transport.
    pub async fn tcp_connect(info: SessionInfo) -> io::Result<Self> {
        let socket_addr = info.to_socket_addr().await?;
        let auth_key = Some(Arc::new(info.auth_key));
        let transport = Transport::<TcpStream>::connect(socket_addr, auth_key).await?;

        debug!(
            "Session has been established with {}",
            match transport.peer_addr() {
                Ok(peer) => peer.to_string(),
                Err(_) => String::from("???"),
            }
        );

        Self::initialize(transport)
    }

    /// Same as [`Session::tcp_connect`], with the whole connect bounded by `duration`
    /// through `utils::timeout`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `utils::timeout`, or any error from `tcp_connect`.
    pub async fn tcp_connect_timeout(info: SessionInfo, duration: Duration) -> io::Result<Self> {
        let session = utils::timeout(duration, Self::tcp_connect(info)).await??;
        Ok(session)
    }
}
#[cfg(unix)]
impl Session<tokio::net::UnixStream> {
    /// Opens a session over the Unix domain socket at `path`.
    ///
    /// `auth_key` is forwarded unchanged to the transport's `connect`. The connected
    /// transport is then passed to [`Session::initialize`].
    ///
    /// # Errors
    ///
    /// Propagates any failure from establishing the transport.
    pub async fn unix_connect(
        path: impl AsRef<std::path::Path>,
        auth_key: Option<Arc<SecretKey>>,
    ) -> io::Result<Self> {
        let transport = Transport::<tokio::net::UnixStream>::connect(path, auth_key).await?;

        debug!(
            "Session has been established with {}",
            match transport.peer_addr() {
                Ok(peer) => format!("{:?}", peer),
                Err(_) => String::from("???"),
            }
        );

        Self::initialize(transport)
    }

    /// Same as [`Session::unix_connect`], with the whole connect bounded by `duration`
    /// through `utils::timeout`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `utils::timeout`, or any error from `unix_connect`.
    pub async fn unix_connect_timeout(
        path: impl AsRef<std::path::Path>,
        auth_key: Option<Arc<SecretKey>>,
        duration: Duration,
    ) -> io::Result<Self> {
        let session = utils::timeout(duration, Self::unix_connect(path, auth_key)).await??;
        Ok(session)
    }
}
impl<T> Session<T>
where
T: DataStream,
{
pub fn initialize(transport: Transport<T>) -> io::Result<Self> {
let (mut t_read, t_write) = transport.into_split();
let callbacks: Callbacks = Arc::new(Mutex::new(HashMap::new()));
let (broadcast_tx, broadcast_rx) = mpsc::channel(CLIENT_BROADCAST_CHANNEL_CAPACITY);
let callbacks_2 = Arc::clone(&callbacks);
let response_task = tokio::spawn(async move {
loop {
match t_read.receive::<Response>().await {
Ok(Some(res)) => {
trace!("Incoming response: {:?}", res);
let maybe_callback = res
.origin_id
.as_ref()
.and_then(|id| callbacks_2.lock().unwrap().remove(id));
if let Some(tx) = maybe_callback {
trace!("Callback exists for response! Triggering!");
if let Err(res) = tx.send(res) {
error!("Failed to trigger callback for response {}", res.id);
}
} else {
trace!("Callback missing for response! Broadcasting!");
if let Err(x) = broadcast_tx.send(res).await {
error!("Failed to trigger broadcast: {}", x);
}
}
}
Ok(None) => {
debug!("Session closing response task as transport read-half closed!");
break;
}
Err(x) => {
error!("{}", x);
break;
}
}
}
});
Ok(Self {
t_write,
callbacks,
broadcast: Some(broadcast_rx),
response_task,
})
}
pub async fn wait(self) -> Result<(), JoinError> {
self.response_task.await
}
pub fn abort(&self) {
self.response_task.abort()
}
pub async fn send(&mut self, req: Request) -> Result<Response, TransportError> {
trace!("Sending request: {:?}", req);
let (tx, rx) = oneshot::channel();
self.callbacks.lock().unwrap().insert(req.id, tx);
self.t_write.send(req).await?;
rx.await
.map_err(|x| TransportError::from(io::Error::new(io::ErrorKind::ConnectionAborted, x)))
}
pub async fn send_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<Response, TransportError> {
utils::timeout(duration, self.send(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
pub async fn fire(&mut self, req: Request) -> Result<(), TransportError> {
trace!("Firing off request: {:?}", req);
self.t_write.send(req).await
}
pub async fn fire_timeout(
&mut self,
req: Request,
duration: Duration,
) -> Result<(), TransportError> {
utils::timeout(duration, self.fire(req))
.await
.map_err(TransportError::from)
.and_then(convert::identity)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        constants::test::TENANT,
        data::{RequestData, ResponseData},
    };
    use std::time::Duration;

    /// Builds a request for the test tenant carrying a single `ProcList` payload.
    fn proc_list_request() -> Request {
        Request::new(TENANT, vec![RequestData::ProcList {}])
    }

    #[tokio::test]
    async fn send_should_wait_until_response_received() {
        let (client_transport, mut server_transport) = Transport::make_pair();
        let mut session = Session::initialize(client_transport).unwrap();

        let request = proc_list_request();
        let expected = Response::new(
            TENANT,
            Some(request.id),
            vec![ResponseData::ProcEntries {
                entries: Vec::new(),
            }],
        );

        let (outcome, _) = tokio::join!(
            session.send(request),
            server_transport.send(expected.clone())
        );
        match outcome {
            Ok(actual) => assert_eq!(actual, expected),
            other => panic!("Unexpected response: {:?}", other),
        }
    }

    #[tokio::test]
    async fn send_timeout_should_fail_if_response_not_received_in_time() {
        let (client_transport, mut server_transport) = Transport::make_pair();
        let mut session = Session::initialize(client_transport).unwrap();

        let request = proc_list_request();
        let outcome = session
            .send_timeout(request, Duration::from_millis(30))
            .await;
        match outcome {
            Err(TransportError::IoError(err)) => assert_eq!(err.kind(), io::ErrorKind::TimedOut),
            other => panic!("Unexpected response: {:?}", other),
        }

        // The request itself must still have reached the peer before the timeout.
        let received = server_transport.receive::<Request>().await.unwrap().unwrap();
        assert_eq!(received.tenant, TENANT);
    }

    #[tokio::test]
    async fn fire_should_send_request_and_not_wait_for_response() {
        let (client_transport, mut server_transport) = Transport::make_pair();
        let mut session = Session::initialize(client_transport).unwrap();

        let request = proc_list_request();
        let outcome = session.fire(request).await;
        if outcome.is_err() {
            panic!("Unexpected response: {:?}", outcome);
        }

        let received = server_transport.receive::<Request>().await.unwrap().unwrap();
        assert_eq!(received.tenant, TENANT);
    }
}