use crate::config::Config;
use crate::plugin::Registry;
use crate::plugins::CachePlugin;
use axum::{
Json, Router,
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tracing::{info, warn};
/// JSON body accepted by `POST /api/cache/control`.
///
/// The handler in this module recognises only the `"clear"` action; any other value is
/// answered with `400 Bad Request`.
#[derive(Serialize, Deserialize, Debug)]
pub struct CacheControlRequest {
    /// Name of the cache operation to perform (currently only `"clear"`).
    pub action: String,
}
/// JSON body returned by `GET /api/cache/stats`.
#[derive(Serialize, Deserialize, Debug)]
pub struct CacheStatsResponse {
    /// Number of entries currently held by the cache plugin.
    pub size: usize,
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that did not find an entry.
    pub misses: u64,
    /// Entries removed by the cache's eviction policy.
    pub evictions: u64,
    /// Entries removed because they expired.
    pub expirations: u64,
    /// Percentage (0–100) of lookups that were hits; `0.0` when there were no lookups.
    pub hit_rate: f64,
}
/// JSON body accepted by `POST /api/config/reload`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ReloadConfigRequest {
    /// Configuration file to load; the handler falls back to `config.yaml` when this is
    /// absent or `null`.
    pub path: Option<String>,
}
/// Generic JSON body for admin operations that completed successfully.
#[derive(Serialize, Deserialize, Debug)]
pub struct SuccessResponse {
    /// Human-readable description of what was done.
    pub message: String,
}
/// Generic JSON body for admin operations that failed.
#[derive(Serialize, Deserialize, Debug)]
pub struct ErrorResponse {
    /// Human-readable description of the failure.
    pub error: String,
}
/// State shared by every admin API handler.
#[derive(Clone)]
pub struct AdminState {
    /// Live configuration; the reload endpoint swaps its contents in place.
    config: Arc<RwLock<Config>>,
    /// Plugin registry, consulted for the plugin registered as `"cache"`.
    registry: Arc<Registry>,
    /// When this state was created; the server-stats endpoint reports uptime from here.
    start_time: Instant,
}
impl AdminState {
    /// Wraps the shared configuration and plugin registry for use by the admin handlers.
    ///
    /// The uptime clock reported by `/api/server/stats` starts at the moment this is called.
    pub fn new(config: Arc<RwLock<Config>>, registry: Arc<Registry>) -> Self {
        let start_time = Instant::now();
        Self {
            start_time,
            config,
            registry,
        }
    }
}
/// HTTP server exposing the admin API (cache control/stats, config reload, server stats).
pub struct AdminServer {
    /// Address string handed verbatim to `TcpListener::bind` when the server runs.
    addr: String,
    /// State shared with every request handler.
    state: Arc<AdminState>,
}
impl AdminServer {
pub fn new(addr: impl Into<String>, state: AdminState) -> Self {
Self {
addr: addr.into(),
state: Arc::new(state),
}
}
pub async fn run_with_signal(
self,
startup_tx: Option<tokio::sync::oneshot::Sender<()>>,
mut shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
) -> Result<(), std::io::Error> {
let app = Router::new()
.route("/api/cache/control", post(cache_control))
.route("/api/cache/stats", get(cache_stats))
.route("/api/config/reload", post(reload_config))
.route("/api/server/stats", get(server_stats))
.with_state(Arc::clone(&self.state));
let listener = TcpListener::bind(&self.addr).await?;
info!("Admin API server listening on {}", self.addr);
if let Some(tx) = startup_tx {
let _ = tx.send(());
}
let shutdown_fut = async move {
if let Some(rx) = shutdown_rx.as_mut() {
let _ = rx.await;
} else {
#[cfg(unix)]
{
use tokio::signal::unix::{SignalKind, signal};
let mut sigterm = signal(SignalKind::terminate()).unwrap();
let mut sighup = signal(SignalKind::hangup()).unwrap();
tokio::select! {
_ = tokio::signal::ctrl_c() => {},
_ = sigterm.recv() => {},
_ = sighup.recv() => {},
}
}
#[cfg(not(unix))]
{
let _ = tokio::signal::ctrl_c().await;
}
}
};
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_fut)
.await?;
Ok(())
}
pub async fn run(self) -> Result<(), std::io::Error> {
self.run_with_signal(None, None).await
}
}
/// Looks up the plugin registered as `"cache"` and checks that it really is a [`CachePlugin`].
///
/// The type-erased `Arc` is returned. Callers downcast it again to reach `CachePlugin`-specific
/// methods, and this check ensures that downcast will succeed.
///
/// # Errors
/// Returns a ready-to-send error [`Response`]:
/// * `404 Not Found` when nothing is registered under `"cache"`.
/// * `500 Internal Server Error` when the `"cache"` entry is not a `CachePlugin`. That is a
///   server-side misconfiguration, not a missing resource, and matches the status the handlers
///   use for the same downcast failure.
// `Response` is large, but the error is handed straight back to axum by the handlers, so boxing
// it would only add an allocation on the error path.
#[allow(clippy::result_large_err)]
fn get_cache_plugin(registry: &Arc<Registry>) -> Result<Arc<dyn crate::plugin::Plugin>, Response> {
    let cache = registry.get("cache").ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: "Cache not configured".to_string(),
            }),
        )
            .into_response()
    })?;

    // Call `as_any` on the `dyn Plugin` object (via `as_ref()`), not on the `Arc` wrapper;
    // this matters if `as_any` is provided by a blanket impl.
    if cache.as_ref().as_any().downcast_ref::<CachePlugin>().is_none() {
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "Cache plugin found but failed to access".to_string(),
            }),
        )
            .into_response());
    }

    Ok(cache)
}
/// `POST /api/cache/control`: runs a management action against the cache plugin.
///
/// Only `"clear"` is supported; it empties the cache and answers `200 OK`. Any other action
/// yields `400 Bad Request`. Errors from the plugin lookup are returned unchanged.
async fn cache_control(
    State(state): State<Arc<AdminState>>,
    Json(request): Json<CacheControlRequest>,
) -> Response {
    let plugin = match get_cache_plugin(&state.registry) {
        Ok(plugin) => plugin,
        Err(lookup_error) => return lookup_error,
    };

    if request.action != "clear" {
        return (
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse {
                error: format!("Unknown action: {}", request.action),
            }),
        )
            .into_response();
    }

    let Some(cache_plugin) = plugin.as_ref().as_any().downcast_ref::<CachePlugin>() else {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "Failed to downcast cache plugin".to_string(),
            }),
        )
            .into_response();
    };

    cache_plugin.clear();
    info!("Cache cleared via admin API");
    (
        StatusCode::OK,
        Json(SuccessResponse {
            message: "Cache cleared successfully".to_string(),
        }),
    )
        .into_response()
}
/// `GET /api/cache/stats`: reports the cache size and its hit/miss/eviction/expiration counters.
///
/// `hit_rate` is the percentage (0–100) of lookups that were hits, or `0.0` when no lookup has
/// been recorded. Errors from the plugin lookup are returned unchanged.
async fn cache_stats(State(state): State<Arc<AdminState>>) -> Response {
    let plugin = match get_cache_plugin(&state.registry) {
        Ok(plugin) => plugin,
        Err(lookup_error) => return lookup_error,
    };

    let Some(cache_plugin) = plugin.as_ref().as_any().downcast_ref::<CachePlugin>() else {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "Failed to downcast cache plugin".to_string(),
            }),
        )
            .into_response();
    };

    let counters = cache_plugin.stats();
    let hits = counters.hits();
    let misses = counters.misses();
    let lookups = hits + misses;
    // Avoid dividing by zero before the first lookup.
    let hit_rate = match lookups {
        0 => 0.0,
        n => hits as f64 / n as f64 * 100.0,
    };

    let body = CacheStatsResponse {
        size: cache_plugin.size(),
        hits,
        misses,
        evictions: counters.evictions(),
        expirations: counters.expirations(),
        hit_rate,
    };
    (StatusCode::OK, Json(body)).into_response()
}
/// `POST /api/config/reload`: loads, validates and installs a new configuration.
///
/// Responds with:
/// * `200 OK` once the new configuration has replaced the live one.
/// * `400 Bad Request` if the file loads but fails `validate()`; the live config is untouched.
/// * `500 Internal Server Error` if the file cannot be loaded; the live config is untouched.
async fn reload_config(
    State(state): State<Arc<AdminState>>,
    Json(request): Json<ReloadConfigRequest>,
) -> Response {
    // NOTE(review): the path comes straight from the request body, so any client that can reach
    // this endpoint chooses which file the server reads; keep the admin listener on a trusted
    // interface. Whether `load_from_file` blocks the async runtime depends on its implementation
    // (not visible here) — TODO confirm.
    let path = request.path.unwrap_or_else(|| "config.yaml".to_string());

    let candidate = match crate::config::loader::load_from_file(&path) {
        Ok(loaded) => loaded,
        Err(load_err) => {
            warn!("Failed to load configuration: {}", load_err);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Failed to load configuration: {}", load_err),
                }),
            )
                .into_response();
        }
    };

    if let Err(invalid) = candidate.validate() {
        warn!("Configuration validation failed: {}", invalid);
        return (
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse {
                error: format!("Configuration validation failed: {}", invalid),
            }),
        )
            .into_response();
    }

    *state.config.write().await = candidate;
    info!("Configuration reloaded from {} via admin API", path);
    (
        StatusCode::OK,
        Json(SuccessResponse {
            message: format!("Configuration reloaded from {}", path),
        }),
    )
        .into_response()
}
/// `GET /api/server/stats`: reports a fixed `"running"` status, the crate version, and uptime
/// since the [`AdminState`] was created, formatted by `format_duration`.
async fn server_stats(State(state): State<Arc<AdminState>>) -> impl IntoResponse {
    // Local response shape; field order determines the JSON key order.
    #[derive(Serialize)]
    struct ServerStatus {
        status: String,
        version: String,
        uptime: String,
    }

    let body = ServerStatus {
        status: "running".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
        uptime: format_duration(state.start_time.elapsed()),
    };
    (StatusCode::OK, Json(body))
}
/// Formats a duration uptime-style as `HH:MM:SS`, prefixed with `<days>d ` once it spans a
/// full day or more (e.g. `"01:01:01"`, `"8d 22:21:00"`).
///
/// Sub-second precision is discarded (truncated, not rounded).
fn format_duration(d: Duration) -> String {
    const SECS_PER_MINUTE: u64 = 60;
    const SECS_PER_HOUR: u64 = 60 * SECS_PER_MINUTE;
    const SECS_PER_DAY: u64 = 24 * SECS_PER_HOUR;

    let total = d.as_secs();
    let (days, rest) = (total / SECS_PER_DAY, total % SECS_PER_DAY);
    let (hours, rest) = (rest / SECS_PER_HOUR, rest % SECS_PER_HOUR);
    let (minutes, seconds) = (rest / SECS_PER_MINUTE, rest % SECS_PER_MINUTE);

    let clock = format!("{hours:02}:{minutes:02}:{seconds:02}");
    if days == 0 {
        clock
    } else {
        format!("{days}d {clock}")
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the admin API DTOs, handlers, and helpers.
    use super::*;

    #[test]
    fn test_cache_control_request_serialization() {
        let req = CacheControlRequest {
            action: "clear".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("clear"));
        let deserialized: CacheControlRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.action, "clear");
    }

    #[test]
    fn test_cache_stats_response_serialization() {
        let resp = CacheStatsResponse {
            size: 100,
            hits: 80,
            misses: 20,
            evictions: 5,
            expirations: 15,
            hit_rate: 80.0,
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("100"));
        assert!(json.contains("80.0"));
        let deserialized: CacheStatsResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.size, 100);
        assert_eq!(deserialized.hits, 80);
        assert_eq!(deserialized.hit_rate, 80.0);
    }

    #[test]
    fn test_reload_config_request_with_path() {
        let req = ReloadConfigRequest {
            path: Some("/etc/config.yaml".to_string()),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("/etc/config.yaml"));
        let deserialized: ReloadConfigRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.path, Some("/etc/config.yaml".to_string()));
    }

    #[test]
    fn test_reload_config_request_without_path() {
        let req = ReloadConfigRequest { path: None };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: ReloadConfigRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.path, None);
    }

    #[test]
    fn test_success_response_serialization() {
        let resp = SuccessResponse {
            message: "Operation successful".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("Operation successful"));
        let deserialized: SuccessResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.message, "Operation successful");
    }

    #[test]
    fn test_error_response_serialization() {
        let resp = ErrorResponse {
            error: "An error occurred".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("An error occurred"));
        let deserialized: ErrorResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.error, "An error occurred");
    }

    #[test]
    fn test_admin_state_creation() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let _state = AdminState::new(Arc::clone(&config), Arc::clone(&registry));
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_admin_state_is_clone() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = AdminState::new(Arc::clone(&config), Arc::clone(&registry));
        let _cloned = state.clone();
    }

    #[test]
    fn test_admin_server_creation() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = AdminState::new(Arc::clone(&config), Arc::clone(&registry));
        let _server = AdminServer::new("127.0.0.1:9999", state);
    }

    #[test]
    fn test_admin_server_creation_with_shorthand() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = AdminState::new(Arc::clone(&config), Arc::clone(&registry));
        // Only construction is exercised; the address is not bound here.
        let _server = AdminServer::new(":9999", state);
    }

    #[tokio::test]
    async fn test_server_stats_endpoint() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = Arc::new(AdminState::new(Arc::clone(&config), Arc::clone(&registry)));
        let response = server_stats(State(state)).await.into_response();
        let (parts, body) = response.into_parts();
        assert_eq!(parts.status, StatusCode::OK);
        let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
        let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
        assert!(body_str.contains("running"));
        assert!(body_str.contains("version"));
        assert!(body_str.contains("uptime"));
    }

    #[test]
    fn test_format_duration_uptime_style() {
        assert_eq!(format_duration(Duration::from_secs(0)), "00:00:00");
        assert_eq!(format_duration(Duration::from_secs(65)), "00:01:05");
        assert_eq!(format_duration(Duration::from_secs(3661)), "01:01:01");
        // Day boundary: the day prefix appears only once a full day has elapsed.
        assert_eq!(format_duration(Duration::from_secs(86399)), "23:59:59");
        assert_eq!(format_duration(Duration::from_secs(86400)), "1d 00:00:00");
        let one_day_one_hour = Duration::from_secs(86400 + 3600);
        assert_eq!(format_duration(one_day_one_hour), "1d 01:00:00");
        let days = Duration::from_secs(8 * 86400 + 22 * 3600 + 21 * 60);
        assert_eq!(format_duration(days), "8d 22:21:00");
    }

    #[tokio::test]
    async fn test_cache_stats_with_no_cache_plugin() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = Arc::new(AdminState::new(Arc::clone(&config), Arc::clone(&registry)));
        let response = cache_stats(State(state)).await;
        let (parts, _body) = response.into_parts();
        assert_eq!(parts.status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_cache_control_with_no_cache_plugin() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = Arc::new(AdminState::new(Arc::clone(&config), Arc::clone(&registry)));
        let request = CacheControlRequest {
            action: "clear".to_string(),
        };
        let response = cache_control(State(state), Json(request)).await;
        let (parts, _body) = response.into_parts();
        assert_eq!(parts.status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_cache_control_with_unknown_action() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = Arc::new(AdminState::new(Arc::clone(&config), Arc::clone(&registry)));
        let request = CacheControlRequest {
            action: "unknown_action".to_string(),
        };
        let response = cache_control(State(state), Json(request)).await;
        let (parts, _body) = response.into_parts();
        // The registry is empty, so the cache lookup fails (404) before the action is
        // inspected; the 400 "Unknown action" path needs a registered cache plugin.
        assert_eq!(parts.status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_reload_config_with_invalid_path() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = Arc::new(AdminState::new(Arc::clone(&config), Arc::clone(&registry)));
        let request = ReloadConfigRequest {
            path: Some("/nonexistent/path/config.yaml".to_string()),
        };
        let response = reload_config(State(state), Json(request)).await;
        let (parts, _body) = response.into_parts();
        assert_eq!(parts.status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn test_reload_config_with_default_path() {
        let config = Arc::new(RwLock::new(Config::new()));
        let registry = Arc::new(Registry::new());
        let state = Arc::new(AdminState::new(Arc::clone(&config), Arc::clone(&registry)));
        let request = ReloadConfigRequest { path: None };
        let response = reload_config(State(state), Json(request)).await;
        let (parts, _body) = response.into_parts();
        // The outcome depends on whether `config.yaml` exists in the working directory.
        assert!(
            parts.status == StatusCode::INTERNAL_SERVER_ERROR
                || parts.status == StatusCode::OK
                || parts.status == StatusCode::BAD_REQUEST
        );
    }

    #[test]
    fn test_hit_rate_calculation_with_all_hits() {
        let hits = 100u64;
        let misses = 0u64;
        let total = hits + misses;
        let hit_rate = if total > 0 {
            (hits as f64 / total as f64) * 100.0
        } else {
            0.0
        };
        assert_eq!(hit_rate, 100.0);
    }

    #[test]
    fn test_hit_rate_calculation_with_all_misses() {
        let hits = 0u64;
        let misses = 100u64;
        let total = hits + misses;
        let hit_rate = if total > 0 {
            (hits as f64 / total as f64) * 100.0
        } else {
            0.0
        };
        assert_eq!(hit_rate, 0.0);
    }

    #[test]
    fn test_hit_rate_calculation_mixed() {
        let hits = 80u64;
        let misses = 20u64;
        let total = hits + misses;
        let hit_rate = if total > 0 {
            (hits as f64 / total as f64) * 100.0
        } else {
            0.0
        };
        assert!((hit_rate - 80.0).abs() < 0.01);
    }

    #[test]
    fn test_hit_rate_calculation_zero_queries() {
        let hits = 0u64;
        let misses = 0u64;
        let total = hits + misses;
        let hit_rate = if total > 0 {
            (hits as f64 / total as f64) * 100.0
        } else {
            0.0
        };
        assert_eq!(hit_rate, 0.0);
    }

    #[test]
    fn test_cache_stats_response_with_large_numbers() {
        let resp = CacheStatsResponse {
            size: 1_000_000,
            hits: 10_000_000,
            misses: 2_000_000,
            evictions: 50_000,
            expirations: 100_000,
            hit_rate: 83.33,
        };
        assert_eq!(resp.size, 1_000_000);
        assert_eq!(resp.hits, 10_000_000);
        assert_eq!(resp.misses, 2_000_000);
        assert_eq!(resp.evictions, 50_000);
        assert!((resp.hit_rate - 83.33).abs() < 0.01);
    }

    #[test]
    fn test_cache_stats_response_zero_values() {
        let resp = CacheStatsResponse {
            size: 0,
            hits: 0,
            misses: 0,
            evictions: 0,
            expirations: 0,
            hit_rate: 0.0,
        };
        assert_eq!(resp.size, 0);
        assert_eq!(resp.hits, 0);
        assert_eq!(resp.misses, 0);
        assert_eq!(resp.evictions, 0);
        assert_eq!(resp.hit_rate, 0.0);
    }

    #[test]
    fn test_cache_control_request_with_empty_action() {
        let req = CacheControlRequest {
            action: String::new(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: CacheControlRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.action, "");
    }

    #[test]
    fn test_cache_control_request_with_special_characters() {
        let req = CacheControlRequest {
            action: "clear-with-dashes_and_underscores".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: CacheControlRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.action, "clear-with-dashes_and_underscores");
    }

    #[test]
    fn test_success_response_with_long_message() {
        let long_msg = "a".repeat(1000);
        let resp = SuccessResponse {
            message: long_msg.clone(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: SuccessResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.message, long_msg);
    }

    #[test]
    fn test_error_response_with_unicode() {
        let resp = ErrorResponse {
            error: "Error: 无法访问缓存".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let deserialized: ErrorResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.error, "Error: 无法访问缓存");
    }
}