pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Fallback provider -- tries a primary LLM, falls back to secondary on transient failure.
//!
//! Only falls back when the primary returns a transient error (as classified by
//! [`PeError::is_transient()`]). Permanent errors propagate immediately.
//!
//! # Example
//!
//! ```ignore
//! use pe_core::fallback_provider::FallbackProvider;
//!
//! let provider = FallbackProvider::new(openai_provider, anthropic_provider);
//! // Uses OpenAI by default; on transient failure, retries with Anthropic
//! ```

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
use crate::message::Message;

/// Wraps a primary and secondary [`LlmProvider`]. On transient failure from
/// the primary, automatically tries the secondary.
///
/// A failure counts as transient when [`PeError::is_transient()`] returns `true`
/// for the error the primary produced; any other error is returned to the caller
/// without consulting the secondary.
///
/// Implements `LlmProvider` itself, so it can be used anywhere a provider is expected
/// (including inside a [`MiddlewareStack`](super::middleware_stack::MiddlewareStack)).
pub struct FallbackProvider {
    /// Provider that every request is sent to first.
    primary: Arc<dyn LlmProvider>,
    /// Provider that receives the request only after a transient primary failure.
    secondary: Arc<dyn LlmProvider>,
}

impl FallbackProvider {
    /// Build a fallback provider from the provider to try first (`primary`) and
    /// the one to use after a transient primary failure (`secondary`).
    ///
    /// Both providers are moved behind `Arc<dyn LlmProvider>` trait objects.
    pub fn new(primary: impl LlmProvider, secondary: impl LlmProvider) -> Self {
        let primary: Arc<dyn LlmProvider> = Arc::new(primary);
        let secondary: Arc<dyn LlmProvider> = Arc::new(secondary);
        Self { primary, secondary }
    }

    /// Run a completion against `primary`, retrying once on `secondary` when the
    /// primary's error is transient.
    ///
    /// Takes owned `messages` and `tools` so the resulting future does not borrow
    /// from the caller's slices.
    ///
    /// # Errors
    ///
    /// Returns the primary's error unchanged when it is not transient. When the
    /// fallback runs, the secondary's result (success or error) is returned and
    /// the primary's error is discarded.
    async fn do_complete(
        primary: &dyn LlmProvider,
        secondary: &dyn LlmProvider,
        messages: Vec<Message>,
        tools: Vec<ToolSchema>,
    ) -> Result<LlmResponse, PeError> {
        // Success on the first attempt short-circuits; otherwise keep the error
        // around so it can be classified.
        let primary_failure = match primary.complete(&messages, &tools).await {
            Ok(response) => return Ok(response),
            Err(failure) => failure,
        };

        // Non-transient failures are not worth retrying elsewhere.
        if !primary_failure.is_transient() {
            return Err(primary_failure);
        }

        secondary.complete(&messages, &tools).await
    }
}

impl LlmProvider for FallbackProvider {
    /// Complete via the primary provider, falling back to the secondary when the
    /// primary reports a transient error.
    fn complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        // The returned future is tied to `self` only, so the borrowed slices are
        // copied into owned vectors before being handed to the helper.
        let fallback_call = Self::do_complete(
            &*self.primary,
            &*self.secondary,
            messages.to_vec(),
            tools.to_vec(),
        );
        Box::pin(fallback_call)
    }

    /// Open a stream via the primary provider, falling back to the secondary when
    /// the primary's `stream` call itself fails with a transient error.
    ///
    /// A stream successfully returned by the primary is handed back unchanged.
    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
        let owned_messages = messages.to_vec();
        let owned_tools = tools.to_vec();
        Box::pin(async move {
            match self.primary.stream(&owned_messages, &owned_tools).await {
                Err(err) if err.is_transient() => {
                    self.secondary.stream(&owned_messages, &owned_tools).await
                }
                // Success and non-transient errors both pass straight through.
                outcome => outcome,
            }
        })
    }

    /// Embed `text` via the primary provider, falling back to the secondary when
    /// the primary reports a transient error.
    fn embed(
        &self,
        text: &str,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        let owned_text = String::from(text);
        Box::pin(async move {
            match self.primary.embed(&owned_text).await {
                Err(err) if err.is_transient() => self.secondary.embed(&owned_text).await,
                // Success and non-transient errors both pass straight through.
                outcome => outcome,
            }
        })
    }

    /// Report the primary provider's name, regardless of which provider ends up
    /// serving a given request.
    fn provider_name(&self) -> &'static str {
        self.primary.as_ref().provider_name()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Builds the `PeError::LlmProvider` value these tests use as the failure
    /// that is expected to trigger the fallback.
    fn provider_failure(details: &'static str) -> PeError {
        PeError::LlmProvider {
            details: details.into(),
        }
    }

    /// A healthy primary answers the request; the secondary's reply is never seen.
    #[tokio::test]
    async fn test_primary_succeeds_no_fallback() {
        let provider = FallbackProvider::new(
            MockProvider::new().respond_with("primary"),
            MockProvider::new().respond_with("secondary"),
        );

        let response = provider.complete(&[], &[]).await.unwrap();

        assert_eq!(response.message.content.as_text(), Some("primary"));
    }

    /// A provider-level error from the primary hands the request to the secondary.
    #[tokio::test]
    async fn test_falls_back_on_transient_error() {
        let failing_primary = MockProvider::new().respond_with_error(provider_failure("503"));
        let healthy_secondary = MockProvider::new().respond_with("fallback");
        let provider = FallbackProvider::new(failing_primary, healthy_secondary);

        let response = provider.complete(&[], &[]).await.unwrap();

        assert_eq!(response.message.content.as_text(), Some("fallback"));
    }

    /// A permission error from the primary is returned as-is instead of the
    /// secondary's reply.
    #[tokio::test]
    async fn test_permanent_error_propagates_no_fallback() {
        let denied = PeError::PermissionDenied {
            action: "call".into(),
        };
        let provider = FallbackProvider::new(
            MockProvider::new().respond_with_error(denied),
            MockProvider::new().respond_with("should not reach"),
        );

        let failure = provider.complete(&[], &[]).await.unwrap_err();

        assert!(matches!(failure, PeError::PermissionDenied { .. }));
    }

    /// When both providers fail, the caller sees the secondary's error.
    #[tokio::test]
    async fn test_both_fail_returns_secondary_error() {
        let provider = FallbackProvider::new(
            MockProvider::new().respond_with_error(provider_failure("primary down")),
            MockProvider::new().respond_with_error(provider_failure("secondary down")),
        );

        let failure = provider.complete(&[], &[]).await.unwrap_err();

        match failure {
            PeError::LlmProvider { details } => assert_eq!(details, "secondary down"),
            other => panic!("expected LlmProvider, got {other:?}"),
        }
    }

    /// The wrapper reports the name of its primary provider.
    #[tokio::test]
    async fn test_provider_name_returns_primary() {
        let provider = FallbackProvider::new(MockProvider::new(), MockProvider::new());

        assert_eq!(provider.provider_name(), "mock");
    }
}