llm-cascade 0.1.0

Resilient cascading LLM inference with automatic failover across multiple providers
Documentation
//! Cascade execution engine: iterates provider entries with failover and cooldown.

use std::time::Instant;

use chrono::{Duration, Utc};
use rusqlite::Connection;

use crate::config::{AppConfig, ProviderConfig, ProviderType};
use crate::db;
use crate::error::{CascadeError, ProviderError};
use crate::models::{Conversation, LlmResponse};
use crate::persistence;
use crate::providers::anthropic::AnthropicProvider;
use crate::providers::gemini::GeminiProvider;
use crate::providers::ollama::OllamaProvider;
use crate::providers::openai::OpenAiProvider;
use crate::providers::LlmProvider;
use crate::secrets;

/// Cooldown, in seconds, used at backoff level 0; `compute_cooldown` doubles it per level.
const BASE_COOLDOWN_SECS: i64 = 30;
/// Upper bound, in seconds, on the cooldown produced by `compute_cooldown` (one hour).
const MAX_COOLDOWN_SECS: i64 = 3600;

/// Constructs the concrete [`LlmProvider`] for one cascade entry.
///
/// For the OpenAI, Anthropic and Gemini provider types an API key is resolved through
/// `secrets::resolve_api_key`, using the configured service name (falling back to
/// `provider_name`) and the configured environment variable (falling back to the
/// provider type's conventional variable name). Ollama is built without an API key.
///
/// # Errors
///
/// Returns [`ProviderError::MissingApiKey`] carrying `provider_name` when the key
/// lookup fails for a provider type that needs one.
fn build_provider(
    provider_name: &str,
    provider_config: &ProviderConfig,
    model: &str,
) -> Result<Box<dyn LlmProvider>, ProviderError> {
    // Shared key lookup; only the fallback environment variable differs per provider type.
    let resolve_key = |fallback_env: &str| {
        let service = provider_config
            .api_key_service
            .as_deref()
            .unwrap_or(provider_name);
        let env_var = provider_config
            .api_key_env
            .as_deref()
            .unwrap_or(fallback_env);
        secrets::resolve_api_key(service, env_var)
            .map_err(|_| ProviderError::MissingApiKey(provider_name.into()))
    };

    let provider: Box<dyn LlmProvider> = match provider_config.r#type {
        ProviderType::Openai => Box::new(OpenAiProvider::new(
            resolve_key("OPENAI_API_KEY")?,
            model.into(),
            provider_config.base_url.clone(),
        )),
        ProviderType::Anthropic => Box::new(AnthropicProvider::new(
            resolve_key("ANTHROPIC_API_KEY")?,
            model.into(),
            provider_config.base_url.clone(),
        )),
        ProviderType::Gemini => Box::new(GeminiProvider::new(
            resolve_key("GOOGLE_API_KEY")?,
            model.into(),
            provider_config.base_url.clone(),
        )),
        // Ollama is constructed from the model name and base URL only.
        ProviderType::Ollama => Box::new(OllamaProvider::new(
            model.into(),
            provider_config.base_url.clone(),
        )),
    };

    Ok(provider)
}

/// Computes the backoff cooldown to apply to `entry_key` after a failed attempt.
///
/// The result is `BASE_COOLDOWN_SECS` doubled once per backoff level (as reported by
/// `query_cooldown_level`), capped at `MAX_COOLDOWN_SECS`.
fn compute_cooldown(entry_key: &str, conn: &Connection) -> Duration {
    let level = query_cooldown_level(entry_key, conn);
    // Plain `2_i64.pow(level)` overflows once `level` reaches 63 (and the multiplication
    // overflows even earlier): a panic in debug builds, and in release builds a wrapped
    // value that can be zero or negative, yielding a cooldown that has already expired.
    // Checked arithmetic treats any overflow as "beyond the cap".
    let secs = 2_i64
        .checked_pow(level)
        .and_then(|factor| BASE_COOLDOWN_SECS.checked_mul(factor))
        .map_or(MAX_COOLDOWN_SECS, |raw| raw.min(MAX_COOLDOWN_SECS));
    Duration::seconds(secs)
}

fn query_cooldown_level(entry_key: &str, conn: &Connection) -> u32 {
    let now = Utc::now().to_rfc3339();
    let count = conn.query_row(
        "SELECT COUNT(*) FROM attempt_log
         WHERE provider_model = ?1 AND http_status >= 400 AND timestamp > ?2",
        rusqlite::params![entry_key, now],
        |row| row.get::<_, i64>(0),
    );

    match count {
        Ok(c) if c > 0 => (c as u32).saturating_sub(1),
        _ => 0,
    }
}

/// Executes the cascade registered under `cascade_name`, walking its provider/model
/// entries in order until one of them answers.
///
/// For each entry the function:
/// 1. records a failure and moves on if the referenced provider is missing from `config`,
/// 2. skips it if `db::is_on_cooldown` reports an active cooldown,
/// 3. builds the provider (recording a failure and moving on if that fails),
/// 4. sends `conversation` and logs the attempt with its latency.
///
/// The first successful response is returned immediately. A failed attempt puts the entry
/// on cooldown: the provider's retry-after hint when the error is rate-limit related and
/// carries one, otherwise the value from `compute_cooldown`.
///
/// # Errors
///
/// Returns a [`CascadeError`] when the cascade is unknown, has no entries, or when no
/// entry produced a response. In every error case the conversation is handed to
/// `persistence::save_failed_conversation` and the result is stored in
/// `failed_prompt_path`.
pub async fn run_cascade(
    cascade_name: &str,
    conversation: &Conversation,
    config: &AppConfig,
    conn: &Connection,
) -> Result<LlmResponse, CascadeError> {
    let cascade = match config.cascades.get(cascade_name) {
        Some(found) => found,
        None => {
            return Err(CascadeError {
                cascade_name: cascade_name.to_string(),
                message: format!("Cascade '{}' not found in configuration", cascade_name),
                failed_prompt_path: persistence::save_failed_conversation(
                    conversation,
                    &config.failure_persistence.dir,
                    cascade_name,
                ),
            });
        }
    };

    if cascade.entries.is_empty() {
        let failed_prompt_path = persistence::save_failed_conversation(
            conversation,
            &config.failure_persistence.dir,
            cascade_name,
        );
        return Err(CascadeError {
            cascade_name: cascade_name.to_string(),
            message: format!("Cascade '{}' has no entries", cascade_name),
            failed_prompt_path,
        });
    }

    // Human-readable notes collected for the final error message.
    let mut failures: Vec<String> = Vec::new();
    let mut on_cooldown: Vec<String> = Vec::new();

    for entry in &cascade.entries {
        // "<provider>/<model>" identifies the entry in logs, the attempt log and cooldowns.
        let entry_key = format!("{}/{}", entry.provider, entry.model);

        let provider_config = if let Some(cfg) = config.providers.get(&entry.provider) {
            cfg
        } else {
            tracing::warn!(
                "Provider '{}' referenced in cascade '{}' not found in config",
                entry.provider,
                cascade_name,
            );
            failures.push(format!("{}: provider not found", entry_key));
            continue;
        };

        if db::is_on_cooldown(conn, &entry_key) {
            tracing::info!("Skipping '{}' — currently on cooldown", entry_key);
            on_cooldown.push(entry_key);
            continue;
        }

        let provider = match build_provider(&entry.provider, provider_config, &entry.model) {
            Ok(built) => built,
            Err(init_err) => {
                tracing::warn!("Failed to initialize provider '{}': {}", entry_key, init_err);
                failures.push(format!("{}: {}", entry_key, init_err));
                continue;
            }
        };

        tracing::info!("Attempting provider: {}", entry_key);
        let started = Instant::now();
        let outcome = provider.complete(conversation).await;
        // Latency is measured once, right after the call returns, for both outcomes.
        let latency_ms = started.elapsed().as_millis() as u64;

        let err = match outcome {
            Ok(response) => {
                db::log_attempt(
                    conn,
                    cascade_name,
                    &entry_key,
                    Some(200),
                    latency_ms,
                    response.input_tokens,
                    response.output_tokens,
                );
                tracing::info!(
                    "Success from {} ({}ms, in_tokens: {:?}, out_tokens: {:?})",
                    entry_key,
                    latency_ms,
                    response.input_tokens,
                    response.output_tokens,
                );
                return Ok(response);
            }
            Err(err) => err,
        };

        let http_status = err.http_status();
        db::log_attempt(
            conn,
            cascade_name,
            &entry_key,
            http_status,
            latency_ms,
            None,
            None,
        );

        // A retry-after hint is only consulted for rate-limit errors; everything else
        // falls back to the computed backoff.
        let retry_hint = if err.is_rate_limited() {
            err.retry_after_seconds()
        } else {
            None
        };
        let cooldown = match retry_hint {
            Some(retry_secs) => Duration::seconds(retry_secs as i64),
            None => compute_cooldown(&entry_key, conn),
        };

        let cooldown_until = (Utc::now() + cooldown).to_rfc3339();
        db::set_cooldown(conn, &entry_key, &cooldown_until);

        tracing::warn!(
            "Provider '{}' failed (HTTP {:?}): {}. Cooldown until {}",
            entry_key,
            http_status,
            err,
            cooldown_until,
        );
        failures.push(format!("{}: {}", entry_key, err));
    }

    // Nothing answered: summarise what was skipped and what failed.
    let skipped_note = if on_cooldown.is_empty() {
        String::new()
    } else {
        format!("Skipped (on cooldown): {}\n", on_cooldown.join(", "))
    };
    let message = format!("{}Failed entries: {}", skipped_note, failures.join("; "));

    let failed_prompt_path = persistence::save_failed_conversation(
        conversation,
        &config.failure_persistence.dir,
        cascade_name,
    );

    Err(CascadeError {
        cascade_name: cascade_name.to_string(),
        message,
        failed_prompt_path,
    })
}