pub mod middleware;
pub mod routes;
use crate::config::{Config, ServerConfig};
use crate::utils::error::{GatewayError, Result};
use actix_cors::Cors;
use actix_web::{
App, HttpResponse, HttpServer as ActixHttpServer,
middleware::{DefaultHeaders, Logger},
web,
};
use chrono;
use serde_json::json;
use std::sync::Arc;
use tracing::info;
/// Shared application state: the gateway configuration plus the auth, routing and
/// storage subsystems.
///
/// Every field is an `Arc`, so the derived `Clone` only copies reference-counted handles.
#[derive(Clone)]
#[allow(dead_code)]
pub struct AppState {
/// Full gateway configuration.
pub config: Arc<Config>,
/// Authentication subsystem.
pub auth: Arc<crate::auth::AuthSystem>,
/// Request router from `crate::core::router`.
pub router: Arc<crate::core::router::Router>,
/// Storage layer handle.
pub storage: Arc<crate::storage::StorageLayer>,
}
impl AppState {
    /// Assembles the shared state, taking ownership of every subsystem and
    /// placing each one behind its own [`Arc`].
    pub fn new(
        config: Config,
        auth: crate::auth::AuthSystem,
        router: crate::core::router::Router,
        storage: crate::storage::StorageLayer,
    ) -> Self {
        // Shadow each argument with its reference-counted wrapper so the
        // struct literal below can use field-init shorthand.
        let config = Arc::new(config);
        let auth = Arc::new(auth);
        let router = Arc::new(router);
        let storage = Arc::new(storage);
        Self {
            config,
            auth,
            router,
            storage,
        }
    }
}
/// HTTP server wrapper: the server settings used for binding plus the shared
/// application state served to handlers.
#[allow(dead_code)]
pub struct HttpServer {
// Server section of the gateway config; `start` reads `host` and `port` from it.
config: ServerConfig,
// State registered with every actix `App` instance built by `create_app`.
state: AppState,
}
#[allow(dead_code)]
impl HttpServer {
pub async fn new(config: &Config) -> Result<Self> {
info!("Creating HTTP server");
let storage = crate::storage::StorageLayer::new(&config.gateway.storage).await?;
let auth =
crate::auth::AuthSystem::new(&config.gateway.auth, Arc::new(storage.clone())).await?;
let router = crate::core::router::Router::new(
config.gateway.providers.clone(),
Arc::new(storage.clone()),
crate::core::router::RoutingStrategy::RoundRobin,
)
.await?;
let state = AppState::new(config.clone(), auth, router, storage);
Ok(Self {
config: config.gateway.server.clone(),
state,
})
}
fn create_app(
state: web::Data<AppState>,
) -> App<
impl actix_web::dev::ServiceFactory<
actix_web::dev::ServiceRequest,
Config = (),
Response = actix_web::dev::ServiceResponse<impl actix_web::body::MessageBody>,
Error = actix_web::Error,
InitError = (),
>,
> {
info!("Setting up routes and middleware");
let cors_config = &state.config.gateway.server.cors;
let mut cors = Cors::default();
if cors_config.enabled {
if cors_config.allows_all_origins() {
cors = cors.allow_any_origin();
cors_config.validate().unwrap_or_else(|e| {
eprintln!("⚠️ CORS Configuration Warning: {}", e);
});
} else {
for origin in &cors_config.allowed_origins {
cors = cors.allowed_origin(origin);
}
}
let methods: Vec<actix_web::http::Method> = cors_config
.allowed_methods
.iter()
.filter_map(|m| m.parse().ok())
.collect();
if !methods.is_empty() {
cors = cors.allowed_methods(methods);
}
let headers: Vec<actix_web::http::header::HeaderName> = cors_config
.allowed_headers
.iter()
.filter_map(|h| h.parse().ok())
.collect();
if !headers.is_empty() {
cors = cors.allowed_headers(headers);
}
cors = cors.max_age(cors_config.max_age as usize);
if cors_config.allow_credentials {
cors = cors.supports_credentials();
}
}
App::new()
.app_data(state)
.wrap(cors)
.wrap(Logger::default())
.wrap(DefaultHeaders::new().add(("Server", "LiteLLM-RS")))
.route("/health", web::get().to(health_check))
.configure(routes::ai::configure_routes)
}
pub async fn start(self) -> Result<()> {
let bind_addr = format!("{}:{}", self.config.host, self.config.port);
info!("Starting HTTP server on {}", bind_addr);
let state = web::Data::new(self.state);
let server = ActixHttpServer::new(move || Self::create_app(state.clone()))
.bind(&bind_addr)
.map_err(|e| GatewayError::server(format!("Failed to bind to {}: {}", bind_addr, e)))?
.run();
info!("HTTP server listening on {}", bind_addr);
server
.await
.map_err(|e| GatewayError::server(format!("Server error: {}", e)))?;
info!("HTTP server stopped");
Ok(())
}
async fn shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("Failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("Failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
info!("Received Ctrl+C signal, shutting down gracefully");
},
_ = terminate => {
info!("Received terminate signal, shutting down gracefully");
},
}
}
pub fn config(&self) -> &ServerConfig {
&self.config
}
pub fn state(&self) -> &AppState {
&self.state
}
}
impl AppState {
    /// Borrows the gateway configuration stored in this state.
    #[allow(dead_code)]
    pub fn config(&self) -> &Config {
        self.config.as_ref()
    }
}
/// Builder for [`HttpServer`] that collects the configuration before the
/// server is constructed.
#[allow(dead_code)]
pub struct ServerBuilder {
// `None` until `with_config` is called; `build` fails while it is still `None`.
config: Option<Config>,
}
#[allow(dead_code)]
impl ServerBuilder {
pub fn new() -> Self {
Self { config: None }
}
pub fn with_config(mut self, config: Config) -> Self {
self.config = Some(config);
self
}
pub async fn build(self) -> Result<HttpServer> {
let config = self
.config
.ok_or_else(|| GatewayError::Config("Configuration is required".to_string()))?;
HttpServer::new(&config).await
}
}
#[allow(dead_code)]
pub async fn run_server() -> Result<()> {
info!("🚀 启动 Rust LiteLLM Gateway");
let config_path = "config/gateway.yaml";
info!("📄 加载配置文件: {}", config_path);
let config = match Config::from_file(config_path).await {
Ok(config) => {
info!("✅ 配置文件加载成功");
config
}
Err(e) => {
info!("⚠️ 配置文件加载失败,使用默认配置: {}", e);
info!("💡 请确保 config/gateway.yaml 文件存在并填入正确的 API 密钥");
Config::default()
}
};
let server = HttpServer::new(&config).await?;
info!(
"🌐 服务器启动地址: http://{}:{}",
config.server().host,
config.server().port
);
info!("📋 API 端点:");
info!(" GET /health - 健康检查");
info!(" GET /v1/models - 模型列表");
info!(" POST /v1/chat/completions - 聊天完成");
info!(" POST /v1/completions - 文本完成");
info!(" POST /v1/embeddings - 文本嵌入");
server.start().await
}
impl Default for ServerBuilder {
fn default() -> Self {
Self::new()
}
}
/// Serializable snapshot of server health figures.
///
/// NOTE(review): nothing in this file constructs or returns this type, so the
/// units and exact meaning of the numeric fields below are not established
/// here — confirm against the code that fills them in.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ServerHealth {
/// Overall status label.
pub status: String,
/// Server uptime — presumably seconds; TODO confirm.
pub uptime: u64,
/// Number of currently active connections.
pub active_connections: u32,
/// Memory usage — presumably bytes; TODO confirm.
pub memory_usage: u64,
/// CPU usage — scale (fraction vs. percent) not shown here; TODO confirm.
pub cpu_usage: f64,
/// Health status reported for the storage layer.
pub storage_health: crate::storage::StorageHealthStatus,
}
/// Per-request measurements and identifying metadata.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct RequestMetrics {
/// Identifier of the request.
pub request_id: String,
/// HTTP method, stored as text.
pub method: String,
/// Request path.
pub path: String,
/// HTTP status code of the response.
pub status_code: u16,
/// Response time in milliseconds (per the field name).
pub response_time_ms: u64,
/// Request size — presumably bytes; TODO confirm.
pub request_size: u64,
/// Response size — presumably bytes; TODO confirm.
pub response_size: u64,
/// User agent, when one is known.
pub user_agent: Option<String>,
/// Client IP address as text, when one is known.
pub client_ip: Option<String>,
/// Id of the user associated with the request, if any.
pub user_id: Option<uuid::Uuid>,
/// Id of the API key associated with the request, if any.
pub api_key_id: Option<uuid::Uuid>,
}
async fn health_check() -> HttpResponse {
HttpResponse::Ok().json(json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"version": env!("CARGO_PKG_VERSION")
}))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created builder carries no configuration.
    #[test]
    fn test_server_builder() {
        let builder = ServerBuilder::new();
        assert!(builder.config.is_none());
    }

    /// `AppState` must stay cloneable and shareable across threads.
    ///
    /// The previous body compared `size_of::<HttpServer>()` with itself, which
    /// can never fail. Building a real `AppState` needs live subsystems, so
    /// this instead pins, at compile time, the bounds `HttpServer::start`
    /// relies on when it moves the state into the server factory closure.
    #[test]
    fn test_app_state_creation() {
        fn assert_shareable<T: Clone + Send + Sync>() {}
        assert_shareable::<AppState>();
    }

    /// `RequestMetrics` can be built field by field and read back.
    #[test]
    fn test_request_metrics_creation() {
        let metrics = RequestMetrics {
            request_id: "req-123".to_string(),
            method: "GET".to_string(),
            path: "/health".to_string(),
            status_code: 200,
            response_time_ms: 50,
            request_size: 0,
            response_size: 100,
            user_agent: Some("test-agent".to_string()),
            client_ip: Some("127.0.0.1".to_string()),
            user_id: None,
            api_key_id: None,
        };
        assert_eq!(metrics.request_id, "req-123");
        assert_eq!(metrics.method, "GET");
        assert_eq!(metrics.status_code, 200);
    }
}