Skip to main content

egui_mcp_protocol/
lib.rs

1//! Common protocol definitions for egui-mcp
2//!
3//! This crate defines the shared types and protocols used for IPC communication
4//! between the MCP server and egui client applications.
5//!
6//! Note: UI tree access, element search, and click/text input operations are
7//! handled via AT-SPI on Linux. This protocol is only used for features that
8//! require direct client integration (screenshots, coordinate-based input, etc.).
9
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use thiserror::Error;
13
/// Default socket path for IPC communication.
///
/// Returns `$XDG_RUNTIME_DIR/egui-mcp.sock` when `XDG_RUNTIME_DIR` is set to an
/// absolute path, and `<system temp dir>/egui-mcp.sock` otherwise.
///
/// Empty or relative values of `XDG_RUNTIME_DIR` are ignored. The XDG Base
/// Directory specification requires these paths to be absolute. An empty value
/// would otherwise yield a socket path relative to the current working
/// directory, so a server and client started from different directories would
/// never find each other. `var_os` is used so that a runtime dir which is not
/// valid UTF-8 is still honoured instead of silently falling back.
pub fn default_socket_path() -> PathBuf {
    std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(std::env::temp_dir)
        .join("egui-mcp.sock")
}
21
22/// Information about a UI node (used for AT-SPI responses)
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct NodeInfo {
25    /// Unique identifier for the node
26    pub id: u64,
27    /// Role of the node (e.g., "Button", "TextInput", "Window")
28    pub role: String,
29    /// Human-readable label
30    pub label: Option<String>,
31    /// Current value (for inputs, sliders, etc.)
32    pub value: Option<String>,
33    /// Bounding rectangle
34    pub bounds: Option<Rect>,
35    /// Child node IDs
36    pub children: Vec<u64>,
37    /// Whether the node is toggled (for checkboxes, toggles)
38    pub toggled: Option<bool>,
39    /// Whether the node is disabled
40    pub disabled: bool,
41    /// Whether the node has focus
42    pub focused: bool,
43}
44
45/// A rectangle in screen coordinates
46#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
47pub struct Rect {
48    pub x: f32,
49    pub y: f32,
50    pub width: f32,
51    pub height: f32,
52}
53
54/// UI tree containing all nodes (used for AT-SPI responses)
55#[derive(Debug, Clone, Default, Serialize, Deserialize)]
56pub struct UiTree {
57    /// Root node IDs
58    pub roots: Vec<u64>,
59    /// All nodes in the tree
60    pub nodes: Vec<NodeInfo>,
61}
62
63/// Mouse button for click operations
64#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
65pub enum MouseButton {
66    Left,
67    Right,
68    Middle,
69}
70
71/// Log entry captured from the application
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct LogEntry {
74    /// Log level (TRACE, DEBUG, INFO, WARN, ERROR)
75    pub level: String,
76    /// Target module/crate
77    pub target: String,
78    /// Log message
79    pub message: String,
80    /// Timestamp in milliseconds since UNIX epoch
81    pub timestamp_ms: u64,
82}
83
84/// Frame statistics for performance monitoring
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct FrameStats {
87    /// Current frames per second
88    pub fps: f32,
89    /// Average frame time in milliseconds
90    pub frame_time_ms: f32,
91    /// Minimum frame time in milliseconds
92    pub frame_time_min_ms: f32,
93    /// Maximum frame time in milliseconds
94    pub frame_time_max_ms: f32,
95    /// Number of frames sampled
96    pub sample_count: usize,
97}
98
99/// Performance report from a recording session
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct PerfReport {
102    /// Recording duration in milliseconds
103    pub duration_ms: u64,
104    /// Total frames recorded
105    pub total_frames: usize,
106    /// Average FPS over the recording
107    pub avg_fps: f32,
108    /// Average frame time in milliseconds
109    pub avg_frame_time_ms: f32,
110    /// Minimum frame time in milliseconds
111    pub min_frame_time_ms: f32,
112    /// Maximum frame time in milliseconds
113    pub max_frame_time_ms: f32,
114    /// 95th percentile frame time in milliseconds
115    pub p95_frame_time_ms: f32,
116    /// 99th percentile frame time in milliseconds
117    pub p99_frame_time_ms: f32,
118}
119
120/// Request types for IPC communication
121///
122/// These are operations that require direct client integration and cannot be
123/// performed via AT-SPI.
124#[derive(Debug, Clone, Serialize, Deserialize)]
125#[serde(tag = "type")]
126pub enum Request {
127    /// Ping the client to check connection
128    Ping,
129
130    /// Request a screenshot of the application window
131    TakeScreenshot,
132
133    /// Request a screenshot of a specific region of the application window
134    TakeScreenshotRegion {
135        /// X coordinate of the region (relative to window)
136        x: f32,
137        /// Y coordinate of the region (relative to window)
138        y: f32,
139        /// Width of the region
140        width: f32,
141        /// Height of the region
142        height: f32,
143    },
144
145    /// Click at specific screen coordinates
146    ClickAt {
147        /// X coordinate (relative to window)
148        x: f32,
149        /// Y coordinate (relative to window)
150        y: f32,
151        /// Mouse button to click
152        button: MouseButton,
153    },
154
155    /// Send keyboard input
156    KeyboardInput {
157        /// Key to press (e.g., "Enter", "Tab", "a", "Ctrl+C")
158        key: String,
159    },
160
161    /// Scroll at specific coordinates
162    Scroll {
163        /// X coordinate (relative to window)
164        x: f32,
165        /// Y coordinate (relative to window)
166        y: f32,
167        /// Horizontal scroll delta
168        delta_x: f32,
169        /// Vertical scroll delta
170        delta_y: f32,
171    },
172
173    /// Move mouse to specific coordinates (for hover effects)
174    MoveMouse {
175        /// X coordinate (relative to window)
176        x: f32,
177        /// Y coordinate (relative to window)
178        y: f32,
179    },
180
181    /// Drag from one position to another
182    Drag {
183        /// Start X coordinate
184        start_x: f32,
185        /// Start Y coordinate
186        start_y: f32,
187        /// End X coordinate
188        end_x: f32,
189        /// End Y coordinate
190        end_y: f32,
191        /// Mouse button to use
192        button: MouseButton,
193    },
194
195    /// Double click at specific screen coordinates
196    DoubleClick {
197        /// X coordinate (relative to window)
198        x: f32,
199        /// Y coordinate (relative to window)
200        y: f32,
201        /// Mouse button to click
202        button: MouseButton,
203    },
204
205    /// Highlight an element with a colored border
206    HighlightElement {
207        /// Bounding box x coordinate
208        x: f32,
209        /// Bounding box y coordinate
210        y: f32,
211        /// Bounding box width
212        width: f32,
213        /// Bounding box height
214        height: f32,
215        /// Color as RGBA (0-255 each)
216        color: [u8; 4],
217        /// Duration in milliseconds (0 = until cleared)
218        duration_ms: u64,
219    },
220
221    /// Clear all highlights
222    ClearHighlights,
223
224    /// Get recent log entries
225    GetLogs {
226        /// Minimum log level to return (TRACE, DEBUG, INFO, WARN, ERROR)
227        /// If None, returns all levels
228        level: Option<String>,
229        /// Maximum number of entries to return
230        limit: Option<usize>,
231    },
232
233    /// Clear the log buffer
234    ClearLogs,
235
236    /// Get current frame statistics
237    GetFrameStats,
238
239    /// Start recording performance data
240    StartPerfRecording {
241        /// Duration to record in milliseconds (0 = until stopped)
242        duration_ms: u64,
243    },
244
245    /// Stop and get performance report
246    GetPerfReport,
247}
248
249/// Response types for IPC communication
250#[derive(Debug, Clone, Serialize, Deserialize)]
251#[serde(tag = "type")]
252pub enum Response {
253    /// Pong response to Ping
254    Pong,
255
256    /// Screenshot response
257    Screenshot {
258        /// Base64 encoded PNG data
259        data: String,
260        /// Image format (always "png")
261        format: String,
262    },
263
264    /// Success response (for operations without data)
265    Success,
266
267    /// Error response
268    Error { message: String },
269
270    /// Log entries response
271    Logs {
272        /// Log entries (oldest first)
273        entries: Vec<LogEntry>,
274    },
275
276    /// Frame statistics response
277    FrameStatsResponse {
278        /// Current frame statistics
279        stats: FrameStats,
280    },
281
282    /// Performance report response
283    PerfReportResponse {
284        /// Performance report (None if not recording or no data)
285        report: Option<PerfReport>,
286    },
287}
288
289/// Protocol errors
290#[derive(Debug, Error)]
291pub enum ProtocolError {
292    #[error("IO error: {0}")]
293    Io(#[from] std::io::Error),
294    #[error("JSON error: {0}")]
295    Json(#[from] serde_json::Error),
296    #[error("Connection closed")]
297    ConnectionClosed,
298    #[error("Message too large: {0} bytes")]
299    MessageTooLarge(usize),
300}
301
/// Maximum size of a single framed message payload, in bytes (16 MiB).
///
/// Enforced on both sides of the connection:
/// - `read_message` rejects larger announced lengths before allocating;
/// - `write_message` refuses to send larger payloads.
///
/// The limit must be large enough for `Response::Screenshot`, whose payload is
/// a base64-encoded PNG of the entire window. Base64 inflates the image by
/// about 4/3, so 1 MiB is easily exceeded by large or HiDPI windows.
///
/// The value must stay below `u32::MAX`, because the frame header encodes the
/// payload length as a `u32`.
pub const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;
304
305/// Read a length-prefixed message from a reader
306pub async fn read_message<R: tokio::io::AsyncReadExt + Unpin>(
307    reader: &mut R,
308) -> Result<Vec<u8>, ProtocolError> {
309    let mut len_buf = [0u8; 4];
310    match reader.read_exact(&mut len_buf).await {
311        Ok(_) => {}
312        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
313            return Err(ProtocolError::ConnectionClosed);
314        }
315        Err(e) => return Err(e.into()),
316    }
317
318    let len = u32::from_be_bytes(len_buf) as usize;
319    if len > MAX_MESSAGE_SIZE {
320        return Err(ProtocolError::MessageTooLarge(len));
321    }
322
323    let mut buf = vec![0u8; len];
324    reader.read_exact(&mut buf).await?;
325    Ok(buf)
326}
327
328/// Write a length-prefixed message to a writer
329pub async fn write_message<W: tokio::io::AsyncWriteExt + Unpin>(
330    writer: &mut W,
331    data: &[u8],
332) -> Result<(), ProtocolError> {
333    if data.len() > MAX_MESSAGE_SIZE {
334        return Err(ProtocolError::MessageTooLarge(data.len()));
335    }
336
337    let len = (data.len() as u32).to_be_bytes();
338    writer.write_all(&len).await?;
339    writer.write_all(data).await?;
340    writer.flush().await?;
341    Ok(())
342}
343
344/// Read and deserialize a request
345pub async fn read_request<R: tokio::io::AsyncReadExt + Unpin>(
346    reader: &mut R,
347) -> Result<Request, ProtocolError> {
348    let data = read_message(reader).await?;
349    let request = serde_json::from_slice(&data)?;
350    Ok(request)
351}
352
353/// Write and serialize a response
354pub async fn write_response<W: tokio::io::AsyncWriteExt + Unpin>(
355    writer: &mut W,
356    response: &Response,
357) -> Result<(), ProtocolError> {
358    let data = serde_json::to_vec(response)?;
359    write_message(writer, &data).await
360}
361
362/// Read and deserialize a response
363pub async fn read_response<R: tokio::io::AsyncReadExt + Unpin>(
364    reader: &mut R,
365) -> Result<Response, ProtocolError> {
366    let data = read_message(reader).await?;
367    let response = serde_json::from_slice(&data)?;
368    Ok(response)
369}
370
371/// Write and serialize a request
372pub async fn write_request<W: tokio::io::AsyncWriteExt + Unpin>(
373    writer: &mut W,
374    request: &Request,
375) -> Result<(), ProtocolError> {
376    let data = serde_json::to_vec(request)?;
377    write_message(writer, &data).await
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Serialize `value` into a JSON tree so assertions can pin the exact wire
    /// shape (tag name, field names, value types), independent of key order
    /// and whitespace.
    fn to_json<T: serde::Serialize>(value: &T) -> Value {
        serde_json::to_value(value).expect("protocol types serialize to JSON")
    }

    #[test]
    fn test_serialize_request() {
        // Unit variants carry only the `type` tag.
        assert_eq!(to_json(&Request::Ping), json!({ "type": "Ping" }));
    }

    #[test]
    fn test_serialize_response() {
        assert_eq!(to_json(&Response::Pong), json!({ "type": "Pong" }));
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        assert!(path.ends_with("egui-mcp.sock"));
    }

    #[test]
    fn test_click_at_request() {
        let req = Request::ClickAt {
            x: 100.0,
            y: 200.0,
            button: MouseButton::Left,
        };
        // Comparing the whole object (not substrings) catches swapped
        // coordinates and renamed fields.
        assert_eq!(
            to_json(&req),
            json!({ "type": "ClickAt", "x": 100.0, "y": 200.0, "button": "Left" })
        );
    }

    #[test]
    fn test_keyboard_input_request() {
        let req = Request::KeyboardInput {
            key: "Enter".to_string(),
        };
        assert_eq!(
            to_json(&req),
            json!({ "type": "KeyboardInput", "key": "Enter" })
        );
    }

    #[test]
    fn test_request_roundtrip_drag() {
        let req = Request::Drag {
            start_x: 10.0,
            start_y: 20.0,
            end_x: 100.0,
            end_y: 200.0,
            button: MouseButton::Left,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::Drag {
            start_x,
            start_y,
            end_x,
            end_y,
            button,
        } = decoded
        {
            assert_eq!(start_x, 10.0);
            assert_eq!(start_y, 20.0);
            assert_eq!(end_x, 100.0);
            assert_eq!(end_y, 200.0);
            assert!(matches!(button, MouseButton::Left));
        } else {
            panic!("Expected Drag request");
        }
    }

    #[test]
    fn test_request_roundtrip_scroll() {
        let req = Request::Scroll {
            x: 50.0,
            y: 60.0,
            delta_x: -10.0,
            delta_y: 20.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: Request = serde_json::from_str(&json).unwrap();
        if let Request::Scroll {
            x,
            y,
            delta_x,
            delta_y,
        } = decoded
        {
            assert_eq!(x, 50.0);
            assert_eq!(y, 60.0);
            assert_eq!(delta_x, -10.0);
            assert_eq!(delta_y, 20.0);
        } else {
            panic!("Expected Scroll request");
        }
    }

    #[test]
    fn test_get_logs_optional_fields_may_be_omitted() {
        // Clients may leave out optional filters entirely.
        let decoded: Request = serde_json::from_value(json!({ "type": "GetLogs" })).unwrap();
        assert!(matches!(
            decoded,
            Request::GetLogs {
                level: None,
                limit: None
            }
        ));
    }

    #[test]
    fn test_unknown_request_type_is_rejected() {
        assert!(serde_json::from_value::<Request>(json!({ "type": "Nope" })).is_err());
        // A missing tag must not silently decode to some variant.
        assert!(serde_json::from_value::<Request>(json!({ "x": 1.0 })).is_err());
    }

    #[test]
    fn test_response_roundtrip_screenshot() {
        let resp = Response::Screenshot {
            data: "base64data".to_string(),
            format: "png".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Response = serde_json::from_str(&json).unwrap();
        if let Response::Screenshot { data, format } = decoded {
            assert_eq!(data, "base64data");
            assert_eq!(format, "png");
        } else {
            panic!("Expected Screenshot response");
        }
    }

    #[test]
    fn test_response_roundtrip_error() {
        let resp = Response::Error {
            message: "Something went wrong".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: Response = serde_json::from_str(&json).unwrap();
        if let Response::Error { message } = decoded {
            assert_eq!(message, "Something went wrong");
        } else {
            panic!("Expected Error response");
        }
    }

    #[test]
    fn test_response_perf_report_none() {
        let resp = Response::PerfReportResponse { report: None };
        assert_eq!(
            to_json(&resp),
            json!({ "type": "PerfReportResponse", "report": null })
        );
    }

    #[test]
    fn test_protocol_error_messages() {
        assert_eq!(ProtocolError::ConnectionClosed.to_string(), "Connection closed");
        assert_eq!(
            ProtocolError::MessageTooLarge(42).to_string(),
            "Message too large: 42 bytes"
        );
    }

    #[test]
    fn test_node_info_serialization() {
        let node = NodeInfo {
            id: 42,
            role: "Button".to_string(),
            label: Some("Click me".to_string()),
            value: None,
            bounds: Some(Rect {
                x: 10.0,
                y: 20.0,
                width: 100.0,
                height: 50.0,
            }),
            children: vec![1, 2, 3],
            toggled: Some(true),
            disabled: false,
            focused: true,
        };
        let json = serde_json::to_string(&node).unwrap();
        let decoded: NodeInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.id, 42);
        assert_eq!(decoded.role, "Button");
        assert_eq!(decoded.label, Some("Click me".to_string()));
        assert_eq!(decoded.value, None);
        let bounds = decoded.bounds.expect("bounds should survive the roundtrip");
        assert_eq!(
            (bounds.x, bounds.y, bounds.width, bounds.height),
            (10.0, 20.0, 100.0, 50.0)
        );
        assert_eq!(decoded.children, vec![1, 2, 3]);
        assert_eq!(decoded.toggled, Some(true));
        assert!(!decoded.disabled);
        assert!(decoded.focused);
    }

    #[test]
    fn test_log_entry_serialization() {
        let entry = LogEntry {
            level: "INFO".to_string(),
            target: "my_app".to_string(),
            message: "Hello world".to_string(),
            timestamp_ms: 1234567890,
        };
        let json = serde_json::to_string(&entry).unwrap();
        let decoded: LogEntry = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.level, "INFO");
        assert_eq!(decoded.target, "my_app");
        assert_eq!(decoded.message, "Hello world");
        assert_eq!(decoded.timestamp_ms, 1234567890);
    }

    #[test]
    fn test_frame_stats_serialization() {
        let stats = FrameStats {
            fps: 60.0,
            frame_time_ms: 16.67,
            frame_time_min_ms: 15.0,
            frame_time_max_ms: 20.0,
            sample_count: 100,
        };
        let json = serde_json::to_string(&stats).unwrap();
        let decoded: FrameStats = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.fps, 60.0);
        assert_eq!(decoded.frame_time_min_ms, 15.0);
        assert_eq!(decoded.frame_time_max_ms, 20.0);
        assert_eq!(decoded.sample_count, 100);
    }

    #[test]
    fn test_mouse_button_variants() {
        let cases = [
            (MouseButton::Left, "Left"),
            (MouseButton::Right, "Right"),
            (MouseButton::Middle, "Middle"),
        ];
        for (button, name) in cases {
            // Buttons are encoded as bare variant-name strings.
            assert_eq!(to_json(&button), json!(name));
            let decoded: MouseButton = serde_json::from_value(json!(name)).unwrap();
            assert!(matches!(
                (button, decoded),
                (MouseButton::Left, MouseButton::Left)
                    | (MouseButton::Right, MouseButton::Right)
                    | (MouseButton::Middle, MouseButton::Middle)
            ));
        }
    }
}