Skip to main content

kitty_rc/
client.rs

1use crate::encryption::Encryptor;
2use crate::error::{ConnectionError, EncryptionError, KittyError};
3use crate::protocol::{KittyMessage, KittyResponse};
4use std::path::Path;
5use std::process::Command;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::UnixStream;
9use tokio::time::timeout;
10use xdg::BaseDirectories;
11
12pub struct Kitty {
13    stream: UnixStream,
14    timeout: Duration,
15    socket_path: String,
16    password: Option<String>,
17    encryptor: Option<Encryptor>,
18}
19
/// Collects connection settings and opens a [`Kitty`] connection.
///
/// Obtain one through `KittyBuilder::new()` / `Default`, chain the setters, then call
/// `connect().await`.
pub struct KittyBuilder {
    /// Unix socket to connect to; `connect` fails when this is unset.
    socket_path: Option<String>,
    /// Remote-control password; enables command encryption when set.
    password: Option<String>,
    /// Explicit public key for encryption; when unset, `connect` may look one up from
    /// the PID embedded in the socket file name.
    public_key: Option<String>,
    /// Limit applied to connecting and to each socket read/write.
    timeout: Duration,
}
26
27impl Default for KittyBuilder {
28    fn default() -> Self {
29        Self {
30            socket_path: None,
31            password: None,
32            public_key: None,
33            timeout: Duration::from_secs(10),
34        }
35    }
36}
37
38impl KittyBuilder {
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
44        self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
45        self
46    }
47
48    pub fn from_pid(mut self, pid: u32) -> Self {
49        let xdg_dirs = BaseDirectories::new();
50        let runtime_dir = xdg_dirs.runtime_dir.clone()
51            .unwrap_or_else(|| Path::new("/tmp").to_path_buf());
52        let socket_path = runtime_dir.join(format!("kitty-{}.sock", pid));
53        self.socket_path = Some(socket_path.to_string_lossy().to_string());
54        self
55    }
56
57    pub fn timeout(mut self, duration: Duration) -> Self {
58        self.timeout = duration;
59        self
60    }
61
62    pub fn password(mut self, password: impl Into<String>) -> Self {
63        self.password = Some(password.into());
64        self
65    }
66
67    pub fn public_key(mut self, public_key: impl Into<String>) -> Self {
68        self.public_key = Some(public_key.into());
69        self
70    }
71
72    pub async fn connect(self) -> Result<Kitty, KittyError> {
73        let socket_path = self.socket_path.ok_or_else(|| {
74            KittyError::Connection(ConnectionError::SocketNotFound(
75                "No socket path provided".to_string(),
76            ))
77        })?;
78
79        let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
80            .await
81            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
82            .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
83
84        let encryptor = if self.password.is_some() {
85            let public_key = if let Some(pk) = self.public_key {
86                Some(pk)
87            } else if let Some(pid) = Self::extract_pid_from_socket(&socket_path) {
88                Self::query_public_key_database(pid).map_err(KittyError::Encryption)?
89            } else {
90                None
91            };
92
93            Some(Encryptor::new_with_public_key(public_key.as_deref())?)
94        } else {
95            None
96        };
97
98        Ok(Kitty {
99            stream,
100            timeout: self.timeout,
101            socket_path,
102            password: self.password,
103            encryptor,
104        })
105    }
106
107    fn extract_pid_from_socket(socket_path: &str) -> Option<u32> {
108        let filename = Path::new(socket_path)
109            .file_name()?
110            .to_str()?;
111
112        let pid_str = filename.strip_prefix("kitty-")?;
113        let pid_str = pid_str.strip_suffix(".sock")?;
114        pid_str.parse().ok()
115    }
116
117    fn query_public_key_database(pid: u32) -> Result<Option<String>, EncryptionError> {
118        let output = Command::new("kitty-pubkey-db")
119            .arg("get")
120            .arg(pid.to_string())
121            .output()
122            .map_err(|e| {
123                EncryptionError::PublicKeyDatabaseError(format!("Failed to run kitty-pubkey-db: {}", e))
124            })?;
125
126        if !output.status.success() {
127            return Ok(None);
128        }
129
130        let pubkey = String::from_utf8(output.stdout)
131            .map_err(|e| {
132                EncryptionError::PublicKeyDatabaseError(format!("Invalid UTF-8 output: {}", e))
133            })?
134            .trim()
135            .to_string();
136
137        if pubkey.is_empty() {
138            Ok(None)
139        } else {
140            Ok(Some(pubkey))
141        }
142    }
143}
144
145impl Kitty {
146    pub fn builder() -> KittyBuilder {
147        KittyBuilder::new()
148    }
149
150    fn encrypt_command(&self, message: KittyMessage) -> Result<KittyMessage, KittyError> {
151        let Some(encryptor) = &self.encryptor else {
152            return Ok(message);
153        };
154
155        let Some(password) = &self.password else {
156            return Ok(message);
157        };
158
159        let timestamp = SystemTime::now()
160            .duration_since(UNIX_EPOCH)
161            .map_err(|_| {
162                KittyError::Encryption(crate::error::EncryptionError::EncryptionFailed(
163                    "Failed to get timestamp".to_string(),
164                ))
165            })?
166            .as_nanos();
167
168        let mut command_json = serde_json::to_value(&message)
169            .map_err(|e| KittyError::Encryption(EncryptionError::EncryptionFailed(e.to_string())))?;
170
171        if let Some(obj) = command_json.as_object_mut() {
172            obj.insert("password".to_string(), serde_json::json!(password));
173            obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
174        }
175
176        let encrypted = encryptor.encrypt_command(command_json)?;
177
178        Ok(KittyMessage {
179            cmd: String::new(),
180            version: vec![0, 43, 1],
181            no_response: None,
182            kitty_window_id: None,
183            payload: None,
184            async_id: None,
185            cancel_async: None,
186            stream_id: None,
187            stream: None,
188            encrypted: encrypted.get("encrypted").and_then(|v| v.as_str().map(String::from)),
189            iv: encrypted.get("iv").and_then(|v| v.as_str().map(String::from)),
190            tag: encrypted.get("tag").and_then(|v| v.as_str().map(String::from)),
191            pubkey: encrypted.get("pubkey").and_then(|v| v.as_str().map(String::from)),
192        })
193    }
194
195    async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
196        let encrypted_msg = self.encrypt_command(message.clone())?;
197        let data = encrypted_msg.encode()?;
198
199        timeout(self.timeout, self.stream.write_all(&data))
200            .await
201            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?;
202
203        Ok(())
204    }
205
206    async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
207        const SUFFIX: &[u8] = b"\x1b\\";
208
209        let mut buffer = Vec::new();
210
211        loop {
212            let mut chunk = vec![0u8; 8192];
213            let n = timeout(self.timeout, self.stream.read(&mut chunk))
214                .await
215                .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
216
217            if n == 0 {
218                break;
219            }
220
221            buffer.extend_from_slice(&chunk[..n]);
222
223            if buffer.ends_with(SUFFIX) {
224                break;
225            }
226        }
227
228        if buffer.is_empty() {
229            return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
230        }
231
232        Ok(KittyResponse::decode(&buffer)?)
233    }
234
235    pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
236        self.send(message).await?;
237        self.receive().await
238    }
239
240    pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
241        if message.needs_streaming() {
242            for chunk in message.clone().into_chunks() {
243                let encrypted_chunk = self.encrypt_command(chunk)?;
244                self.send(&encrypted_chunk).await?;
245            }
246        } else {
247            let encrypted_msg = self.encrypt_command(message.clone())?;
248            self.send(&encrypted_msg).await?;
249        }
250
251        Ok(())
252    }
253
254    pub async fn execute_all(
255        &mut self,
256        message: &KittyMessage,
257    ) -> Result<KittyResponse, KittyError> {
258        self.send_all(message).await?;
259        self.receive().await
260    }
261
262    pub async fn send_command<T: Into<KittyMessage>>(
263        &mut self,
264        command: T,
265    ) -> Result<(), KittyError> {
266        self.send_all(&command.into()).await
267    }
268
269    pub async fn reconnect(&mut self) -> Result<(), KittyError> {
270        let _ = self.stream.shutdown().await;
271
272        let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
273            .await
274            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
275            .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
276
277        self.stream = new_stream;
278        Ok(())
279    }
280
281    pub async fn close(&mut self) -> Result<(), KittyError> {
282        self.stream.shutdown().await.ok();
283        Ok(())
284    }
285}
286
// `Kitty` deliberately has no `Drop` impl. Inside `drop`, `self.stream.shutdown()`
// resolves to `AsyncWriteExt::shutdown`, which only builds a future. That future can
// never be polled from a synchronous destructor, so such an impl is a no-op. Dropping
// the `UnixStream` field already closes the socket; call `Kitty::close().await` for an
// explicit write-side shutdown.
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the socket-file-name PID parser on `path`.
    fn pid_of(path: &str) -> Option<u32> {
        KittyBuilder::extract_pid_from_socket(path)
    }

    #[test]
    fn test_builder_creation() {
        let five_secs = Duration::from_secs(5);
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(five_secs);

        assert_eq!(builder.socket_path.as_deref(), Some("/tmp/test.sock"));
        assert_eq!(builder.timeout, five_secs);
    }

    #[test]
    fn test_builder_with_password() {
        let builder = KittyBuilder::new().password("test-password");
        assert_eq!(builder.password.as_deref(), Some("test-password"));
    }

    #[test]
    fn test_builder_with_public_key() {
        let builder = KittyBuilder::new().public_key("1:abc123");
        assert_eq!(builder.public_key.as_deref(), Some("1:abc123"));
    }

    #[test]
    fn test_builder_from_pid() {
        let path = KittyBuilder::new()
            .from_pid(12345)
            .socket_path
            .expect("from_pid should always set a socket path");
        assert!(path.ends_with("kitty-12345.sock"));
    }

    #[test]
    fn test_extract_pid_from_socket_standard() {
        assert_eq!(pid_of("/tmp/kitty-12345.sock"), Some(12345));
    }

    #[test]
    fn test_extract_pid_from_socket_xdg_runtime_dir() {
        assert_eq!(pid_of("/run/user/1000/kitty-67890.sock"), Some(67890));
    }

    #[test]
    fn test_extract_pid_from_socket_invalid() {
        assert_eq!(pid_of("/tmp/invalid.sock"), None);
    }

    #[test]
    fn test_extract_pid_from_socket_no_prefix() {
        assert_eq!(pid_of("/tmp/12345.sock"), None);
    }

    #[test]
    fn test_extract_pid_from_socket_invalid_pid() {
        assert_eq!(pid_of("/tmp/kitty-abc.sock"), None);
    }

    #[tokio::test]
    async fn test_builder_missing_socket() {
        let outcome = KittyBuilder::new().connect().await;
        assert!(outcome.is_err());
    }
}