use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use layered::Layer;
use crate::timeout::*;
use crate::typestates::{NotSet, Set};
use crate::utils::{EnableIf, TelemetryHelper};
use crate::{ResilienceContext, TelemetryString};
/// Builder for the timeout resilience layer.
///
/// `S1` and `S2` are typestate markers ([`Set`] / [`NotSet`]):
/// - `S1` records whether a default timeout has been configured (via `timeout`/`config`).
/// - `S2` records whether a timeout output has been configured (via `timeout_output` /
///   `timeout_error`).
///
/// [`Layer`] is only implemented for `TimeoutLayer<_, _, Set, Set>`, so an incompletely
/// configured builder cannot be turned into a service. Both markers default to `Set`, so
/// `TimeoutLayer<In, Out>` names a fully configured layer.
#[derive(Debug)]
pub struct TimeoutLayer<In, Out, S1 = Set, S2 = Set> {
/// Context the layer was created from; the built service takes its clock from here.
context: ResilienceContext<In, Out>,
/// Default timeout; always `Some` once `S1 = Set`.
timeout: Option<Duration>,
/// Produces the output for a timed-out call; always `Some` once `S2 = Set`.
timeout_output: Option<TimeoutOutput<Out>>,
/// Optional callback registered through `on_timeout`.
on_timeout: Option<OnTimeout<Out>>,
/// Per-input enablement predicate; starts as `EnableIf::default()`.
enable_if: EnableIf<In>,
/// Telemetry created from the context under this layer's name. It is only handed to the
/// service when the `logs`/`metrics` features (or tests) are enabled.
telemetry: TelemetryHelper,
/// Optional per-input override of the default timeout.
timeout_override: Option<TimeoutOverride<In>>,
/// Carries the typestate parameters without storing values. A function-pointer type is
/// used because `fn(..) -> ..` is `Send + Sync` for any argument/return types, so the
/// marker does not constrain the layer's auto traits.
_state: PhantomData<fn(In, S1, S2) -> Out>,
}
impl<In, Out> TimeoutLayer<In, Out, NotSet, NotSet> {
    /// Creates an empty builder bound to `context`.
    ///
    /// Neither the timeout nor the timeout output is configured yet (both typestates are
    /// [`NotSet`]). The layer starts with the default enablement predicate and no callbacks.
    /// Its telemetry is derived from `context` under `name`.
    #[must_use]
    pub(crate) fn new(name: TelemetryString, context: &ResilienceContext<In, Out>) -> Self {
        // Bindings are created in the same order the original struct literal evaluated them.
        let enable_if = EnableIf::default();
        let telemetry = context.create_telemetry(name);
        let owned_context = context.clone();

        Self {
            context: owned_context,
            telemetry,
            enable_if,
            timeout: None,
            timeout_output: None,
            on_timeout: None,
            timeout_override: None,
            _state: PhantomData,
        }
    }
}
impl<In, Out, E, S1, S2> TimeoutLayer<In, Result<Out, E>, S1, S2> {
    /// Sets the error produced for a timed-out call, for services whose output is a `Result`.
    ///
    /// This is a convenience over `timeout_output`: the closure's return value is wrapped in
    /// `Err(..)`. Only the output typestate (`S2`) becomes [`Set`]. Whether a timeout has been
    /// configured (`S1`) is carried through unchanged.
    #[must_use]
    pub fn timeout_error(
        self,
        timeout_error: impl Fn(TimeoutOutputArgs) -> E + Send + Sync + 'static,
    ) -> TimeoutLayer<In, Result<Out, E>, S1, Set> {
        // `timeout_output` is available for every `S1` and already yields `S1, Set`. There is
        // no need to tag the timeout slot as `Set` temporarily and convert back.
        self.timeout_output(move |args| Err(timeout_error(args)))
    }
}
impl<In, Out, S1, S2> TimeoutLayer<In, Out, S1, S2> {
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> TimeoutLayer<In, Out, Set, S2> {
self.timeout = Some(timeout);
self.into_state::<Set, S2>()
}
#[must_use]
pub fn config(self, config: &TimeoutConfig) -> TimeoutLayer<In, Out, Set, S2> {
self.timeout(config.timeout).enable(config.enabled)
}
#[must_use]
pub fn timeout_output(mut self, output: impl Fn(TimeoutOutputArgs) -> Out + Send + Sync + 'static) -> TimeoutLayer<In, Out, S1, Set> {
self.timeout_output = Some(TimeoutOutput::new(output));
self.into_state::<S1, Set>()
}
#[must_use]
pub fn on_timeout(mut self, on_timeout: impl Fn(&Out, OnTimeoutArgs) + Send + Sync + 'static) -> Self {
self.on_timeout = Some(OnTimeout::new(on_timeout));
self
}
#[must_use]
pub fn timeout_override(
mut self,
timeout_override: impl Fn(&In, TimeoutOverrideArgs) -> Option<Duration> + Send + Sync + 'static,
) -> Self {
self.timeout_override = Some(TimeoutOverride::new(timeout_override));
self
}
#[must_use]
pub fn enable_if(mut self, is_enabled: impl Fn(&In) -> bool + Send + Sync + 'static) -> Self {
self.enable_if = EnableIf::custom(is_enabled);
self
}
#[must_use]
fn enable(mut self, enabled: bool) -> Self {
self.enable_if = EnableIf::new(enabled);
self
}
#[must_use]
pub fn enable_always(self) -> Self {
self.enable(true)
}
#[must_use]
pub fn disable(self) -> Self {
self.enable(false)
}
}
impl<In, Out, S> Layer<S> for TimeoutLayer<In, Out, Set, Set> {
    type Service = Timeout<In, Out, S>;

    /// Wraps `inner` in a [`Timeout`] service that shares a snapshot of this configuration.
    ///
    /// This is only implemented once both typestates are [`Set`], so the required settings are
    /// guaranteed to be present.
    fn layer(&self, inner: S) -> Self::Service {
        // Both unwraps are guaranteed by the typestate: `Set, Set` implies both are `Some`.
        let timeout = self.timeout.expect("enforced by the type state pattern");
        let timeout_output = self
            .timeout_output
            .clone()
            .expect("enforced by the type state pattern");

        Timeout {
            inner,
            shared: Arc::new(TimeoutShared {
                clock: self.context.get_clock().clone(),
                timeout,
                enable_if: self.enable_if.clone(),
                on_timeout: self.on_timeout.clone(),
                timeout_output,
                timeout_override: self.timeout_override.clone(),
                #[cfg(any(feature = "logs", feature = "metrics", test))]
                telemetry: self.telemetry.clone(),
            }),
        }
    }
}
impl<In, Out, S1, S2> TimeoutLayer<In, Out, S1, S2> {
    /// Re-tags the builder with new typestate markers, moving every field over unchanged.
    ///
    /// Callers are responsible for only advancing a marker to `Set` after populating the
    /// corresponding field.
    fn into_state<T1, T2>(self) -> TimeoutLayer<In, Out, T1, T2> {
        let Self {
            context,
            timeout,
            timeout_output,
            on_timeout,
            enable_if,
            telemetry,
            timeout_override,
            _state: _,
        } = self;

        TimeoutLayer {
            context,
            timeout,
            timeout_output,
            on_timeout,
            enable_if,
            telemetry,
            timeout_override,
            _state: PhantomData,
        }
    }
}
// Unit tests for the `TimeoutLayer` builder: typestate transitions, storage of the
// configured callbacks, and construction of the `Timeout` service via `Layer::layer`.
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, Ordering};
use layered::Execute;
use tick::Clock;
use super::*;
// NOTE(review): despite its name, this snapshots the Debug output of a fully configured
// layer (`create_ready_layer`). Renaming it would also rename the insta snapshot, because
// insta derives the snapshot name from the test function when none is given. Skipped
// under Miri.
#[cfg_attr(miri, ignore)]
#[test]
fn new_needs_timeout_output() {
let layer = create_ready_layer();
insta::assert_debug_snapshot!(layer);
}
// `timeout_output` moves only the output typestate to `Set` and stores a closure that
// receives the configured args.
#[test]
fn timeout_output_ensure_set_correctly() {
let context = create_test_context();
let layer = TimeoutLayer::new("test".into(), &context);
let layer: TimeoutLayer<_, _, NotSet, Set> = layer.timeout_output(|args| format!("timeout: {}", args.timeout().as_millis()));
let result = layer.timeout_output.unwrap().call(TimeoutOutputArgs {
timeout: Duration::from_millis(3),
});
assert_eq!(result, "timeout: 3");
}
// `timeout_error` behaves like `timeout_output` but wraps the closure's value in `Err`.
#[test]
fn timeout_error_ensure_set_correctly() {
let context = create_test_context_result();
let layer = TimeoutLayer::new("test".into(), &context);
let layer: TimeoutLayer<_, _, NotSet, Set> = layer.timeout_error(|args| format!("timeout: {}", args.timeout().as_millis()));
let result = layer
.timeout_output
.unwrap()
.call(TimeoutOutputArgs {
timeout: Duration::from_millis(3),
})
.unwrap_err();
assert_eq!(result, "timeout: 3");
}
// Setting both the output and the timeout yields a `Set, Set` layer with the given timeout.
#[test]
fn timeout_ensure_set_correctly() {
let layer: TimeoutLayer<_, _, Set, Set> = TimeoutLayer::new("test".into(), &create_test_context())
.timeout_output(|_args| "timeout: ".to_string())
.timeout(Duration::from_millis(3));
assert_eq!(layer.timeout.unwrap(), Duration::from_millis(3));
}
// The stored `on_timeout` callback is the one supplied; invoking it directly flips the flag.
#[test]
fn on_timeout_ok() {
let called = Arc::new(AtomicBool::new(false));
let called_clone = Arc::clone(&called);
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().on_timeout(move |_output, _args| {
called_clone.store(true, Ordering::SeqCst);
});
layer.on_timeout.unwrap().call(
&"output".to_string(),
OnTimeoutArgs {
timeout: Duration::from_millis(3),
},
);
assert!(called.load(Ordering::SeqCst));
}
// The stored override hook returns whatever the supplied closure returns.
#[test]
fn timeout_override_ok() {
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().timeout_override(|_input, _args| Some(Duration::from_secs(3)));
let result = layer.timeout_override.unwrap().call(
&"a".to_string(),
TimeoutOverrideArgs {
default_timeout: Duration::from_millis(3),
},
);
assert_eq!(result, Some(Duration::from_secs(3)));
}
// A custom predicate is evaluated per input.
#[test]
fn enable_if_ok() {
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().enable_if(|input| matches!(input.as_ref(), "enable"));
assert!(layer.enable_if.call(&"enable".to_string()));
assert!(!layer.enable_if.call(&"disable".to_string()));
}
// `disable` makes the predicate false for any input.
#[test]
fn disable_ok() {
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().disable();
assert!(!layer.enable_if.call(&"whatever".to_string()));
}
// `enable_always` overrides an earlier `disable`.
#[test]
fn enable_ok() {
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().disable().enable_always();
assert!(layer.enable_if.call(&"whatever".to_string()));
}
// The timeout can be replaced after the layer is already fully configured.
#[test]
fn timeout_when_ready_ok() {
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().timeout(Duration::from_secs(123));
assert_eq!(layer.timeout.unwrap(), Duration::from_secs(123));
}
// The timeout output can be replaced after the layer is already fully configured.
#[test]
fn timeout_output_when_ready_ok() {
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer().timeout_output(|_args| "some new value".to_string());
assert!(layer.timeout_output.is_some());
let result = layer.timeout_output.unwrap().call(TimeoutOutputArgs {
timeout: Duration::from_secs(123),
});
assert_eq!(result, "some new value");
}
// The timeout error can be replaced after a `Result`-output layer is fully configured.
#[test]
fn timeout_error_when_ready_ok() {
let layer: TimeoutLayer<_, _, Set, Set> = create_ready_layer_with_result().timeout_error(|_args| "some error value".to_string());
assert!(layer.timeout_output.is_some());
let result = layer
.timeout_output
.unwrap()
.call(TimeoutOutputArgs {
timeout: Duration::from_secs(123),
})
.unwrap_err();
assert_eq!(result, "some error value");
}
// A fully configured layer can wrap a service (smoke test; the result is not exercised).
#[test]
fn layer_ok() {
let _layered = create_ready_layer().layer(Execute::new(|input: String| async move { input }));
}
// `config` sets both the timeout and the enablement flag. Verified via an insta snapshot
// of the Debug output. Skipped under Miri.
#[cfg_attr(miri, ignore)]
#[test]
fn config_applies_all_settings() {
let config = TimeoutConfig {
enabled: false,
timeout: Duration::from_secs(45),
};
let context = create_test_context();
let layer = TimeoutLayer::new("test".into(), &context)
.timeout_output(|_args| "timeout".to_string())
.config(&config);
insta::assert_debug_snapshot!(layer);
}
// Compile-time checks: `Layer` is implemented only when both typestates are `Set`.
#[test]
fn static_assertions() {
static_assertions::assert_impl_all!(TimeoutLayer<String, String, Set, Set>: Layer<String>);
static_assertions::assert_not_impl_all!(TimeoutLayer<String, String, Set, NotSet>: Layer<String>);
static_assertions::assert_not_impl_all!(TimeoutLayer<String, String, NotSet, Set>: Layer<String>);
static_assertions::assert_impl_all!(TimeoutLayer<String, String, Set, Set>: Debug);
}
// Context with a frozen clock and a fixed pipeline name, for `String -> String` services.
fn create_test_context() -> ResilienceContext<String, String> {
ResilienceContext::new(Clock::new_frozen()).name("test_pipeline")
}
// Same as `create_test_context`, but for services returning `Result<String, String>`.
fn create_test_context_result() -> ResilienceContext<String, Result<String, String>> {
ResilienceContext::new(Clock::new_frozen()).name("test_pipeline")
}
// Fully configured `String -> String` layer with a 3 ms timeout.
fn create_ready_layer() -> TimeoutLayer<String, String, Set, Set> {
TimeoutLayer::new("test".into(), &create_test_context())
.timeout_output(|_args| "timeout: ".to_string())
.timeout(Duration::from_millis(3))
}
// Fully configured `Result`-output layer with a 3 ms timeout, configured via `timeout_error`.
fn create_ready_layer_with_result() -> TimeoutLayer<String, Result<String, String>, Set, Set> {
TimeoutLayer::new("test".into(), &create_test_context_result())
.timeout_error(|_args| "timeout: ".to_string())
.timeout(Duration::from_millis(3))
}
}