llm_router 0.1.0

A high-performance router and load balancer for OpenAI-style LLM APIs
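The example below embeds the router in an axum HTTP service: each incoming chat completion request is matched to a healthy backend instance, tracked for load-based balancing, and forwarded with reqwest; instances whose requests fail are timed out. The backends point at httpbin.org purely for demonstration — swap in real OpenAI-compatible endpoints for actual use.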
use axum::{
    extract::{Json, State},
    http::StatusCode,
    response::IntoResponse,
    routing::post,
    serve,
    Router as AxumRouter,
};
use llm_router::{
    types::{ChatCompletionRequest, ChatCompletionResponse, ModelCapability, RouterError, RoutingStrategy},
    ModelInstanceConfig, RequestTracker, Router,
};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{error, info};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

// Shared application state
#[derive(Clone)]
struct AppState {
    router: Arc<Router>,
    http_client: reqwest::Client,
}

#[tokio::main]
async fn main() {
    // Initialize tracing
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            std::env::var("RUST_LOG").unwrap_or_else(|_| "info,llm_router=debug".into()),
        ))
        .with(tracing_subscriber::fmt::layer())
        .init();

    info!("Initializing LLM Router...");

    // Configure the router
    let router_config = Router::builder()
        .strategy(RoutingStrategy::LoadBased) // Route each request to the least-loaded healthy instance
        .instance_with_models(
            "instance_1",
            // Replace with a *real*, accessible backend URL for testing
            // This example uses httpbin.org for demonstration purposes
            "https://httpbin.org", 
            vec![
                ModelInstanceConfig {
                    model_name: "gpt-4".to_string(),
                    capabilities: vec![ModelCapability::Chat],
                },
                ModelInstanceConfig {
                    model_name: "text-embedding-ada-002".to_string(),
                    capabilities: vec![ModelCapability::Embedding],
                },
            ],
        )
        .instance_with_models(
            "instance_2",
            // Replace with another *real*, accessible backend URL
            "https://httpbin.org", // Using httpbin again for demo
            vec![ModelInstanceConfig {
                model_name: "gpt-3.5-turbo".to_string(),
                capabilities: vec![ModelCapability::Chat],
            }],
        )
        .health_check_path("/get") // httpbin.org has a /get endpoint
        .health_check_interval(Duration::from_secs(10))
        .instance_timeout_duration(Duration::from_secs(30))
        .build();

    let shared_router = Arc::new(router_config);

    // Create a reusable reqwest client
    let http_client = reqwest::Client::new();

    // Create the application state
    let app_state = AppState {
        router: shared_router,
        http_client,
    };

    // Build our application router
    let app = AxumRouter::new()
        // Define the route for chat completions
        // Note: to mirror OpenAI's API surface, consider exposing this at /v1/chat/completions
        .route("/chat/completions", post(chat_completions_handler))
        .with_state(app_state);

    // Run the server
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    info!("Starting server on {}", addr);
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    info!("Listening on {}", addr);
    serve(listener, app.into_make_service()).await.unwrap();
}

async fn chat_completions_handler(
    State(state): State<AppState>,
    Json(payload): Json<ChatCompletionRequest>, // Use the actual request type
) -> impl IntoResponse {
    let model_name = payload.model.clone();
    info!(model = %model_name, "Received chat completion request");

    // 1. Select an instance
    let instance = match state
        .router
        .select_instance_for_model(&model_name, ModelCapability::Chat)
        .await
    {
        Ok(instance) => {
            info!(instance_id = %instance.id, "Selected instance");
            instance
        }
        Err(RouterError::NoHealthyInstancesForModel(model, cap)) => {
            error!(%model, ?cap, "No healthy instances available");
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({"error": "No healthy instances available for the requested model and capability"})),
            )
                .into_response();
        }
        Err(e) => {
            error!(error = %e, "Failed to select instance");
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({"error": "Internal server error during instance selection"})),
            )
                .into_response();
        }
    };

    // 2. Use RequestTracker for load balancing (will decrement count on drop)
    let _tracker = RequestTracker::new(state.router.as_ref().clone(), instance.id.clone()).await;
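    //    Bind to `_tracker` (not `_`) so the guard lives until the end of the
    //    handler; `let _ = ...` would drop it immediately and undercount this
    //    instance's in-flight load.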

    // 3. Construct the target URL
    //    IMPORTANT: Adjust the path based on the actual backend API endpoint structure
    let target_url = format!("{}/post", instance.base_url); // httpbin.org /post echoes the request

    info!(url = %target_url, "Forwarding request to backend");

    // 4. Forward the request
    match state
        .http_client
        .post(&target_url)
        .json(&payload) // Send the original payload
        .send()
        .await
    {
        Ok(response) => {
            let status = response.status();
            match response.json::<ChatCompletionResponse>().await {
                Ok(backend_response) => {
                    info!(instance_id = %instance.id, status = %status, "Received response from backend");
                    (status, Json(backend_response)).into_response()
                }
                Err(e) => {
                    error!(instance_id = %instance.id, status = %status, error = %e, "Failed to decode backend response");
                    // Maybe timeout the instance if decoding fails consistently?
                    // Consider calling state.router.timeout_instance(&instance.id).await;
                    (
                        StatusCode::BAD_GATEWAY,
                        Json(serde_json::json!({"error": "Failed to decode response from backend instance"})),
                    ).into_response()
                }
            }
        }
        Err(e) => {
            error!(instance_id = %instance.id, error = %e, "Failed to forward request to backend");
            // Request failed (timeout, connection error, etc.) - timeout the instance
            match state.router.timeout_instance(&instance.id).await {
                Ok(_) => info!(instance_id = %instance.id, "Timed out instance due to request failure"),
                Err(timeout_err) => error!(instance_id = %instance.id, error = %timeout_err, "Failed to timeout instance"),
            }
            (
                StatusCode::BAD_GATEWAY,
                Json(serde_json::json!({"error": format!("Failed to forward request to backend instance: {}", instance.id)})),
            ).into_response()
        }
    }
}
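A minimal sketch of the Cargo.toml dependencies this example assumes; the exact versions are illustrative. Pairing axum 0.7 with reqwest 0.12 matters because both build on http 1.x, which is what lets the backend's reqwest::StatusCode be returned directly from the axum handler.

[dependencies]
llm_router = "0.1"
axum = "0.7"
tokio = { version = "1", features = ["full"] }
reqwest = { version = "0.12", features = ["json"] }
serde_json = "1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

Once the server is running, it can be exercised with a small client. This is a hypothetical smoke test, not part of the crate; the request body follows the OpenAI chat schema and is assumed to match llm_router's ChatCompletionRequest.

// In a separate binary, with the same tokio/reqwest/serde_json dependencies.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = serde_json::json!({
        "model": "gpt-4",
        "messages": [{ "role": "user", "content": "Hello" }]
    });
    let resp = reqwest::Client::new()
        .post("http://127.0.0.1:3000/chat/completions")
        .json(&body)
        .send()
        .await?;
    // With the httpbin backends above, the echoed JSON will likely not decode
    // as a ChatCompletionResponse, so expect a 502 from the router in that setup.
    println!("status: {}", resp.status());
    println!("body: {}", resp.text().await?);
    Ok(())
}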