use crate::{
ComponentHealth, ErrorResponse, HealthResponse, HealthStatus, KernelResponse, RequestMetadata,
ResponseMetadata,
common::{ServiceConfig, ServiceMetrics, headers, paths},
};
use axum::{
Router,
extract::{Path, State},
http::{HeaderMap, HeaderValue, StatusCode, header},
middleware::{self, Next},
response::{Json, Response},
routing::{get, post},
};
use rustkernel_core::registry::KernelRegistry;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
/// Shared state attached to the router with `with_state` and read by the
/// handlers in this file through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
/// Kernel registry queried by the listing, info, execute, health and metrics handlers.
pub registry: Arc<KernelRegistry>,
/// Service configuration; handlers in this file read `version` and `default_timeout` from it.
pub config: ServiceConfig,
/// Request, error and latency counters updated by `execute_kernel` and reported
/// by the health and metrics endpoints.
pub metrics: Arc<ServiceMetrics>,
/// Instant captured when the state was constructed; used to compute uptime.
pub start_time: Instant,
}
impl AppState {
    /// Creates the state for a freshly started service.
    ///
    /// The metrics handle comes from `ServiceMetrics::new()` and the uptime
    /// clock starts at the moment of this call.
    pub fn new(registry: Arc<KernelRegistry>, config: ServiceConfig) -> Self {
        let metrics = ServiceMetrics::new();
        let start_time = Instant::now();
        Self {
            registry,
            config,
            metrics,
            start_time,
        }
    }
}
/// Switches that control which routes and layers `KernelRouter::build` installs.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Register the health, liveness and readiness routes.
    pub health_endpoints: bool,
    /// Register the metrics route.
    pub metrics_endpoint: bool,
    /// Wrap the router in a CORS layer.
    pub cors_enabled: bool,
    /// Path prefix under which the kernel API routes are nested.
    pub api_prefix: String,
}

impl Default for RouterConfig {
    /// Everything enabled, with the kernel API mounted at `/api/v1`.
    fn default() -> Self {
        let api_prefix = String::from("/api/v1");
        RouterConfig {
            health_endpoints: true,
            metrics_endpoint: true,
            cors_enabled: true,
            api_prefix,
        }
    }
}
/// Builder that assembles the axum `Router` exposing the kernel REST API.
pub struct KernelRouter {
// Registry handed to the application state when the router is built.
registry: Arc<KernelRegistry>,
// Route/layer switches and the API path prefix.
config: RouterConfig,
// Service-level settings; cloned into the application state and used to build the CORS layer.
service_config: ServiceConfig,
}
impl KernelRouter {
pub fn new(registry: Arc<KernelRegistry>) -> Self {
Self {
registry,
config: RouterConfig::default(),
service_config: ServiceConfig::default(),
}
}
pub fn with_config(mut self, config: RouterConfig) -> Self {
self.config = config;
self
}
pub fn with_service_config(mut self, config: ServiceConfig) -> Self {
self.service_config = config;
self
}
pub fn with_health_endpoints(mut self) -> Self {
self.config.health_endpoints = true;
self
}
pub fn with_metrics(mut self) -> Self {
self.config.metrics_endpoint = true;
self
}
pub fn build(self) -> Router {
let cors_enabled = self.config.cors_enabled;
let state = AppState::new(self.registry, self.service_config.clone());
let mut router = Router::new();
let api_routes = Router::new()
.route("/kernels", get(list_kernels))
.route("/kernels/:kernel_id", get(get_kernel_info))
.route("/kernels/:kernel_id/execute", post(execute_kernel));
router = router.nest(&self.config.api_prefix, api_routes);
if self.config.health_endpoints {
router = router
.route(paths::HEALTH, get(health_check))
.route(paths::LIVENESS, get(liveness_check))
.route(paths::READINESS, get(readiness_check));
}
if self.config.metrics_endpoint {
router = router.route(paths::METRICS, get(metrics_endpoint));
}
router = router.layer(middleware::from_fn(request_id_middleware));
if cors_enabled {
let cors = build_cors(&self.service_config);
router = router.layer(cors);
}
router.with_state(state)
}
}
/// Builds the CORS layer from the service configuration.
///
/// GET, POST and OPTIONS are allowed, together with the content-type,
/// authorization, request-ID, tenant-ID and API-key headers. If any configured
/// origin is the literal `*`, every origin is allowed; otherwise only the
/// configured origins that parse as header values are allowed, and the rest
/// are skipped.
///
/// # Panics
///
/// Panics if one of the custom header-name constants is not a valid header name.
fn build_cors(config: &ServiceConfig) -> tower_http::cors::CorsLayer {
    use tower_http::cors::{Any, CorsLayer};

    let allowed_methods = [
        axum::http::Method::GET,
        axum::http::Method::POST,
        axum::http::Method::OPTIONS,
    ];
    let allowed_headers = [
        header::CONTENT_TYPE,
        header::AUTHORIZATION,
        headers::X_REQUEST_ID.parse().unwrap(),
        headers::X_TENANT_ID.parse().unwrap(),
        headers::X_API_KEY.parse().unwrap(),
    ];
    let layer = CorsLayer::new()
        .allow_methods(allowed_methods)
        .allow_headers(allowed_headers);

    let allow_any_origin = config.cors_origins.iter().any(|o| o == "*");
    if allow_any_origin {
        return layer.allow_origin(Any);
    }

    // Origins that are not valid header values are left out of the allow-list.
    let mut allowed_origins: Vec<HeaderValue> = Vec::new();
    for origin in config.cors_origins.iter() {
        if let Ok(value) = origin.parse() {
            allowed_origins.push(value);
        }
    }
    layer.allow_origin(allowed_origins)
}
/// Ensures every request/response pair carries a request ID.
///
/// If the incoming request has a request-ID header that is valid text, that
/// value is reused. Otherwise a fresh UUID v4 is generated and written into
/// the *request* headers before the request is passed on, so that downstream
/// handlers reading the header observe the same ID that is echoed on the
/// response. Previously the generated ID was only attached to the response,
/// so a handler that found no header generated a second, different ID for
/// its body.
///
/// The ID is always set on the response header when it is a valid header value.
async fn request_id_middleware(mut req: axum::extract::Request, next: Next) -> Response {
    let supplied = req
        .headers()
        .get(headers::X_REQUEST_ID)
        .and_then(|v| v.to_str().ok())
        .map(String::from);

    let request_id = match supplied {
        Some(id) => id,
        None => {
            let generated = uuid::Uuid::new_v4().to_string();
            // Propagate the generated ID downstream; this also replaces a
            // client-supplied value that was not valid text.
            if let Ok(val) = HeaderValue::from_str(&generated) {
                req.headers_mut().insert(headers::X_REQUEST_ID, val);
            }
            generated
        }
    };

    let mut response = next.run(req).await;
    if let Ok(val) = HeaderValue::from_str(&request_id) {
        response.headers_mut().insert(headers::X_REQUEST_ID, val);
    }
    response
}
/// Detailed health report: per-component status plus an overall verdict.
///
/// * `kernel_registry` is healthy when at least one kernel is registered,
///   degraded otherwise.
/// * `execution_engine` is healthy below a 10% error rate, degraded below
///   50%, and unhealthy from 50% upwards. With no requests recorded the error
///   rate counts as zero.
///
/// The overall status is unhealthy if the execution engine is unhealthy,
/// healthy only if both components are healthy, and degraded otherwise.
async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
    let uptime_secs = state.start_time.elapsed().as_secs();
    let stats = state.registry.stats();

    let registry_status = match stats.total {
        0 => HealthStatus::Degraded,
        _ => HealthStatus::Healthy,
    };
    let registry_message = format!(
        "{} kernels ({} batch, {} ring)",
        stats.total, stats.batch_kernels, stats.ring_kernels
    );

    let error_rate = if state.metrics.request_count() > 0 {
        state.metrics.error_count() as f64 / state.metrics.request_count() as f64
    } else {
        0.0
    };
    let execution_status = match error_rate {
        rate if rate < 0.1 => HealthStatus::Healthy,
        rate if rate < 0.5 => HealthStatus::Degraded,
        _ => HealthStatus::Unhealthy,
    };
    let execution_message = format!(
        "{} requests, {:.1}% error rate, {:.0}us avg latency",
        state.metrics.request_count(),
        error_rate * 100.0,
        state.metrics.avg_latency_us()
    );

    // The execution engine's verdict takes precedence; an empty registry can
    // only pull an otherwise healthy service down to degraded.
    let status = match (&registry_status, &execution_status) {
        (_, HealthStatus::Unhealthy) => HealthStatus::Unhealthy,
        (HealthStatus::Healthy, HealthStatus::Healthy) => HealthStatus::Healthy,
        _ => HealthStatus::Degraded,
    };

    let components = vec![
        ComponentHealth {
            name: "kernel_registry".to_string(),
            status: registry_status,
            message: Some(registry_message),
        },
        ComponentHealth {
            name: "execution_engine".to_string(),
            status: execution_status,
            message: Some(execution_message),
        },
    ];

    Json(HealthResponse {
        status,
        version: state.config.version.clone(),
        uptime_secs,
        components,
    })
}
/// Liveness probe: unconditionally answers `200 OK` without inspecting any state.
async fn liveness_check() -> StatusCode {
StatusCode::OK
}
/// Readiness probe: `200 OK` once the registry holds at least one kernel,
/// `503 Service Unavailable` while it is empty.
async fn readiness_check(State(state): State<AppState>) -> StatusCode {
    match state.registry.stats().total {
        0 => StatusCode::SERVICE_UNAVAILABLE,
        _ => StatusCode::OK,
    }
}
/// Renders the service counters and registry statistics as plain text, one
/// `# HELP` / `# TYPE` / sample triple per metric, followed by one
/// `rustkernels_kernels_by_domain` sample per registry domain.
async fn metrics_endpoint(State(state): State<AppState>) -> String {
    // Appends the HELP line, TYPE line and single sample for one metric.
    fn push_metric(
        out: &mut String,
        name: &str,
        help: &str,
        kind: &str,
        value: impl std::fmt::Display,
    ) {
        out.push_str(&format!(
            "# HELP {0} {1}\n# TYPE {0} {2}\n{0} {3}\n",
            name, help, kind, value
        ));
    }

    let metrics = &state.metrics;
    let uptime = state.start_time.elapsed().as_secs();
    let stats = state.registry.stats();
    // With no requests recorded the error rate is reported as zero.
    let error_rate = if metrics.request_count() > 0 {
        metrics.error_count() as f64 / metrics.request_count() as f64
    } else {
        0.0
    };

    let mut body = String::with_capacity(2048);
    push_metric(
        &mut body,
        "rustkernels_requests_total",
        "Total number of requests",
        "counter",
        metrics.request_count(),
    );
    push_metric(
        &mut body,
        "rustkernels_errors_total",
        "Total number of errors",
        "counter",
        metrics.error_count(),
    );
    push_metric(
        &mut body,
        "rustkernels_request_duration_us",
        "Average request duration in microseconds",
        "gauge",
        format_args!("{:.2}", metrics.avg_latency_us()),
    );
    push_metric(
        &mut body,
        "rustkernels_request_duration_min_us",
        "Minimum request duration in microseconds",
        "gauge",
        metrics.min_latency_us(),
    );
    push_metric(
        &mut body,
        "rustkernels_request_duration_max_us",
        "Maximum request duration in microseconds",
        "gauge",
        metrics.max_latency_us(),
    );
    push_metric(
        &mut body,
        "rustkernels_error_rate",
        "Current error rate (0.0-1.0)",
        "gauge",
        format_args!("{:.6}", error_rate),
    );
    push_metric(
        &mut body,
        "rustkernels_uptime_seconds",
        "Service uptime in seconds",
        "gauge",
        uptime,
    );
    push_metric(
        &mut body,
        "rustkernels_kernels_registered",
        "Total registered kernels",
        "gauge",
        &stats.total,
    );
    push_metric(
        &mut body,
        "rustkernels_batch_kernels",
        "Batch kernels available for execution",
        "gauge",
        &stats.batch_kernels,
    );
    push_metric(
        &mut body,
        "rustkernels_ring_kernels",
        "Ring kernels registered",
        "gauge",
        &stats.ring_kernels,
    );

    // The per-domain metric shares one HELP/TYPE header across all its samples.
    body.push_str(
        "# HELP rustkernels_kernels_by_domain Kernels by domain\n\
         # TYPE rustkernels_kernels_by_domain gauge\n",
    );
    for (domain, count) in &stats.by_domain {
        body.push_str(&format!(
            "rustkernels_kernels_by_domain{{domain=\"{}\"}} {}\n",
            domain, count
        ));
    }

    body
}
/// Lists every kernel ID the registry reports and can still resolve.
///
/// `total` is taken from the registry statistics, not from the length of the
/// returned list; IDs whose lookup returns nothing are skipped.
async fn list_kernels(State(state): State<AppState>) -> Json<KernelListResponse> {
    let stats = state.registry.stats();

    let mut kernels: Vec<KernelSummary> = Vec::new();
    for id in state.registry.all_kernel_ids().iter() {
        if let Some(meta) = state.registry.get(id) {
            kernels.push(KernelSummary {
                id: meta.id.clone(),
                domain: format!("{:?}", meta.domain),
                mode: format!("{:?}", meta.mode),
                description: meta.description.clone(),
            });
        }
    }

    Json(KernelListResponse {
        total: stats.total,
        kernels,
    })
}
/// Returns the metadata of a single kernel.
///
/// # Errors
///
/// Responds with `404 Not Found` and a `KERNEL_NOT_FOUND` error body (without
/// a request ID) when the registry has no kernel under `kernel_id`.
async fn get_kernel_info(
    State(state): State<AppState>,
    Path(kernel_id): Path<String>,
) -> Result<Json<KernelInfoResponse>, (StatusCode, Json<ErrorResponse>)> {
    state
        .registry
        .get(&kernel_id)
        .map(|meta| {
            Json(KernelInfoResponse {
                id: meta.id.clone(),
                domain: format!("{:?}", meta.domain),
                mode: format!("{:?}", meta.mode),
                description: meta.description.clone(),
                expected_throughput: meta.expected_throughput,
                target_latency_us: meta.target_latency_us,
            })
        })
        .ok_or_else(|| {
            let body = ErrorResponse {
                code: "KERNEL_NOT_FOUND".to_string(),
                message: format!("Kernel not found: {}", kernel_id),
                request_id: None,
                details: None,
            };
            (StatusCode::NOT_FOUND, Json(body))
        })
}
/// Executes a batch kernel with the JSON `input` of the request body.
///
/// The input is serialized to bytes, passed to the kernel, and the kernel's
/// output bytes are parsed back into JSON. Execution is bounded by the
/// request's `timeout_ms`, falling back to the configured default timeout.
/// Every path that reaches the handler body records exactly one request in
/// the metrics, flagged as an error on all failure paths.
///
/// # Errors
///
/// * `404 KERNEL_NOT_FOUND` - no kernel is registered under `kernel_id`.
/// * `422 RING_KERNEL_REST_UNSUPPORTED` - the kernel exists but is not
///   available as a batch kernel.
/// * `400 INVALID_INPUT` - the input could not be serialized.
/// * `504 EXECUTION_TIMEOUT` - the kernel did not finish within the timeout.
/// * `500 EXECUTION_FAILED` - the kernel returned an error.
/// * `500 OUTPUT_DESERIALIZATION_ERROR` - the kernel output was not valid JSON.
async fn execute_kernel(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(kernel_id): Path<String>,
    Json(request): Json<ExecuteRequest>,
) -> Result<Json<KernelResponse>, (StatusCode, Json<ErrorResponse>)> {
    // Assembles the `(status, JSON body)` pair returned on every failure path.
    fn failure(
        status: StatusCode,
        code: &str,
        message: String,
        request_id: String,
        details: Option<serde_json::Value>,
    ) -> (StatusCode, Json<ErrorResponse>) {
        let body = ErrorResponse {
            code: code.to_string(),
            message,
            request_id: Some(request_id),
            details,
        };
        (status, Json(body))
    }

    let started = Instant::now();
    let request_id = extract_request_id(&headers);

    // Only batch kernels can be run through this endpoint; anything else is
    // rejected up front.
    let entry = match state.registry.get_batch(&kernel_id) {
        Some(entry) => entry,
        None => {
            let known = state.registry.get(&kernel_id);
            state
                .metrics
                .record_request(started.elapsed().as_micros() as u64, true);
            return Err(match known {
                Some(meta) => failure(
                    StatusCode::UNPROCESSABLE_ENTITY,
                    "RING_KERNEL_REST_UNSUPPORTED",
                    format!(
                        "Kernel '{}' is a {} mode kernel. Ring kernels require persistent \
                         deployment via the Ring protocol or gRPC streaming API.",
                        kernel_id, meta.mode
                    ),
                    request_id,
                    Some(serde_json::json!({
                        "kernel_mode": meta.mode.as_str(),
                        "kernel_domain": format!("{:?}", meta.domain),
                    })),
                ),
                None => failure(
                    StatusCode::NOT_FOUND,
                    "KERNEL_NOT_FOUND",
                    format!("Kernel not found: {}", kernel_id),
                    request_id,
                    None,
                ),
            });
        }
    };

    let kernel = entry.create();

    let input_bytes = match serde_json::to_vec(&request.input) {
        Ok(bytes) => bytes,
        Err(e) => {
            state
                .metrics
                .record_request(started.elapsed().as_micros() as u64, true);
            return Err(failure(
                StatusCode::BAD_REQUEST,
                "INVALID_INPUT",
                format!("Failed to serialize input: {}", e),
                request_id,
                None,
            ));
        }
    };

    // A per-request timeout overrides the service default.
    let timeout_ms = match request.metadata.timeout_ms {
        Some(ms) => ms,
        None => state.config.default_timeout.as_millis() as u64,
    };
    let limit = std::time::Duration::from_millis(timeout_ms);
    let outcome = tokio::time::timeout(limit, kernel.execute_dyn(&input_bytes)).await;

    let output_bytes = match outcome {
        Ok(Ok(bytes)) => bytes,
        Ok(Err(e)) => {
            let duration_us = started.elapsed().as_micros() as u64;
            state.metrics.record_request(duration_us, true);
            return Err(failure(
                StatusCode::INTERNAL_SERVER_ERROR,
                "EXECUTION_FAILED",
                format!("Kernel execution failed: {}", e),
                request_id,
                None,
            ));
        }
        Err(_) => {
            state
                .metrics
                .record_request(started.elapsed().as_micros() as u64, true);
            return Err(failure(
                StatusCode::GATEWAY_TIMEOUT,
                "EXECUTION_TIMEOUT",
                format!("Kernel execution timed out after {}ms", timeout_ms),
                request_id,
                None,
            ));
        }
    };

    let output = match serde_json::from_slice::<serde_json::Value>(&output_bytes) {
        Ok(value) => value,
        Err(e) => {
            state
                .metrics
                .record_request(started.elapsed().as_micros() as u64, true);
            return Err(failure(
                StatusCode::INTERNAL_SERVER_ERROR,
                "OUTPUT_DESERIALIZATION_ERROR",
                format!("Failed to deserialize kernel output: {}", e),
                request_id,
                None,
            ));
        }
    };

    let duration_us = started.elapsed().as_micros() as u64;
    state.metrics.record_request(duration_us, false);

    Ok(Json(KernelResponse {
        request_id,
        kernel_id,
        output,
        metadata: ResponseMetadata {
            duration_us,
            backend: entry.metadata.mode.as_str().to_uppercase(),
            gpu_memory_bytes: None,
            trace_id: extract_trace_id(&headers),
        },
    }))
}
/// Returns the request ID carried in the request-ID header, or a freshly
/// generated UUID v4 when the header is absent or is not valid text.
fn extract_request_id(headers: &HeaderMap) -> String {
    match headers.get(headers::X_REQUEST_ID).map(|value| value.to_str()) {
        Some(Ok(id)) => id.to_owned(),
        _ => uuid::Uuid::new_v4().to_string(),
    }
}
/// Returns the value of the trace-parent header, if present and valid text.
///
/// The header value is returned verbatim; it is not split into its
/// individual fields.
fn extract_trace_id(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(headers::TRACEPARENT)?;
    raw.to_str().ok().map(String::from)
}
/// JSON body accepted by the kernel execute endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteRequest {
/// Arbitrary JSON payload; the execute handler serializes it to bytes and passes it to the kernel.
pub input: serde_json::Value,
/// Optional per-request settings; falls back to `RequestMetadata::default()` when omitted.
/// The execute handler reads `timeout_ms` from it.
#[serde(default)]
pub metadata: RequestMetadata,
}
/// Response body of the kernel listing endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelListResponse {
/// Kernel count as reported by the registry statistics.
pub total: usize,
/// One summary per kernel ID that the registry could resolve.
pub kernels: Vec<KernelSummary>,
}
/// Short description of one kernel, as returned by the listing endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelSummary {
/// Kernel identifier.
pub id: String,
/// Kernel domain, rendered with its `Debug` formatting.
pub domain: String,
/// Kernel mode, rendered with its `Debug` formatting.
pub mode: String,
/// Human-readable description copied from the kernel metadata.
pub description: String,
}
/// Response body of the single-kernel info endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelInfoResponse {
/// Kernel identifier.
pub id: String,
/// Kernel domain, rendered with its `Debug` formatting.
pub domain: String,
/// Kernel mode, rendered with its `Debug` formatting.
pub mode: String,
/// Human-readable description copied from the kernel metadata.
pub description: String,
/// Expected throughput copied from the kernel metadata (unit not visible here).
pub expected_throughput: u64,
/// Target latency copied from the kernel metadata; microseconds, judging by the field name.
pub target_latency_us: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default router configuration enables the health and metrics
    /// endpoints and mounts the API under `/api/v1`.
    #[test]
    fn test_router_config() {
        let RouterConfig {
            health_endpoints,
            metrics_endpoint,
            api_prefix,
            ..
        } = RouterConfig::default();
        assert!(health_endpoints);
        assert!(metrics_endpoint);
        assert_eq!(api_prefix, "/api/v1");
    }

    /// A freshly created application state has recorded no requests.
    #[test]
    fn test_app_state() {
        let state = AppState::new(Arc::new(KernelRegistry::new()), ServiceConfig::default());
        assert_eq!(state.metrics.request_count(), 0);
    }
}