use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::Response,
};
use chrono::{DateTime, Utc};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashSet, VecDeque};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use super::SharedState;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActiveTool {
pub name: String,
#[serde(with = "instant_serde")]
pub started_at: Instant,
pub parameters: serde_json::Value,
pub progress: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileAccessEvent {
pub path: PathBuf,
pub access_type: AccessType,
#[serde(with = "instant_serde")]
pub timestamp: Instant,
pub tool_name: String,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AccessType {
Read,
Write,
Analyze,
Search,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExecution {
pub name: String,
pub completed_at: DateTime<Utc>,
pub duration_ms: u64,
pub success: bool,
pub summary: String,
}
#[derive(Debug, Default)]
pub struct McpActivityState {
pub active_tool: Option<ActiveTool>,
pub file_access_log: VecDeque<FileAccessEvent>,
pub current_operation: String,
pub files_touched: HashSet<PathBuf>,
pub directories_explored: HashSet<PathBuf>,
pub tool_history: VecDeque<ToolExecution>,
pub last_update: Option<Instant>,
}
impl McpActivityState {
const MAX_FILE_EVENTS: usize = 100;
const MAX_TOOL_HISTORY: usize = 20;
pub fn start_tool(&mut self, name: &str, parameters: serde_json::Value) {
self.active_tool = Some(ActiveTool {
name: name.to_string(),
started_at: Instant::now(),
parameters,
progress: None,
});
self.current_operation = format!("Executing {}...", name);
self.last_update = Some(Instant::now());
}
pub fn update_progress(&mut self, progress: f32) {
if let Some(ref mut tool) = self.active_tool {
tool.progress = Some(progress.clamp(0.0, 1.0));
self.last_update = Some(Instant::now());
}
}
pub fn record_file_access(&mut self, path: PathBuf, access_type: AccessType, tool_name: &str) {
let event = FileAccessEvent {
path: path.clone(),
access_type,
timestamp: Instant::now(),
tool_name: tool_name.to_string(),
};
self.file_access_log.push_back(event);
if self.file_access_log.len() > Self::MAX_FILE_EVENTS {
self.file_access_log.pop_front();
}
self.files_touched.insert(path.clone());
if let Some(parent) = path.parent() {
self.directories_explored.insert(parent.to_path_buf());
}
self.last_update = Some(Instant::now());
}
pub fn complete_tool(&mut self, success: bool, summary: &str) {
if let Some(tool) = self.active_tool.take() {
let duration = tool.started_at.elapsed();
let execution = ToolExecution {
name: tool.name,
completed_at: Utc::now(),
duration_ms: duration.as_millis() as u64,
success,
summary: summary.to_string(),
};
self.tool_history.push_back(execution);
if self.tool_history.len() > Self::MAX_TOOL_HISTORY {
self.tool_history.pop_front();
}
}
self.current_operation = if success {
"Ready".to_string()
} else {
"Error occurred".to_string()
};
self.last_update = Some(Instant::now());
}
pub fn recent_file_events(&self, max_age_ms: u64) -> Vec<FileEventDto> {
let now = Instant::now();
self.file_access_log
.iter()
.filter_map(|event| {
let age = now.duration_since(event.timestamp).as_millis() as u64;
if age <= max_age_ms {
Some(FileEventDto {
path: event.path.to_string_lossy().to_string(),
access_type: event.access_type,
age_ms: age,
tool_name: event.tool_name.clone(),
})
} else {
None
}
})
.collect()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserHint {
pub hint_type: HintType,
pub content: String,
pub timestamp: DateTime<Utc>,
#[serde(default)]
pub consumed: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HintType {
Click { target: String },
Text { message: String },
Voice { transcript: String, salience: f32 },
}
#[derive(Debug, Default)]
pub struct UserHintsQueue {
pub hints: VecDeque<UserHint>,
pub max_size: usize,
}
impl UserHintsQueue {
const DEFAULT_MAX_SIZE: usize = 50;
pub fn new() -> Self {
Self {
hints: VecDeque::new(),
max_size: Self::DEFAULT_MAX_SIZE,
}
}
pub fn push(&mut self, hint: UserHint) {
self.hints.push_back(hint);
while self.hints.len() > self.max_size {
self.hints.pop_front();
}
}
pub fn consume_next(&mut self) -> Option<UserHint> {
for hint in &mut self.hints {
if !hint.consumed {
hint.consumed = true;
return Some(hint.clone());
}
}
None
}
pub fn peek_unconsumed(&self) -> Vec<&UserHint> {
self.hints.iter().filter(|h| !h.consumed).collect()
}
pub fn unconsumed_count(&self) -> usize {
self.hints.iter().filter(|h| !h.consumed).count()
}
pub fn gc(&mut self, max_age: Duration) {
let now = Utc::now();
self.hints.retain(|h| {
!h.consumed || (now - h.timestamp).num_milliseconds() < max_age.as_millis() as i64
});
}
}
#[derive(Debug, Serialize)]
pub struct StateUpdateDto {
#[serde(rename = "type")]
pub msg_type: &'static str,
pub timestamp: i64,
pub mcp: McpStateDto,
pub file_log: Vec<FileEventDto>,
pub wave_compass: WaveCompassDto,
pub hints_pending: usize,
}
#[derive(Debug, Serialize)]
pub struct McpStateDto {
pub active_tool: Option<String>,
pub current_operation: String,
pub progress: Option<f32>,
pub tools_executed: usize,
}
#[derive(Debug, Serialize)]
pub struct FileEventDto {
pub path: String,
pub access_type: AccessType,
pub age_ms: u64,
pub tool_name: String,
}
#[derive(Debug, Serialize)]
pub struct WaveCompassDto {
pub hot_regions: Vec<HotRegion>,
pub trail: Vec<[f32; 2]>,
}
#[derive(Debug, Serialize)]
pub struct HotRegion {
pub x: f32,
pub y: f32,
pub intensity: f32,
pub label: String,
}
#[derive(Debug, Deserialize)]
pub struct HintMessageDto {
#[serde(rename = "type")]
pub msg_type: String,
pub hint_type: String,
#[serde(default)]
pub target: Option<String>,
#[serde(default)]
pub content: Option<String>,
#[serde(default)]
pub transcript: Option<String>,
#[serde(default)]
pub salience: Option<f32>,
}
pub async fn state_handler(
ws: WebSocketUpgrade,
State(state): State<SharedState>,
) -> Response {
ws.on_upgrade(|socket| handle_state_socket(socket, state))
}
async fn handle_state_socket(socket: WebSocket, state: SharedState) {
let (mut sender, mut receiver) = socket.split();
{
let mut dashboard = state.write().await;
dashboard.connections += 1;
}
let mcp_activity = {
let dashboard = state.read().await;
Arc::clone(&dashboard.mcp_activity)
};
let user_hints = {
let dashboard = state.read().await;
Arc::clone(&dashboard.user_hints)
};
let mcp_activity_send = Arc::clone(&mcp_activity);
let user_hints_send = Arc::clone(&user_hints);
let send_task = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_millis(16)); let mut last_update: Option<Instant> = None;
loop {
interval.tick().await;
let activity = mcp_activity_send.read().await;
if activity.last_update == last_update {
continue;
}
last_update = activity.last_update;
let update = build_state_update(&activity, &*user_hints_send.read().await);
drop(activity);
let json = match serde_json::to_string(&update) {
Ok(j) => j,
Err(_) => continue,
};
if sender.send(Message::Text(json)).await.is_err() {
break; }
}
});
let recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
if let Message::Text(text) = msg {
if let Ok(hint_msg) = serde_json::from_str::<HintMessageDto>(&text) {
if hint_msg.msg_type == "hint" {
let hint = parse_hint_message(hint_msg);
user_hints.write().await.push(hint);
}
}
}
}
});
tokio::select! {
_ = send_task => {},
_ = recv_task => {},
}
{
let mut dashboard = state.write().await;
dashboard.connections = dashboard.connections.saturating_sub(1);
}
}
fn build_state_update(activity: &McpActivityState, hints: &UserHintsQueue) -> StateUpdateDto {
let file_log = activity.recent_file_events(10_000);
let wave_compass = build_wave_compass(&file_log, &activity.directories_explored);
StateUpdateDto {
msg_type: "state_update",
timestamp: Utc::now().timestamp_millis(),
mcp: McpStateDto {
active_tool: activity.active_tool.as_ref().map(|t| t.name.clone()),
current_operation: activity.current_operation.clone(),
progress: activity.active_tool.as_ref().and_then(|t| t.progress),
tools_executed: activity.tool_history.len(),
},
file_log,
wave_compass,
hints_pending: hints.unconsumed_count(),
}
}
fn build_wave_compass(file_log: &[FileEventDto], _directories: &HashSet<PathBuf>) -> WaveCompassDto {
use std::collections::HashMap;
let mut dir_intensity: HashMap<String, f32> = HashMap::new();
for event in file_log {
let dir = PathBuf::from(&event.path)
.parent()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_default();
let intensity = 1.0 - (event.age_ms as f32 / 10_000.0).min(1.0);
*dir_intensity.entry(dir).or_default() += intensity * 0.3;
}
let hot_regions: Vec<HotRegion> = dir_intensity
.into_iter()
.map(|(dir, intensity)| {
let (x, y) = path_to_coords(&dir);
HotRegion {
x,
y,
intensity: intensity.min(1.0),
label: dir.split('/').next_back().unwrap_or(&dir).to_string(),
}
})
.filter(|r| r.intensity > 0.05)
.collect();
let trail: Vec<[f32; 2]> = file_log
.iter()
.rev()
.take(20)
.map(|e| {
let (x, y) = path_to_coords(&e.path);
[x, y]
})
.collect();
WaveCompassDto { hot_regions, trail }
}
fn path_to_coords(path: &str) -> (f32, f32) {
let parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
if parts.is_empty() {
return (0.5, 0.5);
}
let quadrant = match parts.first().copied() {
Some("src") => (0.0, 0.0), Some("tests") => (0.5, 0.0), Some("docs") => (0.0, 0.5), Some("scripts") => (0.5, 0.5), Some("examples") => (0.25, 0.25),
_ => (0.25, 0.75),
};
let sub_hash = simple_hash(&parts[1..].join("/"));
let x = quadrant.0 + (sub_hash % 100) as f32 / 200.0;
let y = quadrant.1 + ((sub_hash / 100) % 100) as f32 / 200.0;
(x.clamp(0.02, 0.98), y.clamp(0.02, 0.98))
}
fn simple_hash(s: &str) -> u32 {
s.bytes().fold(0u32, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u32))
}
fn parse_hint_message(msg: HintMessageDto) -> UserHint {
let content = msg.content.unwrap_or_default();
let hint_type = match msg.hint_type.as_str() {
"click" => HintType::Click {
target: msg.target.unwrap_or_default(),
},
"text" => HintType::Text {
message: content.clone(),
},
"voice" => HintType::Voice {
transcript: msg.transcript.unwrap_or_default(),
salience: msg.salience.unwrap_or(0.5),
},
_ => HintType::Text {
message: content.clone(),
},
};
UserHint {
hint_type,
content,
timestamp: Utc::now(),
consumed: false,
}
}
mod instant_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::Instant;
pub fn serialize<S>(instant: &Instant, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let age_ms = instant.elapsed().as_millis() as u64;
age_ms.serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Instant, D::Error>
where
D: Deserializer<'de>,
{
let age_ms = u64::deserialize(deserializer)?;
Ok(Instant::now() - std::time::Duration::from_millis(age_ms))
}
}