1#[cfg(feature = "dashboard")]
34use crate::inspector::Inspector;
35
36#[cfg(feature = "dashboard")]
37use axum::{
38 extract::{
39 ws::{Message, WebSocket},
40 State as AxumState, WebSocketUpgrade,
41 },
42 response::{Html, IntoResponse},
43 routing::get,
44 Json, Router,
45};
46#[cfg(feature = "dashboard")]
47use serde::{Deserialize, Serialize};
48#[cfg(feature = "dashboard")]
49use std::net::SocketAddr;
50#[cfg(feature = "dashboard")]
51use std::time::{Duration, SystemTime, UNIX_EPOCH};
52#[cfg(feature = "dashboard")]
53use tokio::sync::broadcast;
54#[cfg(feature = "dashboard")]
55use tokio::task::JoinHandle;
56#[cfg(feature = "dashboard")]
57use tower_http::cors::CorsLayer;
58
/// An event streamed to dashboard WebSocket clients.
///
/// Serialized as an internally tagged JSON object: the `"type"` field carries the
/// snake_case variant name and the variant's fields sit alongside it, e.g.
/// `{"type": "metrics_snapshot", "total_tasks": 3, ...}`.
// NOTE(review): only `TaskSpawned` and `MetricsSnapshot` are constructed in this module;
// the meaning of the other variants' fields is taken from their names — confirm against
// their producers.
#[cfg(feature = "dashboard")]
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum DashboardEvent {
    /// A task exists. Also replayed for every known task when a client connects.
    TaskSpawned {
        /// Numeric id of the task.
        task_id: u64,
        /// Name of the task as recorded by the inspector.
        name: String,
        /// Numeric id of the parent task, if there is one.
        parent: Option<u64>,
        /// Milliseconds since the Unix epoch at which the event was produced.
        timestamp: u128,
    },
    /// A task finished successfully.
    TaskCompleted {
        /// Numeric id of the task.
        task_id: u64,
        /// How long the task ran, in milliseconds.
        duration_ms: f64,
        /// Event time (presumably milliseconds since the Unix epoch, like other variants).
        timestamp: u128,
    },
    /// A task failed.
    TaskFailed {
        /// Numeric id of the task.
        task_id: u64,
        /// Failure description, when one is available.
        error: Option<String>,
        /// Event time (presumably milliseconds since the Unix epoch, like other variants).
        timestamp: u128,
    },
    /// A task moved from one state to another.
    StateChanged {
        /// Numeric id of the task.
        task_id: u64,
        /// State before the transition, rendered as text.
        old_state: String,
        /// State after the transition, rendered as text.
        new_state: String,
        /// Event time (presumably milliseconds since the Unix epoch, like other variants).
        timestamp: u128,
    },
    /// Periodic snapshot of the inspector's aggregate task counters.
    MetricsSnapshot {
        /// Number of tasks tracked in total.
        total_tasks: usize,
        /// Number of tasks currently running.
        running_tasks: usize,
        /// Number of tasks that completed.
        completed_tasks: usize,
        /// Number of tasks that failed.
        failed_tasks: usize,
        /// Number of tasks currently blocked.
        blocked_tasks: usize,
        /// Milliseconds since the Unix epoch at which the snapshot was taken.
        timestamp: u128,
    },
    /// A task began awaiting a labelled operation.
    AwaitStarted {
        /// Numeric id of the task.
        task_id: u64,
        /// Label identifying the awaited operation.
        label: String,
        /// Event time (presumably milliseconds since the Unix epoch, like other variants).
        timestamp: u128,
    },
    /// A task finished awaiting a labelled operation.
    AwaitEnded {
        /// Numeric id of the task.
        task_id: u64,
        /// Label identifying the awaited operation.
        label: String,
        /// How long the await took, in milliseconds.
        duration_ms: f64,
        /// Event time (presumably milliseconds since the Unix epoch, like other variants).
        timestamp: u128,
    },
}
140
/// Shared state handed to every axum handler of the dashboard router.
#[cfg(feature = "dashboard")]
#[derive(Clone)]
struct DashboardState {
    /// Broadcast sender; WebSocket handlers call `subscribe()` on it to receive live events.
    event_tx: tokio::sync::broadcast::Sender<DashboardEvent>,
    /// Process-wide inspector queried for task lists and aggregate stats.
    inspector: &'static crate::inspector::Inspector,
}
150
/// Local web dashboard exposing inspector data over HTTP and a live WebSocket feed.
///
/// Construct with `Dashboard::new(port)` and launch with `start`.
#[cfg(feature = "dashboard")]
pub struct Dashboard {
    /// TCP port the server binds on the loopback interface (`127.0.0.1`).
    port: u16,
    /// Sending half of the event broadcast channel shared with WebSocket clients.
    event_tx: tokio::sync::broadcast::Sender<DashboardEvent>,
}
159
#[cfg(feature = "dashboard")]
impl Dashboard {
    /// Creates a dashboard that will listen on `127.0.0.1:<port>` once started.
    ///
    /// The internal tokio `broadcast` channel holds at most 1000 undelivered events.
    #[must_use]
    pub fn new(port: u16) -> Self {
        let (event_tx, _) = broadcast::channel(1000);

        Self { port, event_tx }
    }

    /// Binds the HTTP server on `127.0.0.1:<port>` and starts serving the dashboard.
    ///
    /// Routes: `/` (UI), `/ws` (live event stream), `/api/tasks` and `/api/stats`.
    /// While the server runs, a background task publishes a
    /// [`DashboardEvent::MetricsSnapshot`] every 100 ms to connected WebSocket clients.
    /// That task stops when serving ends or when the returned task is aborted.
    ///
    /// # Errors
    ///
    /// Returns an error if the listening socket cannot be bound (for example because the
    /// port is already in use). Errors that occur while serving are reported through the
    /// returned [`JoinHandle`].
    pub async fn start(self) -> Result<JoinHandle<Result<(), std::io::Error>>, std::io::Error> {
        let addr = SocketAddr::from(([127, 0, 0, 1], self.port));
        // Bind before spawning anything: bind failures are then returned from `start`
        // instead of being buried in the handle, and no background task is left running.
        let listener = tokio::net::TcpListener::bind(addr).await?;

        let inspector = Inspector::global();
        let state = DashboardState {
            event_tx: self.event_tx.clone(),
            inspector,
        };

        let app = Router::new()
            .route("/", get(serve_dashboard))
            .route("/ws", get(websocket_handler))
            .route("/api/tasks", get(api_tasks))
            .route("/api/stats", get(api_stats))
            .layer(CorsLayer::permissive())
            .with_state(state);

        let metrics_task =
            MetricsTaskGuard(tokio::spawn(publish_metrics(inspector, self.event_tx)));

        let handle = tokio::spawn(async move {
            // Owning the guard inside the server future ties the metrics publisher's
            // lifetime to the server: it is aborted when serving returns or this task is
            // aborted.
            let _metrics_task = metrics_task;
            axum::serve(listener, app).await
        });

        Ok(handle)
    }
}

/// Publishes a [`DashboardEvent::MetricsSnapshot`] of `inspector`'s stats on `tx` every
/// 100 ms. Runs until aborted.
#[cfg(feature = "dashboard")]
async fn publish_metrics(inspector: &'static Inspector, tx: broadcast::Sender<DashboardEvent>) {
    let mut interval = tokio::time::interval(Duration::from_millis(100));
    // After a stall, resume the regular cadence instead of firing a burst of catch-up ticks.
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    loop {
        interval.tick().await;

        // With no subscribed clients `send` would fail anyway; skip collecting stats too.
        if tx.receiver_count() == 0 {
            continue;
        }

        let stats = inspector.stats();
        let snapshot = DashboardEvent::MetricsSnapshot {
            total_tasks: stats.total_tasks,
            running_tasks: stats.running_tasks,
            completed_tasks: stats.completed_tasks,
            failed_tasks: stats.failed_tasks,
            blocked_tasks: stats.blocked_tasks,
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis(),
        };

        // Can only fail if every client disconnected since the check above; nothing to do.
        let _ = tx.send(snapshot);
    }
}

/// Aborts the wrapped background task when dropped.
#[cfg(feature = "dashboard")]
struct MetricsTaskGuard(JoinHandle<()>);

#[cfg(feature = "dashboard")]
impl Drop for MetricsTaskGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}
225
/// `GET /`: serves the dashboard page, embedded into the binary at compile time.
#[cfg(feature = "dashboard")]
async fn serve_dashboard() -> Html<&'static str> {
    const INDEX_HTML: &str = include_str!("static/index.html");
    Html(INDEX_HTML)
}
231
/// `GET /ws`: upgrades the request to a WebSocket and hands the socket to
/// `handle_websocket` together with the shared dashboard state.
#[cfg(feature = "dashboard")]
async fn websocket_handler(
    ws: WebSocketUpgrade,
    AxumState(state): AxumState<DashboardState>,
) -> impl IntoResponse {
    let on_connect = move |socket: WebSocket| handle_websocket(socket, state);
    ws.on_upgrade(on_connect)
}
240
/// Drives a single dashboard WebSocket connection.
///
/// First replays every task currently known to the inspector as a
/// [`DashboardEvent::TaskSpawned`] event. It then forwards every event broadcast on the
/// dashboard channel as a JSON text frame. This continues until a send fails (the client
/// went away) or the channel is closed.
#[cfg(feature = "dashboard")]
async fn handle_websocket(mut socket: WebSocket, state: DashboardState) {
    // Subscribe before taking the snapshot so events published in between are not lost.
    let mut event_rx = state.event_tx.subscribe();

    // All replayed events describe the same snapshot, so they share one timestamp.
    let snapshot_time = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();

    let tasks = state.inspector.get_all_tasks();
    for task in tasks {
        let event = DashboardEvent::TaskSpawned {
            task_id: task.id.as_u64(),
            name: task.name.clone(),
            parent: task.parent.map(|p| p.as_u64()),
            timestamp: snapshot_time,
        };

        if !send_event(&mut socket, &event).await {
            return;
        }
    }

    loop {
        match event_rx.recv().await {
            Ok(event) => {
                if !send_event(&mut socket, &event).await {
                    break;
                }
            }
            // The client fell behind the channel's buffer. Skip the overwritten events and
            // keep streaming instead of dropping the connection.
            Err(broadcast::error::RecvError::Lagged(_)) => continue,
            // Every sender is gone; no further events can arrive.
            Err(broadcast::error::RecvError::Closed) => break,
        }
    }
}

/// Serializes `event` to JSON and sends it over `socket` as a text frame.
///
/// Returns `false` once the socket can no longer be written to. An event that fails to
/// serialize is skipped and the connection is kept open (returns `true`).
#[cfg(feature = "dashboard")]
async fn send_event(socket: &mut WebSocket, event: &DashboardEvent) -> bool {
    match serde_json::to_string(event) {
        Ok(json) => socket.send(Message::Text(json)).await.is_ok(),
        Err(_) => true,
    }
}
275
/// `GET /api/tasks`: lists every task known to the inspector.
///
/// Responds with `{"tasks": [{"id", "name", "state", "parent", "poll_count"}, ...]}`.
/// `state` is the `Debug` rendering of the task's state, and `parent` is `null` when the
/// task has no parent.
#[cfg(feature = "dashboard")]
async fn api_tasks(AxumState(state): AxumState<DashboardState>) -> Json<serde_json::Value> {
    let mut entries: Vec<serde_json::Value> = Vec::new();

    for task in state.inspector.get_all_tasks() {
        let entry = serde_json::json!({
            "id": task.id.as_u64(),
            "name": task.name,
            "state": format!("{:?}", task.state),
            "parent": task.parent.map(|p| p.as_u64()),
            "poll_count": task.poll_count,
        });
        entries.push(entry);
    }

    let body = serde_json::json!({ "tasks": entries });
    Json(body)
}
296
/// `GET /api/stats`: returns the inspector's aggregate task counters as a flat JSON object.
///
/// The object has the keys `total_tasks`, `running_tasks`, `completed_tasks`,
/// `failed_tasks` and `blocked_tasks`.
#[cfg(feature = "dashboard")]
async fn api_stats(
    AxumState(DashboardState { inspector, .. }): AxumState<DashboardState>,
) -> Json<serde_json::Value> {
    let stats = inspector.stats();

    let body = serde_json::json!({
        "total_tasks": stats.total_tasks,
        "running_tasks": stats.running_tasks,
        "completed_tasks": stats.completed_tasks,
        "failed_tasks": stats.failed_tasks,
        "blocked_tasks": stats.blocked_tasks,
    });
    Json(body)
}
310
// Every item in this file is gated on the `dashboard` feature. Compiling the file without
// that feature is turned into a hard error with an explicit message instead of silently
// yielding an empty module.
// NOTE(review): if the parent module already declares `mod dashboard` behind
// `#[cfg(feature = "dashboard")]`, this guard can never fire; if it does not, the whole
// crate fails to build without the feature. Confirm which is intended.
#[cfg(not(feature = "dashboard"))]
compile_error!("The dashboard module requires the 'dashboard' feature to be enabled");