//! llm_router 0.1.0
//!
//! A high-performance router and load balancer for LLM APIs such as
//! OpenAI-compatible chat completion endpoints (e.g., ChatGPT).
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{atomic::{AtomicUsize, AtomicBool}, Arc};
use std::time::Instant;
use thiserror::Error;
use tokio::sync::RwLock; // Async RwLock: instance state is mutated at runtime

// --- Public Structs ---

#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelCapability {
    Chat,
    Embedding,
    Completion,
}
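
// With `rename_all = "lowercase"`, capabilities round-trip as lowercase JSON
// strings ("chat", "embedding", "completion"):
#[cfg(test)]
mod capability_serde_tests {
    use super::*;

    #[test]
    fn capability_round_trips_lowercase() {
        assert_eq!(
            serde_json::to_string(&ModelCapability::Chat).unwrap(),
            "\"chat\""
        );
        assert_eq!(
            serde_json::from_str::<ModelCapability>("\"embedding\"").unwrap(),
            ModelCapability::Embedding
        );
    }
}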

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InstanceStatus {
    Healthy,
    Unhealthy,
    Unknown, // Initial state before first health check
    TimedOut, // Temporarily pulled from rotation until the timeout expires
}

#[derive(Debug, Clone)] // Clones are cheap: all mutable state is shared via Arc
pub struct LLMInstance {
    pub id: String,
    pub base_url: String,
    pub active_requests: Arc<AtomicUsize>, // Track load
    pub status: Arc<RwLock<InstanceStatus>>, // Track health
    pub is_in_timeout: Arc<AtomicBool>, // Whether instance is in timeout
    pub timeout_until: Arc<RwLock<Option<Instant>>>, // When the timeout expires
    pub supported_models: Arc<RwLock<HashMap<String, Vec<ModelCapability>>>>, // Map of model_name -> capabilities
}

// Constructor that wires up the shared state with an initial status and model map
impl LLMInstance {
    pub fn new(
        id: String, 
        base_url: String, 
        initial_status: InstanceStatus,
        supported_models_map: HashMap<String, Vec<ModelCapability>>
    ) -> Self {
        LLMInstance {
            id,
            base_url,
            active_requests: Arc::new(AtomicUsize::new(0)),
            status: Arc::new(RwLock::new(initial_status)),
            is_in_timeout: Arc::new(AtomicBool::new(false)),
            timeout_until: Arc::new(RwLock::new(None)),
            supported_models: Arc::new(RwLock::new(supported_models_map)),
        }
    }
}
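
// Example (sketch): wiring up an instance. The model name and URL below are
// illustrative placeholders, not values shipped with this crate.
#[allow(dead_code)]
fn example_instance() -> LLMInstance {
    let mut models = HashMap::new();
    models.insert(
        "llama-3-8b".to_string(), // hypothetical model name
        vec![ModelCapability::Chat, ModelCapability::Completion],
    );
    LLMInstance::new(
        "backend-1".to_string(),
        "http://localhost:8000".to_string(), // hypothetical backend URL
        InstanceStatus::Unknown, // let the first health check promote it
        models,
    )
}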

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum RoutingStrategy {
    RoundRobin,
    LoadBased, // Least active requests
}
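
// A minimal sketch of what `LoadBased` selection could look like over the
// types above: pick the healthy instance with the fewest active requests.
// `pick_least_loaded` is a hypothetical helper, not part of this crate's API.
#[allow(dead_code)]
async fn pick_least_loaded(instances: &[LLMInstance]) -> Option<&LLMInstance> {
    use std::sync::atomic::Ordering;

    let mut best: Option<(&LLMInstance, usize)> = None;
    for inst in instances {
        // Skip anything not currently marked healthy.
        if *inst.status.read().await != InstanceStatus::Healthy {
            continue;
        }
        let load = inst.active_requests.load(Ordering::Relaxed);
        if best.map_or(true, |(_, lowest)| load < lowest) {
            best = Some((inst, load));
        }
    }
    best.map(|(inst, _)| inst)
}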

// --- API Payloads ---

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // ... other fields
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionResponse {
    #[serde(flatten)]
    pub data: serde_json::Value,
}
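
// Example (sketch): serializing a request with serde_json; the model name is
// an arbitrary placeholder.
#[allow(dead_code)]
fn example_request_body() -> serde_json::Result<String> {
    let req = ChatCompletionRequest {
        model: "gpt-4o".to_string(),
        messages: vec![ChatMessage {
            role: "user".to_string(),
            content: "Hello!".to_string(),
        }],
    };
    // Yields: {"model":"gpt-4o","messages":[{"role":"user","content":"Hello!"}]}
    serde_json::to_string(&req)
}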

// --- Custom Error Type ---
#[derive(Error, Debug)]
pub enum RouterError {
    #[error("No healthy backend instances available")]
    NoHealthyInstances,

    #[error("No healthy backend instances available for model {0} with capability {1:?}")]
    NoHealthyInstancesForModel(String, ModelCapability),

    #[error("Instance with ID '{0}' not found")]
    InstanceNotFound(String),

    #[error("Instance with ID '{0}' already exists")]
    InstanceExists(String),

    #[error("Invalid backend URL format: {0}")]
    InvalidUrl(#[from] url::ParseError), // Converted automatically via #[from]

    #[error("Failed to forward request to backend instance {instance_id}: {source}")]
    BackendRequestFailed {
        instance_id: String,
        #[source]
        source: reqwest::Error,
    },

    #[error("Failed to decode backend response from instance {instance_id}: {source}")]
    BackendResponseDecodeFailed {
        instance_id: String,
        #[source]
        source: reqwest::Error,
    },

    #[error("Backend request timed out for instance {instance_id}")]
    BackendTimeout { instance_id: String },

    #[error("Backend returned status {status} for instance {instance_id}")]
    BackendErrorStatus {
        instance_id: String,
        status: reqwest::StatusCode,
    },

    #[error("Internal synchronization error: {0}")]
    SyncError(String),
}
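
// A sketch of how a caller might classify errors for retry. `is_retryable`
// is a hypothetical helper, not part of this crate's API: timeouts and
// transport failures are usually worth retrying on another instance, while
// lookup and URL errors are not.
#[allow(dead_code)]
fn is_retryable(err: &RouterError) -> bool {
    matches!(
        err,
        RouterError::BackendTimeout { .. }
            | RouterError::BackendRequestFailed { .. }
            | RouterError::NoHealthyInstances
    )
}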