Skip to main content

kitty_rc/
protocol.rs

1use crate::error::ProtocolError;
2use serde::{Deserialize, Serialize};
3use std::sync::atomic::{AtomicU32, Ordering};
4
/// DCS opener (`ESC P`) plus the `@kitty-cmd` tag that precedes the JSON body of every framed message.
const PREFIX: &str = "\x1bP@kitty-cmd";
/// String terminator (`ESC \`) that closes a framed message.
const SUFFIX: &str = "\x1b\\";
/// Largest payload string, in bytes, that is sent in a single message; longer
/// strings make `KittyMessage::needs_streaming` return true and are split by
/// `KittyMessage::into_chunks`.
const MAX_CHUNK_SIZE: usize = 4096;

/// Counter behind `KittyMessage::generate_unique_id`; the first id handed out is 1.
static STREAM_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
10
/// A remote-control command addressed to kitty.
///
/// `encode` serializes this struct to JSON and wraps it between `PREFIX` and
/// `SUFFIX`. Every `Option` field is omitted from the JSON while it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KittyMessage {
    /// Command name, e.g. `"ls"`, `"send-text"`, `"close-window"`.
    pub cmd: String,
    /// Protocol version triple, e.g. `[0, 14, 2]`.
    pub version: Vec<u32>,
    /// Set via `no_response`; presumably tells kitty not to reply (confirm against kitty's rc protocol docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_response: Option<bool>,
    /// Set via `kitty_window_id`; NOTE(review): meaning of the id is defined by kitty, not by this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kitty_window_id: Option<String>,
    /// Command-specific arguments. String values longer than `MAX_CHUNK_SIZE`
    /// bytes trigger streaming (see `needs_streaming` / `into_chunks`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload: Option<serde_json::Value>,
    /// Set via `async_id`; identifies an asynchronous request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub async_id: Option<String>,
    /// Set via `cancel_async`; used together with `async_id` in the tests.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cancel_async: Option<bool>,
    /// Id shared by all chunks that `into_chunks` produces from one message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_id: Option<String>,
    /// `Some(true)` on every chunk produced by `into_chunks`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// When set, `encode` emits only `version`, `iv`, `tag`, `pubkey` and this
    /// field; presumably the encrypted command body (format defined elsewhere).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encrypted: Option<String>,
    /// Emitted alongside `encrypted` by `encode`; assumed to be the cipher IV — confirm with the encryption code.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iv: Option<String>,
    /// Emitted alongside `encrypted` by `encode`; assumed to be the authentication tag — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tag: Option<String>,
    /// Emitted alongside `encrypted` by `encode`; assumed to be the sender's public key — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pubkey: Option<String>,
}
38
39impl KittyMessage {
40    pub fn new(cmd: impl Into<String>, version: impl Into<Vec<u32>>) -> Self {
41        Self {
42            cmd: cmd.into(),
43            version: version.into(),
44            no_response: None,
45            kitty_window_id: None,
46            payload: None,
47            async_id: None,
48            cancel_async: None,
49            stream_id: None,
50            stream: None,
51            encrypted: None,
52            iv: None,
53            tag: None,
54            pubkey: None,
55        }
56    }
57
58    pub fn no_response(mut self, value: bool) -> Self {
59        self.no_response = Some(value);
60        self
61    }
62
63    pub fn kitty_window_id(mut self, id: impl Into<String>) -> Self {
64        self.kitty_window_id = Some(id.into());
65        self
66    }
67
68    pub fn payload(mut self, payload: serde_json::Value) -> Self {
69        self.payload = Some(payload);
70        self
71    }
72
73    pub fn async_id(mut self, id: impl Into<String>) -> Self {
74        self.async_id = Some(id.into());
75        self
76    }
77
78    pub fn cancel_async(mut self, value: bool) -> Self {
79        self.cancel_async = Some(value);
80        self
81    }
82
83    pub fn stream_id(mut self, id: impl Into<String>) -> Self {
84        self.stream_id = Some(id.into());
85        self
86    }
87
88    pub fn stream(mut self, value: bool) -> Self {
89        self.stream = Some(value);
90        self
91    }
92
93    pub fn generate_unique_id() -> String {
94        let id = STREAM_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
95        format!("{:x}", id)
96    }
97
98    pub fn needs_streaming(&self) -> bool {
99        if let Some(payload) = &self.payload {
100            if let Some(obj) = payload.as_object() {
101                for (_key, value) in obj {
102                    if let Some(s) = value.as_str() {
103                        if s.len() > MAX_CHUNK_SIZE {
104                            return true;
105                        }
106                    }
107                }
108            }
109        }
110        false
111    }
112
113    pub fn into_chunks(mut self) -> Vec<KittyMessage> {
114        let mut chunks = Vec::new();
115
116        if !self.needs_streaming() {
117            return vec![self];
118        }
119
120        if let Some(payload) = self.payload.take() {
121            if let Some(obj) = payload.as_object() {
122                let stream_id = Self::generate_unique_id();
123
124                for (_key, value) in obj {
125                    if let Some(s) = value.as_str() {
126                        if s.len() > MAX_CHUNK_SIZE {
127                            for (i, chunk_data) in s.as_bytes().chunks(MAX_CHUNK_SIZE).enumerate() {
128                                let mut chunk_msg = self.clone();
129                                chunk_msg.stream_id = Some(stream_id.clone());
130                                chunk_msg.stream = Some(true);
131
132                                let mut chunk_payload = serde_json::Map::new();
133                                chunk_payload.insert(
134                                    "data".to_string(),
135                                    serde_json::Value::String(
136                                        String::from_utf8_lossy(chunk_data).to_string(),
137                                    ),
138                                );
139                                chunk_payload.insert("chunk_num".to_string(), serde_json::json!(i));
140                                chunk_msg.payload = Some(serde_json::Value::Object(chunk_payload));
141
142                                chunks.push(chunk_msg);
143                            }
144
145                            let mut end_chunk = self.clone();
146                            end_chunk.stream_id = Some(stream_id);
147                            end_chunk.stream = Some(true);
148                            let mut end_payload = serde_json::Map::new();
149                            end_payload.insert(
150                                "data".to_string(),
151                                serde_json::Value::String(String::new()),
152                            );
153                            end_chunk.payload = Some(serde_json::Value::Object(end_payload));
154                            chunks.push(end_chunk);
155
156                            return chunks;
157                        }
158                    }
159                }
160            }
161        }
162
163        chunks.push(self);
164        chunks
165    }
166
167    pub fn encode(&self) -> Result<Vec<u8>, ProtocolError> {
168        let json = if self.encrypted.is_some() {
169            serde_json::json!({
170                "version": self.version,
171                "iv": self.iv,
172                "tag": self.tag,
173                "pubkey": self.pubkey,
174                "encrypted": self.encrypted,
175            })
176        } else {
177            serde_json::to_value(self)?
178        };
179        let json_str = serde_json::to_string(&json)?;
180        let message = format!("{}{}{}", PREFIX, json_str, SUFFIX);
181        Ok(message.into_bytes())
182    }
183
184    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
185        let s = std::str::from_utf8(data)
186            .map_err(|e| ProtocolError::InvalidMessageFormat(e.to_string()))?;
187
188        if !s.starts_with(PREFIX) {
189            return Err(ProtocolError::InvalidEscapeSequence);
190        }
191
192        if !s.ends_with(SUFFIX) {
193            return Err(ProtocolError::InvalidEscapeSequence);
194        }
195
196        let json_start = PREFIX.len();
197        let json_end = s.len() - SUFFIX.len();
198        let json_str = &s[json_start..json_end];
199
200        serde_json::from_str(json_str).map_err(ProtocolError::JsonError)
201    }
202}
203
/// A reply from kitty, parsed by `KittyResponse::decode`.
///
/// Missing `data` / `error` keys deserialize as `None` (the test response in
/// this file has no `error` key).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KittyResponse {
    /// Success flag reported by kitty.
    pub ok: bool,
    /// Command result, if any; its shape depends on the command (a JSON array in the `ls`-style test).
    pub data: Option<serde_json::Value>,
    /// Error text; presumably present when `ok` is false — confirm against kitty's rc protocol.
    pub error: Option<String>,
}
210
211impl KittyResponse {
212    pub fn decode(data: &[u8]) -> Result<Self, ProtocolError> {
213        let s = std::str::from_utf8(data)
214            .map_err(|e| ProtocolError::EnvelopeParseError(e.to_string()))?;
215
216        if !s.starts_with("\x1bP@kitty-cmd") {
217            return Err(ProtocolError::EnvelopeParseError(
218                "Invalid response prefix".to_string(),
219            ));
220        }
221
222        if !s.ends_with("\x1b\\") {
223            return Err(ProtocolError::EnvelopeParseError(
224                "Invalid response suffix".to_string(),
225            ));
226        }
227
228        let json_start = PREFIX.len();
229        let json_end = s.len() - SUFFIX.len();
230        let json_str = &s[json_start..json_end];
231
232        let msg: serde_json::Value =
233            serde_json::from_str(json_str).map_err(ProtocolError::JsonError)?;
234
235        if !msg.is_object() {
236            return Err(ProtocolError::EnvelopeParseError(
237                "Response is not a JSON object".to_string(),
238            ));
239        }
240
241        serde_json::from_value(msg).map_err(ProtocolError::JsonError)
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Protocol version used by every message built in these tests.
    const VERSION: [u32; 3] = [0, 14, 2];

    /// Encodes `msg` and decodes the bytes back, panicking on any failure.
    fn roundtrip(msg: &KittyMessage) -> KittyMessage {
        let wire = msg.encode().unwrap();
        KittyMessage::decode(&wire).unwrap()
    }

    /// Builds a `cmd` message whose payload is `{"data": data}`.
    fn with_data(cmd: &str, data: &str) -> KittyMessage {
        KittyMessage::new(cmd, VERSION).payload(serde_json::json!({ "data": data }))
    }

    #[test]
    fn test_message_encode() {
        let back = roundtrip(&KittyMessage::new("ls", VERSION));
        assert_eq!(back.cmd, "ls");
        assert_eq!(back.version, VERSION);
    }

    #[test]
    fn test_message_with_payload() {
        let outgoing = KittyMessage::new("send-text", VERSION)
            .payload(serde_json::json!({"match": "id:1", "data": "text:hello"}));
        let back = roundtrip(&outgoing);
        assert_eq!(back.cmd, "send-text");
        assert!(back.payload.is_some());
    }

    #[test]
    fn test_message_no_response() {
        let back = roundtrip(&KittyMessage::new("close-window", VERSION).no_response(true));
        assert_eq!(back.no_response, Some(true));
    }

    #[test]
    fn test_invalid_escape_sequence() {
        assert!(KittyMessage::decode(b"invalid message").is_err());
    }

    #[test]
    fn test_response_decode() {
        let wire = b"\x1bP@kitty-cmd{\"ok\":true,\"data\":[{\"id\":1,\"title\":\"test\"}]}\x1b\\";
        let reply = KittyResponse::decode(wire).unwrap();
        assert!(reply.ok);
        assert!(reply.data.is_some());
    }

    #[test]
    fn test_async_id() {
        let back = roundtrip(&KittyMessage::new("select-window", VERSION).async_id("abc123"));
        assert_eq!(back.async_id.as_deref(), Some("abc123"));
    }

    #[test]
    fn test_cancel_async() {
        let outgoing = KittyMessage::new("select-window", VERSION)
            .async_id("abc123")
            .cancel_async(true);
        assert_eq!(roundtrip(&outgoing).cancel_async, Some(true));
    }

    #[test]
    fn test_unique_id_generation() {
        let first = KittyMessage::generate_unique_id();
        let second = KittyMessage::generate_unique_id();
        assert_ne!(first, second);
    }

    #[test]
    fn test_needs_streaming_false() {
        assert!(!with_data("send-text", "hello").needs_streaming());
    }

    #[test]
    fn test_needs_streaming_true() {
        let big = "x".repeat(5000);
        assert!(with_data("send-text", &big).needs_streaming());
    }

    #[test]
    fn test_into_chunks_no_streaming() {
        assert_eq!(with_data("send-text", "hello").into_chunks().len(), 1);
    }

    #[test]
    fn test_into_chunks_with_streaming() {
        let big = "x".repeat(5000);
        let pieces = with_data("set-background-image", &big).into_chunks();
        assert!(pieces.len() > 1);
        for piece in &pieces {
            assert!(piece.stream_id.is_some());
            assert_eq!(piece.stream, Some(true));
        }
    }
}