use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use thiserror::Error;
/// Returns the default filesystem path of the egui-mcp IPC socket.
///
/// The socket file `egui-mcp.sock` is placed in `$XDG_RUNTIME_DIR` when that
/// variable holds an absolute path, and in [`std::env::temp_dir`] otherwise.
///
/// Following the XDG Base Directory specification, an unset, empty, or
/// relative value is treated as invalid and ignored. The variable is read
/// with [`std::env::var_os`], so a non-UTF-8 (but otherwise valid) directory
/// is still honoured instead of being silently replaced by the fallback.
pub fn default_socket_path() -> PathBuf {
    let runtime_dir = std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        // An empty value would otherwise produce a bare relative
        // "egui-mcp.sock" resolved against the current working directory,
        // so server and client could disagree on the socket location.
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(std::env::temp_dir);
    runtime_dir.join("egui-mcp.sock")
}
/// Snapshot of a single node of the application's UI tree.
///
/// Derives `PartialEq` so snapshots can be compared directly (e.g. in tests or
/// when diffing two trees) instead of field by field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Numeric node identifier; the ids stored in `children` and
    /// [`UiTree::roots`] are expected to refer to these values.
    pub id: u64,
    /// Role of the node as a string (e.g. `"Button"`).
    pub role: String,
    /// Human-readable label, if the node has one.
    pub label: Option<String>,
    /// Current value as a string, if the node exposes one.
    pub value: Option<String>,
    /// On-screen bounds, when known.
    pub bounds: Option<Rect>,
    /// Ids of this node's direct children.
    pub children: Vec<u64>,
    /// Toggle state for checkable widgets; `None` when not applicable.
    pub toggled: Option<bool>,
    /// Whether the node is disabled.
    pub disabled: bool,
    /// Whether the node currently has focus.
    pub focused: bool,
}

/// Axis-aligned rectangle given by an origin (`x`, `y`) and a size.
///
/// NOTE(review): the coordinate space and units are defined by the producer
/// (presumably egui logical points) — confirm before mixing with pixel values.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Rect {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
}

/// A flattened UI tree: every node lives in `nodes`, and `roots` lists the ids
/// of the top-level nodes.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct UiTree {
    pub roots: Vec<u64>,
    pub nodes: Vec<NodeInfo>,
}
/// Mouse button used by the pointer-related [`Request`] variants.
///
/// Serialized as the bare variant name (e.g. `"Left"`). Derives `PartialEq`,
/// `Eq` and `Hash` so buttons can be compared with `==` and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MouseButton {
    /// The left mouse button.
    Left,
    /// The right mouse button.
    Right,
    /// The middle mouse button.
    Middle,
}
/// A single captured log record returned in [`Response::Logs`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LogEntry {
    /// Severity level as a string (e.g. `"INFO"`).
    pub level: String,
    /// Log target, typically the emitting module or crate.
    pub target: String,
    /// The formatted log message.
    pub message: String,
    /// Timestamp in milliseconds.
    /// NOTE(review): the reference point (Unix epoch vs. app start) is chosen
    /// by the producer — confirm.
    pub timestamp_ms: u64,
}
/// Frame-timing statistics returned in [`Response::FrameStatsResponse`].
///
/// NOTE(review): how large the sampling window is (and whether
/// `frame_time_ms` is a mean) is decided by the producer — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FrameStats {
    /// Frames per second.
    pub fps: f32,
    /// Representative frame time in milliseconds.
    pub frame_time_ms: f32,
    /// Shortest frame time in the sample set, in milliseconds.
    pub frame_time_min_ms: f32,
    /// Longest frame time in the sample set, in milliseconds.
    pub frame_time_max_ms: f32,
    /// Number of frames the statistics were computed from.
    pub sample_count: usize,
}
/// Summary of a performance recording, returned in
/// [`Response::PerfReportResponse`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PerfReport {
    /// Length of the recording in milliseconds.
    pub duration_ms: u64,
    /// Number of frames captured during the recording.
    pub total_frames: usize,
    /// Average frames per second over the recording.
    pub avg_fps: f32,
    /// Mean frame time in milliseconds.
    pub avg_frame_time_ms: f32,
    /// Shortest frame time in milliseconds.
    pub min_frame_time_ms: f32,
    /// Longest frame time in milliseconds.
    pub max_frame_time_ms: f32,
    /// 95th-percentile frame time in milliseconds.
    pub p95_frame_time_ms: f32,
    /// 99th-percentile frame time in milliseconds.
    pub p99_frame_time_ms: f32,
}
/// A request message exchanged over the egui-mcp socket.
///
/// Framed with [`write_request`] / [`read_request`]. Because of
/// `#[serde(tag = "type")]` the JSON form is a single object whose `"type"` key
/// holds the variant name, with the variant's fields alongside it — for
/// example `{"type":"KeyboardInput","key":"Enter"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Request {
    /// Connectivity check.
    Ping,
    /// Capture the whole window.
    TakeScreenshot,
    /// Capture only the rectangle with origin (`x`, `y`) and the given size.
    TakeScreenshotRegion {
        x: f32,
        y: f32,
        width: f32,
        height: f32,
    },
    /// Click `button` at (`x`, `y`).
    ClickAt { x: f32, y: f32, button: MouseButton },
    /// Send a key; `key` is a key name such as `"Enter"` (the set of
    /// accepted names is decided by the receiving side).
    KeyboardInput { key: String },
    /// Scroll by (`delta_x`, `delta_y`) with the pointer at (`x`, `y`).
    Scroll {
        x: f32,
        y: f32,
        delta_x: f32,
        delta_y: f32,
    },
    /// Move the pointer to (`x`, `y`).
    MoveMouse { x: f32, y: f32 },
    /// Drag with `button` from (`start_x`, `start_y`) to (`end_x`, `end_y`).
    Drag {
        start_x: f32,
        start_y: f32,
        end_x: f32,
        end_y: f32,
        button: MouseButton,
    },
    /// Double-click `button` at (`x`, `y`).
    DoubleClick { x: f32, y: f32, button: MouseButton },
    /// Highlight a rectangle for `duration_ms` milliseconds.
    ///
    /// NOTE(review): `color` is presumably RGBA — confirm channel order with
    /// the renderer.
    HighlightElement {
        x: f32,
        y: f32,
        width: f32,
        height: f32,
        color: [u8; 4],
        duration_ms: u64,
    },
    /// Remove highlights created with `HighlightElement`.
    ClearHighlights,
    /// Fetch captured log entries, optionally filtered by `level` and capped
    /// at `limit` entries.
    GetLogs { level: Option<String>, limit: Option<usize> },
    /// Discard captured log entries.
    ClearLogs,
    /// Fetch the current frame-timing statistics.
    GetFrameStats,
    /// Start a performance recording lasting `duration_ms` milliseconds.
    StartPerfRecording { duration_ms: u64 },
    /// Fetch the report of the performance recording.
    GetPerfReport,
}
/// A response message exchanged over the egui-mcp socket.
///
/// Framed with [`write_response`] / [`read_response`] and encoded like
/// [`Request`]: an internally tagged JSON object whose `"type"` key holds the
/// variant name (e.g. `{"type":"Pong"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Response {
    /// Reply to a connectivity check.
    Pong,
    /// Encoded image; `format` names the image format (e.g. `"png"`).
    ///
    /// NOTE(review): `data` is presumably base64 — confirm with the producer.
    Screenshot { data: String, format: String },
    /// Acknowledgement carrying no payload.
    Success,
    /// The request failed; `message` describes why.
    Error { message: String },
    /// Captured log entries.
    Logs { entries: Vec<LogEntry> },
    /// Current frame-timing statistics.
    FrameStatsResponse { stats: FrameStats },
    /// Performance report; `None` when no report is available.
    PerfReportResponse { report: Option<PerfReport> },
}
/// Errors produced while framing, reading, writing, or (de)serializing
/// protocol messages.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// The underlying stream reported an I/O failure (including EOF in the
    /// middle of a message payload).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A payload could not be serialized, or did not decode as the expected
    /// message type.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// The stream hit EOF while a message's 4-byte length prefix was being read.
    #[error("Connection closed")]
    ConnectionClosed,

    /// A frame's length exceeded [`MAX_MESSAGE_SIZE`]; carries that length in bytes.
    #[error("Message too large: {0} bytes")]
    MessageTooLarge(usize),
}
/// Upper bound, in bytes, on a single framed message payload (1 MiB).
///
/// Enforced on both the read and the write side. Must stay `<= u32::MAX`
/// because frames carry a 4-byte length prefix.
/// NOTE(review): base64-encoded screenshots of large windows may approach
/// this limit — confirm it is sufficient.
pub const MAX_MESSAGE_SIZE: usize = 1 << 20;
/// Reads one frame from `reader`: a 4-byte big-endian length prefix followed
/// by exactly that many payload bytes, which are returned.
///
/// # Errors
///
/// - [`ProtocolError::ConnectionClosed`] if EOF is reached before the length
///   prefix has been fully read.
/// - [`ProtocolError::MessageTooLarge`] if the prefix announces more than
///   [`MAX_MESSAGE_SIZE`] bytes; nothing is allocated in that case.
/// - [`ProtocolError::Io`] for any other I/O failure, including EOF part-way
///   through the payload.
pub async fn read_message<R: tokio::io::AsyncReadExt + Unpin>(
    reader: &mut R,
) -> Result<Vec<u8>, ProtocolError> {
    let mut header = [0u8; 4];
    reader.read_exact(&mut header).await.map_err(|err| {
        // EOF on the header means the peer hung up between messages.
        if err.kind() == std::io::ErrorKind::UnexpectedEof {
            ProtocolError::ConnectionClosed
        } else {
            ProtocolError::Io(err)
        }
    })?;

    let body_len = u32::from_be_bytes(header) as usize;
    // Validate before allocating so a hostile prefix cannot force a huge buffer.
    if body_len > MAX_MESSAGE_SIZE {
        return Err(ProtocolError::MessageTooLarge(body_len));
    }

    let mut body = vec![0u8; body_len];
    reader.read_exact(&mut body).await?;
    Ok(body)
}
/// Writes `data` to `writer` as one frame: a 4-byte big-endian length prefix,
/// then the payload. The writer is flushed afterwards.
///
/// # Errors
///
/// - [`ProtocolError::MessageTooLarge`] if `data` is longer than
///   [`MAX_MESSAGE_SIZE`]; nothing is written in that case.
/// - [`ProtocolError::Io`] if writing or flushing fails.
pub async fn write_message<W: tokio::io::AsyncWriteExt + Unpin>(
    writer: &mut W,
    data: &[u8],
) -> Result<(), ProtocolError> {
    let size = data.len();
    // MAX_MESSAGE_SIZE fits in u32, so for any size that passes the limit
    // check the conversion is lossless.
    let prefix = match u32::try_from(size) {
        Ok(n) if size <= MAX_MESSAGE_SIZE => n.to_be_bytes(),
        _ => return Err(ProtocolError::MessageTooLarge(size)),
    };

    writer.write_all(&prefix).await?;
    writer.write_all(data).await?;
    writer.flush().await?;
    Ok(())
}
pub async fn read_request<R: tokio::io::AsyncReadExt + Unpin>(
reader: &mut R,
) -> Result<Request, ProtocolError> {
let data = read_message(reader).await?;
let request = serde_json::from_slice(&data)?;
Ok(request)
}
pub async fn write_response<W: tokio::io::AsyncWriteExt + Unpin>(
writer: &mut W,
response: &Response,
) -> Result<(), ProtocolError> {
let data = serde_json::to_vec(response)?;
write_message(writer, &data).await
}
pub async fn read_response<R: tokio::io::AsyncReadExt + Unpin>(
reader: &mut R,
) -> Result<Response, ProtocolError> {
let data = read_message(reader).await?;
let response = serde_json::from_slice(&data)?;
Ok(response)
}
pub async fn write_request<W: tokio::io::AsyncWriteExt + Unpin>(
writer: &mut W,
request: &Request,
) -> Result<(), ProtocolError> {
let data = serde_json::to_vec(request)?;
write_message(writer, &data).await
}
#[cfg(test)]
mod tests {
    use super::*;

    // The JSON encoding is a contract between the in-app server and the MCP
    // client, so these tests pin the exact wire format where practical rather
    // than only checking that a substring appears somewhere in the output.

    #[test]
    fn test_serialize_request() {
        let json = serde_json::to_string(&Request::Ping).unwrap();
        assert_eq!(json, r#"{"type":"Ping"}"#);
    }

    #[test]
    fn test_serialize_response() {
        let json = serde_json::to_string(&Response::Pong).unwrap();
        assert_eq!(json, r#"{"type":"Pong"}"#);
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        assert_eq!(
            path.file_name().and_then(|name| name.to_str()),
            Some("egui-mcp.sock")
        );
    }

    #[test]
    fn test_click_at_request() {
        let req = Request::ClickAt {
            x: 100.0,
            y: 200.0,
            button: MouseButton::Left,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert_eq!(
            json,
            r#"{"type":"ClickAt","x":100.0,"y":200.0,"button":"Left"}"#
        );
    }

    #[test]
    fn test_keyboard_input_request() {
        let req = Request::KeyboardInput {
            key: "Enter".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert_eq!(json, r#"{"type":"KeyboardInput","key":"Enter"}"#);
    }

    #[test]
    fn test_get_logs_request_from_wire() {
        let decoded: Request =
            serde_json::from_str(r#"{"type":"GetLogs","level":"INFO","limit":10}"#).unwrap();
        if let Request::GetLogs { level, limit } = decoded {
            assert_eq!(level.as_deref(), Some("INFO"));
            assert_eq!(limit, Some(10));
        } else {
            panic!("Expected GetLogs request");
        }
    }

    #[test]
    fn test_request_roundtrip_drag() {
        let req = Request::Drag {
            start_x: 10.0,
            start_y: 20.0,
            end_x: 100.0,
            end_y: 200.0,
            button: MouseButton::Left,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::Drag {
            start_x,
            start_y,
            end_x,
            end_y,
            button,
        } = decoded
        {
            assert_eq!(start_x, 10.0);
            assert_eq!(start_y, 20.0);
            assert_eq!(end_x, 100.0);
            assert_eq!(end_y, 200.0);
            assert!(matches!(button, MouseButton::Left));
        } else {
            panic!("Expected Drag request");
        }
    }

    #[test]
    fn test_request_roundtrip_scroll() {
        let req = Request::Scroll {
            x: 50.0,
            y: 60.0,
            delta_x: -10.0,
            delta_y: 20.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::Scroll {
            x,
            y,
            delta_x,
            delta_y,
        } = decoded
        {
            assert_eq!(x, 50.0);
            assert_eq!(y, 60.0);
            assert_eq!(delta_x, -10.0);
            assert_eq!(delta_y, 20.0);
        } else {
            panic!("Expected Scroll request");
        }
    }

    #[test]
    fn test_response_roundtrip_screenshot() {
        let resp = Response::Screenshot {
            data: "base64data".to_string(),
            format: "png".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Response = serde_json::from_str(&json).unwrap();
        if let Response::Screenshot { data, format } = decoded {
            assert_eq!(data, "base64data");
            assert_eq!(format, "png");
        } else {
            panic!("Expected Screenshot response");
        }
    }

    #[test]
    fn test_response_roundtrip_error() {
        let resp = Response::Error {
            message: "Something went wrong".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Response = serde_json::from_str(&json).unwrap();
        if let Response::Error { message } = decoded {
            assert_eq!(message, "Something went wrong");
        } else {
            panic!("Expected Error response");
        }
    }

    #[test]
    fn test_perf_report_response_none_is_null() {
        let resp = Response::PerfReportResponse { report: None };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(json, r#"{"type":"PerfReportResponse","report":null}"#);
        let decoded: Response = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            decoded,
            Response::PerfReportResponse { report: None }
        ));
    }

    #[test]
    fn test_max_message_size_fits_length_prefix() {
        // write_message encodes the length as a u32 prefix; the limit must
        // never exceed what that prefix can represent.
        assert!(u32::try_from(MAX_MESSAGE_SIZE).is_ok());
    }

    #[test]
    fn test_node_info_serialization() {
        let node = NodeInfo {
            id: 42,
            role: "Button".to_string(),
            label: Some("Click me".to_string()),
            value: None,
            bounds: Some(Rect {
                x: 10.0,
                y: 20.0,
                width: 100.0,
                height: 50.0,
            }),
            children: vec![1, 2, 3],
            toggled: Some(true),
            disabled: false,
            focused: true,
        };
        let json = serde_json::to_string(&node).unwrap();
        let decoded: NodeInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.id, 42);
        assert_eq!(decoded.role, "Button");
        assert_eq!(decoded.label, Some("Click me".to_string()));
        assert!(decoded.bounds.is_some());
        assert_eq!(decoded.children, vec![1, 2, 3]);
        assert_eq!(decoded.toggled, Some(true));
        assert!(!decoded.disabled);
        assert!(decoded.focused);
    }

    #[test]
    fn test_log_entry_serialization() {
        let entry = LogEntry {
            level: "INFO".to_string(),
            target: "my_app".to_string(),
            message: "Hello world".to_string(),
            timestamp_ms: 1234567890,
        };
        let json = serde_json::to_string(&entry).unwrap();
        let decoded: LogEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.level, "INFO");
        assert_eq!(decoded.target, "my_app");
        assert_eq!(decoded.message, "Hello world");
        assert_eq!(decoded.timestamp_ms, 1234567890);
    }

    #[test]
    fn test_frame_stats_serialization() {
        let stats = FrameStats {
            fps: 60.0,
            frame_time_ms: 16.67,
            frame_time_min_ms: 15.0,
            frame_time_max_ms: 20.0,
            sample_count: 100,
        };
        let json = serde_json::to_string(&stats).unwrap();
        let decoded: FrameStats = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.fps, 60.0);
        assert_eq!(decoded.sample_count, 100);
    }

    #[test]
    fn test_mouse_button_variants() {
        let cases = [
            (MouseButton::Left, r#""Left""#),
            (MouseButton::Right, r#""Right""#),
            (MouseButton::Middle, r#""Middle""#),
        ];
        for (button, expected_json) in cases {
            let json = serde_json::to_string(&button).unwrap();
            assert_eq!(json, expected_json);
            let decoded: MouseButton = serde_json::from_str(&json).unwrap();
            assert!(matches!(
                (&button, &decoded),
                (MouseButton::Left, MouseButton::Left)
                    | (MouseButton::Right, MouseButton::Right)
                    | (MouseButton::Middle, MouseButton::Middle)
            ));
        }
    }
}