// cortexai_dashboard/state.rs

1//! Dashboard state management
2
3use chrono::{DateTime, Utc};
4use parking_lot::RwLock;
5use cortexai_monitoring::{CostStats, CostTracker};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::broadcast;
10
11// Import PrometheusHandle for metrics endpoint
12use metrics_exporter_prometheus::PrometheusHandle;
13
/// Agent status information.
///
/// Snapshot of one agent as tracked by `DashboardState` (keyed there by `id`) and
/// pushed to WebSocket clients as `WsMessage::AgentUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStatus {
    /// Agent identifier; used as the key in `DashboardState`'s agent map.
    pub id: String,
    /// Agent name.
    pub name: String,
    /// Agent role (free-form string).
    pub role: String,
    /// Free-form status string. `DashboardState::start_agent` / `stop_agent` write
    /// `"running"` / `"stopped"`, and `get_metrics` counts only `"running"` agents as active.
    pub status: String,
    /// Messages processed by the agent.
    /// NOTE(review): not updated anywhere in this module — presumably reported by the
    /// agent itself via `update_agent`; confirm.
    pub messages_processed: u64,
    /// Time of the last recorded activity; refreshed by `start_agent` / `stop_agent`.
    pub last_activity: Option<DateTime<Utc>>,
    /// Task currently in progress, if any; cleared by `stop_agent`.
    pub current_task: Option<String>,
}
25
/// Trace entry representing a single step in agent execution.
///
/// Stored by `DashboardState::add_trace` and pushed to WebSocket clients as
/// `WsMessage::TraceUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceEntry {
    /// Identifier of this trace entry.
    pub id: String,
    /// Session the step belongs to; `get_session_traces` filters on this field.
    pub session_id: String,
    /// When the step happened; `get_traces` lists entries newest-first by this field.
    pub timestamp: DateTime<Utc>,
    /// Kind of step together with its kind-specific payload.
    pub entry_type: TraceEntryType,
    /// Duration of the step in milliseconds, when it was measured.
    pub duration_ms: Option<f64>,
    /// Arbitrary extra JSON attached to the entry.
    pub metadata: Option<serde_json::Value>,
}
36
/// Type of trace entry.
///
/// Adjacently tagged for serde: the variant name (used verbatim, e.g. `"LlmRequest"`)
/// goes under the `"type"` key and the variant's fields under the `"data"` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum TraceEntryType {
    /// A request sent to an LLM.
    LlmRequest {
        /// Model identifier.
        model: String,
        /// Number of prompt tokens.
        prompt_tokens: u32,
        /// Number of completion tokens.
        completion_tokens: u32,
        /// Cost of the request.
        /// NOTE(review): currency/unit is not defined in this module — confirm.
        cost: f64,
    },
    /// A response received from an LLM.
    LlmResponse {
        /// Response text.
        content: String,
        /// Finish reason reported for the response, if any.
        finish_reason: Option<String>,
    },
    /// A tool invocation.
    ToolCall {
        /// Name of the invoked tool.
        tool_name: String,
        /// Arguments passed to the tool, as JSON.
        arguments: serde_json::Value,
    },
    /// The outcome of a tool invocation.
    ToolResult {
        /// Name of the tool that produced the result.
        tool_name: String,
        /// Result payload, as JSON.
        result: serde_json::Value,
        /// Whether the tool call succeeded.
        success: bool,
    },
    /// Intermediate reasoning emitted by an agent.
    AgentThought {
        /// The thought text.
        thought: String,
    },
    /// An error encountered during execution.
    Error {
        /// Error message.
        message: String,
        /// Error category (free-form string).
        error_type: String,
    },
}
68
/// Session representing a conversation.
///
/// Stored by `DashboardState::update_session` and pushed to WebSocket clients as
/// `WsMessage::SessionUpdate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier; key for both the session map and the per-session messages.
    pub id: String,
    /// Optional session name.
    pub name: Option<String>,
    /// Creation time.
    pub created_at: DateTime<Utc>,
    /// Last modification time. `get_sessions` sorts most-recently-updated first, and
    /// `add_session_message` bumps this to "now".
    pub updated_at: DateTime<Utc>,
    /// Number of messages; incremented by `add_session_message`.
    pub message_count: u32,
    /// Lifecycle status.
    pub status: SessionStatus,
    /// Agent associated with the session, if any.
    pub agent_id: Option<String>,
    /// Arbitrary extra JSON attached to the session.
    pub metadata: Option<serde_json::Value>,
}
81
82/// Session status
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(rename_all = "snake_case")]
85pub enum SessionStatus {
86    Active,
87    Completed,
88    Failed,
89    Archived,
90}
91
/// Message in a session.
///
/// Appended per session by `DashboardState::add_session_message`, which also bumps the
/// owning session's `message_count`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMessage {
    /// Message identifier.
    pub id: String,
    /// Session the message belongs to; used as the grouping key.
    pub session_id: String,
    /// Author role (free-form string).
    /// NOTE(review): presumably values like "user"/"assistant" — confirm with producers.
    pub role: String,
    /// Message text.
    pub content: String,
    /// When the message was created.
    pub timestamp: DateTime<Utc>,
    /// Arbitrary extra JSON attached to the message.
    pub metadata: Option<serde_json::Value>,
}
102
/// Dashboard metrics snapshot.
///
/// Built by `DashboardState::get_metrics` and pushed to clients as `WsMessage::Metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DashboardMetrics {
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
    /// Statistics reported by the shared `CostTracker`.
    pub cost_stats: CostStats,
    /// All registered agents at snapshot time.
    pub agents: Vec<AgentStatus>,
    /// Number of agents whose `status` is `"running"`.
    pub active_agents: usize,
    /// Value of the counter incremented by `DashboardState::record_message`.
    pub total_messages: u64,
    /// Whole seconds elapsed since the `DashboardState` was created.
    pub uptime_seconds: u64,
}
113
/// WebSocket message types.
///
/// Adjacently tagged for serde: the variant name (used verbatim, e.g. `"AgentUpdate"`)
/// goes under `"type"` and the payload, where the variant has one, under `"data"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum WsMessage {
    /// Full metrics update
    Metrics(DashboardMetrics),
    /// Agent status changed (registered, updated, started or stopped)
    AgentUpdate(AgentStatus),
    /// New request recorded
    RequestRecorded {
        /// Model identifier.
        model: String,
        /// Request cost. NOTE(review): unit not defined here — confirm.
        cost: f64,
        /// Token count for the request.
        tokens: u64,
        /// Request latency in milliseconds.
        latency_ms: f64,
    },
    /// Trace update (a new trace entry was added)
    TraceUpdate(TraceEntry),
    /// Session update
    SessionUpdate(Session),
    /// Error occurred
    Error {
        /// Human-readable error description.
        message: String,
    },
    /// Ping/pong for keepalive
    Ping,
    Pong,
}
141
/// Shared dashboard state.
///
/// Every method takes `&self`; mutable collections are guarded by `parking_lot::RwLock`,
/// so one instance can be shared between handlers (presumably behind an `Arc` — confirm
/// at the construction site). Most mutations are also pushed to WebSocket subscribers
/// through `broadcast_tx`.
pub struct DashboardState {
    /// Cost tracker reference; queried for `CostStats` in metrics snapshots.
    pub cost_tracker: Arc<CostTracker>,
    /// Prometheus metrics handle for /metrics endpoint
    pub prometheus_handle: PrometheusHandle,
    /// Registered agents, keyed by agent id.
    agents: RwLock<HashMap<String, AgentStatus>>,
    /// Sessions, keyed by session id.
    sessions: RwLock<HashMap<String, Session>>,
    /// Messages per session id, in insertion order.
    session_messages: RwLock<HashMap<String, Vec<SessionMessage>>>,
    /// Trace entries in insertion order.
    /// NOTE(review): nothing in this module ever removes entries, so this grows for the
    /// lifetime of the process.
    traces: RwLock<Vec<TraceEntry>>,
    /// Counter incremented by `record_message`.
    total_messages: RwLock<u64>,
    /// Creation time; uptime is measured from here.
    started_at: DateTime<Utc>,
    /// Broadcast channel for WebSocket updates; clients obtain receivers via `subscribe`.
    pub broadcast_tx: broadcast::Sender<WsMessage>,
}
163
164impl DashboardState {
165    /// Create new dashboard state with Prometheus handle
166    pub fn new(cost_tracker: Arc<CostTracker>, prometheus_handle: PrometheusHandle) -> Self {
167        let (broadcast_tx, _) = broadcast::channel(1024);
168
169        Self {
170            cost_tracker,
171            prometheus_handle,
172            agents: RwLock::new(HashMap::new()),
173            sessions: RwLock::new(HashMap::new()),
174            session_messages: RwLock::new(HashMap::new()),
175            traces: RwLock::new(Vec::new()),
176            total_messages: RwLock::new(0),
177            started_at: Utc::now(),
178            broadcast_tx,
179        }
180    }
181
182    // ==================== Agents ====================
183
184    /// Register or update an agent
185    pub fn update_agent(&self, status: AgentStatus) {
186        self.agents
187            .write()
188            .insert(status.id.clone(), status.clone());
189        let _ = self.broadcast_tx.send(WsMessage::AgentUpdate(status));
190    }
191
192    /// Remove an agent
193    pub fn remove_agent(&self, agent_id: &str) {
194        self.agents.write().remove(agent_id);
195    }
196
197    /// Get all agents
198    pub fn get_agents(&self) -> Vec<AgentStatus> {
199        self.agents.read().values().cloned().collect()
200    }
201
202    /// Get a specific agent
203    pub fn get_agent(&self, id: &str) -> Option<AgentStatus> {
204        self.agents.read().get(id).cloned()
205    }
206
207    /// Start an agent
208    pub fn start_agent(&self, id: &str) -> Result<(), String> {
209        let mut agents = self.agents.write();
210        if let Some(agent) = agents.get_mut(id) {
211            agent.status = "running".to_string();
212            agent.last_activity = Some(Utc::now());
213            let _ = self
214                .broadcast_tx
215                .send(WsMessage::AgentUpdate(agent.clone()));
216            Ok(())
217        } else {
218            Err(format!("Agent {} not found", id))
219        }
220    }
221
222    /// Stop an agent
223    pub fn stop_agent(&self, id: &str) -> Result<(), String> {
224        let mut agents = self.agents.write();
225        if let Some(agent) = agents.get_mut(id) {
226            agent.status = "stopped".to_string();
227            agent.current_task = None;
228            agent.last_activity = Some(Utc::now());
229            let _ = self
230                .broadcast_tx
231                .send(WsMessage::AgentUpdate(agent.clone()));
232            Ok(())
233        } else {
234            Err(format!("Agent {} not found", id))
235        }
236    }
237
238    /// Restart an agent
239    pub fn restart_agent(&self, id: &str) -> Result<(), String> {
240        self.stop_agent(id)?;
241        self.start_agent(id)
242    }
243
244    // ==================== Sessions ====================
245
246    /// Add or update a session
247    pub fn update_session(&self, session: Session) {
248        self.sessions
249            .write()
250            .insert(session.id.clone(), session.clone());
251        let _ = self.broadcast_tx.send(WsMessage::SessionUpdate(session));
252    }
253
254    /// Get all sessions
255    pub fn get_sessions(&self) -> Vec<Session> {
256        let mut sessions: Vec<Session> = self.sessions.read().values().cloned().collect();
257        sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
258        sessions
259    }
260
261    /// Get a specific session
262    pub fn get_session(&self, id: &str) -> Option<Session> {
263        self.sessions.read().get(id).cloned()
264    }
265
266    /// Add a message to a session
267    pub fn add_session_message(&self, message: SessionMessage) {
268        let session_id = message.session_id.clone();
269        self.session_messages
270            .write()
271            .entry(session_id.clone())
272            .or_default()
273            .push(message);
274
275        // Update session message count
276        if let Some(session) = self.sessions.write().get_mut(&session_id) {
277            session.message_count += 1;
278            session.updated_at = Utc::now();
279        }
280    }
281
282    /// Get messages for a session
283    pub fn get_session_messages(&self, session_id: &str) -> Vec<SessionMessage> {
284        self.session_messages
285            .read()
286            .get(session_id)
287            .cloned()
288            .unwrap_or_default()
289    }
290
291    // ==================== Traces ====================
292
293    /// Add a trace entry
294    pub fn add_trace(&self, trace: TraceEntry) {
295        self.traces.write().push(trace.clone());
296        let _ = self.broadcast_tx.send(WsMessage::TraceUpdate(trace));
297    }
298
299    /// Get all traces
300    pub fn get_traces(&self) -> Vec<TraceEntry> {
301        let mut traces = self.traces.read().clone();
302        traces.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
303        traces.truncate(1000); // Limit to last 1000 traces
304        traces
305    }
306
307    /// Get traces for a specific session
308    pub fn get_session_traces(&self, session_id: &str) -> Vec<TraceEntry> {
309        self.traces
310            .read()
311            .iter()
312            .filter(|t| t.session_id == session_id)
313            .cloned()
314            .collect()
315    }
316
317    // ==================== Messages & Metrics ====================
318
319    /// Increment message counter
320    pub fn record_message(&self) {
321        *self.total_messages.write() += 1;
322    }
323
324    /// Record a new LLM request (broadcasts to clients)
325    pub fn record_request(&self, model: &str, cost: f64, tokens: u64, latency_ms: f64) {
326        let _ = self.broadcast_tx.send(WsMessage::RequestRecorded {
327            model: model.to_string(),
328            cost,
329            tokens,
330            latency_ms,
331        });
332    }
333
334    /// Get current metrics snapshot
335    pub fn get_metrics(&self) -> DashboardMetrics {
336        let agents: Vec<AgentStatus> = self.agents.read().values().cloned().collect();
337        let active_agents = agents.iter().filter(|a| a.status == "running").count();
338        let uptime = Utc::now()
339            .signed_duration_since(self.started_at)
340            .num_seconds() as u64;
341
342        DashboardMetrics {
343            timestamp: Utc::now(),
344            cost_stats: self.cost_tracker.stats(),
345            agents,
346            active_agents,
347            total_messages: *self.total_messages.read(),
348            uptime_seconds: uptime,
349        }
350    }
351
352    /// Subscribe to updates
353    pub fn subscribe(&self) -> broadcast::Receiver<WsMessage> {
354        self.broadcast_tx.subscribe()
355    }
356
357    /// Broadcast metrics to all connected clients
358    pub fn broadcast_metrics(&self) {
359        let metrics = self.get_metrics();
360        let _ = self.broadcast_tx.send(WsMessage::Metrics(metrics));
361    }
362}