// seatbelt 0.4.4
//
// Resilience and recovery mechanisms for fallible operations.
// Documentation
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use std::marker::PhantomData;
use std::sync::Arc;

use layered::Layer;

use crate::typestates::{NotSet, Set};
use crate::utils::{EnableIf, TelemetryHelper};
use crate::{ResilienceContext, TelemetryString, fallback::*};

/// Builder for configuring fallback resilience middleware.
///
/// This type is created by calling [`Fallback::layer`](crate::fallback::Fallback::layer)
/// and uses the type-state pattern to enforce that required properties are configured
/// before the layer can be built:
///
/// - [`should_fallback`][FallbackLayer::should_fallback]: Required predicate that decides
///   whether the inner service output needs a replacement
/// - [`fallback`][FallbackLayer::fallback], [`fallback_output`][FallbackLayer::fallback_output]
///   or [`fallback_async`][FallbackLayer::fallback_async]: Required function (or fixed value)
///   that produces the replacement output
///
/// For comprehensive examples, see the [fallback module][crate::fallback] documentation.
///
/// # Type State
///
/// - `S1`: Tracks whether [`should_fallback`][FallbackLayer::should_fallback] has been set
/// - `S2`: Tracks whether a fallback action has been set through
///   [`fallback`][FallbackLayer::fallback], [`fallback_output`][FallbackLayer::fallback_output]
///   or [`fallback_async`][FallbackLayer::fallback_async]
///
/// Both markers default to `Set`, so `FallbackLayer<In, Out>` names the fully
/// configured builder, which is the only state that implements `Layer`.
#[derive(Debug)]
pub struct FallbackLayer<In, Out, S1 = Set, S2 = Set> {
    // Predicate over the inner service output; `None` until `should_fallback` is called.
    should_fallback: Option<ShouldFallback<Out>>,
    // Action producing the replacement output; `None` until one of the
    // `fallback*` setters is called.
    fallback_action: Option<FallbackAction<Out>>,
    // Per-request switch evaluated against the input.
    enable_if: EnableIf<In>,
    // Telemetry handle created from the resilience context in `new`.
    telemetry: TelemetryHelper,
    // Carries `In`, `Out` and the two type-state markers without storing any
    // values of those types.
    _state: PhantomData<fn(In, S1, S2) -> Out>,
}

impl<In, Out> FallbackLayer<In, Out, NotSet, NotSet> {
    /// Creates an unconfigured fallback layer builder.
    ///
    /// Neither the predicate nor the fallback action is set yet, the enable
    /// condition starts at its default, and the telemetry helper is created
    /// from `context` under the given strategy `name`.
    #[must_use]
    pub(crate) fn new(name: TelemetryString, context: &ResilienceContext<In, Out>) -> Self {
        let enable_if = EnableIf::default();
        let telemetry = context.create_telemetry(name);

        Self {
            telemetry,
            enable_if,
            fallback_action: None,
            should_fallback: None,
            _state: PhantomData,
        }
    }
}

impl<In, Out, S1, S2> FallbackLayer<In, Out, S1, S2> {
    /// Configures when the fallback action runs.
    ///
    /// `predicate` inspects a reference to the output of the inner service.
    /// Returning `true` marks that output as invalid, so the fallback action
    /// supplies a replacement. A predicate configured earlier is overwritten.
    #[must_use]
    pub fn should_fallback(self, predicate: impl Fn(&Out) -> bool + Send + Sync + 'static) -> FallbackLayer<In, Out, Set, S2> {
        // Move into the `Set` state first, then store the predicate.
        let mut configured = self.into_state::<Set, S2>();
        configured.should_fallback = Some(ShouldFallback::new(predicate));
        configured
    }

    /// Makes fallback protection conditional on the request input.
    ///
    /// `is_enabled` is given a reference to the input and returns `true` when
    /// the fallback middleware should be active for that request. While
    /// disabled, the inner service output is returned unchanged, whatever the
    /// [`should_fallback`][FallbackLayer::should_fallback] predicate would say.
    /// This call replaces any previous condition.
    ///
    /// **Default**: Always enabled
    #[must_use]
    pub fn enable_if(self, is_enabled: impl Fn(&In) -> bool + Send + Sync + 'static) -> Self {
        Self {
            enable_if: EnableIf::custom(is_enabled),
            ..self
        }
    }

    /// Switches the fallback middleware on or off for every request.
    ///
    /// Shared implementation behind [`enable_always`][FallbackLayer::enable_always]
    /// and [`disable`][FallbackLayer::disable]. This call replaces any previous
    /// condition.
    #[must_use]
    fn enable(self, enabled: bool) -> Self {
        Self {
            enable_if: EnableIf::new(enabled),
            ..self
        }
    }

    /// Turns fallback protection on for every request.
    ///
    /// This call replaces any previous condition.
    ///
    /// **Note**: This is the default behavior.
    #[must_use]
    pub fn enable_always(self) -> Self {
        let enabled = true;
        self.enable(enabled)
    }

    /// Turns fallback protection off for every request.
    ///
    /// Requests pass straight through without fallback protection.
    /// This call replaces any previous condition.
    #[must_use]
    pub fn disable(self) -> Self {
        let enabled = false;
        self.enable(enabled)
    }
}

impl<In, Out: Send + 'static, S1, S2> FallbackLayer<In, Out, S1, S2> {
    /// Installs a synchronous fallback action.
    ///
    /// `action` is handed the original (invalid) output together with
    /// [`FallbackActionArgs`] and returns the output that replaces it.
    /// A fallback action configured earlier is overwritten.
    #[must_use]
    pub fn fallback(self, action: impl Fn(Out, FallbackActionArgs) -> Out + Send + Sync + 'static) -> FallbackLayer<In, Out, S1, Set> {
        // Move into the `Set` state first, then store the action.
        let mut configured = self.into_state::<S1, Set>();
        configured.fallback_action = Some(FallbackAction::new_sync(action));
        configured
    }

    /// Installs a constant replacement output.
    ///
    /// Shorthand for [`fallback`][FallbackLayer::fallback] for the case where
    /// the replacement never varies: the original (invalid) output is dropped
    /// and a clone of `value` is returned instead on every invocation.
    ///
    /// A fallback action configured earlier is overwritten.
    #[must_use]
    pub fn fallback_output(self, value: Out) -> FallbackLayer<In, Out, S1, Set>
    where
        Out: Clone + Sync,
    {
        self.fallback(move |_invalid_output, _args| value.clone())
    }

    /// Installs an asynchronous fallback action.
    ///
    /// `action` is handed the original (invalid) output together with
    /// [`FallbackActionArgs`] and returns a future resolving to the output
    /// that replaces it. A fallback action configured earlier is overwritten.
    #[must_use]
    pub fn fallback_async<F, Fut>(self, action: F) -> FallbackLayer<In, Out, S1, Set>
    where
        F: Fn(Out, FallbackActionArgs) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Out> + Send + 'static,
    {
        // Move into the `Set` state first, then store the action.
        let mut configured = self.into_state::<S1, Set>();
        configured.fallback_action = Some(FallbackAction::new_async(action));
        configured
    }
}

/// Only the fully configured builder (`Set`, `Set`) can wrap a service.
impl<In, Out, S> Layer<S> for FallbackLayer<In, Out, Set, Set> {
    type Service = Fallback<In, Out, S>;

    /// Wraps `inner` in a [`Fallback`] service.
    ///
    /// The builder is borrowed, so its configuration is cloned into a
    /// reference-counted `FallbackShared` owned by the returned service.
    fn layer(&self, inner: S) -> Self::Service {
        let enable_if = self.enable_if.clone();

        // Both options are populated whenever this impl is reachable, because
        // it exists only for the `Set`/`Set` state.
        let should_fallback = self.should_fallback.clone().expect("enforced by the type state pattern");
        let fallback_action = self.fallback_action.clone().expect("enforced by the type state pattern");

        let shared = Arc::new(FallbackShared {
            enable_if,
            should_fallback,
            fallback_action,
            #[cfg(any(feature = "logs", feature = "metrics", test))]
            telemetry: self.telemetry.clone(),
        });

        Fallback { inner, shared }
    }
}

impl<In, Out, S1, S2> FallbackLayer<In, Out, S1, S2> {
    /// Re-tags the builder with the type-state markers `T1` and `T2`.
    ///
    /// Every configured field is moved over unchanged; only the marker type
    /// differs between the input and the returned builder.
    fn into_state<T1, T2>(self) -> FallbackLayer<In, Out, T1, T2> {
        let Self {
            should_fallback,
            fallback_action,
            enable_if,
            telemetry,
            _state: _,
        } = self;

        FallbackLayer {
            should_fallback,
            fallback_action,
            enable_if,
            telemetry,
            _state: PhantomData,
        }
    }
}

// Unit tests for the `FallbackLayer` builder: they check the stored
// configuration after each builder call and the type-state transitions.
#[cfg_attr(coverage_nightly, coverage(off))]
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use layered::Execute;
    use tick::Clock;

    use super::*;

    // A freshly created builder has no predicate and no action, carries the
    // given strategy name, and is enabled for any input.
    #[test]
    fn new_needs_should_fallback_and_action() {
        let context = create_test_context();
        let layer: FallbackLayer<_, _, NotSet, NotSet> = FallbackLayer::new("test_fallback".into(), &context);

        assert!(layer.should_fallback.is_none());
        assert!(layer.fallback_action.is_none());
        assert_eq!(layer.telemetry.strategy_name.as_ref(), "test_fallback");
        assert!(layer.enable_if.call(&"test_input".to_string()));
    }

    // `should_fallback` flips only `S1` to `Set` and stores the given predicate.
    #[test]
    fn should_fallback_ensure_set_correctly() {
        let context = create_test_context();
        let layer: FallbackLayer<_, _, Set, NotSet> =
            FallbackLayer::new("test".into(), &context).should_fallback(|output: &String| output == "bad");

        assert!(layer.should_fallback.as_ref().unwrap().call(&"bad".to_string()));
        assert!(!layer.should_fallback.as_ref().unwrap().call(&"good".to_string()));
    }

    // `fallback` flips only `S2` to `Set` and stores an action.
    #[test]
    fn fallback_sync_ensure_set_correctly() {
        let context = create_test_context();
        let layer: FallbackLayer<_, _, NotSet, Set> =
            FallbackLayer::new("test".into(), &context).fallback(|_output: String, _args| "replaced".to_string());

        assert!(layer.fallback_action.is_some());
    }

    // `fallback_async` flips only `S2` to `Set`; the stored action is invoked
    // and awaited to check the replacement value. Ignored under Miri.
    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn fallback_async_ensure_set_correctly() {
        let context = create_test_context();
        let layer: FallbackLayer<_, _, NotSet, Set> =
            FallbackLayer::new("test".into(), &context).fallback_async(|_output: String, _args| async { "replaced".to_string() });

        let result = layer.fallback_action.unwrap().call("bad".to_string(), FallbackActionArgs {}).await;
        assert_eq!(result, "replaced");
    }

    // A custom `enable_if` condition is evaluated against the input.
    #[test]
    fn enable_if_ok() {
        let layer: FallbackLayer<_, _, Set, Set> = create_ready_layer().enable_if(|input| matches!(input.as_ref(), "enable"));

        assert!(layer.enable_if.call(&"enable".to_string()));
        assert!(!layer.enable_if.call(&"disable".to_string()));
    }

    // `disable` turns the condition off regardless of the input.
    #[test]
    fn disable_ok() {
        let layer: FallbackLayer<_, _, Set, Set> = create_ready_layer().disable();

        assert!(!layer.enable_if.call(&"whatever".to_string()));
    }

    // `enable_always` overrides an earlier `disable`.
    #[test]
    fn enable_ok() {
        let layer: FallbackLayer<_, _, Set, Set> = create_ready_layer().disable().enable_always();

        assert!(layer.enable_if.call(&"whatever".to_string()));
    }

    // Calling `should_fallback` on a ready layer replaces the predicate and
    // keeps the `Set`/`Set` state.
    #[test]
    fn should_fallback_when_ready_ok() {
        let layer: FallbackLayer<_, _, Set, Set> = create_ready_layer().should_fallback(|output: &String| output == "new_bad");

        assert!(layer.should_fallback.unwrap().call(&"new_bad".to_string()));
    }

    // `fallback_output` also counts as setting the fallback action (`S2 = Set`).
    #[test]
    fn fallback_output_ok() {
        let context = create_test_context();
        let layer: FallbackLayer<_, _, NotSet, Set> = FallbackLayer::new("test".into(), &context).fallback_output("fixed".to_string());

        assert!(layer.fallback_action.is_some());
    }

    // Calling `fallback` on a ready layer keeps the `Set`/`Set` state.
    #[test]
    fn fallback_when_ready_ok() {
        let layer: FallbackLayer<_, _, Set, Set> = create_ready_layer().fallback(|_, _| "new_fallback".to_string());

        assert!(layer.fallback_action.is_some());
    }

    // A ready layer can wrap a service without panicking on the `expect` calls.
    #[test]
    fn layer_ok() {
        let _layered = create_ready_layer().layer(Execute::new(|input: String| async move { input }));
    }

    // Compile-time checks: only the `Set`/`Set` state implements `Layer`, and
    // the builder implements `Debug`.
    #[test]
    fn static_assertions() {
        static_assertions::assert_impl_all!(FallbackLayer<String, String, Set, Set>: Layer<String>);
        static_assertions::assert_not_impl_all!(FallbackLayer<String, String, Set, NotSet>: Layer<String>);
        static_assertions::assert_not_impl_all!(FallbackLayer<String, String, NotSet, Set>: Layer<String>);
        static_assertions::assert_impl_all!(FallbackLayer<String, String, Set, Set>: Debug);
    }

    // Builds a `String -> String` resilience context named "test_pipeline",
    // using `Clock::new_frozen()` as its clock.
    fn create_test_context() -> ResilienceContext<String, String> {
        ResilienceContext::new(Clock::new_frozen()).name("test_pipeline")
    }

    // Builds a fully configured layer: outputs equal to "bad" are replaced
    // with "fallback_value".
    fn create_ready_layer() -> FallbackLayer<String, String, Set, Set> {
        FallbackLayer::new("test".into(), &create_test_context())
            .should_fallback(|output: &String| output == "bad")
            .fallback(|_output: String, _args| "fallback_value".to_string())
    }
}