#[cfg(feature = "anomaly-detection")]
use crate::anomaly::AnomalyDetector;
use crate::config::{NodeConfig, NodeRole, ProxyConfig};
#[cfg(feature = "edge-proxy")]
use crate::edge::{EdgeCache, EdgeRegistry, InvalidationEvent};
#[cfg(feature = "wasm-plugins")]
use crate::plugins::PluginManager;
#[cfg(feature = "ha-tr")]
use crate::replay::{ReplayEngine, TimeTravelRequest};
use crate::server::{NodeHealth, ServerMetricsSnapshot};
use crate::{ProxyError, Result};
#[cfg(feature = "ha-tr")]
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, RwLock};
/// Embedded admin dashboard page, compiled in from `admin_ui.html` and served
/// for `GET /`, `GET /ui` and `GET /ui/`.
const ADMIN_UI_HTML: &str = include_str!("admin_ui.html");
/// Minimal HTTP/1.1 admin API server.
///
/// Serves health, metrics, topology, node-control, SQL-routing and
/// feature-gated diagnostics endpoints backed by a shared [`AdminState`].
/// Started with `run()` and stopped with `shutdown()`.
pub struct AdminServer {
// Address handed to `TcpListener::bind` in `run()`.
listen_address: String,
// Shared state cloned into every connection task.
state: Arc<AdminState>,
// Signals the accept loop in `run()` to exit.
shutdown_tx: broadcast::Sender<()>,
}
/// Shared, lock-protected state read and mutated by the admin API handlers.
///
/// Optional subsystems (replay engine, plugin manager, anomaly detector, edge
/// cache/registry) exist only when their cargo feature is enabled. They start
/// as `None` and are attached through the corresponding `with_*` method.
pub struct AdminState {
// Per-node health keyed by node address; used by /nodes, readiness, topology and SQL routing.
pub node_health: RwLock<HashMap<String, NodeHealth>>,
// Metrics snapshot served by /metrics and /metrics/prometheus.
pub metrics: RwLock<ServerMetricsSnapshot>,
// Session count served by /sessions (only read in this file; presumably updated by the proxy server).
pub active_sessions: RwLock<u64>,
// Serializable configuration summary served by /config and used by /topology.
pub config_snapshot: RwLock<ConfigSnapshot>,
// Full proxy configuration used for /api/sql routing; `None` until `set_proxy_config`.
pub proxy_config: RwLock<Option<ProxyConfig>>,
// Round-robin cursor used when choosing a read node for /api/sql.
read_lb_counter: AtomicUsize,
// Named command handlers added with `register_command`.
commands: RwLock<HashMap<String, CommandHandler>>,
#[cfg(feature = "ha-tr")]
pub replay_engine: RwLock<Option<Arc<ReplayEngine>>>,
#[cfg(feature = "wasm-plugins")]
pub plugin_manager: RwLock<Option<Arc<PluginManager>>>,
// Chaos overrides (nodes forced unhealthy via POST /api/chaos), keyed by node address.
pub chaos_overrides: RwLock<HashMap<String, ChaosOverride>>,
#[cfg(feature = "anomaly-detection")]
pub anomaly_detector: RwLock<Option<Arc<AnomalyDetector>>>,
#[cfg(feature = "edge-proxy")]
pub edge_cache: RwLock<Option<Arc<EdgeCache>>>,
#[cfg(feature = "edge-proxy")]
pub edge_registry: RwLock<Option<Arc<EdgeRegistry>>>,
}
/// A chaos-testing override currently applied to a node, listed by
/// `GET /api/chaos`.
#[derive(Debug, Clone, Serialize)]
pub struct ChaosOverride {
// RFC 3339 timestamp of when the override was applied.
pub since: String,
// Override type, e.g. "force_unhealthy".
pub kind: String,
// Free-form human-readable note.
pub note: String,
}
/// A named admin command: takes the argument list and returns textual output.
/// Registered with `AdminState::register_command` and invoked by `execute_command`.
type CommandHandler = Arc<dyn Fn(&[&str]) -> Result<String> + Send + Sync>;
/// Serializable summary of the proxy configuration, returned by `GET /config`.
/// Its node list is also used to compute `/topology`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigSnapshot {
pub listen_address: String,
pub admin_address: String,
pub tr_enabled: bool,
pub tr_mode: String,
pub pool_min_connections: usize,
pub pool_max_connections: usize,
pub nodes: Vec<NodeSnapshot>,
}
/// One backend node as reported inside a [`ConfigSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeSnapshot {
// Looked up as the key into `AdminState::node_health` by topology computation.
pub address: String,
// Role name; compared case-insensitively against "primary" for topology.
pub role: String,
pub weight: u32,
pub enabled: bool,
}
impl AdminServer {
pub fn new(listen_address: String, state: Arc<AdminState>) -> Self {
let (shutdown_tx, _) = broadcast::channel(1);
Self {
listen_address,
state,
shutdown_tx,
}
}
pub async fn run(&self) -> Result<()> {
let listener = TcpListener::bind(&self.listen_address)
.await
.map_err(|e| ProxyError::Network(format!("Failed to bind admin: {}", e)))?;
tracing::info!("Admin API listening on {}", self.listen_address);
let mut shutdown_rx = self.shutdown_tx.subscribe();
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((stream, addr)) => {
let state = self.state.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(stream, addr, state).await {
tracing::error!("Admin connection error: {}", e);
}
});
}
Err(e) => {
tracing::error!("Admin accept error: {}", e);
}
}
}
_ = shutdown_rx.recv() => {
tracing::info!("Admin server shutting down");
break;
}
}
}
Ok(())
}
async fn handle_connection(
mut stream: TcpStream,
addr: SocketAddr,
state: Arc<AdminState>,
) -> Result<()> {
tracing::debug!("Admin connection from {}", addr);
let (reader, mut writer) = stream.split();
let mut reader = BufReader::new(reader);
let mut line = String::new();
let mut headers = Vec::new();
let mut content_length: usize = 0;
loop {
line.clear();
let bytes_read = reader
.read_line(&mut line)
.await
.map_err(|e| ProxyError::Network(format!("Read error: {}", e)))?;
if bytes_read == 0 || line == "\r\n" {
break;
}
let trimmed = line.trim();
if trimmed.to_lowercase().starts_with("content-length:") {
if let Some(len_str) = trimmed.split(':').nth(1) {
content_length = len_str.trim().parse().unwrap_or(0);
}
}
headers.push(trimmed.to_string());
}
if headers.is_empty() {
return Ok(());
}
let request_line = &headers[0];
let parts: Vec<&str> = request_line.split_whitespace().collect();
if parts.len() < 2 {
Self::send_response(&mut writer, 400, "Bad Request", "Invalid request line").await?;
return Ok(());
}
let method = parts[0];
let path = parts[1];
let body = if content_length > 0 && (method == "POST" || method == "PUT") {
let mut body_buf = vec![0u8; content_length];
reader.read_exact(&mut body_buf).await
.map_err(|e| ProxyError::Network(format!("Body read error: {}", e)))?;
Some(String::from_utf8_lossy(&body_buf).to_string())
} else {
None
};
if method == "GET" && (path == "/" || path == "/ui" || path == "/ui/") {
Self::send_html_response(&mut writer, 200, ADMIN_UI_HTML).await?;
return Ok(());
}
let response = Self::route_request(method, path, body.as_deref(), &state).await;
match response {
Ok((status, body)) => {
Self::send_json_response(&mut writer, status, &body).await?;
}
Err(e) => {
let error = ErrorResponse {
error: e.to_string(),
};
Self::send_json_response(&mut writer, 500, &error).await?;
}
}
Ok(())
}
async fn send_html_response(
writer: &mut tokio::net::tcp::WriteHalf<'_>,
status: u16,
html: &str,
) -> Result<()> {
let status_text = match status {
200 => "OK",
404 => "Not Found",
_ => "Unknown",
};
let response = format!(
"HTTP/1.1 {} {}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
status,
status_text,
html.len(),
html
);
writer
.write_all(response.as_bytes())
.await
.map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
Ok(())
}
async fn route_request(
method: &str,
path: &str,
body: Option<&str>,
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
match (method, path) {
("POST", "/api/sql") => {
Self::handle_sql_request(body, state).await
}
("GET", "/health") => {
let health = HealthResponse { status: "ok" };
Ok((200, serde_json::to_value(health)?))
}
("GET", "/health/ready") => {
let ready = Self::check_readiness(state).await;
let response = ReadinessResponse {
ready,
message: if ready {
"Proxy is ready"
} else {
"Proxy is not ready"
},
};
let status = if ready { 200 } else { 503 };
Ok((status, serde_json::to_value(response)?))
}
("GET", "/health/live") => {
let response = LivenessResponse { alive: true };
Ok((200, serde_json::to_value(response)?))
}
("GET", "/metrics") => {
let metrics = state.metrics.read().await.clone();
Ok((200, serde_json::to_value(MetricsResponse::from(metrics))?))
}
("GET", "/metrics/prometheus") => {
let metrics = state.metrics.read().await.clone();
let prometheus = Self::format_prometheus_metrics(&metrics);
Ok((200, serde_json::json!({ "text": prometheus })))
}
("GET", "/nodes") => {
let health = state.node_health.read().await;
let nodes: Vec<NodeHealthResponse> = health
.values()
.map(|h| NodeHealthResponse::from(h.clone()))
.collect();
Ok((200, serde_json::to_value(nodes)?))
}
("GET", path) if path.starts_with("/nodes/") => {
let node_addr = path.trim_start_matches("/nodes/");
let health = state.node_health.read().await;
match health.get(node_addr) {
Some(h) => Ok((200, serde_json::to_value(NodeHealthResponse::from(h.clone()))?)),
None => Ok((404, serde_json::json!({ "error": "Node not found" }))),
}
}
("POST", path) if path.starts_with("/nodes/") && path.ends_with("/enable") => {
let node_addr = path
.trim_start_matches("/nodes/")
.trim_end_matches("/enable");
Self::set_node_enabled(state, node_addr, true).await?;
Ok((200, serde_json::json!({ "status": "enabled" })))
}
("POST", path) if path.starts_with("/nodes/") && path.ends_with("/disable") => {
let node_addr = path
.trim_start_matches("/nodes/")
.trim_end_matches("/disable");
Self::set_node_enabled(state, node_addr, false).await?;
Ok((200, serde_json::json!({ "status": "disabled" })))
}
("GET", "/topology") => {
let topo = Self::compute_topology(state).await;
Ok((200, serde_json::to_value(topo)?))
}
#[cfg(feature = "ha-tr")]
("POST", "/api/replay") => Self::handle_replay_request(body, state).await,
#[cfg(not(feature = "ha-tr"))]
("POST", "/api/replay") => Ok((
503,
serde_json::json!({ "error": "ha-tr feature not compiled in" }),
)),
#[cfg(feature = "ha-tr")]
("POST", "/api/shadow") => Self::handle_shadow_request(body).await,
#[cfg(not(feature = "ha-tr"))]
("POST", "/api/shadow") => Ok((
503,
serde_json::json!({ "error": "ha-tr feature not compiled in" }),
)),
("GET", "/plugins") => Self::handle_plugins_list(state).await,
#[cfg(feature = "anomaly-detection")]
("GET", p) if p == "/anomalies" || p.starts_with("/anomalies?") => {
Self::handle_anomalies_list(p, state).await
}
#[cfg(not(feature = "anomaly-detection"))]
("GET", p) if p == "/anomalies" || p.starts_with("/anomalies?") => Ok((
503,
serde_json::json!({ "error": "anomaly-detection feature not compiled in" }),
)),
#[cfg(feature = "edge-proxy")]
("GET", "/api/edge") => Self::handle_edge_status(state).await,
#[cfg(feature = "edge-proxy")]
("POST", "/api/edge/register") => {
Self::handle_edge_register(body, state).await
}
#[cfg(feature = "edge-proxy")]
("POST", "/api/edge/invalidate") => {
Self::handle_edge_invalidate(body, state).await
}
#[cfg(not(feature = "edge-proxy"))]
("GET", "/api/edge")
| ("POST", "/api/edge/register")
| ("POST", "/api/edge/invalidate") => Ok((
503,
serde_json::json!({ "error": "edge-proxy feature not compiled in" }),
)),
("POST", "/api/chaos") => Self::handle_chaos_request(body, state).await,
("GET", "/api/chaos") => {
let overrides = state.chaos_overrides.read().await.clone();
Ok((200, serde_json::to_value(overrides)?))
}
("GET", "/config") => {
let config = state.config_snapshot.read().await.clone();
Ok((200, serde_json::to_value(config)?))
}
("GET", "/sessions") => {
let count = *state.active_sessions.read().await;
let response = SessionsResponse {
active_sessions: count,
};
Ok((200, serde_json::to_value(response)?))
}
("GET", "/pools") => {
let pools = Self::get_pool_stats(state).await;
Ok((200, serde_json::to_value(pools)?))
}
("GET", "/version") => {
let response = VersionResponse {
version: crate::VERSION.to_string(),
build_time: env!("CARGO_PKG_VERSION").to_string(),
};
Ok((200, serde_json::to_value(response)?))
}
_ => Ok((404, serde_json::json!({ "error": "Not found" }))),
}
}
async fn handle_sql_request(
body: Option<&str>,
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
let body = body.ok_or_else(|| ProxyError::Internal("Missing request body".to_string()))?;
let request: SqlRequest = serde_json::from_str(body)
.map_err(|e| ProxyError::Internal(format!("Invalid JSON: {}", e)))?;
let sql = request.query.trim();
if sql.is_empty() {
return Ok((400, serde_json::json!({ "error": "Empty query" })));
}
let is_write = Self::is_write_query(sql);
let query_type = if is_write { "write" } else { "read" };
let proxy_config = state.proxy_config.read().await;
let config = proxy_config.as_ref()
.ok_or_else(|| ProxyError::Internal("Proxy config not initialized".to_string()))?;
let health = state.node_health.read().await;
let target_node = if is_write {
Self::select_primary_node(config, &health)?
} else {
Self::select_read_node(config, &health, state)?
};
let target_address = format!("{}:{}", target_node.host, target_node.port);
let http_port = target_node.http_port;
let http_url = format!("http://{}:{}/api/sql", target_node.host, http_port);
tracing::debug!(
"Routing {} query to {} ({})",
query_type,
target_address,
match target_node.role {
NodeRole::Primary => "primary",
NodeRole::Standby => "standby",
NodeRole::ReadReplica => "replica",
}
);
let result = Self::forward_sql_request(&http_url, sql).await?;
let response = SqlResponse {
query_type: query_type.to_string(),
routed_to: target_address,
node_role: format!("{:?}", target_node.role).to_lowercase(),
result,
};
Ok((200, serde_json::to_value(response)?))
}
fn is_write_query(sql: &str) -> bool {
let upper = sql.trim().to_uppercase();
if upper.starts_with("INSERT")
|| upper.starts_with("UPDATE")
|| upper.starts_with("DELETE")
|| upper.starts_with("CREATE")
|| upper.starts_with("ALTER")
|| upper.starts_with("DROP")
|| upper.starts_with("TRUNCATE")
|| upper.starts_with("GRANT")
|| upper.starts_with("REVOKE")
|| upper.starts_with("VACUUM")
|| upper.starts_with("REINDEX")
|| upper.starts_with("MERGE")
|| upper.starts_with("UPSERT")
{
return true;
}
if upper.starts_with("BEGIN")
|| upper.starts_with("COMMIT")
|| upper.starts_with("ROLLBACK")
|| upper.starts_with("SAVEPOINT")
{
return true;
}
false
}
fn select_primary_node<'a>(
config: &'a ProxyConfig,
health: &HashMap<String, NodeHealth>,
) -> Result<&'a NodeConfig> {
config.nodes.iter()
.find(|n| {
n.role == NodeRole::Primary
&& n.enabled
&& health.get(&n.address()).map(|h| h.healthy).unwrap_or(false)
})
.ok_or_else(|| ProxyError::Internal("No healthy primary node available".to_string()))
}
fn select_read_node<'a>(
config: &'a ProxyConfig,
health: &HashMap<String, NodeHealth>,
state: &AdminState,
) -> Result<&'a NodeConfig> {
let healthy_nodes: Vec<&NodeConfig> = config.nodes.iter()
.filter(|n| n.enabled && health.get(&n.address()).map(|h| h.healthy).unwrap_or(false))
.collect();
if healthy_nodes.is_empty() {
return Err(ProxyError::Internal("No healthy nodes available".to_string()));
}
if config.load_balancer.read_write_split {
let read_nodes: Vec<&NodeConfig> = healthy_nodes.iter()
.filter(|n| n.role == NodeRole::Standby || n.role == NodeRole::ReadReplica)
.copied()
.collect();
if !read_nodes.is_empty() {
let counter = state.read_lb_counter.fetch_add(1, Ordering::Relaxed);
let index = counter % read_nodes.len();
return Ok(read_nodes[index]);
}
}
let counter = state.read_lb_counter.fetch_add(1, Ordering::Relaxed);
let index = counter % healthy_nodes.len();
Ok(healthy_nodes[index])
}
async fn forward_sql_request(url: &str, sql: &str) -> Result<serde_json::Value> {
let request_body = serde_json::json!({ "query": sql });
let body_bytes = serde_json::to_vec(&request_body)
.map_err(|e| ProxyError::Internal(format!("JSON serialization error: {}", e)))?;
let url_parts: Vec<&str> = url.trim_start_matches("http://").splitn(2, '/').collect();
if url_parts.is_empty() {
return Err(ProxyError::Internal("Invalid URL".to_string()));
}
let host_port = url_parts[0];
let path = if url_parts.len() > 1 {
format!("/{}", url_parts[1])
} else {
"/".to_string()
};
let stream = TcpStream::connect(host_port).await
.map_err(|e| ProxyError::Network(format!("Failed to connect to {}: {}", host_port, e)))?;
let (reader, mut writer) = stream.into_split();
let mut reader = BufReader::new(reader);
let request = format!(
"POST {} HTTP/1.1\r\nHost: {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
path,
host_port,
body_bytes.len()
);
writer.write_all(request.as_bytes()).await
.map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
writer.write_all(&body_bytes).await
.map_err(|e| ProxyError::Network(format!("Write body error: {}", e)))?;
let mut response_headers = Vec::new();
let mut line = String::new();
let mut content_length: usize = 0;
loop {
line.clear();
let bytes_read = reader.read_line(&mut line).await
.map_err(|e| ProxyError::Network(format!("Response read error: {}", e)))?;
if bytes_read == 0 || line == "\r\n" {
break;
}
let trimmed = line.trim();
if trimmed.to_lowercase().starts_with("content-length:") {
if let Some(len_str) = trimmed.split(':').nth(1) {
content_length = len_str.trim().parse().unwrap_or(0);
}
}
response_headers.push(trimmed.to_string());
}
let mut body_buf = vec![0u8; content_length];
if content_length > 0 {
reader.read_exact(&mut body_buf).await
.map_err(|e| ProxyError::Network(format!("Response body read error: {}", e)))?;
}
let response_body = String::from_utf8_lossy(&body_buf);
serde_json::from_str(&response_body)
.map_err(|e| ProxyError::Internal(format!("Invalid JSON response: {} - body: {}", e, response_body)))
}
async fn check_readiness(state: &Arc<AdminState>) -> bool {
let health = state.node_health.read().await;
health.values().any(|h| h.healthy)
}
async fn set_node_enabled(state: &Arc<AdminState>, node_addr: &str, enabled: bool) -> Result<()> {
let mut health = state.node_health.write().await;
if let Some(node_health) = health.get_mut(node_addr) {
node_health.healthy = enabled;
Ok(())
} else {
Err(ProxyError::Config(format!("Node not found: {}", node_addr)))
}
}
async fn get_pool_stats(_state: &Arc<AdminState>) -> Vec<PoolStatsResponse> {
Vec::new()
}
#[cfg(feature = "ha-tr")]
async fn handle_replay_request(
body: Option<&str>,
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
let raw = body.ok_or_else(|| {
ProxyError::Internal("replay: empty request body".to_string())
})?;
let req: ReplayRequestBody = match serde_json::from_str(raw) {
Ok(r) => r,
Err(e) => {
return Ok((
400,
serde_json::json!({ "error": format!("invalid body: {}", e) }),
));
}
};
let engine = match state.replay_engine.read().await.clone() {
Some(e) => e,
None => {
return Ok((
503,
serde_json::json!({ "error": "replay engine not attached" }),
));
}
};
let tt = TimeTravelRequest {
from: req.from,
to: req.to,
target_host: req.target_host,
target_port: req.target_port,
target_user: req.target_user,
target_password: req.target_password,
target_database: req.target_database,
};
match engine.replay_window(&tt).await {
Ok(summary) => Ok((200, serde_json::to_value(summary)?)),
Err(e) => Ok((
500,
serde_json::json!({ "error": format!("replay failed: {}", e) }),
)),
}
}
#[cfg(feature = "edge-proxy")]
async fn handle_edge_status(
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
let cache_stats = match state.edge_cache.read().await.clone() {
Some(c) => Some(c.stats()),
None => None,
};
let edges = match state.edge_registry.read().await.clone() {
Some(r) => r.list(),
None => Vec::new(),
};
Ok((200, serde_json::json!({
"cache": cache_stats,
"registered": edges,
"edge_count": edges.len(),
})))
}
#[cfg(feature = "edge-proxy")]
async fn handle_edge_register(
body: Option<&str>,
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
let raw = body.ok_or_else(|| {
ProxyError::Internal("edge register: empty body".to_string())
})?;
let req: EdgeRegisterBody = match serde_json::from_str(raw) {
Ok(r) => r,
Err(e) => {
return Ok((
400,
serde_json::json!({ "error": format!("invalid body: {}", e) }),
));
}
};
let registry = match state.edge_registry.read().await.clone() {
Some(r) => r,
None => {
return Ok((
503,
serde_json::json!({ "error": "edge registry not attached" }),
));
}
};
let now = chrono::Utc::now().to_rfc3339();
match registry.register(&req.edge_id, &req.region, &req.base_url, &now) {
Ok(_rx) => {
Ok((201, serde_json::json!({
"edge_id": req.edge_id,
"region": req.region,
"base_url": req.base_url,
"registered_at": now,
})))
}
Err(e) => Ok((
503,
serde_json::json!({ "error": e.to_string() }),
)),
}
}
#[cfg(feature = "edge-proxy")]
async fn handle_edge_invalidate(
body: Option<&str>,
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
let raw = body.ok_or_else(|| {
ProxyError::Internal("edge invalidate: empty body".to_string())
})?;
let req: EdgeInvalidateBody = match serde_json::from_str(raw) {
Ok(r) => r,
Err(e) => {
return Ok((
400,
serde_json::json!({ "error": format!("invalid body: {}", e) }),
));
}
};
let cache = match state.edge_cache.read().await.clone() {
Some(c) => c,
None => {
return Ok((
503,
serde_json::json!({ "error": "edge cache not attached" }),
));
}
};
let registry = match state.edge_registry.read().await.clone() {
Some(r) => r,
None => {
return Ok((
503,
serde_json::json!({ "error": "edge registry not attached" }),
));
}
};
let version = req.up_to_version.unwrap_or_else(|| cache.next_version());
let dropped_local = cache.invalidate(version, &req.tables);
let ev = InvalidationEvent {
up_to_version: version,
tables: req.tables.clone(),
committed_at: chrono::Utc::now().to_rfc3339(),
};
let (sent, pruned) = registry.broadcast(ev).await;
Ok((200, serde_json::json!({
"version": version,
"tables": req.tables,
"dropped_local": dropped_local,
"edges_notified": sent,
"edges_pruned": pruned,
})))
}
#[cfg(feature = "anomaly-detection")]
async fn handle_anomalies_list(
path: &str,
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
let limit = parse_limit_query(path, 100, 1024);
let det = match state.anomaly_detector.read().await.clone() {
Some(d) => d,
None => {
return Ok((
503,
serde_json::json!({ "error": "anomaly detector not attached" }),
));
}
};
let events = det.recent_events(limit);
Ok((200, serde_json::json!({
"count": events.len(),
"limit": limit,
"events": events,
"buffer_total": det.event_count(),
})))
}
#[cfg(feature = "ha-tr")]
async fn handle_shadow_request(
body: Option<&str>,
) -> Result<(u16, serde_json::Value)> {
use crate::backend::{tls::default_client_config, BackendClient, BackendConfig, ParamValue, TlsMode};
use crate::shadow_execute::shadow_execute;
let raw = body.ok_or_else(|| {
ProxyError::Internal("shadow: empty request body".to_string())
})?;
let req: ShadowRequestBody = match serde_json::from_str(raw) {
Ok(r) => r,
Err(e) => {
return Ok((
400,
serde_json::json!({ "error": format!("invalid body: {}", e) }),
));
}
};
let mk_cfg = |host: String, port: u16, user: Option<String>, password: Option<String>, database: Option<String>| BackendConfig {
host,
port,
user: user.unwrap_or_else(|| "postgres".into()),
password,
database,
application_name: Some("heliosdb-proxy-shadow".into()),
tls_mode: TlsMode::Disable,
connect_timeout: std::time::Duration::from_secs(5),
query_timeout: std::time::Duration::from_secs(30),
tls_config: default_client_config(),
};
let source_cfg = mk_cfg(
req.source_host,
req.source_port,
req.source_user,
req.source_password,
req.source_database,
);
let shadow_cfg = mk_cfg(
req.shadow_host,
req.shadow_port,
req.shadow_user,
req.shadow_password,
req.shadow_database,
);
let mut source = match BackendClient::connect(&source_cfg).await {
Ok(c) => c,
Err(e) => {
return Ok((
500,
serde_json::json!({ "error": format!("source connect: {}", e) }),
));
}
};
let params: Vec<ParamValue> = req
.params
.unwrap_or_default()
.into_iter()
.map(|s| ParamValue::Text(s))
.collect();
let outcome = shadow_execute(&mut source, &shadow_cfg, &req.sql, ¶ms).await;
source.close().await;
match outcome {
Ok((_qr, report)) => Ok((200, serde_json::json!({
"sql": report.sql,
"both_succeeded": report.both_succeeded,
"row_count_match": report.row_count_match,
"row_hash_match": report.row_hash_match,
"primary_elapsed_us": report.primary_elapsed_us,
"shadow_elapsed_us": report.shadow_elapsed_us,
"primary_error": report.primary_error,
"shadow_error": report.shadow_error,
"is_clean": report.is_clean(),
}))),
Err(e) => Ok((
500,
serde_json::json!({ "error": format!("shadow execute: {}", e) }),
)),
}
}
async fn handle_chaos_request(
body: Option<&str>,
state: &Arc<AdminState>,
) -> Result<(u16, serde_json::Value)> {
let raw = body.ok_or_else(|| {
ProxyError::Internal("chaos: empty request body".to_string())
})?;
let action: ChaosAction = match serde_json::from_str(raw) {
Ok(a) => a,
Err(e) => {
return Ok((
400,
serde_json::json!({ "error": format!("invalid body: {}", e) }),
));
}
};
match action {
ChaosAction::ForceUnhealthy { target_node } => {
if let Err(e) = Self::set_node_enabled(state, &target_node, false).await {
return Ok((
404,
serde_json::json!({ "error": e.to_string() }),
));
}
state.chaos_overrides.write().await.insert(
target_node.clone(),
ChaosOverride {
since: chrono::Utc::now().to_rfc3339(),
kind: "force_unhealthy".to_string(),
note: format!("forced unhealthy via chaos endpoint"),
},
);
Ok((200, serde_json::json!({
"applied": "force_unhealthy",
"target_node": target_node,
})))
}
ChaosAction::Restore { target_node } => {
if let Err(e) = Self::set_node_enabled(state, &target_node, true).await {
return Ok((
404,
serde_json::json!({ "error": e.to_string() }),
));
}
state.chaos_overrides.write().await.remove(&target_node);
Ok((200, serde_json::json!({
"restored": target_node,
})))
}
ChaosAction::Reset => {
let overrides: Vec<String> =
state.chaos_overrides.read().await.keys().cloned().collect();
let mut restored = Vec::with_capacity(overrides.len());
for addr in overrides {
let _ = Self::set_node_enabled(state, &addr, true).await;
restored.push(addr);
}
state.chaos_overrides.write().await.clear();
Ok((200, serde_json::json!({
"reset": true,
"restored": restored,
})))
}
}
}
#[cfg(feature = "wasm-plugins")]
async fn handle_plugins_list(state: &Arc<AdminState>) -> Result<(u16, serde_json::Value)> {
let pm = match state.plugin_manager.read().await.clone() {
Some(p) => p,
None => {
return Ok((
503,
serde_json::json!({ "error": "plugin manager not attached" }),
));
}
};
let plugins: Vec<PluginListEntry> = pm
.list_plugins()
.into_iter()
.map(|info| PluginListEntry {
name: info.name,
version: info.version,
description: info.description,
hooks: info
.hooks
.iter()
.map(|h| h.export_name().to_string())
.collect(),
state: format!("{:?}", info.state),
invocations: info.stats.total_calls,
errors: info.stats.error_count,
})
.collect();
Ok((200, serde_json::to_value(plugins)?))
}
#[cfg(not(feature = "wasm-plugins"))]
async fn handle_plugins_list(_state: &Arc<AdminState>) -> Result<(u16, serde_json::Value)> {
Ok((
503,
serde_json::json!({ "error": "wasm-plugins feature not compiled in" }),
))
}
async fn compute_topology(state: &Arc<AdminState>) -> TopologyResponse {
let health = state.node_health.read().await;
let cfg = state.config_snapshot.read().await;
let mut current_primary: Option<String> = None;
for n in &cfg.nodes {
if n.role.eq_ignore_ascii_case("primary") {
let healthy = health.get(&n.address).map(|h| h.healthy).unwrap_or(false);
if healthy {
current_primary = Some(n.address.clone());
break;
}
}
}
let healthy_nodes = health.values().filter(|h| h.healthy).count() as u32;
let unhealthy_nodes = health.values().filter(|h| !h.healthy).count() as u32;
let total_nodes = cfg.nodes.len() as u32;
TopologyResponse {
current_primary,
healthy_nodes,
unhealthy_nodes,
total_nodes,
last_failover_at: None,
}
}
fn format_prometheus_metrics(metrics: &ServerMetricsSnapshot) -> String {
let mut output = String::new();
output.push_str("# HELP heliosdb_proxy_connections_total Total connections accepted\n");
output.push_str("# TYPE heliosdb_proxy_connections_total counter\n");
output.push_str(&format!(
"heliosdb_proxy_connections_total {}\n",
metrics.connections_accepted
));
output.push_str("# HELP heliosdb_proxy_connections_closed Total connections closed\n");
output.push_str("# TYPE heliosdb_proxy_connections_closed counter\n");
output.push_str(&format!(
"heliosdb_proxy_connections_closed {}\n",
metrics.connections_closed
));
output.push_str("# HELP heliosdb_proxy_queries_total Total queries processed\n");
output.push_str("# TYPE heliosdb_proxy_queries_total counter\n");
output.push_str(&format!(
"heliosdb_proxy_queries_total {}\n",
metrics.queries_processed
));
output.push_str("# HELP heliosdb_proxy_bytes_received_total Total bytes received\n");
output.push_str("# TYPE heliosdb_proxy_bytes_received_total counter\n");
output.push_str(&format!(
"heliosdb_proxy_bytes_received_total {}\n",
metrics.bytes_received
));
output.push_str("# HELP heliosdb_proxy_bytes_sent_total Total bytes sent\n");
output.push_str("# TYPE heliosdb_proxy_bytes_sent_total counter\n");
output.push_str(&format!(
"heliosdb_proxy_bytes_sent_total {}\n",
metrics.bytes_sent
));
output.push_str("# HELP heliosdb_proxy_failovers_total Total failovers\n");
output.push_str("# TYPE heliosdb_proxy_failovers_total counter\n");
output.push_str(&format!(
"heliosdb_proxy_failovers_total {}\n",
metrics.failovers
));
output
}
async fn send_response(
writer: &mut tokio::net::tcp::WriteHalf<'_>,
status: u16,
status_text: &str,
body: &str,
) -> Result<()> {
let response = format!(
"HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
status,
status_text,
body.len(),
body
);
writer
.write_all(response.as_bytes())
.await
.map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
Ok(())
}
async fn send_json_response<T: Serialize>(
writer: &mut tokio::net::tcp::WriteHalf<'_>,
status: u16,
body: &T,
) -> Result<()> {
let json = serde_json::to_string(body)
.map_err(|e| ProxyError::Internal(format!("JSON error: {}", e)))?;
let status_text = match status {
200 => "OK",
400 => "Bad Request",
404 => "Not Found",
500 => "Internal Server Error",
503 => "Service Unavailable",
_ => "Unknown",
};
let response = format!(
"HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
status,
status_text,
json.len(),
json
);
writer
.write_all(response.as_bytes())
.await
.map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
Ok(())
}
pub fn shutdown(&self) {
let _ = self.shutdown_tx.send(());
}
}
impl AdminState {
    /// Builds an empty state.
    ///
    /// The state has no nodes, zeroed metrics, a blank config snapshot, no proxy
    /// config and no commands. Every optional subsystem starts detached (`None`).
    pub fn new() -> Self {
        let zeroed_metrics = ServerMetricsSnapshot {
            connections_accepted: 0,
            connections_closed: 0,
            queries_processed: 0,
            bytes_received: 0,
            bytes_sent: 0,
            failovers: 0,
        };
        let blank_config = ConfigSnapshot {
            listen_address: String::new(),
            admin_address: String::new(),
            tr_enabled: false,
            tr_mode: String::new(),
            pool_min_connections: 0,
            pool_max_connections: 0,
            nodes: Vec::new(),
        };
        Self {
            node_health: RwLock::new(HashMap::new()),
            metrics: RwLock::new(zeroed_metrics),
            active_sessions: RwLock::new(0),
            config_snapshot: RwLock::new(blank_config),
            proxy_config: RwLock::new(None),
            read_lb_counter: AtomicUsize::new(0),
            commands: RwLock::new(HashMap::new()),
            #[cfg(feature = "ha-tr")]
            replay_engine: RwLock::new(None),
            #[cfg(feature = "wasm-plugins")]
            plugin_manager: RwLock::new(None),
            chaos_overrides: RwLock::new(HashMap::new()),
            #[cfg(feature = "anomaly-detection")]
            anomaly_detector: RwLock::new(None),
            #[cfg(feature = "edge-proxy")]
            edge_cache: RwLock::new(None),
            #[cfg(feature = "edge-proxy")]
            edge_registry: RwLock::new(None),
        }
    }

    /// Attaches the anomaly detector used by `GET /anomalies`.
    #[cfg(feature = "anomaly-detection")]
    pub async fn with_anomaly_detector(&self, detector: Arc<AnomalyDetector>) {
        let mut slot = self.anomaly_detector.write().await;
        *slot = Some(detector);
    }

    /// Attaches the edge cache and registry used by the `/api/edge*` endpoints.
    /// The cache is stored first, then the registry.
    #[cfg(feature = "edge-proxy")]
    pub async fn with_edge(&self, cache: Arc<EdgeCache>, registry: Arc<EdgeRegistry>) {
        {
            let mut cache_slot = self.edge_cache.write().await;
            *cache_slot = Some(cache);
        }
        let mut registry_slot = self.edge_registry.write().await;
        *registry_slot = Some(registry);
    }

    /// Attaches the replay engine used by `POST /api/replay`.
    #[cfg(feature = "ha-tr")]
    pub async fn with_replay_engine(&self, engine: Arc<ReplayEngine>) {
        let mut slot = self.replay_engine.write().await;
        *slot = Some(engine);
    }

    /// Attaches the plugin manager used by `GET /plugins`.
    #[cfg(feature = "wasm-plugins")]
    pub async fn with_plugin_manager(&self, manager: Arc<PluginManager>) {
        let mut slot = self.plugin_manager.write().await;
        *slot = Some(manager);
    }

    /// Installs (or replaces) the proxy configuration used for SQL routing.
    pub async fn set_proxy_config(&self, config: ProxyConfig) {
        *self.proxy_config.write().await = Some(config);
    }

    /// Registers `handler` under `name`, replacing any previous handler of that name.
    pub async fn register_command<F>(&self, name: &str, handler: F)
    where
        F: Fn(&[&str]) -> Result<String> + Send + Sync + 'static,
    {
        self.commands
            .write()
            .await
            .insert(name.to_owned(), Arc::new(handler));
    }

    /// Runs the command registered as `name` with `args`.
    ///
    /// # Errors
    /// - `ProxyError::Internal` if no command has that name.
    /// - Otherwise, whatever the handler returns.
    pub async fn execute_command(&self, name: &str, args: &[&str]) -> Result<String> {
        let registry = self.commands.read().await;
        registry
            .get(name)
            .ok_or_else(|| ProxyError::Internal(format!("Unknown command: {}", name)))
            .and_then(|handler| handler(args))
    }
}
impl Default for AdminState {
fn default() -> Self {
Self::new()
}
}
/// JSON body of `POST /api/sql`.
#[derive(Debug, Deserialize)]
struct SqlRequest {
// SQL text to route; leading and trailing whitespace is trimmed before use.
query: String,
}
/// JSON reply of `POST /api/sql`.
#[derive(Debug, Serialize)]
struct SqlResponse {
// "write" or "read", as classified by `is_write_query`.
query_type: String,
// "host:port" of the node the query was sent to.
routed_to: String,
// Lowercased Debug name of the node's role.
node_role: String,
// JSON returned by the node's HTTP SQL endpoint.
result: serde_json::Value,
}
/// Body of `GET /health`; the admin API always reports `"ok"`.
#[derive(Serialize)]
struct HealthResponse {
status: &'static str,
}
/// Body of `GET /health/ready`. Served with 200 when ready and 503 otherwise.
#[derive(Serialize)]
struct ReadinessResponse {
// True when at least one node is reported healthy.
ready: bool,
message: &'static str,
}
/// Body of `GET /health/live`.
#[derive(Serialize)]
struct LivenessResponse {
alive: bool,
}
/// JSON body sent with a 500 when a route handler returns an error.
#[derive(Serialize)]
struct ErrorResponse {
error: String,
}
/// Body of `GET /metrics`: the server counters plus a derived active-connection count.
#[derive(Serialize)]
struct MetricsResponse {
connections_accepted: u64,
connections_closed: u64,
// accepted - closed, saturating at zero.
connections_active: u64,
queries_processed: u64,
bytes_received: u64,
bytes_sent: u64,
failovers: u64,
}
impl From<ServerMetricsSnapshot> for MetricsResponse {
fn from(m: ServerMetricsSnapshot) -> Self {
Self {
connections_accepted: m.connections_accepted,
connections_closed: m.connections_closed,
connections_active: m.connections_accepted.saturating_sub(m.connections_closed),
queries_processed: m.queries_processed,
bytes_received: m.bytes_received,
bytes_sent: m.bytes_sent,
failovers: m.failovers,
}
}
}
/// JSON view of a `NodeHealth` record, built through the `From<NodeHealth>` impl.
///
/// The fields mirror `NodeHealth`, except that `last_check` is rendered as an
/// RFC 3339 string.
#[derive(Serialize)]
struct NodeHealthResponse {
address: String,
healthy: bool,
/// Time of the last health check, formatted with `to_rfc3339()`.
last_check: String,
failure_count: u32,
last_error: Option<String>,
latency_ms: f64,
/// `None` when no replication lag value is available for the node.
replication_lag_bytes: Option<u64>,
}
impl From<NodeHealth> for NodeHealthResponse {
fn from(h: NodeHealth) -> Self {
Self {
address: h.address,
healthy: h.healthy,
last_check: h.last_check.to_rfc3339(),
failure_count: h.failure_count,
last_error: h.last_error,
latency_ms: h.latency_ms,
replication_lag_bytes: h.replication_lag_bytes,
}
}
}
/// JSON body carrying an active-session count.
///
/// NOTE(review): presumably populated from `AdminState::active_sessions`; the handler is
/// not in this chunk.
#[derive(Serialize)]
struct SessionsResponse {
active_sessions: u64,
}
/// JSON body for `AdminServer::handle_edge_register`.
///
/// Example: `{"edge_id":"e1","region":"us-east","base_url":"https://e1.svc"}`.
/// All three fields are required (none has `#[serde(default)]`).
#[cfg(feature = "edge-proxy")]
#[derive(Debug, Deserialize)]
struct EdgeRegisterBody {
/// Identifier of the edge node being registered.
edge_id: String,
/// Region label for the edge node.
region: String,
/// Base URL of the edge node.
base_url: String,
}
/// JSON body for `AdminServer::handle_edge_invalidate`, e.g. `{"tables":["users"]}`.
///
/// Both fields are optional. A missing `tables` deserializes to an empty list, and a
/// missing `up_to_version` deserializes to `None`.
#[cfg(feature = "edge-proxy")]
#[derive(Debug, Deserialize)]
struct EdgeInvalidateBody {
/// Tables whose cached entries should be invalidated.
#[serde(default)]
tables: Vec<String>,
/// NOTE(review): presumably an upper bound on cached entry versions to drop; the
/// handler is not in this chunk, so confirm the exact semantics there.
#[serde(default)]
up_to_version: Option<u64>,
}
/// Extracts the `limit` query parameter from a request path such as `/anomalies?limit=5`.
///
/// The query string is everything after the first `?`. Its `&`-separated pairs are split
/// at the first `=`. The first pair whose key is exactly `limit` and whose value parses as
/// a `usize` wins. Pairs without `=`, with another key, or with a non-numeric value are
/// skipped.
///
/// Returns `default` when there is no query string or no usable `limit` pair. The result is
/// always capped at `max`, and the cap applies to the fallback `default` as well, so callers
/// can rely on the return value never exceeding `max`.
// The helper is pure string parsing with no feature dependency. Only the anomalies endpoint
// calls it, so the dead-code lint is silenced in other builds instead of compiling it out.
#[cfg_attr(not(feature = "anomaly-detection"), allow(dead_code))]
fn parse_limit_query(path: &str, default: usize, max: usize) -> usize {
    path.split_once('?')
        .and_then(|(_, query)| {
            query
                .split('&')
                .filter_map(|pair| pair.split_once('='))
                .find_map(|(key, value)| {
                    if key == "limit" {
                        value.parse::<usize>().ok()
                    } else {
                        None
                    }
                })
        })
        .unwrap_or(default)
        .min(max)
}
/// JSON body for `AdminServer::handle_shadow_request`.
///
/// The body carries one SQL statement (e.g. `SELECT 1`) with an optional list of string
/// parameters, plus connection settings for two servers: a "source" and a "shadow". Host
/// and port are required for both. User, password and database are optional and default
/// to `None` when absent from the JSON.
///
/// `Debug` is implemented by hand (below) so that the password fields are never printed.
#[cfg(feature = "ha-tr")]
#[derive(Deserialize)]
struct ShadowRequestBody {
    /// SQL statement text.
    sql: String,
    /// Optional list of string parameters for `sql`.
    #[serde(default)]
    params: Option<Vec<String>>,
    source_host: String,
    source_port: u16,
    #[serde(default)]
    source_user: Option<String>,
    #[serde(default)]
    source_password: Option<String>,
    #[serde(default)]
    source_database: Option<String>,
    shadow_host: String,
    shadow_port: u16,
    #[serde(default)]
    shadow_user: Option<String>,
    #[serde(default)]
    shadow_password: Option<String>,
    #[serde(default)]
    shadow_database: Option<String>,
}

#[cfg(feature = "ha-tr")]
impl std::fmt::Debug for ShadowRequestBody {
    /// Formats like a derived `Debug`, except that each password is shown only as
    /// `Some("<redacted>")` or `None`. Logging a request with `{:?}` therefore never
    /// writes credentials.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ShadowRequestBody")
            .field("sql", &self.sql)
            .field("params", &self.params)
            .field("source_host", &self.source_host)
            .field("source_port", &self.source_port)
            .field("source_user", &self.source_user)
            .field(
                "source_password",
                &self.source_password.as_ref().map(|_| "<redacted>"),
            )
            .field("source_database", &self.source_database)
            .field("shadow_host", &self.shadow_host)
            .field("shadow_port", &self.shadow_port)
            .field("shadow_user", &self.shadow_user)
            .field(
                "shadow_password",
                &self.shadow_password.as_ref().map(|_| "<redacted>"),
            )
            .field("shadow_database", &self.shadow_database)
            .finish()
    }
}
/// Chaos-testing command accepted by `AdminServer::handle_chaos_request`.
///
/// The enum is internally tagged on `action`, with snake_case variant names. For example:
/// `{"action":"force_unhealthy","target_node":"primary.svc:5432"}`,
/// `{"action":"restore","target_node":"primary.svc:5432"}` or `{"action":"reset"}`.
/// Any other `action` value fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "action", rename_all = "snake_case")]
enum ChaosAction {
/// Mark `target_node` unhealthy and record a chaos override for it.
ForceUnhealthy { target_node: String },
/// Undo a previous `ForceUnhealthy` on `target_node`.
Restore { target_node: String },
/// Undo every active override.
Reset,
}
/// One row of the plugin listing.
///
/// NOTE(review): presumably built from `PluginManager` state by `handle_plugins_list`,
/// which is not in this chunk. The field meanings below are inferred from names.
#[cfg(feature = "wasm-plugins")]
#[derive(Serialize)]
struct PluginListEntry {
name: String,
version: String,
description: String,
/// Names of the hooks the plugin is attached to.
hooks: Vec<String>,
/// Plugin state, in string form.
state: String,
/// Invocation counter.
invocations: u64,
/// Error counter.
errors: u64,
}
/// JSON body for `AdminServer::handle_replay_request`.
///
/// `from` and `to` are parsed as UTC timestamps (e.g. `"2026-04-25T10:00:00Z"`). The target
/// host and port are required. User, password and database are optional and default to
/// `None` when absent.
///
/// NOTE(review): how the replay engine interprets the `from`/`to` window (inclusive or
/// exclusive bounds) is not visible here.
///
/// `Debug` is implemented by hand (below) so that `target_password` is never printed.
#[cfg(feature = "ha-tr")]
#[derive(Deserialize)]
struct ReplayRequestBody {
    from: DateTime<Utc>,
    to: DateTime<Utc>,
    target_host: String,
    target_port: u16,
    #[serde(default)]
    target_user: Option<String>,
    #[serde(default)]
    target_password: Option<String>,
    #[serde(default)]
    target_database: Option<String>,
}

#[cfg(feature = "ha-tr")]
impl std::fmt::Debug for ReplayRequestBody {
    /// Formats like a derived `Debug`, except that the password is shown only as
    /// `Some("<redacted>")` or `None`. Logging a request with `{:?}` therefore never
    /// writes credentials.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReplayRequestBody")
            .field("from", &self.from)
            .field("to", &self.to)
            .field("target_host", &self.target_host)
            .field("target_port", &self.target_port)
            .field("target_user", &self.target_user)
            .field(
                "target_password",
                &self.target_password.as_ref().map(|_| "<redacted>"),
            )
            .field("target_database", &self.target_database)
            .finish()
    }
}
/// Cluster topology summary produced by `AdminServer::compute_topology`.
///
/// Serialized with camelCase keys (`currentPrimary`, `healthyNodes`, `unhealthyNodes`,
/// `totalNodes`, `lastFailoverAt`) via the per-field `serde(rename)` attributes.
#[derive(Serialize)]
struct TopologyResponse {
/// Address of the primary node. The tests in this file expect `None` when the primary
/// is unhealthy or absent.
#[serde(rename = "currentPrimary")]
current_primary: Option<String>,
#[serde(rename = "healthyNodes")]
healthy_nodes: u32,
#[serde(rename = "unhealthyNodes")]
unhealthy_nodes: u32,
#[serde(rename = "totalNodes")]
total_nodes: u32,
/// NOTE(review): presumably a timestamp string of the last failover; its format is not
/// visible in this chunk.
#[serde(rename = "lastFailoverAt")]
last_failover_at: Option<String>,
}
/// Per-node connection-pool statistics body.
///
/// NOTE(review): the producer is not in this chunk; the field meanings are inferred from
/// names only.
#[derive(Serialize)]
struct PoolStatsResponse {
/// Node the statistics belong to.
node: String,
active_connections: u64,
idle_connections: u64,
pending_requests: u64,
total_connections_created: u64,
total_connections_closed: u64,
}
/// JSON body reporting a version string and a build-time string.
///
/// NOTE(review): where these values come from (e.g. compile-time env vars) is not visible
/// in this chunk.
#[derive(Serialize)]
struct VersionResponse {
version: String,
build_time: String,
}
// Unit tests for `AdminState` and the `AdminServer` handler helpers. Handler tests call the
// `AdminServer::handle_*` functions directly and inspect the returned `(status, json)` pair.
#[cfg(test)]
mod tests {
use super::*;
// A freshly constructed state reports zero active sessions.
#[tokio::test]
async fn test_admin_state_creation() {
let state = AdminState::new();
let sessions = state.active_sessions.read().await;
assert_eq!(*sessions, 0);
}
// With no node health recorded, the readiness check fails.
#[tokio::test]
async fn test_readiness_check_no_nodes() {
let state = Arc::new(AdminState::new());
let ready = AdminServer::check_readiness(&state).await;
assert!(!ready);
}
// A single healthy node in `node_health` is enough for readiness.
#[tokio::test]
async fn test_readiness_check_with_healthy_node() {
let state = Arc::new(AdminState::new());
{
let mut health = state.node_health.write().await;
health.insert(
"localhost:5432".to_string(),
NodeHealth {
address: "localhost:5432".to_string(),
healthy: true,
last_check: chrono::Utc::now(),
failure_count: 0,
last_error: None,
latency_ms: 1.0,
replication_lag_bytes: None,
},
);
}
let ready = AdminServer::check_readiness(&state).await;
assert!(ready);
}
// A registered command receives the argument slice, and its Ok value is returned as-is.
#[tokio::test]
async fn test_command_registration() {
let state = AdminState::new();
state
.register_command("test", |args| {
Ok(format!("Test command with {} args", args.len()))
})
.await;
let result = state.execute_command("test", &["arg1", "arg2"]).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "Test command with 2 args");
}
// Executing a name that was never registered yields an error.
#[tokio::test]
async fn test_unknown_command() {
let state = AdminState::new();
let result = state.execute_command("unknown", &[]).await;
assert!(result.is_err());
}
// The Prometheus text output contains the accepted-connection, query and failover
// counters under their `heliosdb_proxy_*` names.
#[test]
fn test_prometheus_metrics_format() {
let metrics = ServerMetricsSnapshot {
connections_accepted: 100,
connections_closed: 50,
queries_processed: 1000,
bytes_received: 50000,
bytes_sent: 100000,
failovers: 2,
};
let output = AdminServer::format_prometheus_metrics(&metrics);
assert!(output.contains("heliosdb_proxy_connections_total 100"));
assert!(output.contains("heliosdb_proxy_queries_total 1000"));
assert!(output.contains("heliosdb_proxy_failovers_total 2"));
}
// `connections_active` is derived as accepted minus closed.
#[test]
fn test_metrics_response_active_connections() {
let snapshot = ServerMetricsSnapshot {
connections_accepted: 100,
connections_closed: 30,
queries_processed: 500,
bytes_received: 10000,
bytes_sent: 20000,
failovers: 1,
};
let response = MetricsResponse::from(snapshot);
assert_eq!(response.connections_active, 70);
}
// Builds a state from `(address, role, healthy)` triples. The config snapshot lists every
// node (weight 100, enabled), and `node_health` records each node's healthy flag.
async fn topology_state(
nodes: &[(&str, &str, bool)],
) -> Arc<AdminState> {
let state = Arc::new(AdminState::new());
{
let mut cfg = state.config_snapshot.write().await;
cfg.nodes = nodes
.iter()
.map(|(addr, role, _)| NodeSnapshot {
address: (*addr).to_string(),
role: (*role).to_string(),
weight: 100,
enabled: true,
})
.collect();
}
{
let mut health = state.node_health.write().await;
for (addr, _, healthy) in nodes {
health.insert(
(*addr).to_string(),
NodeHealth {
address: (*addr).to_string(),
healthy: *healthy,
last_check: chrono::Utc::now(),
failure_count: 0,
last_error: None,
latency_ms: 1.0,
replication_lag_bytes: None,
},
);
}
}
state
}
// A healthy primary is reported, and the health counts cover every configured node.
#[tokio::test]
async fn test_topology_returns_healthy_primary() {
let state = topology_state(&[
("primary.svc:5432", "primary", true),
("standby-a.svc:5432", "standby", true),
("standby-b.svc:5432", "standby", false),
])
.await;
let topo = AdminServer::compute_topology(&state).await;
assert_eq!(topo.current_primary.as_deref(), Some("primary.svc:5432"));
assert_eq!(topo.healthy_nodes, 2);
assert_eq!(topo.unhealthy_nodes, 1);
assert_eq!(topo.total_nodes, 3);
}
// An unhealthy primary is not reported as the current primary.
#[tokio::test]
async fn test_topology_no_primary_when_primary_unhealthy() {
let state = topology_state(&[
("primary.svc:5432", "primary", false),
("standby.svc:5432", "standby", true),
])
.await;
let topo = AdminServer::compute_topology(&state).await;
assert_eq!(topo.current_primary, None);
assert_eq!(topo.healthy_nodes, 1);
assert_eq!(topo.unhealthy_nodes, 1);
}
// An empty cluster yields no primary and zero counts.
#[tokio::test]
async fn test_topology_handles_empty_cluster() {
let state = Arc::new(AdminState::new());
let topo = AdminServer::compute_topology(&state).await;
assert_eq!(topo.current_primary, None);
assert_eq!(topo.healthy_nodes, 0);
assert_eq!(topo.unhealthy_nodes, 0);
assert_eq!(topo.total_nodes, 0);
}
// The role comparison ignores case ("PRIMARY" counts as primary).
#[tokio::test]
async fn test_topology_role_match_is_case_insensitive() {
let state = topology_state(&[
("primary.svc:5432", "PRIMARY", true),
])
.await;
let topo = AdminServer::compute_topology(&state).await;
assert_eq!(topo.current_primary.as_deref(), Some("primary.svc:5432"));
}
// A valid replay request returns 503 while no replay engine is attached.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_replay_returns_503_when_engine_unattached() {
let state = Arc::new(AdminState::new());
let body = r#"{
"from": "2026-04-25T10:00:00Z",
"to": "2026-04-25T11:00:00Z",
"target_host": "127.0.0.1",
"target_port": 5432
}"#;
let (status, value) = AdminServer::handle_replay_request(Some(body), &state)
.await
.expect("handler returns Ok with status code");
assert_eq!(status, 503);
assert_eq!(value["error"], "replay engine not attached");
}
// A body that is not JSON yields 400.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_replay_400_on_malformed_body() {
let state = Arc::new(AdminState::new());
let (status, _) = AdminServer::handle_replay_request(Some("not json"), &state)
.await
.expect("handler returns Ok with status code");
assert_eq!(status, 400);
}
// A missing body is surfaced as `Err` rather than as a status code.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_replay_errors_on_empty_body() {
let state = Arc::new(AdminState::new());
let err = AdminServer::handle_replay_request(None, &state).await;
assert!(err.is_err(), "empty body must surface as Err");
}
// With the feature enabled but no plugin manager attached, listing returns 503.
#[cfg(feature = "wasm-plugins")]
#[tokio::test]
async fn test_plugins_list_returns_503_when_manager_unattached() {
let state = Arc::new(AdminState::new());
let (status, value) = AdminServer::handle_plugins_list(&state)
.await
.expect("handler returns Ok with status code");
assert_eq!(status, 503);
assert_eq!(value["error"], "plugin manager not attached");
}
// With the feature compiled out, listing still returns 503.
#[cfg(not(feature = "wasm-plugins"))]
#[tokio::test]
async fn test_plugins_list_503_without_feature() {
let state = Arc::new(AdminState::new());
let (status, _) = AdminServer::handle_plugins_list(&state)
.await
.expect("handler returns Ok");
assert_eq!(status, 503);
}
// Builds a state whose `node_health` holds a single healthy node at `addr`.
async fn chaos_state_with_node(addr: &str) -> Arc<AdminState> {
let state = Arc::new(AdminState::new());
state.node_health.write().await.insert(
addr.to_string(),
NodeHealth {
address: addr.to_string(),
healthy: true,
last_check: chrono::Utc::now(),
failure_count: 0,
last_error: None,
latency_ms: 1.0,
replication_lag_bytes: None,
},
);
state
}
// `force_unhealthy` marks the node unhealthy and records a chaos override for it.
#[tokio::test]
async fn test_chaos_force_unhealthy_flips_node_and_records_override() {
let state = chaos_state_with_node("primary.svc:5432").await;
let body = r#"{"action":"force_unhealthy","target_node":"primary.svc:5432"}"#;
let (status, value) = AdminServer::handle_chaos_request(Some(body), &state)
.await
.expect("handler returns Ok");
assert_eq!(status, 200);
assert_eq!(value["applied"], "force_unhealthy");
assert!(!state.node_health.read().await["primary.svc:5432"].healthy);
assert!(state.chaos_overrides.read().await.contains_key("primary.svc:5432"));
}
// `restore` marks the node healthy again and removes its override.
#[tokio::test]
async fn test_chaos_restore_clears_override_and_flips_back() {
let state = chaos_state_with_node("primary.svc:5432").await;
let _ = AdminServer::handle_chaos_request(
Some(r#"{"action":"force_unhealthy","target_node":"primary.svc:5432"}"#),
&state,
)
.await
.unwrap();
let (status, _) = AdminServer::handle_chaos_request(
Some(r#"{"action":"restore","target_node":"primary.svc:5432"}"#),
&state,
)
.await
.unwrap();
assert_eq!(status, 200);
assert!(state.node_health.read().await["primary.svc:5432"].healthy);
assert!(state.chaos_overrides.read().await.is_empty());
}
// `reset` undoes every override: both nodes become healthy, and the response lists
// both as restored.
#[tokio::test]
async fn test_chaos_reset_restores_all_overrides() {
let state = chaos_state_with_node("a:5432").await;
state.node_health.write().await.insert(
"b:5432".to_string(),
NodeHealth {
address: "b:5432".to_string(),
healthy: true,
last_check: chrono::Utc::now(),
failure_count: 0,
last_error: None,
latency_ms: 1.0,
replication_lag_bytes: None,
},
);
for addr in &["a:5432", "b:5432"] {
let body = format!(r#"{{"action":"force_unhealthy","target_node":"{}"}}"#, addr);
let _ = AdminServer::handle_chaos_request(Some(&body), &state)
.await
.unwrap();
}
let (status, value) = AdminServer::handle_chaos_request(
Some(r#"{"action":"reset"}"#),
&state,
)
.await
.unwrap();
assert_eq!(status, 200);
assert_eq!(value["reset"], true);
let restored = value["restored"].as_array().unwrap();
assert_eq!(restored.len(), 2);
for addr in &["a:5432", "b:5432"] {
assert!(state.node_health.read().await[*addr].healthy);
}
assert!(state.chaos_overrides.read().await.is_empty());
}
// Targeting a node that is not in `node_health` yields 404.
#[tokio::test]
async fn test_chaos_force_unhealthy_404s_when_node_unknown() {
let state = Arc::new(AdminState::new());
let body = r#"{"action":"force_unhealthy","target_node":"missing.svc:5432"}"#;
let (status, _) = AdminServer::handle_chaos_request(Some(body), &state)
.await
.expect("handler returns Ok");
assert_eq!(status, 404);
}
// A body that is not JSON yields 400.
#[tokio::test]
async fn test_chaos_400_on_malformed_body() {
let state = Arc::new(AdminState::new());
let (status, _) = AdminServer::handle_chaos_request(Some("not json"), &state)
.await
.expect("handler returns Ok");
assert_eq!(status, 400);
}
// An `action` that is not a `ChaosAction` variant yields 400.
#[tokio::test]
async fn test_chaos_400_on_unknown_action() {
let state = Arc::new(AdminState::new());
let body = r#"{"action":"format_disk","target_node":"x"}"#;
let (status, _) = AdminServer::handle_chaos_request(Some(body), &state)
.await
.expect("handler returns Ok");
assert_eq!(status, 400);
}
// A body that is not JSON yields 400.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_shadow_400_on_malformed_body() {
let (status, _) = AdminServer::handle_shadow_request(Some("not json"))
.await
.expect("handler returns Ok");
assert_eq!(status, 400);
}
// An unreachable source yields 500 with a "source connect" error.
// Assumes nothing is listening on 127.0.0.1:1 in the test environment.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_shadow_500_on_source_unreachable() {
let body = r#"{
"sql": "SELECT 1",
"source_host": "127.0.0.1",
"source_port": 1,
"shadow_host": "127.0.0.1",
"shadow_port": 1
}"#;
let (status, value) = AdminServer::handle_shadow_request(Some(body))
.await
.expect("handler returns Ok");
assert_eq!(status, 500);
let err = value["error"].as_str().expect("error field");
assert!(
err.contains("source connect"),
"expected source connect error, got {}",
err
);
}
// A missing body is surfaced as `Err` rather than as a status code.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_shadow_errors_on_empty_body() {
let err = AdminServer::handle_shadow_request(None).await;
assert!(err.is_err(), "empty body must surface as Err");
}
// Without an attached detector the anomalies endpoint returns 503.
#[cfg(feature = "anomaly-detection")]
#[tokio::test]
async fn test_anomalies_returns_503_when_detector_unattached() {
let state = Arc::new(AdminState::new());
let (status, value) =
AdminServer::handle_anomalies_list("/anomalies", &state)
.await
.expect("handler returns Ok");
assert_eq!(status, 503);
assert_eq!(value["error"], "anomaly detector not attached");
}
// A query containing an `OR 1=1 --` pattern is expected to produce at least one event
// that the endpoint then reports.
#[cfg(feature = "anomaly-detection")]
#[tokio::test]
async fn test_anomalies_returns_attached_detector_events() {
use crate::anomaly::{AnomalyConfig, AnomalyDetector, QueryObservation};
let state = Arc::new(AdminState::new());
let det = Arc::new(AnomalyDetector::new(AnomalyConfig::default()));
let _ = det.record_query(&QueryObservation {
tenant: "test".into(),
fingerprint: "fp".into(),
sql: "SELECT * FROM users WHERE id = 1 OR 1=1 --".into(),
timestamp: std::time::Instant::now(),
iso_timestamp: "ts".into(),
});
state.with_anomaly_detector(det.clone()).await;
let (status, value) =
AdminServer::handle_anomalies_list("/anomalies", &state)
.await
.expect("handler returns Ok");
assert_eq!(status, 200);
let count = value["count"].as_u64().expect("count field");
assert!(count > 0, "expected at least one event, got {}", count);
}
// `?limit=5` caps the returned events at five.
// NOTE(review): relies on the default detector config emitting at least five events for
// 50 distinct fingerprints; confirm against `AnomalyConfig::default()`.
#[cfg(feature = "anomaly-detection")]
#[tokio::test]
async fn test_anomalies_limit_query_string_respected() {
use crate::anomaly::{AnomalyConfig, AnomalyDetector, QueryObservation};
let state = Arc::new(AdminState::new());
let det = Arc::new(AnomalyDetector::new(AnomalyConfig::default()));
for i in 0..50 {
let fp = format!("fp{}", i);
let _ = det.record_query(&QueryObservation {
tenant: "test".into(),
fingerprint: fp,
sql: "SELECT 1".into(),
timestamp: std::time::Instant::now(),
iso_timestamp: "ts".into(),
});
}
state.with_anomaly_detector(det).await;
let (status, value) =
AdminServer::handle_anomalies_list("/anomalies?limit=5", &state)
.await
.expect("handler returns Ok");
assert_eq!(status, 200);
assert_eq!(value["limit"].as_u64().unwrap(), 5);
assert_eq!(value["events"].as_array().unwrap().len(), 5);
}
// `parse_limit_query`: default without a query, numeric limit honored, capped at max,
// non-numeric ignored, and other keys skipped.
#[cfg(feature = "anomaly-detection")]
#[test]
fn test_parse_limit_query_helper() {
assert_eq!(parse_limit_query("/anomalies", 100, 1024), 100);
assert_eq!(parse_limit_query("/anomalies?limit=42", 100, 1024), 42);
assert_eq!(parse_limit_query("/anomalies?limit=99999", 100, 1024), 1024);
assert_eq!(parse_limit_query("/anomalies?limit=abc", 100, 1024), 100);
assert_eq!(parse_limit_query("/anomalies?other=x&limit=7", 100, 1024), 7);
}
// Builds a state with an edge cache and an edge registry attached. The constructor
// arguments are arbitrary test values.
#[cfg(feature = "edge-proxy")]
async fn edge_state() -> Arc<AdminState> {
use crate::edge::{EdgeCache, EdgeRegistry};
use std::time::Duration;
let s = Arc::new(AdminState::new());
let cache = Arc::new(EdgeCache::new(100));
let registry = Arc::new(EdgeRegistry::new(8, Duration::from_secs(60)));
s.with_edge(cache, registry).await;
s
}
// A fresh edge setup reports no registered edges but does include cache stats.
#[cfg(feature = "edge-proxy")]
#[tokio::test]
async fn test_edge_status_returns_empty_lists_initially() {
let s = edge_state().await;
let (status, value) = AdminServer::handle_edge_status(&s)
.await
.expect("handler returns Ok");
assert_eq!(status, 200);
assert_eq!(value["edge_count"].as_u64().unwrap(), 0);
assert_eq!(value["registered"].as_array().unwrap().len(), 0);
assert!(value["cache"].is_object(), "cache stats present");
}
// Registering an edge returns 201, and the edge then appears in the status listing.
#[cfg(feature = "edge-proxy")]
#[tokio::test]
async fn test_edge_register_then_status_lists_edge() {
let s = edge_state().await;
let body = r#"{"edge_id":"e1","region":"us-east","base_url":"https://e1.svc"}"#;
let (status, _) = AdminServer::handle_edge_register(Some(body), &s)
.await
.expect("handler ok");
assert_eq!(status, 201);
let (status2, value2) = AdminServer::handle_edge_status(&s).await.unwrap();
assert_eq!(status2, 200);
assert_eq!(value2["edge_count"].as_u64().unwrap(), 1);
assert_eq!(
value2["registered"][0]["edge_id"].as_str().unwrap(),
"e1"
);
}
// A body that is not JSON yields 400.
#[cfg(feature = "edge-proxy")]
#[tokio::test]
async fn test_edge_register_400_on_malformed_body() {
let s = edge_state().await;
let (status, _) = AdminServer::handle_edge_register(Some("not json"), &s)
.await
.expect("handler ok");
assert_eq!(status, 400);
}
// Invalidating a table drops the matching entry from the local cache and reports the
// number of entries dropped.
#[cfg(feature = "edge-proxy")]
#[tokio::test]
async fn test_edge_invalidate_drops_local_cache_entries() {
use crate::edge::{CacheEntry, CacheKey};
use std::time::{Duration, Instant};
let s = edge_state().await;
let cache = s.edge_cache.read().await.clone().unwrap();
cache.insert(
CacheKey::new("fp1", "p1"),
CacheEntry {
version: 1,
response_bytes: b"row".to_vec(),
tables: vec!["users".into()],
expires_at: Instant::now() + Duration::from_secs(60),
},
);
assert!(cache.get(&CacheKey::new("fp1", "p1")).is_some());
let body = r#"{"tables":["users"]}"#;
let (status, value) = AdminServer::handle_edge_invalidate(Some(body), &s)
.await
.expect("handler ok");
assert_eq!(status, 200);
assert_eq!(value["dropped_local"].as_u64().unwrap(), 1);
assert!(cache.get(&CacheKey::new("fp1", "p1")).is_none());
}
// Without an attached edge cache, invalidation returns 503.
#[cfg(feature = "edge-proxy")]
#[tokio::test]
async fn test_edge_invalidate_503_when_cache_unattached() {
let s = Arc::new(AdminState::new());
let body = r#"{"tables":["users"]}"#;
let (status, _) = AdminServer::handle_edge_invalidate(Some(body), &s)
.await
.expect("handler ok");
assert_eq!(status, 503);
}
}