use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures::stream::{self, StreamExt};
use serde_json::Value;
use crate::error::{CognisError, Result};
use super::base::Runnable;
use super::config::RunnableConfig;
use super::RunnableStream;
/// What a [`RunnableTimeout`] does once the wrapped runnable exceeds its time budget.
///
/// `Clone` is cheap: the fallback is shared through an `Arc`.
#[derive(Clone)]
pub enum TimeoutBehavior {
    /// Fail with a `CognisError::Other` describing the timeout.
    Error,
    /// Produce this value instead (a stream yields it as its single item).
    Default(Value),
    /// Run this runnable with the original input and config instead.
    Fallback(Arc<dyn Runnable>),
}

/// Time budget and timeout policy used by [`RunnableTimeout`].
///
/// Deriving `Clone` lets one config be reused for several wrapped runnables.
#[derive(Clone)]
pub struct TimeoutConfig {
    /// Time budget granted to the wrapped runnable.
    pub duration: Duration,
    /// Policy applied when `duration` elapses.
    pub on_timeout: TimeoutBehavior,
}
impl TimeoutConfig {
pub fn new(duration: Duration) -> Self {
Self {
duration,
on_timeout: TimeoutBehavior::Error,
}
}
pub fn with_behavior(mut self, behavior: TimeoutBehavior) -> Self {
self.on_timeout = behavior;
self
}
pub fn with_default_value(self, value: Value) -> Self {
self.with_behavior(TimeoutBehavior::Default(value))
}
pub fn with_fallback(self, fallback: Arc<dyn Runnable>) -> Self {
self.with_behavior(TimeoutBehavior::Fallback(fallback))
}
}
pub struct RunnableTimeout {
inner: Arc<dyn Runnable>,
config: TimeoutConfig,
}
impl RunnableTimeout {
pub fn new(inner: Arc<dyn Runnable>, config: TimeoutConfig) -> Self {
Self { inner, config }
}
async fn handle_timeout(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
match &self.config.on_timeout {
TimeoutBehavior::Error => Err(CognisError::Other(format!(
"Timeout: {} exceeded {:?}",
self.inner.name(),
self.config.duration
))),
TimeoutBehavior::Default(val) => Ok(val.clone()),
TimeoutBehavior::Fallback(fallback) => fallback.invoke(input, config).await,
}
}
}
#[async_trait]
impl Runnable for RunnableTimeout {
fn name(&self) -> &str {
self.inner.name()
}
async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
match tokio::time::timeout(
self.config.duration,
self.inner.invoke(input.clone(), config),
)
.await
{
Ok(result) => result,
Err(_elapsed) => self.handle_timeout(input, config).await,
}
}
async fn stream(
&self,
input: Value,
config: Option<&RunnableConfig>,
) -> Result<RunnableStream> {
let duration = self.config.duration;
match tokio::time::timeout(duration, self.inner.stream(input.clone(), config)).await {
Ok(Ok(inner_stream)) => {
let deadline = tokio::time::Instant::now() + duration;
let wrapped = inner_stream.take_while(move |_| {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
async move { remaining > Duration::ZERO }
});
Ok(Box::pin(wrapped))
}
Ok(Err(e)) => Err(e),
Err(_elapsed) => {
match &self.config.on_timeout {
TimeoutBehavior::Error => Err(CognisError::Other(format!(
"Timeout: {} stream exceeded {:?}",
self.inner.name(),
self.config.duration
))),
TimeoutBehavior::Default(val) => {
let val = val.clone();
Ok(Box::pin(stream::once(async move { Ok(val) })))
}
TimeoutBehavior::Fallback(fallback) => fallback.stream(input, config).await,
}
}
}
}
}
/// Wraps a runnable so that it must finish before a fixed point in time.
///
/// The budget is an absolute [`Instant`], not a duration. Several wrappers built
/// with the same instant therefore share one deadline.
pub struct RunnableDeadline {
    inner: Arc<dyn Runnable>,
    deadline: Instant,
}

impl RunnableDeadline {
    /// Wraps `inner` so that calls to it must complete before `deadline`.
    pub fn new(inner: Arc<dyn Runnable>, deadline: Instant) -> Self {
        Self { inner, deadline }
    }

    /// Time left until the deadline, or `Duration::ZERO` once it has passed.
    fn remaining(&self) -> Duration {
        let now = Instant::now();
        self.deadline
            .checked_duration_since(now)
            .unwrap_or(Duration::ZERO)
    }
}
#[async_trait]
impl Runnable for RunnableDeadline {
fn name(&self) -> &str {
self.inner.name()
}
async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
let remaining = self.remaining();
if remaining.is_zero() {
return Err(CognisError::Other(format!(
"Deadline exceeded for {}",
self.inner.name()
)));
}
match tokio::time::timeout(remaining, self.inner.invoke(input, config)).await {
Ok(result) => result,
Err(_elapsed) => Err(CognisError::Other(format!(
"Deadline exceeded for {}",
self.inner.name()
))),
}
}
async fn stream(
&self,
input: Value,
config: Option<&RunnableConfig>,
) -> Result<RunnableStream> {
let remaining = self.remaining();
if remaining.is_zero() {
return Err(CognisError::Other(format!(
"Deadline exceeded for {} stream",
self.inner.name()
)));
}
match tokio::time::timeout(remaining, self.inner.stream(input, config)).await {
Ok(Ok(inner_stream)) => {
let deadline = self.deadline;
let wrapped = inner_stream.take_while(move |_| {
let now = Instant::now();
async move { now < deadline }
});
Ok(Box::pin(wrapped))
}
Ok(Err(e)) => Err(e),
Err(_elapsed) => Err(CognisError::Other(format!(
"Deadline exceeded for {} stream",
self.inner.name()
))),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnables::lambda::RunnableLambda;
    use futures::StreamExt;
    use serde_json::json;

    /// Doubles an integer input without any delay.
    fn fast_double() -> RunnableLambda {
        RunnableLambda::new("fast_double", |v: Value| async move {
            let n = v.as_i64().unwrap();
            Ok(json!(n * 2))
        })
    }

    /// Returns its input unchanged after sleeping for `delay_ms` milliseconds.
    fn slow_identity(delay_ms: u64) -> RunnableLambda {
        RunnableLambda::new("slow_identity", move |v: Value| async move {
            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            Ok(v)
        })
    }

    /// Doubles an integer input after sleeping for `delay_ms` milliseconds.
    fn slow_double(delay_ms: u64) -> RunnableLambda {
        RunnableLambda::new("slow_double", move |v: Value| async move {
            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            let n = v.as_i64().unwrap();
            Ok(json!(n * 2))
        })
    }

    /// Always fails with `CognisError::Other("inner error")`.
    fn error_runnable() -> RunnableLambda {
        RunnableLambda::new("error", |_v: Value| async move {
            Err(CognisError::Other("inner error".to_string()))
        })
    }

    #[test]
    fn test_timeout_config_defaults_to_error_behavior() {
        let config = TimeoutConfig::new(Duration::from_secs(5));
        assert_eq!(config.duration, Duration::from_secs(5));
        assert!(matches!(config.on_timeout, TimeoutBehavior::Error));
    }

    #[test]
    fn test_timeout_config_with_default_value() {
        let config =
            TimeoutConfig::new(Duration::from_secs(1)).with_default_value(json!("fallback"));
        assert!(matches!(config.on_timeout, TimeoutBehavior::Default(_)));
        if let TimeoutBehavior::Default(v) = &config.on_timeout {
            assert_eq!(v, &json!("fallback"));
        }
    }

    #[test]
    fn test_timeout_config_with_fallback() {
        let fb = Arc::new(fast_double()) as Arc<dyn Runnable>;
        let config = TimeoutConfig::new(Duration::from_secs(1)).with_fallback(fb);
        assert!(matches!(config.on_timeout, TimeoutBehavior::Fallback(_)));
    }

    #[tokio::test]
    async fn test_timeout_fast_operation_succeeds() {
        let inner = fast_double();
        let config = TimeoutConfig::new(Duration::from_secs(1));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let result = timed.invoke(json!(5), None).await.unwrap();
        assert_eq!(result, json!(10));
    }

    #[tokio::test]
    async fn test_timeout_slow_operation_errors() {
        let inner = slow_identity(200);
        let config = TimeoutConfig::new(Duration::from_millis(50));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let result = timed.invoke(json!(1), None).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("Timeout"));
    }

    #[tokio::test]
    async fn test_timeout_slow_operation_returns_default() {
        let inner = slow_identity(200);
        let config =
            TimeoutConfig::new(Duration::from_millis(50)).with_default_value(json!("default"));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let result = timed.invoke(json!(1), None).await.unwrap();
        assert_eq!(result, json!("default"));
    }

    #[tokio::test]
    async fn test_timeout_slow_operation_uses_fallback() {
        let inner = slow_identity(200);
        let fallback = Arc::new(fast_double()) as Arc<dyn Runnable>;
        let config = TimeoutConfig::new(Duration::from_millis(50)).with_fallback(fallback);
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let result = timed.invoke(json!(7), None).await.unwrap();
        assert_eq!(result, json!(14));
    }

    #[tokio::test]
    async fn test_timeout_name_delegates() {
        let inner = fast_double();
        let config = TimeoutConfig::new(Duration::from_secs(1));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        assert_eq!(timed.name(), "fast_double");
    }

    #[tokio::test]
    async fn test_timeout_inner_error_propagated() {
        let inner = error_runnable();
        let config = TimeoutConfig::new(Duration::from_secs(1));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let result = timed.invoke(json!(1), None).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("inner error"));
    }

    #[tokio::test]
    async fn test_timeout_just_within_limit() {
        let inner = slow_double(30);
        let config = TimeoutConfig::new(Duration::from_millis(100));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let result = timed.invoke(json!(3), None).await.unwrap();
        assert_eq!(result, json!(6));
    }

    #[tokio::test]
    async fn test_timeout_completes_quickly_when_fast() {
        let inner = fast_double();
        let config = TimeoutConfig::new(Duration::from_secs(10));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let start = Instant::now();
        let _result = timed.invoke(json!(1), None).await.unwrap();
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(100),
            "Fast operation should complete quickly, took {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_timeout_stream_fast_succeeds() {
        let inner = fast_double();
        let config = TimeoutConfig::new(Duration::from_secs(1));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let mut stream = timed.stream(json!(4), None).await.unwrap();
        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, json!(8));
    }

    #[tokio::test]
    async fn test_timeout_stream_slow_errors() {
        let inner = slow_identity(200);
        let config = TimeoutConfig::new(Duration::from_millis(50));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let result = timed.stream(json!(1), None).await;
        assert!(result.is_err());
    }

    /// A stream that times out while opening yields the default value exactly once.
    #[tokio::test]
    async fn test_timeout_stream_slow_returns_default() {
        let inner = slow_identity(200);
        let config =
            TimeoutConfig::new(Duration::from_millis(50)).with_default_value(json!("default"));
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let mut stream = timed.stream(json!(1), None).await.unwrap();
        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, json!("default"));
        assert!(stream.next().await.is_none());
    }

    /// A stream that times out while opening is replaced by the fallback's stream.
    #[tokio::test]
    async fn test_timeout_stream_slow_uses_fallback() {
        let inner = slow_identity(200);
        let fallback = Arc::new(fast_double()) as Arc<dyn Runnable>;
        let config = TimeoutConfig::new(Duration::from_millis(50)).with_fallback(fallback);
        let timed = RunnableTimeout::new(Arc::new(inner), config);
        let mut stream = timed.stream(json!(7), None).await.unwrap();
        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, json!(14));
    }

    #[tokio::test]
    async fn test_deadline_fast_operation_succeeds() {
        let inner = fast_double();
        let deadline = Instant::now() + Duration::from_secs(1);
        let dl = RunnableDeadline::new(Arc::new(inner), deadline);
        let result = dl.invoke(json!(5), None).await.unwrap();
        assert_eq!(result, json!(10));
    }

    #[tokio::test]
    async fn test_deadline_slow_operation_errors() {
        let inner = slow_identity(200);
        let deadline = Instant::now() + Duration::from_millis(50);
        let dl = RunnableDeadline::new(Arc::new(inner), deadline);
        let result = dl.invoke(json!(1), None).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("Deadline exceeded"));
    }

    #[tokio::test]
    async fn test_deadline_already_past() {
        let inner = fast_double();
        let deadline = Instant::now() - Duration::from_secs(1);
        let dl = RunnableDeadline::new(Arc::new(inner), deadline);
        let result = dl.invoke(json!(5), None).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("Deadline exceeded"));
    }

    #[tokio::test]
    async fn test_deadline_name_delegates() {
        let inner = fast_double();
        let deadline = Instant::now() + Duration::from_secs(1);
        let dl = RunnableDeadline::new(Arc::new(inner), deadline);
        assert_eq!(dl.name(), "fast_double");
    }

    #[tokio::test]
    async fn test_deadline_stream_succeeds() {
        let inner = fast_double();
        let deadline = Instant::now() + Duration::from_secs(1);
        let dl = RunnableDeadline::new(Arc::new(inner), deadline);
        let mut stream = dl.stream(json!(6), None).await.unwrap();
        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, json!(12));
    }

    #[tokio::test]
    async fn test_deadline_stream_past_deadline_errors() {
        let inner = fast_double();
        let deadline = Instant::now() - Duration::from_secs(1);
        let dl = RunnableDeadline::new(Arc::new(inner), deadline);
        let result = dl.stream(json!(1), None).await;
        assert!(result.is_err());
    }

    /// A deadline that elapses while the inner stream is being opened is an error.
    #[tokio::test]
    async fn test_deadline_stream_slow_errors() {
        let inner = slow_identity(200);
        let deadline = Instant::now() + Duration::from_millis(50);
        let dl = RunnableDeadline::new(Arc::new(inner), deadline);
        let result = dl.stream(json!(1), None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_deadline_shared_across_operations() {
        let deadline = Instant::now() + Duration::from_millis(150);
        let dl1 = RunnableDeadline::new(Arc::new(slow_double(50)), deadline);
        let dl2 = RunnableDeadline::new(Arc::new(slow_double(50)), deadline);
        let r1 = dl1.invoke(json!(2), None).await.unwrap();
        let r2 = dl2.invoke(json!(3), None).await.unwrap();
        assert_eq!(r1, json!(4));
        assert_eq!(r2, json!(6));
    }

    #[tokio::test]
    async fn test_ext_with_timeout() {
        use crate::runnables::ext::RunnableExt;
        let timed = fast_double().with_timeout(Duration::from_secs(1));
        let result = timed.invoke(json!(10), None).await.unwrap();
        assert_eq!(result, json!(20));
    }

    #[tokio::test]
    async fn test_ext_with_timeout_config() {
        use crate::runnables::ext::RunnableExt;
        let config =
            TimeoutConfig::new(Duration::from_millis(50)).with_default_value(json!("timed_out"));
        let timed = slow_identity(200).with_timeout_config(config);
        let result = timed.invoke(json!(1), None).await.unwrap();
        assert_eq!(result, json!("timed_out"));
    }
}