1use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use thiserror::Error;
13
/// Returns the default socket path: `egui-mcp.sock` inside the runtime directory.
///
/// `$XDG_RUNTIME_DIR` is used when it holds a non-empty absolute path. Otherwise
/// the path falls back to [`std::env::temp_dir`]. The XDG Base Directory
/// specification requires these variables to be absolute and says relative
/// values must be ignored. An empty value is not absolute, so it also falls back
/// instead of yielding a socket in the current working directory. The value is
/// read with `var_os`, so non-UTF-8 paths are honoured rather than discarded.
pub fn default_socket_path() -> PathBuf {
    socket_path_in_runtime_dir(std::env::var_os("XDG_RUNTIME_DIR"))
}

/// Resolves the socket path from an optional runtime-directory value.
///
/// Kept separate from [`default_socket_path`] so the fallback rules can be
/// exercised without mutating the process environment.
fn socket_path_in_runtime_dir(runtime_dir: Option<std::ffi::OsString>) -> PathBuf {
    let base = runtime_dir
        .map(PathBuf::from)
        // Empty and relative values are invalid per the XDG spec; ignore them.
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(std::env::temp_dir);
    base.join("egui-mcp.sock")
}
21
/// Snapshot of a single node of the UI, as stored in [`UiTree::nodes`].
///
/// Nodes refer to each other by `id` (see `children`), so the tree is stored flat.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Identifier of this node. Presumably unique within one [`UiTree`]; it is
    /// the value that appears in other nodes' `children` and in `UiTree::roots`.
    pub id: u64,
    /// Role name of the node as free-form text (e.g. `"Button"`).
    pub role: String,
    /// Text label, if the node has one.
    pub label: Option<String>,
    /// Current value rendered as text, if the node has one.
    pub value: Option<String>,
    /// Bounding rectangle, if known.
    /// NOTE(review): coordinate space is not established here. It is presumably the
    /// same space as the pointer requests (`ClickAt`, ...); confirm.
    pub bounds: Option<Rect>,
    /// Ids of this node's child nodes.
    pub children: Vec<u64>,
    /// Toggle/checked state, `None` when absent.
    /// NOTE(review): confirm whether `None` means "not a toggle" or "unknown".
    pub toggled: Option<bool>,
    /// Whether the node is disabled.
    pub disabled: bool,
    /// Whether the node currently has focus.
    pub focused: bool,
}
44
/// Axis-aligned rectangle described by an origin (`x`, `y`) and a size.
///
/// NOTE(review): presumably (`x`, `y`) is the top-left corner in egui screen
/// coordinates; confirm against the producer of these values.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Rect {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
}
53
/// Flat representation of the UI hierarchy.
///
/// Every node lives in `nodes`. The hierarchy is expressed through each node's
/// `children` ids, and `roots` lists the ids of the top-level nodes.
/// `Default` yields an empty tree.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UiTree {
    /// Ids of the top-level nodes.
    pub roots: Vec<u64>,
    /// All nodes of the tree, in no order established by this type.
    pub nodes: Vec<NodeInfo>,
}
62
63#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
65pub enum MouseButton {
66 Left,
67 Right,
68 Middle,
69}
70
/// A single log record, as carried by `Response::Logs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogEntry {
    /// Level name as text (e.g. `"INFO"`).
    pub level: String,
    /// Log target, e.g. the emitting module or crate name.
    pub target: String,
    /// The log message text.
    pub message: String,
    /// Timestamp in milliseconds.
    /// NOTE(review): the epoch is not established here. It is presumably the Unix
    /// epoch; verify.
    pub timestamp_ms: u64,
}
83
/// Frame-timing statistics computed from `sample_count` frame samples.
///
/// NOTE(review): the sampling window and averaging method are decided by the
/// producer; confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrameStats {
    /// Frames per second.
    pub fps: f32,
    /// Frame time in milliseconds (presumably the mean over the samples).
    pub frame_time_ms: f32,
    /// Shortest frame time among the samples, in milliseconds.
    pub frame_time_min_ms: f32,
    /// Longest frame time among the samples, in milliseconds.
    pub frame_time_max_ms: f32,
    /// Number of frame samples the statistics were computed from.
    pub sample_count: usize,
}
98
/// Summary of a performance recording, presumably one started with
/// `Request::StartPerfRecording`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerfReport {
    /// Length of the recording in milliseconds.
    pub duration_ms: u64,
    /// Number of frames observed during the recording.
    pub total_frames: usize,
    /// Average frames per second.
    pub avg_fps: f32,
    /// Mean frame time in milliseconds.
    pub avg_frame_time_ms: f32,
    /// Shortest frame time in milliseconds.
    pub min_frame_time_ms: f32,
    /// Longest frame time in milliseconds.
    pub max_frame_time_ms: f32,
    /// 95th-percentile frame time in milliseconds.
    pub p95_frame_time_ms: f32,
    /// 99th-percentile frame time in milliseconds.
    pub p99_frame_time_ms: f32,
}
119
/// A request sent over the framed socket protocol.
///
/// Encoded with [`write_request`] and decoded with [`read_request`]. Because of
/// `#[serde(tag = "type")]` the enum is internally tagged: the variant name is
/// stored in a `"type"` field next to the variant's own fields, e.g.
/// `{"type":"Ping"}`. Renaming a variant or field therefore changes the wire format.
///
/// NOTE(review): the coordinate space of the `f32` positions and sizes below
/// (logical points vs. physical pixels, origin) is not established in this file.
/// It is presumably egui screen coordinates; confirm against the handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Request {
    /// Liveness check (presumably answered with `Response::Pong`).
    Ping,

    /// Capture a screenshot of the whole UI (presumably answered with
    /// `Response::Screenshot`).
    TakeScreenshot,

    /// Capture a screenshot of the rectangle (`x`, `y`, `width`, `height`) only.
    TakeScreenshotRegion {
        x: f32,
        y: f32,
        width: f32,
        height: f32,
    },

    /// Click `button` at position (`x`, `y`).
    ClickAt {
        x: f32,
        y: f32,
        button: MouseButton,
    },

    /// Send keyboard input identified by `key` (e.g. `"Enter"`).
    /// NOTE(review): the set of accepted key names is defined by the handler.
    KeyboardInput {
        key: String,
    },

    /// Scroll by (`delta_x`, `delta_y`) with the pointer at (`x`, `y`).
    Scroll {
        x: f32,
        y: f32,
        delta_x: f32,
        delta_y: f32,
    },

    /// Move the pointer to (`x`, `y`).
    MoveMouse {
        x: f32,
        y: f32,
    },

    /// Drag with `button` from (`start_x`, `start_y`) to (`end_x`, `end_y`).
    Drag {
        start_x: f32,
        start_y: f32,
        end_x: f32,
        end_y: f32,
        button: MouseButton,
    },

    /// Double-click `button` at position (`x`, `y`).
    DoubleClick {
        x: f32,
        y: f32,
        button: MouseButton,
    },

    /// Highlight the rectangle (`x`, `y`, `width`, `height`) with `color` for
    /// `duration_ms` milliseconds.
    HighlightElement {
        x: f32,
        y: f32,
        width: f32,
        height: f32,
        /// Four colour bytes, presumably RGBA; verify channel order with the handler.
        color: [u8; 4],
        /// Display duration in milliseconds.
        /// NOTE(review): whether 0 has a special meaning is not established here.
        duration_ms: u64,
    },

    /// Remove all highlights (presumably those added via `HighlightElement`).
    ClearHighlights,

    /// Fetch log entries (presumably answered with `Response::Logs`).
    GetLogs {
        /// Optional level filter (e.g. `"INFO"`). `None` presumably means no filter.
        /// Exact vs. minimum-level matching is decided by the handler.
        level: Option<String>,
        /// Optional cap on the number of entries. `None` presumably means no cap.
        limit: Option<usize>,
    },

    /// Discard stored log entries.
    ClearLogs,

    /// Fetch current frame statistics (presumably answered with
    /// `Response::FrameStatsResponse`).
    GetFrameStats,

    /// Begin a performance recording lasting `duration_ms` milliseconds.
    StartPerfRecording {
        duration_ms: u64,
    },

    /// Fetch the latest performance report (presumably answered with
    /// `Response::PerfReportResponse`).
    GetPerfReport,
}
248
/// A reply to a [`Request`].
///
/// Encoded with [`write_response`] and decoded with [`read_response`]. Like
/// `Request`, it is internally tagged via `#[serde(tag = "type")]`: the variant
/// name goes in a `"type"` field, e.g. `{"type":"Pong"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Response {
    /// Liveness reply, presumably to `Request::Ping`.
    Pong,

    /// A captured image.
    Screenshot {
        /// Image bytes encoded as a string, presumably base64 (the tests use
        /// `"base64data"`); confirm with the producer.
        data: String,
        /// Image format name, e.g. `"png"`.
        format: String,
    },

    /// Generic acknowledgement for requests that return no payload.
    /// NOTE(review): which requests reply with this is decided by the handler.
    Success,

    /// The request failed; `message` describes why.
    Error { message: String },

    /// Log entries, presumably in reply to `Request::GetLogs`.
    Logs {
        entries: Vec<LogEntry>,
    },

    /// Frame statistics, presumably in reply to `Request::GetFrameStats`.
    FrameStatsResponse {
        stats: FrameStats,
    },

    /// Performance report. `None` presumably means no recording has produced
    /// a report yet.
    PerfReportResponse {
        report: Option<PerfReport>,
    },
}
288
/// Errors produced while framing, encoding or decoding protocol messages.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// Underlying I/O failure on the stream, including EOF in the middle of a payload.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The payload could not be serialized to, or deserialized from, JSON.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// The stream reached EOF before a complete 4-byte length prefix was read.
    /// This includes an orderly close between frames.
    #[error("Connection closed")]
    ConnectionClosed,
    /// A payload length exceeded `MAX_MESSAGE_SIZE`. Carries the length in bytes.
    #[error("Message too large: {0} bytes")]
    MessageTooLarge(usize),
}
301
/// Largest payload, in bytes (1 MiB), that `read_message` accepts and
/// `write_message` sends. The 4-byte length prefix is not counted.
///
/// NOTE(review): `Response::Screenshot` carries image data as a string, and
/// large windows may exceed this limit. Confirm the limit suits screenshot traffic.
pub const MAX_MESSAGE_SIZE: usize = 1 << 20;
304
305pub async fn read_message<R: tokio::io::AsyncReadExt + Unpin>(
307 reader: &mut R,
308) -> Result<Vec<u8>, ProtocolError> {
309 let mut len_buf = [0u8; 4];
310 match reader.read_exact(&mut len_buf).await {
311 Ok(_) => {}
312 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
313 return Err(ProtocolError::ConnectionClosed);
314 }
315 Err(e) => return Err(e.into()),
316 }
317
318 let len = u32::from_be_bytes(len_buf) as usize;
319 if len > MAX_MESSAGE_SIZE {
320 return Err(ProtocolError::MessageTooLarge(len));
321 }
322
323 let mut buf = vec![0u8; len];
324 reader.read_exact(&mut buf).await?;
325 Ok(buf)
326}
327
328pub async fn write_message<W: tokio::io::AsyncWriteExt + Unpin>(
330 writer: &mut W,
331 data: &[u8],
332) -> Result<(), ProtocolError> {
333 if data.len() > MAX_MESSAGE_SIZE {
334 return Err(ProtocolError::MessageTooLarge(data.len()));
335 }
336
337 let len = (data.len() as u32).to_be_bytes();
338 writer.write_all(&len).await?;
339 writer.write_all(data).await?;
340 writer.flush().await?;
341 Ok(())
342}
343
344pub async fn read_request<R: tokio::io::AsyncReadExt + Unpin>(
346 reader: &mut R,
347) -> Result<Request, ProtocolError> {
348 let data = read_message(reader).await?;
349 let request = serde_json::from_slice(&data)?;
350 Ok(request)
351}
352
353pub async fn write_response<W: tokio::io::AsyncWriteExt + Unpin>(
355 writer: &mut W,
356 response: &Response,
357) -> Result<(), ProtocolError> {
358 let data = serde_json::to_vec(response)?;
359 write_message(writer, &data).await
360}
361
362pub async fn read_response<R: tokio::io::AsyncReadExt + Unpin>(
364 reader: &mut R,
365) -> Result<Response, ProtocolError> {
366 let data = read_message(reader).await?;
367 let response = serde_json::from_slice(&data)?;
368 Ok(response)
369}
370
371pub async fn write_request<W: tokio::io::AsyncWriteExt + Unpin>(
373 writer: &mut W,
374 request: &Request,
375) -> Result<(), ProtocolError> {
376 let data = serde_json::to_vec(request)?;
377 write_message(writer, &data).await
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_request() {
        // `#[serde(tag = "type")]` encodes a unit variant as an object holding only the tag.
        let json = serde_json::to_string(&Request::Ping).unwrap();
        assert_eq!(json, r#"{"type":"Ping"}"#);
    }

    #[test]
    fn test_serialize_response() {
        let json = serde_json::to_string(&Response::Pong).unwrap();
        assert_eq!(json, r#"{"type":"Pong"}"#);
    }

    #[test]
    fn test_default_socket_path() {
        // Check the final path component rather than a substring of the whole path.
        assert!(default_socket_path().ends_with("egui-mcp.sock"));
    }

    #[test]
    fn test_click_at_request() {
        let req = Request::ClickAt {
            x: 100.0,
            y: 200.0,
            button: MouseButton::Left,
        };
        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["type"], "ClickAt");
        assert_eq!(value["x"], 100.0_f64);
        assert_eq!(value["y"], 200.0_f64);

        let decoded: Request = serde_json::from_value(value).unwrap();
        assert!(matches!(
            decoded,
            Request::ClickAt { x, y, button: MouseButton::Left } if x == 100.0 && y == 200.0
        ));
    }

    #[test]
    fn test_keyboard_input_request() {
        let req = Request::KeyboardInput {
            key: "Enter".to_string(),
        };
        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["type"], "KeyboardInput");
        assert_eq!(value["key"], "Enter");
    }

    #[test]
    fn test_request_roundtrip_drag() {
        let req = Request::Drag {
            start_x: 10.0,
            start_y: 20.0,
            end_x: 100.0,
            end_y: 200.0,
            button: MouseButton::Left,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::Drag {
            start_x,
            start_y,
            end_x,
            end_y,
            button,
        } = decoded
        {
            assert_eq!(start_x, 10.0);
            assert_eq!(start_y, 20.0);
            assert_eq!(end_x, 100.0);
            assert_eq!(end_y, 200.0);
            assert!(matches!(button, MouseButton::Left));
        } else {
            panic!("Expected Drag request");
        }
    }

    #[test]
    fn test_request_roundtrip_scroll() {
        let req = Request::Scroll {
            x: 50.0,
            y: 60.0,
            delta_x: -10.0,
            delta_y: 20.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::Scroll {
            x,
            y,
            delta_x,
            delta_y,
        } = decoded
        {
            assert_eq!(x, 50.0);
            assert_eq!(y, 60.0);
            assert_eq!(delta_x, -10.0);
            assert_eq!(delta_y, 20.0);
        } else {
            panic!("Expected Scroll request");
        }
    }

    #[test]
    fn test_request_roundtrip_highlight_element() {
        let req = Request::HighlightElement {
            x: 1.0,
            y: 2.0,
            width: 3.0,
            height: 4.0,
            color: [255, 0, 0, 128],
            duration_ms: 1500,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::HighlightElement {
            x,
            y,
            width,
            height,
            color,
            duration_ms,
        } = decoded
        {
            assert_eq!((x, y, width, height), (1.0, 2.0, 3.0, 4.0));
            assert_eq!(color, [255, 0, 0, 128]);
            assert_eq!(duration_ms, 1500);
        } else {
            panic!("Expected HighlightElement request");
        }
    }

    #[test]
    fn test_request_roundtrip_get_logs_optional_fields() {
        let req = Request::GetLogs {
            level: Some("WARN".to_string()),
            limit: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::GetLogs { level, limit } = decoded {
            assert_eq!(level.as_deref(), Some("WARN"));
            assert_eq!(limit, None);
        } else {
            panic!("Expected GetLogs request");
        }
    }

    #[test]
    fn test_response_roundtrip_screenshot() {
        let resp = Response::Screenshot {
            data: "base64data".to_string(),
            format: "png".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Response = serde_json::from_str(&json).unwrap();
        if let Response::Screenshot { data, format } = decoded {
            assert_eq!(data, "base64data");
            assert_eq!(format, "png");
        } else {
            panic!("Expected Screenshot response");
        }
    }

    #[test]
    fn test_response_roundtrip_error() {
        let resp = Response::Error {
            message: "Something went wrong".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Response = serde_json::from_str(&json).unwrap();
        if let Response::Error { message } = decoded {
            assert_eq!(message, "Something went wrong");
        } else {
            panic!("Expected Error response");
        }
    }

    #[test]
    fn test_response_roundtrip_perf_report_none() {
        let resp = Response::PerfReportResponse { report: None };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Response = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            decoded,
            Response::PerfReportResponse { report: None }
        ));
    }

    #[test]
    fn test_node_info_serialization() {
        let node = NodeInfo {
            id: 42,
            role: "Button".to_string(),
            label: Some("Click me".to_string()),
            value: None,
            bounds: Some(Rect {
                x: 10.0,
                y: 20.0,
                width: 100.0,
                height: 50.0,
            }),
            children: vec![1, 2, 3],
            toggled: Some(true),
            disabled: false,
            focused: true,
        };
        let json = serde_json::to_string(&node).unwrap();
        let decoded: NodeInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.id, 42);
        assert_eq!(decoded.role, "Button");
        assert_eq!(decoded.label, Some("Click me".to_string()));
        assert_eq!(decoded.value, None);
        let bounds = decoded.bounds.expect("bounds should survive the roundtrip");
        assert_eq!(bounds.x, 10.0);
        assert_eq!(bounds.y, 20.0);
        assert_eq!(bounds.width, 100.0);
        assert_eq!(bounds.height, 50.0);
        assert_eq!(decoded.children, vec![1, 2, 3]);
        assert_eq!(decoded.toggled, Some(true));
        assert!(!decoded.disabled);
        assert!(decoded.focused);
    }

    #[test]
    fn test_ui_tree_default_is_empty() {
        let tree = UiTree::default();
        assert!(tree.roots.is_empty());
        assert!(tree.nodes.is_empty());
    }

    #[test]
    fn test_log_entry_serialization() {
        let entry = LogEntry {
            level: "INFO".to_string(),
            target: "my_app".to_string(),
            message: "Hello world".to_string(),
            timestamp_ms: 1234567890,
        };
        let json = serde_json::to_string(&entry).unwrap();
        let decoded: LogEntry = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.level, "INFO");
        assert_eq!(decoded.target, "my_app");
        assert_eq!(decoded.message, "Hello world");
        assert_eq!(decoded.timestamp_ms, 1234567890);
    }

    #[test]
    fn test_frame_stats_serialization() {
        let stats = FrameStats {
            fps: 60.0,
            frame_time_ms: 16.67,
            frame_time_min_ms: 15.0,
            frame_time_max_ms: 20.0,
            sample_count: 100,
        };
        let json = serde_json::to_string(&stats).unwrap();
        let decoded: FrameStats = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.fps, 60.0);
        // Only exactly representable values are compared for equality here.
        assert_eq!(decoded.frame_time_min_ms, 15.0);
        assert_eq!(decoded.frame_time_max_ms, 20.0);
        assert_eq!(decoded.sample_count, 100);
    }

    #[test]
    fn test_mouse_button_variants() {
        let buttons = [MouseButton::Left, MouseButton::Right, MouseButton::Middle];
        for button in buttons {
            let json = serde_json::to_string(&button).unwrap();
            let decoded: MouseButton = serde_json::from_str(&json).unwrap();
            assert!(matches!(
                (&button, &decoded),
                (MouseButton::Left, MouseButton::Left)
                    | (MouseButton::Right, MouseButton::Right)
                    | (MouseButton::Middle, MouseButton::Middle)
            ));
        }
    }

    #[test]
    fn test_max_message_size_fits_length_prefix() {
        // write_message casts the payload length to u32 for the frame header.
        assert!(MAX_MESSAGE_SIZE <= u32::MAX as usize);
    }
}