Skip to main content

kitty_rc/
protocol.rs

1use crate::error::ProtocolError;
2use serde::{Deserialize, Serialize};
3use std::sync::atomic::{AtomicU32, Ordering};
4
/// Largest number of bytes of a single payload string field sent in one
/// message; longer fields are split into a stream of chunks.
const MAX_CHUNK_SIZE: usize = 4 * 1024;

/// Opening of the escape sequence (ESC P, i.e. DCS, plus `@kitty-cmd`) that
/// wraps every remote-control message.
const PREFIX: &str = "\x1bP@kitty-cmd";
/// String terminator (ESC \) that closes the escape sequence.
const SUFFIX: &str = "\x1b\\";

/// Process-wide counter feeding generated identifiers; starts at 1.
static STREAM_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct KittyMessage {
13    pub cmd: String,
14    pub version: Vec<u32>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub no_response: Option<bool>,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub kitty_window_id: Option<String>,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub payload: Option<serde_json::Value>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub async_id: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub cancel_async: Option<bool>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub stream_id: Option<String>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub stream: Option<bool>,
29}
30
31impl KittyMessage {
32    pub fn new(cmd: impl Into<String>, version: impl Into<Vec<u32>>) -> Self {
33        Self {
34            cmd: cmd.into(),
35            version: version.into(),
36            no_response: None,
37            kitty_window_id: None,
38            payload: None,
39            async_id: None,
40            cancel_async: None,
41            stream_id: None,
42            stream: None,
43        }
44    }
45
46    pub fn no_response(mut self, value: bool) -> Self {
47        self.no_response = Some(value);
48        self
49    }
50
51    pub fn kitty_window_id(mut self, id: impl Into<String>) -> Self {
52        self.kitty_window_id = Some(id.into());
53        self
54    }
55
56    pub fn payload(mut self, payload: serde_json::Value) -> Self {
57        self.payload = Some(payload);
58        self
59    }
60
61    pub fn async_id(mut self, id: impl Into<String>) -> Self {
62        self.async_id = Some(id.into());
63        self
64    }
65
66    pub fn cancel_async(mut self, value: bool) -> Self {
67        self.cancel_async = Some(value);
68        self
69    }
70
71    pub fn stream_id(mut self, id: impl Into<String>) -> Self {
72        self.stream_id = Some(id.into());
73        self
74    }
75
76    pub fn stream(mut self, value: bool) -> Self {
77        self.stream = Some(value);
78        self
79    }
80
81    pub fn generate_unique_id() -> String {
82        let id = STREAM_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
83        format!("{:x}", id)
84    }
85
86    pub fn needs_streaming(&self) -> bool {
87        if let Some(payload) = &self.payload {
88            if let Some(obj) = payload.as_object() {
89                for (_key, value) in obj {
90                    if let Some(s) = value.as_str() {
91                        if s.len() > MAX_CHUNK_SIZE {
92                            return true;
93                        }
94                    }
95                }
96            }
97        }
98        false
99    }
100
101    pub fn into_chunks(mut self) -> Vec<KittyMessage> {
102        let mut chunks = Vec::new();
103
104        if !self.needs_streaming() {
105            return vec![self];
106        }
107
108        if let Some(payload) = self.payload.take() {
109            if let Some(obj) = payload.as_object() {
110                let stream_id = Self::generate_unique_id();
111
112                for (_key, value) in obj {
113                    if let Some(s) = value.as_str() {
114                        if s.len() > MAX_CHUNK_SIZE {
115                            for (i, chunk_data) in s.as_bytes().chunks(MAX_CHUNK_SIZE).enumerate() {
116                                let mut chunk_msg = self.clone();
117                                chunk_msg.stream_id = Some(stream_id.clone());
118                                chunk_msg.stream = Some(true);
119
120                                let mut chunk_payload = serde_json::Map::new();
121                                chunk_payload.insert(
122                                    "data".to_string(),
123                                    serde_json::Value::String(
124                                        String::from_utf8_lossy(chunk_data).to_string(),
125                                    ),
126                                );
127                                chunk_payload.insert("chunk_num".to_string(), serde_json::json!(i));
128                                chunk_msg.payload = Some(serde_json::Value::Object(chunk_payload));
129
130                                chunks.push(chunk_msg);
131                            }
132
133                            let mut end_chunk = self.clone();
134                            end_chunk.stream_id = Some(stream_id);
135                            end_chunk.stream = Some(true);
136                            let mut end_payload = serde_json::Map::new();
137                            end_payload.insert(
138                                "data".to_string(),
139                                serde_json::Value::String(String::new()),
140                            );
141                            end_chunk.payload = Some(serde_json::Value::Object(end_payload));
142                            chunks.push(end_chunk);
143
144                            return chunks;
145                        }
146                    }
147                }
148            }
149        }
150
151        chunks.push(self);
152        chunks
153    }
154
155    pub fn encode(&self) -> Result<Vec<u8>, ProtocolError> {
156        let json = serde_json::to_string(self)?;
157        let message = format!("{}{}{}", PREFIX, json, SUFFIX);
158        Ok(message.into_bytes())
159    }
160
161    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
162        let s = std::str::from_utf8(data)
163            .map_err(|e| ProtocolError::InvalidMessageFormat(e.to_string()))?;
164
165        if !s.starts_with(PREFIX) {
166            return Err(ProtocolError::InvalidEscapeSequence);
167        }
168
169        if !s.ends_with(SUFFIX) {
170            return Err(ProtocolError::InvalidEscapeSequence);
171        }
172
173        let json_start = PREFIX.len();
174        let json_end = s.len() - SUFFIX.len();
175        let json_str = &s[json_start..json_end];
176
177        serde_json::from_str(json_str).map_err(ProtocolError::JsonError)
178    }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct KittyResponse {
183    pub ok: bool,
184    pub data: Option<serde_json::Value>,
185    pub error: Option<String>,
186}
187
188impl KittyResponse {
189    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
190        let s = std::str::from_utf8(data)
191            .map_err(|e| ProtocolError::EnvelopeParseError(e.to_string()))?;
192
193        if !s.starts_with("\x1bP@kitty-cmd") {
194            return Err(ProtocolError::EnvelopeParseError(
195                "Invalid response prefix".to_string(),
196            ));
197        }
198
199        if !s.ends_with("\x1b\\") {
200            return Err(ProtocolError::EnvelopeParseError(
201                "Invalid response suffix".to_string(),
202            ));
203        }
204
205        let json_start = PREFIX.len();
206        let json_end = s.len() - SUFFIX.len();
207        let json_str = &s[json_start..json_end];
208
209        let msg: serde_json::Value =
210            serde_json::from_str(json_str).map_err(ProtocolError::JsonError)?;
211
212        if !msg.is_object() {
213            return Err(ProtocolError::EnvelopeParseError(
214                "Response is not a JSON object".to_string(),
215            ));
216        }
217
218        serde_json::from_value(msg).map_err(ProtocolError::JsonError)
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_encode() {
        let msg = KittyMessage::new("ls", vec![0, 14, 2]);
        let encoded = msg.encode().unwrap();
        let decoded = KittyMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.cmd, "ls");
        assert_eq!(decoded.version, vec![0, 14, 2]);
    }

    #[test]
    fn test_message_with_payload() {
        let msg = KittyMessage::new("send-text", vec![0, 14, 2])
            .payload(serde_json::json!({"match": "id:1", "data": "text:hello"}));
        let encoded = msg.encode().unwrap();
        let decoded = KittyMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.cmd, "send-text");
        assert!(decoded.payload.is_some());
    }

    #[test]
    fn test_message_no_response() {
        let msg = KittyMessage::new("close-window", vec![0, 14, 2]).no_response(true);
        let encoded = msg.encode().unwrap();
        let decoded = KittyMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.no_response, Some(true));
    }

    #[test]
    fn test_invalid_escape_sequence() {
        let data = b"invalid message";
        let result = KittyMessage::decode(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_message_decode_rejects_missing_suffix() {
        let data = b"\x1bP@kitty-cmd{\"cmd\":\"ls\",\"version\":[0,14,2]}";
        assert!(KittyMessage::decode(data).is_err());
    }

    #[test]
    fn test_response_decode() {
        let raw = b"\x1bP@kitty-cmd{\"ok\":true,\"data\":[{\"id\":1,\"title\":\"test\"}]}\x1b\\";
        let response = KittyResponse::decode(raw).unwrap();
        assert!(response.ok);
        assert!(response.data.is_some());
    }

    #[test]
    fn test_response_decode_error_reply() {
        let raw = b"\x1bP@kitty-cmd{\"ok\":false,\"error\":\"No matching windows\"}\x1b\\";
        let response = KittyResponse::decode(raw).unwrap();
        assert!(!response.ok);
        assert!(response.data.is_none());
        assert_eq!(response.error.as_deref(), Some("No matching windows"));
    }

    #[test]
    fn test_response_decode_rejects_bad_envelope() {
        // Missing prefix.
        assert!(KittyResponse::decode(b"{\"ok\":true}\x1b\\").is_err());
        // Missing suffix.
        assert!(KittyResponse::decode(b"\x1bP@kitty-cmd{\"ok\":true}").is_err());
        // Valid JSON that is not an object.
        assert!(KittyResponse::decode(b"\x1bP@kitty-cmd[1,2]\x1b\\").is_err());
    }

    #[test]
    fn test_async_id() {
        let msg = KittyMessage::new("select-window", vec![0, 14, 2]).async_id("abc123");
        let encoded = msg.encode().unwrap();
        let decoded = KittyMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.async_id, Some("abc123".to_string()));
    }

    #[test]
    fn test_cancel_async() {
        let msg = KittyMessage::new("select-window", vec![0, 14, 2])
            .async_id("abc123")
            .cancel_async(true);
        let encoded = msg.encode().unwrap();
        let decoded = KittyMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.cancel_async, Some(true));
    }

    #[test]
    fn test_unique_id_generation() {
        let id1 = KittyMessage::generate_unique_id();
        let id2 = KittyMessage::generate_unique_id();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_needs_streaming_false() {
        let msg = KittyMessage::new("send-text", vec![0, 14, 2])
            .payload(serde_json::json!({"data": "hello"}));
        assert!(!msg.needs_streaming());
    }

    #[test]
    fn test_needs_streaming_true() {
        let large_data = "x".repeat(5000);
        let msg = KittyMessage::new("send-text", vec![0, 14, 2])
            .payload(serde_json::json!({"data": large_data}));
        assert!(msg.needs_streaming());
    }

    #[test]
    fn test_into_chunks_no_streaming() {
        let msg = KittyMessage::new("send-text", vec![0, 14, 2])
            .payload(serde_json::json!({"data": "hello"}));
        let chunks = msg.into_chunks();
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn test_into_chunks_with_streaming() {
        let large_data = "x".repeat(5000);
        let msg = KittyMessage::new("set-background-image", vec![0, 14, 2])
            .payload(serde_json::json!({"data": large_data}));
        let chunks = msg.into_chunks();
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|c| c.stream_id.is_some()));
        assert!(chunks.iter().all(|c| c.stream == Some(true)));
    }

    #[test]
    fn test_into_chunks_reassembles_and_ends_with_empty_marker() {
        let msg = KittyMessage::new("set-background-image", vec![0, 14, 2])
            .payload(serde_json::json!({"data": "x".repeat(5000)}));
        let chunks = msg.into_chunks();

        // 4096 + 904 bytes of data, then the end-of-stream marker.
        assert_eq!(chunks.len(), 3);
        assert!(chunks.iter().all(|c| c.cmd == "set-background-image"));

        let first_id = chunks[0].stream_id.clone();
        assert!(chunks.iter().all(|c| c.stream_id == first_id));

        let second = chunks[1].payload.as_ref().unwrap();
        assert_eq!(second.get("chunk_num").and_then(|v| v.as_u64()), Some(1));

        let last = chunks.last().unwrap().payload.as_ref().unwrap();
        assert_eq!(last.get("data").and_then(|v| v.as_str()), Some(""));

        let rebuilt: String = chunks
            .iter()
            .filter_map(|c| c.payload.as_ref()?.get("data")?.as_str())
            .collect();
        assert_eq!(rebuilt, "x".repeat(5000));
    }
}