use crate::api::config::PanelConfig;
use crate::api::events::EventBus;
use axum::extract::DefaultBodyLimit;
use axum::http::{HeaderName, Method};
use axum::middleware as axum_mw;
use axum::routing::{get, post, put};
use axum::{extract::State, Json, Router};
use std::path::PathBuf;
use std::sync::Arc;
use tower_http::cors::{AllowOrigin, CorsLayer};
/// Shared state handed (behind an `Arc`) to every panel API handler and middleware.
///
/// Only `api_token` and `event_bus` are supplied at construction time via
/// `AppState::new`; everything else starts with a default and the `pub` fields may be
/// filled in afterwards by the caller.
#[derive(Clone)]
pub struct AppState {
    // --- authentication ---
    /// API token supplied by the caller of `AppState::new`.
    /// NOTE(review): presumably compared by the auth middleware — confirm there.
    pub api_token: String,
    /// Optional password hash; `AppState::new` leaves it as `None`.
    /// NOTE(review): presumably `None` disables password login — confirm in `routes::auth`.
    pub password_hash: Option<String>,
    /// Per-instance random secret (a v4 UUID string generated in `AppState::new`); also
    /// used as the key for CSRF token generation.
    pub jwt_secret: String,

    // --- realtime ---
    /// Event bus passed in by the caller of `AppState::new`.
    pub event_bus: EventBus,
    /// Semaphore created with `AppState::MAX_WS_CONNECTIONS` permits.
    /// NOTE(review): presumably acquired per WebSocket connection — confirm in `routes::ws`.
    pub ws_semaphore: Arc<tokio::sync::Semaphore>,

    // --- optional subsystems (all `None` after `AppState::new`) ---
    pub session_manager: Option<Arc<crate::session::SessionManager>>,
    pub task_store: Option<Arc<crate::api::tasks::TaskStore>>,
    pub health_registry: Option<Arc<crate::health::HealthRegistry>>,
    pub usage_metrics: Option<Arc<crate::health::UsageMetrics>>,
    pub metrics_collector: Option<Arc<crate::utils::metrics::MetricsCollector>>,
    pub provider: Option<Arc<dyn crate::providers::LLMProvider>>,
    pub config: Option<Arc<crate::config::Config>>,
}
impl AppState {
    /// Number of permits `ws_semaphore` is created with, i.e. the intended upper bound
    /// on simultaneously open WebSocket connections.
    pub const MAX_WS_CONNECTIONS: usize = 5;

    /// Creates a new state from an API token and an event bus.
    ///
    /// Every call generates a new random `jwt_secret` (a v4 UUID rendered as a string),
    /// so two states never share a secret. `password_hash` and all optional subsystems
    /// start out as `None`.
    pub fn new(api_token: String, event_bus: EventBus) -> Self {
        let jwt_secret = uuid::Uuid::new_v4().to_string();
        let ws_permits = tokio::sync::Semaphore::new(Self::MAX_WS_CONNECTIONS);

        Self {
            api_token,
            event_bus,
            jwt_secret,
            ws_semaphore: Arc::new(ws_permits),
            password_hash: None,
            session_manager: None,
            task_store: None,
            health_registry: None,
            usage_metrics: None,
            metrics_collector: None,
            provider: None,
            config: None,
        }
    }
}
/// Handler for `GET /api/csrf-token`.
///
/// Responds with a JSON object `{"token": <token>}`, where the token comes from
/// `middleware::generate_csrf_token` keyed with this state's `jwt_secret`.
async fn csrf_token_handler(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
    let body = serde_json::json!({
        "token": super::middleware::generate_csrf_token(&state.jwt_secret)
    });
    Json(body)
}
pub fn build_router(
state: AppState,
static_dir: Option<PathBuf>,
cors_origin: Option<String>,
) -> Router {
let shared_state = Arc::new(state);
let origin_str = cors_origin.unwrap_or_else(|| "http://localhost:9092".to_string());
let origin_value = origin_str
.parse::<axum::http::HeaderValue>()
.unwrap_or_else(|_| {
"http://localhost:9092"
.parse::<axum::http::HeaderValue>()
.expect("fallback origin is valid")
});
let cors = CorsLayer::new()
.allow_origin(AllowOrigin::exact(origin_value))
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_headers([
HeaderName::from_static("content-type"),
HeaderName::from_static("authorization"),
HeaderName::from_static("x-csrf-token"),
]);
let api = Router::new()
.route("/api/auth/login", post(super::routes::auth::login))
.route("/api/csrf-token", get(csrf_token_handler))
.route("/api/health", get(super::routes::health::get_health))
.route("/api/metrics", get(super::routes::metrics::get_metrics))
.route("/api/sessions", get(super::routes::sessions::list_sessions))
.route(
"/api/sessions/{key}",
get(super::routes::sessions::get_session)
.delete(super::routes::sessions::delete_session),
)
.route("/api/channels", get(super::routes::channels::list_channels))
.route(
"/api/cron",
get(super::routes::cron::list_jobs).post(super::routes::cron::create_job),
)
.route(
"/api/cron/{id}",
put(super::routes::cron::update_job).delete(super::routes::cron::delete_job),
)
.route(
"/api/cron/{id}/trigger",
post(super::routes::cron::trigger_job),
)
.route(
"/api/routines",
get(super::routes::routines::list_routines)
.post(super::routes::routines::create_routine),
)
.route(
"/api/routines/{id}",
put(super::routes::routines::update_routine)
.delete(super::routes::routines::delete_routine),
)
.route(
"/api/routines/{id}/toggle",
post(super::routes::routines::toggle_routine),
)
.route(
"/api/tasks",
get(super::routes::tasks::list_tasks).post(super::routes::tasks::create_task),
)
.route(
"/api/tasks/{id}",
put(super::routes::tasks::update_task).delete(super::routes::tasks::delete_task),
)
.route(
"/api/tasks/{id}/move",
post(super::routes::tasks::move_task),
)
.route("/ws/events", get(super::routes::ws::ws_events))
.route(
"/v1/chat/completions",
post(super::routes::openai::chat_completions),
)
.route("/v1/models", get(super::routes::openai::list_models))
.layer(DefaultBodyLimit::max(1024 * 1024))
.layer(cors)
.layer(axum_mw::from_fn_with_state(
shared_state.clone(),
super::middleware::auth_middleware,
))
.with_state(shared_state);
if let Some(dir) = static_dir {
api.fallback_service(tower_http::services::ServeDir::new(dir))
} else {
api
}
}
/// Binds the panel API listener and serves the router until `axum::serve` returns.
///
/// The only CORS origin allowed is `http://{config.bind}:{config.port}`, the address
/// the panel UI is expected to be loaded from. The API itself listens on
/// `{config.bind}:{config.api_port}`.
///
/// NOTE(review): if `config.bind` is a wildcard address such as `0.0.0.0`, the derived
/// origin (`http://0.0.0.0:<port>`) is unlikely to match what browsers send. Confirm
/// the intended deployment.
///
/// # Errors
/// Returns an error if the listener cannot be bound or if serving fails.
pub async fn start_server(
    config: &PanelConfig,
    state: AppState,
    static_dir: Option<PathBuf>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let ui_origin = format!("http://{}:{}", config.bind, config.port);
    let router = build_router(state, static_dir, Some(ui_origin));

    let addr = format!("{}:{}", config.bind, config.api_port);
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    tracing::info!("Panel API server listening on {addr}");

    axum::serve(listener, router).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for constructing an `AppState` with the given token and bus.
    fn make_state(token: &str, bus: EventBus) -> AppState {
        AppState::new(token.to_owned(), bus)
    }

    #[test]
    fn test_app_state_new() {
        let state = make_state("test-token", EventBus::new(16));
        assert_eq!(state.api_token, "test-token");
        assert!(state.password_hash.is_none());
        assert!(!state.jwt_secret.is_empty());
    }

    #[test]
    fn test_app_state_jwt_secret_rotates() {
        let first = make_state("tok", EventBus::new(4));
        let second = make_state("tok", EventBus::new(4));
        assert_ne!(first.jwt_secret, second.jwt_secret);
    }

    #[test]
    fn test_build_router_no_static() {
        let _router = build_router(make_state("tok", EventBus::new(16)), None, None);
    }

    #[test]
    fn test_build_router_with_static() {
        let static_dir = std::env::temp_dir();
        let _router = build_router(make_state("tok", EventBus::new(16)), Some(static_dir), None);
    }

    #[test]
    fn test_build_router_with_custom_cors_origin() {
        let origin = String::from("http://10.0.0.1:3000");
        let _router = build_router(make_state("tok", EventBus::new(16)), None, Some(origin));
    }

    #[test]
    fn test_ws_semaphore_initialized_with_correct_permits() {
        let state = make_state("tok", EventBus::new(4));
        let permits = state.ws_semaphore.available_permits();
        assert_eq!(permits, AppState::MAX_WS_CONNECTIONS);
    }

    #[test]
    fn test_ws_semaphore_max_connections_constant() {
        assert_eq!(AppState::MAX_WS_CONNECTIONS, 5);
    }
}