use async_trait::async_trait;
use std::time::Duration;
use super::traits::{BoxHandshaker, HandshakeReceiver, HandshakeSender, Handshaker};
use crate::context::ConnectionContext;
use crate::error::HandshakeError;
/// A [`Handshaker`] that runs a list of other handshakers one after another.
///
/// Steps are executed in the order they were added; the first step that
/// returns an error ends the chain and that error is returned to the caller.
pub struct ChainedHandshaker {
/// Steps in insertion order; `handshake` walks this list front to back.
handshakers: Vec<BoxHandshaker>,
/// Value handed back by `Handshaker::timeout`. `None` means no timeout is
/// reported. The chain's own `handshake` body does not read this field.
timeout: Option<Duration>,
}
impl ChainedHandshaker {
    /// Creates a chain with no steps and a reported timeout of 60 seconds.
    #[must_use]
    pub fn new() -> Self {
        let default_timeout = Duration::from_secs(60);
        Self {
            timeout: Some(default_timeout),
            handshakers: Vec::new(),
        }
    }

    /// Appends `handshaker` as the last step of the chain.
    ///
    /// The value is boxed and stored; see [`Self::then_boxed`] when the
    /// caller already holds a `BoxHandshaker`.
    #[must_use]
    pub fn then<H: Handshaker + 'static>(self, handshaker: H) -> Self {
        self.then_boxed(Box::new(handshaker))
    }

    /// Appends an already boxed handshaker as the last step of the chain.
    #[must_use]
    pub fn then_boxed(mut self, handshaker: BoxHandshaker) -> Self {
        self.handshakers.push(handshaker);
        self
    }

    /// Replaces the reported timeout with `timeout`.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Clears the reported timeout, so `Handshaker::timeout` returns `None`.
    #[must_use]
    pub const fn without_timeout(mut self) -> Self {
        self.timeout = None;
        self
    }

    /// Returns how many steps have been added to the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { handshakers, .. } = self;
        handshakers.len()
    }

    /// Returns `true` when no step has been added yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { handshakers, .. } = self;
        handshakers.is_empty()
    }
}
/// The default chain is exactly what [`ChainedHandshaker::new`] produces.
impl Default for ChainedHandshaker {
    /// Builds the default chain by delegating to [`ChainedHandshaker::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Handshaker for ChainedHandshaker {
    /// Runs every step in insertion order over the same sender, receiver and
    /// context.
    ///
    /// A debug event is emitted before each step. When a step fails, a warning
    /// is emitted and the step's error is returned unchanged; later steps are
    /// not run. A chain with no steps returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any step.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        let total = self.handshakers.len();

        // Step numbers in the log events are 1-based.
        for (current, step) in self.handshakers.iter().zip(1_usize..) {
            tracing::debug!(
                step,
                total,
                handshaker = current.name(),
                "Executing handshake step"
            );

            if let Err(err) = current.handshake(sender, receiver, context).await {
                tracing::warn!(
                    step,
                    handshaker = current.name(),
                    error = ?err,
                    "Handshake step failed"
                );
                return Err(err);
            }
        }

        Ok(())
    }

    /// Defers the retry decision to the error itself.
    fn is_retryable(&self, error: &HandshakeError) -> bool {
        error.is_retryable()
    }

    /// Fixed identifier for this handshaker.
    fn name(&self) -> &'static str {
        "chained"
    }

    /// Returns the timeout configured on the chain, if any.
    fn timeout(&self) -> Option<Duration> {
        self.timeout
    }
}
/// Builds a `ChainedHandshaker` from one or more handshaker expressions.
///
/// `chain_handshakers!(a, b)` expands to
/// `$crate::handshake::ChainedHandshaker::new().then(a).then(b)`, so the
/// steps are added in the order they are written. At least one expression
/// is required; a trailing comma is accepted.
#[macro_export]
macro_rules! chain_handshakers {
// One or more comma-separated expressions, optional trailing comma.
($($handshaker:expr),+ $(,)?) => {
$crate::handshake::ChainedHandshaker::new()
$(.then($handshaker))+
};
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handshake::NoOpHandshaker;

    /// Adding two steps through `then` yields a non-empty chain of length two.
    #[test]
    fn test_chained_handshaker_creation() {
        let two_steps = ChainedHandshaker::new()
            .then(NoOpHandshaker)
            .then(NoOpHandshaker);

        assert!(!two_steps.is_empty());
        assert_eq!(two_steps.len(), 2);
    }

    /// A freshly constructed chain contains no steps.
    #[test]
    fn test_empty_chain() {
        let no_steps = ChainedHandshaker::new();

        assert_eq!(no_steps.len(), 0);
        assert!(no_steps.is_empty());
    }
}