use std::{io, sync::Arc, time::Duration};
use bincode::Encode;
use futures::future::BoxFuture;
use log::error;
use tokio::{net::TcpStream, time::timeout};
use super::{ClientError, ClientPreambleFailureHandler, ClientPreambleSuccessHandler};
use crate::preamble::write_preamble;
/// Client-side preamble settings: the value to send, plus optional callbacks and a deadline.
///
/// Cloning requires `P: Clone`. The handler fields are installed by the setters via
/// `Arc::new`, so cloning a config shares the same callbacks instead of duplicating them.
// `derive(Clone)` produces exactly the former hand-written impl: bound `P: Clone`,
// field-by-field `clone()`.
#[derive(Clone)]
pub(crate) struct PreambleConfig<P> {
    /// Value passed to `write_preamble` at the start of the exchange.
    pub(crate) preamble: P,
    /// Optional callback run after the preamble is written. The bytes it returns become the
    /// exchange result; its I/O errors surface as `ClientError::PreambleRead`.
    pub(crate) on_success: Option<ClientPreambleSuccessHandler<P>>,
    /// Optional callback invoked with the error when the exchange fails. Its own errors are
    /// logged rather than returned.
    pub(crate) on_failure: Option<ClientPreambleFailureHandler>,
    /// Optional deadline covering both the preamble write and the success handler. If it
    /// expires, the exchange fails with `ClientError::PreambleTimeout`.
    pub(crate) timeout: Option<Duration>,
}
impl<P> PreambleConfig<P>
where
P: Encode + Send + Sync + 'static,
{
pub(crate) fn new(preamble: P) -> Self {
Self {
preamble,
on_success: None,
on_failure: None,
timeout: None,
}
}
pub(crate) fn set_timeout(&mut self, timeout: Duration) { self.timeout = Some(timeout); }
pub(crate) fn set_success_handler<H>(&mut self, handler: H)
where
H: for<'a> Fn(&'a P, &'a mut TcpStream) -> BoxFuture<'a, io::Result<Vec<u8>>>
+ Send
+ Sync
+ 'static,
{
self.on_success = Some(Arc::new(handler));
}
pub(crate) fn set_failure_handler<H>(&mut self, handler: H)
where
H: for<'a> Fn(&'a ClientError, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>>
+ Send
+ Sync
+ 'static,
{
self.on_failure = Some(Arc::new(handler));
}
}
/// Writes the configured preamble to `stream` and runs the optional success handler.
///
/// On success, returns the bytes produced by the success handler (empty when none is set).
/// On failure, the configured failure handler (if any) is awaited with the error before
/// that same error is returned to the caller.
pub(crate) async fn perform_preamble_exchange<P>(
    stream: &mut TcpStream,
    config: PreambleConfig<P>,
) -> Result<Vec<u8>, ClientError>
where
    P: Encode + Send + Sync + 'static,
{
    let PreambleConfig {
        preamble,
        on_success,
        on_failure,
        timeout: preamble_timeout,
    } = config;
    match run_preamble_exchange(stream, &preamble, on_success, preamble_timeout).await {
        Ok(received) => Ok(received),
        Err(failure) => {
            // Give the caller's hook a chance to react before propagating the error.
            invoke_failure_handler(stream, &failure, on_failure.as_ref()).await;
            Err(failure)
        }
    }
}
/// Performs the write-then-handle phase of the exchange, optionally under a deadline.
///
/// Errors:
/// - `ClientError::PreambleEncode` when `write_preamble` fails;
/// - `ClientError::PreambleRead` when the success handler fails;
/// - `ClientError::PreambleTimeout` when `preamble_timeout` elapses first.
async fn run_preamble_exchange<P>(
    stream: &mut TcpStream,
    preamble: &P,
    on_success: Option<ClientPreambleSuccessHandler<P>>,
    preamble_timeout: Option<Duration>,
) -> Result<Vec<u8>, ClientError>
where
    P: Encode + Send + Sync + 'static,
{
    let exchange = async {
        if let Err(encode_err) = write_preamble(stream, preamble).await {
            return Err(ClientError::PreambleEncode(encode_err));
        }
        // Without a success handler there is nothing to read back.
        let Some(handler) = on_success.as_ref() else {
            return Ok(Vec::new());
        };
        handler(preamble, stream)
            .await
            .map_err(ClientError::PreambleRead)
    };

    let Some(limit) = preamble_timeout else {
        return exchange.await;
    };
    // The deadline covers the write and the success handler together.
    match timeout(limit, exchange).await {
        Ok(outcome) => outcome,
        Err(_elapsed) => Err(ClientError::PreambleTimeout),
    }
}
async fn invoke_failure_handler(
stream: &mut TcpStream,
err: &ClientError,
on_failure: Option<&ClientPreambleFailureHandler>,
) {
if let Some(handler) = on_failure
&& let Err(e) = handler(err, stream).await
{
error!("preamble failure handler error: {e}");
}
}