use axum::{
extract::{Json, State},
http::StatusCode,
response::IntoResponse,
routing::post,
serve,
Router as AxumRouter,
};
use llm_router::{
types::{ChatCompletionRequest, ChatCompletionResponse, ModelCapability, RouterError, RoutingStrategy},
ModelInstanceConfig, RequestTracker, Router,
};
use reqwest;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{error, info};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Shared state handed to the axum application via `with_state` and
/// extracted in request handlers through `State<AppState>`.
#[derive(Clone)]
struct AppState {
/// The LLM router used by handlers to select a backend instance; held in an
/// `Arc` so clones of the state share the same router.
router: Arc<Router>,
/// HTTP client used to forward the incoming request to the selected backend.
http_client: reqwest::Client,
}
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(
std::env::var("RUST_LOG").unwrap_or_else(|_| "info,llm_router=debug".into()),
))
.with(tracing_subscriber::fmt::layer())
.init();
info!("Initializing LLM Router...");
let router_config = Router::builder()
.strategy(RoutingStrategy::LoadBased) .instance_with_models(
"instance_1",
"https://httpbin.org",
vec![
ModelInstanceConfig {
model_name: "gpt-4".to_string(),
capabilities: vec![ModelCapability::Chat],
},
ModelInstanceConfig {
model_name: "text-embedding-ada-002".to_string(),
capabilities: vec![ModelCapability::Embedding],
},
],
)
.instance_with_models(
"instance_2",
"https://httpbin.org", vec![ModelInstanceConfig {
model_name: "gpt-3.5-turbo".to_string(),
capabilities: vec![ModelCapability::Chat],
}],
)
.health_check_path("/get") .health_check_interval(Duration::from_secs(10))
.instance_timeout_duration(Duration::from_secs(30))
.build();
let shared_router = Arc::new(router_config);
let http_client = reqwest::Client::new();
let app_state = AppState {
router: shared_router,
http_client,
};
let app = AxumRouter::new()
.route("/chat/completions", post(chat_completions_handler))
.with_state(app_state);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
info!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
info!("Listening on {}", addr);
serve(listener, app.into_make_service()).await.unwrap();
}
async fn chat_completions_handler(
State(state): State<AppState>,
Json(payload): Json<ChatCompletionRequest>, ) -> impl IntoResponse {
let model_name = payload.model.clone();
info!(model = %model_name, "Received chat completion request");
let instance = match state
.router
.select_instance_for_model(&model_name, ModelCapability::Chat)
.await
{
Ok(instance) => {
info!(instance_id = %instance.id, "Selected instance");
instance
}
Err(RouterError::NoHealthyInstancesForModel(model, cap)) => {
error!(%model, ?cap, "No healthy instances available");
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({"error": "No healthy instances available for the requested model and capability"})),
)
.into_response();
}
Err(e) => {
error!(error = %e, "Failed to select instance");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": "Internal server error during instance selection"})),
)
.into_response();
}
};
let _tracker = RequestTracker::new(state.router.as_ref().clone(), instance.id.clone()).await;
let target_url = format!("{}/post", instance.base_url);
info!(url = %target_url, "Forwarding request to backend");
match state
.http_client
.post(&target_url)
.json(&payload) .send()
.await
{
Ok(response) => {
let status = response.status();
match response.json::<ChatCompletionResponse>().await {
Ok(backend_response) => {
info!(instance_id = %instance.id, status = %status, "Received response from backend");
(status, Json(backend_response)).into_response()
}
Err(e) => {
error!(instance_id = %instance.id, status = %status, error = %e, "Failed to decode backend response");
(
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({"error": "Failed to decode response from backend instance"})),
).into_response()
}
}
}
Err(e) => {
error!(instance_id = %instance.id, error = %e, "Failed to forward request to backend");
match state.router.timeout_instance(&instance.id).await {
Ok(_) => info!(instance_id = %instance.id, "Timed out instance due to request failure"),
Err(timeout_err) => error!(instance_id = %instance.id, error = %timeout_err, "Failed to timeout instance"),
}
(
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({"error": format!("Failed to forward request to backend instance: {}", instance.id)})),
).into_response()
}
}
}