use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use super::{Transport, TransportKind};
use crate::error::{SrxError, TransportError};
/// A [`Transport`] decorator that bounds how long `send` and `recv` may take.
///
/// `send` and `recv` on the wrapped transport are raced against a timer and
/// reported as a timeout error when the timer wins. `kind`, `is_healthy` and
/// `close` are forwarded to the wrapped transport without a time limit.
pub struct TimeoutTransport {
/// The wrapped transport every operation is forwarded to.
inner: Arc<dyn Transport>,
/// Maximum time a single `send` call may take.
send_timeout: Duration,
/// Maximum time a single `recv` call may take.
recv_timeout: Duration,
}
impl TimeoutTransport {
pub fn new(inner: Arc<dyn Transport>, timeout: Duration) -> Self {
Self {
inner,
send_timeout: timeout,
recv_timeout: timeout,
}
}
pub fn with_separate_timeouts(
inner: Arc<dyn Transport>,
send_timeout: Duration,
recv_timeout: Duration,
) -> Self {
Self {
inner,
send_timeout,
recv_timeout,
}
}
pub fn send_timeout(&self) -> Duration {
self.send_timeout
}
pub fn recv_timeout(&self) -> Duration {
self.recv_timeout
}
}
#[async_trait]
impl Transport for TimeoutTransport {
    /// Reports the kind of the wrapped transport.
    fn kind(&self) -> TransportKind {
        let wrapped = &self.inner;
        wrapped.kind()
    }

    /// Forwards `data` to the wrapped transport, giving up after `send_timeout`.
    ///
    /// # Errors
    ///
    /// Yields a `TransportError::Timeout` (wrapped in `SrxError::Transport`) when
    /// the limit elapses first; otherwise returns whatever the wrapped transport
    /// returned.
    async fn send(&self, data: Bytes) -> crate::error::Result<()> {
        let limit = self.send_timeout;
        let attempt = tokio::time::timeout(limit, self.inner.send(data)).await;
        // The outer `Err` means the timer fired. `?` strips that layer, so the
        // wrapped transport's own result becomes the return value.
        attempt.map_err(|_elapsed| {
            let transport = format!("{:?}", self.inner.kind());
            let details = format!("send timed out after {:?}", self.send_timeout);
            SrxError::Transport(TransportError::Timeout { transport, details })
        })?
    }

    /// Receives from the wrapped transport, giving up after `recv_timeout`.
    ///
    /// # Errors
    ///
    /// Yields a `TransportError::Timeout` (wrapped in `SrxError::Transport`) when
    /// the limit elapses first; otherwise returns whatever the wrapped transport
    /// returned.
    async fn recv(&self) -> crate::error::Result<Bytes> {
        let limit = self.recv_timeout;
        let attempt = tokio::time::timeout(limit, self.inner.recv()).await;
        // Same layering as `send`: outer error = timeout, inner result = transport.
        attempt.map_err(|_elapsed| {
            let transport = format!("{:?}", self.inner.kind());
            let details = format!("recv timed out after {:?}", self.recv_timeout);
            SrxError::Transport(TransportError::Timeout { transport, details })
        })?
    }

    /// Forwards the health check to the wrapped transport; no time limit is applied.
    async fn is_healthy(&self) -> bool {
        let wrapped = &self.inner;
        wrapped.is_healthy().await
    }

    /// Forwards `close` to the wrapped transport; no time limit is applied.
    async fn close(&self) -> crate::error::Result<()> {
        let wrapped = &self.inner;
        wrapped.close().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Transport whose `send` and `recv` never complete, so any call made
    /// through `TimeoutTransport` must end in a timeout.
    struct HangingTransport;

    #[async_trait]
    impl Transport for HangingTransport {
        fn kind(&self) -> TransportKind {
            TransportKind::Tcp
        }

        async fn send(&self, _data: Bytes) -> crate::error::Result<()> {
            std::future::pending().await
        }

        async fn recv(&self) -> crate::error::Result<Bytes> {
            std::future::pending().await
        }

        async fn is_healthy(&self) -> bool {
            true
        }

        async fn close(&self) -> crate::error::Result<()> {
            Ok(())
        }
    }

    /// Transport that completes immediately and counts how many sends reached it.
    struct InstantTransport {
        send_count: AtomicU32,
    }

    impl InstantTransport {
        fn new() -> Self {
            Self {
                send_count: AtomicU32::new(0),
            }
        }
    }

    #[async_trait]
    impl Transport for InstantTransport {
        fn kind(&self) -> TransportKind {
            TransportKind::Udp
        }

        async fn send(&self, _data: Bytes) -> crate::error::Result<()> {
            self.send_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn recv(&self) -> crate::error::Result<Bytes> {
            Ok(Bytes::from_static(b"instant"))
        }

        async fn is_healthy(&self) -> bool {
            true
        }

        async fn close(&self) -> crate::error::Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn send_timeout_triggers() {
        let t = TimeoutTransport::new(Arc::new(HangingTransport), Duration::from_millis(10));
        let err = t.send(Bytes::from_static(b"test")).await;
        assert!(err.is_err());
        let msg = err.unwrap_err().to_string();
        assert!(msg.contains("timed out"));
    }

    #[tokio::test]
    async fn recv_timeout_triggers() {
        let t = TimeoutTransport::new(Arc::new(HangingTransport), Duration::from_millis(10));
        let err = t.recv().await;
        assert!(err.is_err());
        // Mirror the send test: the failure must be reported as a timeout.
        let msg = err.unwrap_err().to_string();
        assert!(msg.contains("timed out"));
    }

    #[tokio::test]
    async fn fast_operation_succeeds() {
        // Keep a typed handle so the send counter can be inspected afterwards.
        let inner = Arc::new(InstantTransport::new());
        let shared: Arc<dyn Transport> = inner.clone();
        let t = TimeoutTransport::new(shared, Duration::from_secs(5));

        t.send(Bytes::from_static(b"ok")).await.unwrap();
        // The wrapper must have forwarded the send exactly once.
        assert_eq!(inner.send_count.load(Ordering::SeqCst), 1);

        let data = t.recv().await.unwrap();
        assert_eq!(data.as_ref(), b"instant");
    }

    #[tokio::test]
    async fn kind_delegates() {
        let t = TimeoutTransport::new(Arc::new(InstantTransport::new()), Duration::from_secs(1));
        assert_eq!(t.kind(), TransportKind::Udp);
    }

    #[tokio::test]
    async fn separate_timeouts() {
        let t = TimeoutTransport::with_separate_timeouts(
            Arc::new(HangingTransport),
            Duration::from_millis(5),
            Duration::from_millis(50),
        );
        assert_eq!(t.send_timeout(), Duration::from_millis(5));
        assert_eq!(t.recv_timeout(), Duration::from_millis(50));

        let send_result = t.send(Bytes::from_static(b"x")).await;
        assert!(send_result.is_err());

        // The recv limit is configured independently and must also be enforced.
        let recv_result = t.recv().await;
        assert!(recv_result.is_err());
    }

    #[tokio::test]
    async fn is_healthy_delegates() {
        let t = TimeoutTransport::new(Arc::new(InstantTransport::new()), Duration::from_secs(1));
        assert!(t.is_healthy().await);
    }

    #[tokio::test]
    async fn close_delegates() {
        let t = TimeoutTransport::new(Arc::new(InstantTransport::new()), Duration::from_secs(1));
        t.close().await.unwrap();
    }
}