#[cfg(feature = "dashboard")]
use crate::inspector::Inspector;
#[cfg(feature = "dashboard")]
use axum::{
extract::{
ws::{Message, WebSocket},
State as AxumState, WebSocketUpgrade,
},
response::{Html, IntoResponse},
routing::get,
Json, Router,
};
#[cfg(feature = "dashboard")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "dashboard")]
use std::net::SocketAddr;
#[cfg(feature = "dashboard")]
use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[cfg(feature = "dashboard")]
use tokio::sync::broadcast;
#[cfg(feature = "dashboard")]
use tokio::task::JoinHandle;
#[cfg(feature = "dashboard")]
use tower_http::cors::CorsLayer;
#[cfg(feature = "dashboard")]
/// An event streamed to dashboard WebSocket clients.
///
/// Serialized as an internally tagged JSON object. The `type` field holds the
/// snake_case variant name, e.g. `{"type": "task_spawned", "task_id": 1, ...}`.
///
/// Timestamps produced in this module are milliseconds since the Unix epoch
/// (`SystemTime::now().duration_since(UNIX_EPOCH).as_millis()`). Producers elsewhere
/// are presumably expected to use the same unit. TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DashboardEvent {
    /// A task is known to the inspector. Also replayed for every existing task when a
    /// WebSocket client connects.
    TaskSpawned {
        task_id: u64,
        name: String,
        /// Id of the parent task, if any.
        parent: Option<u64>,
        timestamp: u128,
    },
    /// A task finished; `duration_ms` is presumably its lifetime in milliseconds.
    TaskCompleted {
        task_id: u64,
        duration_ms: f64,
        timestamp: u128,
    },
    /// A task failed, with an optional error description.
    TaskFailed {
        task_id: u64,
        error: Option<String>,
        timestamp: u128,
    },
    /// A task moved between states (states are carried as free-form strings).
    StateChanged {
        task_id: u64,
        old_state: String,
        new_state: String,
        timestamp: u128,
    },
    /// Periodic aggregate counters taken from the inspector's stats.
    MetricsSnapshot {
        total_tasks: usize,
        running_tasks: usize,
        completed_tasks: usize,
        failed_tasks: usize,
        blocked_tasks: usize,
        timestamp: u128,
    },
    /// A task began awaiting the operation identified by `label`.
    AwaitStarted {
        task_id: u64,
        label: String,
        timestamp: u128,
    },
    /// A task finished awaiting the operation identified by `label`.
    AwaitEnded {
        task_id: u64,
        label: String,
        duration_ms: f64,
        timestamp: u128,
    },
}
#[cfg(feature = "dashboard")]
/// Shared state handed to every axum handler (cloned per request).
#[derive(Clone)]
struct DashboardState {
    /// Broadcast sender; each WebSocket connection subscribes its own receiver.
    event_tx: broadcast::Sender<DashboardEvent>,
    /// Inspector queried for task lists and stats (the global instance in `Dashboard::start`).
    inspector: &'static Inspector,
}
#[cfg(feature = "dashboard")]
/// A local web dashboard that exposes the global [`Inspector`] over HTTP and WebSocket.
///
/// Construct it with `Dashboard::new(port)` and launch it with `Dashboard::start`.
#[derive(Debug)]
pub struct Dashboard {
    /// TCP port the server listens on (bound to `127.0.0.1`).
    port: u16,
    /// Sender for events fanned out to all connected WebSocket clients.
    event_tx: broadcast::Sender<DashboardEvent>,
}
#[cfg(feature = "dashboard")]
impl Dashboard {
    /// Capacity of the broadcast channel feeding WebSocket clients. A client that falls
    /// more than this many events behind loses the oldest ones.
    const EVENT_CHANNEL_CAPACITY: usize = 1000;

    /// Period between `MetricsSnapshot` broadcasts.
    const METRICS_INTERVAL: Duration = Duration::from_millis(100);

    /// Creates a dashboard that will listen on `127.0.0.1:port` once started.
    #[must_use]
    pub fn new(port: u16) -> Self {
        // The initial receiver is dropped; every WebSocket client subscribes its own.
        let (event_tx, _) = broadcast::channel(Self::EVENT_CHANNEL_CAPACITY);
        Self { port, event_tx }
    }

    /// Binds `127.0.0.1:<port>` and serves the dashboard in a background task.
    ///
    /// Routes: `/` (bundled HTML page), `/ws` (event stream), `/api/tasks`, `/api/stats`.
    /// While the server task is alive, a companion task broadcasts a `MetricsSnapshot`
    /// every [`Self::METRICS_INTERVAL`]. It is aborted automatically when the server task
    /// finishes or is aborted.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the listening socket cannot be bound (e.g. the port is
    /// already in use). Errors raised while serving are reported through the returned
    /// handle.
    pub async fn start(self) -> Result<JoinHandle<Result<(), std::io::Error>>, std::io::Error> {
        /// Aborts the wrapped task when dropped.
        struct AbortOnDrop(JoinHandle<()>);
        impl Drop for AbortOnDrop {
            fn drop(&mut self) {
                self.0.abort();
            }
        }

        let addr = SocketAddr::from(([127, 0, 0, 1], self.port));
        // Bind eagerly so a failure reaches the caller instead of being hidden inside the
        // spawned task (and before any background work is started).
        let listener = tokio::net::TcpListener::bind(addr).await?;

        let inspector = Inspector::global();
        let state = DashboardState {
            event_tx: self.event_tx.clone(),
            inspector,
        };

        let metrics_task = AbortOnDrop(tokio::spawn(Self::publish_metrics(
            inspector,
            self.event_tx.clone(),
        )));

        // NOTE(review): permissive CORS lets any web page open in the user's browser read
        // `/api/*` from this localhost server. The bundled page is same-origin and does
        // not need it; consider restricting it if external origins are not required.
        let app = Router::new()
            .route("/", get(serve_dashboard))
            .route("/ws", get(websocket_handler))
            .route("/api/tasks", get(api_tasks))
            .route("/api/stats", get(api_stats))
            .layer(CorsLayer::permissive())
            .with_state(state);

        let handle = tokio::spawn(async move {
            // The guard lives exactly as long as this future. When serving ends or the
            // handle is aborted, it is dropped and the metrics loop stops with it.
            let _metrics_task = metrics_task;
            axum::serve(listener, app).await
        });
        Ok(handle)
    }

    /// Broadcasts an inspector stats snapshot on `tx` every [`Self::METRICS_INTERVAL`].
    /// Runs until its task is aborted.
    async fn publish_metrics(inspector: &'static Inspector, tx: broadcast::Sender<DashboardEvent>) {
        let mut interval = tokio::time::interval(Self::METRICS_INTERVAL);
        // After a stall, resume the normal cadence instead of firing a burst of catch-up ticks.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            interval.tick().await;
            // With no subscribers the send would be discarded anyway; skip the stats work.
            if tx.receiver_count() == 0 {
                continue;
            }
            let stats = inspector.stats();
            let snapshot = DashboardEvent::MetricsSnapshot {
                total_tasks: stats.total_tasks,
                running_tasks: stats.running_tasks,
                completed_tasks: stats.completed_tasks,
                failed_tasks: stats.failed_tasks,
                blocked_tasks: stats.blocked_tasks,
                timestamp: SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_millis(),
            };
            // Fails only if every receiver disconnected meanwhile, which is harmless.
            let _ = tx.send(snapshot);
        }
    }
}
#[cfg(feature = "dashboard")]
/// Serves the single-page dashboard UI, embedded into the binary at compile time.
async fn serve_dashboard() -> Html<&'static str> {
    const INDEX_PAGE: &str = include_str!("static/index.html");
    Html(INDEX_PAGE)
}
#[cfg(feature = "dashboard")]
/// Upgrades a `/ws` request and hands the connection to the per-client event stream.
async fn websocket_handler(
    ws: WebSocketUpgrade,
    AxumState(shared): AxumState<DashboardState>,
) -> impl IntoResponse {
    ws.on_upgrade(move |conn| handle_websocket(conn, shared))
}
#[cfg(feature = "dashboard")]
/// Streams dashboard events to one WebSocket client.
///
/// First replays every task currently known to the inspector as a `TaskSpawned` event,
/// then forwards every broadcast event as JSON text frames. Returns when a send fails
/// (the client is gone) or when the broadcast channel is closed. Incoming client
/// messages are never read.
async fn handle_websocket(mut socket: WebSocket, state: DashboardState) {
    // Subscribe before taking the snapshot so no event published in between is lost
    // (a task may be reported twice instead).
    let mut event_rx = state.event_tx.subscribe();

    // The real spawn time is not available here, so the snapshot time stands in for it.
    let snapshot_time = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    for task in state.inspector.get_all_tasks() {
        let event = DashboardEvent::TaskSpawned {
            task_id: task.id.as_u64(),
            name: task.name.clone(),
            parent: task.parent.map(|p| p.as_u64()),
            timestamp: snapshot_time,
        };
        if let Ok(json) = serde_json::to_string(&event) {
            if socket.send(Message::Text(json)).await.is_err() {
                return;
            }
        }
    }

    loop {
        let event = match event_rx.recv().await {
            Ok(event) => event,
            // This client fell more than the channel capacity behind. The oldest events
            // were dropped, but the receiver is still usable, so keep streaming instead
            // of silently disconnecting the client.
            Err(broadcast::error::RecvError::Lagged(_)) => continue,
            // Every sender is gone; no further events can arrive.
            Err(broadcast::error::RecvError::Closed) => break,
        };
        if let Ok(json) = serde_json::to_string(&event) {
            if socket.send(Message::Text(json)).await.is_err() {
                break;
            }
        }
    }
}
#[cfg(feature = "dashboard")]
/// `GET /api/tasks`: returns `{"tasks": [...]}` with one JSON object per task known to
/// the inspector (id, name, debug-formatted state, parent id, poll count).
async fn api_tasks(AxumState(shared): AxumState<DashboardState>) -> Json<serde_json::Value> {
    let mut entries: Vec<serde_json::Value> = Vec::new();
    for info in shared.inspector.get_all_tasks() {
        let state_label = format!("{:?}", info.state);
        entries.push(serde_json::json!({
            "id": info.id.as_u64(),
            "name": info.name,
            "state": state_label,
            "parent": info.parent.map(|p| p.as_u64()),
            "poll_count": info.poll_count,
        }));
    }
    let body = serde_json::json!({ "tasks": entries });
    Json(body)
}
#[cfg(feature = "dashboard")]
/// `GET /api/stats`: returns the inspector's aggregate task counters as a JSON object.
async fn api_stats(AxumState(shared): AxumState<DashboardState>) -> Json<serde_json::Value> {
    let counters = shared.inspector.stats();
    let body = serde_json::json!({
        "total_tasks": counters.total_tasks,
        "running_tasks": counters.running_tasks,
        "completed_tasks": counters.completed_tasks,
        "failed_tasks": counters.failed_tasks,
        "blocked_tasks": counters.blocked_tasks,
    });
    Json(body)
}
// Fail the build with an explicit message if this module is compiled without the
// `dashboard` feature. Otherwise every item above would be cfg'd out and the module
// would silently be empty.
#[cfg(not(feature = "dashboard"))]
compile_error!("The dashboard module requires the 'dashboard' feature to be enabled");