Skip to main content

kitty_rc/
client.rs

1use crate::encryption::Encryptor;
2use crate::error::{ConnectionError, EncryptionError, KittyError};
3use crate::protocol::{KittyMessage, KittyResponse};
4use std::path::Path;
5use std::process::Command;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::UnixStream;
9use tokio::time::timeout;
10
11pub struct Kitty {
12    stream: UnixStream,
13    timeout: Duration,
14    socket_path: String,
15    password: Option<String>,
16    encryptor: Option<Encryptor>,
17}
18
/// Configuration collected before opening a [`Kitty`] connection.
///
/// Finish with [`KittyBuilder::connect`].
pub struct KittyBuilder {
    /// Path to kitty's remote-control socket; `connect` fails without it.
    socket_path: Option<String>,
    /// Optional remote-control password; when set, commands are encrypted.
    password: Option<String>,
    /// Optional explicit public key that bypasses the automatic lookup.
    public_key: Option<String>,
    /// Bound applied to connecting and to each socket read/write.
    timeout: Duration,
}
25
26impl KittyBuilder {
27    pub fn new() -> Self {
28        Self {
29            socket_path: None,
30            password: None,
31            public_key: None,
32            timeout: Duration::from_secs(10),
33        }
34    }
35
36    fn extract_pid_from_socket(socket_path: &str) -> Option<u32> {
37        let filename = Path::new(socket_path)
38            .file_name()?
39            .to_str()?;
40
41        let pid_str = filename.strip_prefix("kitty-")?;
42        let pid_str = pid_str.strip_suffix(".sock")?;
43        pid_str.parse().ok()
44    }
45
46    fn query_public_key_database(pid: u32) -> Result<Option<String>, EncryptionError> {
47        let output = Command::new("kitty-pubkey-db")
48            .arg("get")
49            .arg(pid.to_string())
50            .output()
51            .map_err(|e| {
52                EncryptionError::PublicKeyDatabaseError(format!("Failed to run kitty-pubkey-db: {}", e))
53            })?;
54
55        if !output.status.success() {
56            return Ok(None);
57        }
58
59        let pubkey = String::from_utf8(output.stdout)
60            .map_err(|e| {
61                EncryptionError::PublicKeyDatabaseError(format!("Invalid UTF-8 output: {}", e))
62            })?
63            .trim()
64            .to_string();
65
66        if pubkey.is_empty() {
67            Ok(None)
68        } else {
69            Ok(Some(pubkey))
70        }
71    }
72
73    pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
74        self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
75        self
76    }
77
78    pub fn timeout(mut self, duration: Duration) -> Self {
79        self.timeout = duration;
80        self
81    }
82
83    pub fn password(mut self, password: impl Into<String>) -> Self {
84        self.password = Some(password.into());
85        self
86    }
87
88    pub fn public_key(mut self, public_key: impl Into<String>) -> Self {
89        self.public_key = Some(public_key.into());
90        self
91    }
92
93    pub async fn connect(self) -> Result<Kitty, KittyError> {
94        let socket_path = self.socket_path.ok_or_else(|| {
95            KittyError::Connection(ConnectionError::SocketNotFound(
96                "No socket path provided".to_string(),
97            ))
98        })?;
99
100        let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
101            .await
102            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
103            .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
104
105        let encryptor = if self.password.is_some() {
106            let public_key = if let Some(pk) = self.public_key {
107                Some(pk)
108            } else if std::env::var("KITTY_PUBLIC_KEY").is_ok() {
109                None
110            } else if let Some(pid) = Self::extract_pid_from_socket(&socket_path) {
111                Self::query_public_key_database(pid).map_err(KittyError::Encryption)?
112            } else {
113                None
114            };
115
116            Some(Encryptor::new_with_public_key(public_key.as_deref())?)
117        } else {
118            None
119        };
120
121        Ok(Kitty {
122            stream,
123            timeout: self.timeout,
124            socket_path,
125            password: self.password,
126            encryptor,
127        })
128    }
129}
130
131impl Kitty {
132    pub fn builder() -> KittyBuilder {
133        KittyBuilder::new()
134    }
135
136    fn encrypt_command(&self, mut message: KittyMessage) -> Result<KittyMessage, KittyError> {
137        let Some(encryptor) = &self.encryptor else {
138            return Ok(message);
139        };
140
141        let Some(password) = &self.password else {
142            return Ok(message);
143        };
144
145        let timestamp = SystemTime::now()
146            .duration_since(UNIX_EPOCH)
147            .map_err(|_| {
148                KittyError::Encryption(crate::error::EncryptionError::EncryptionFailed(
149                    "Failed to get timestamp".to_string(),
150                ))
151            })?
152            .as_nanos();
153
154        if let Some(payload) = &mut message.payload {
155            if let Some(obj) = payload.as_object_mut() {
156                obj.insert("password".to_string(), serde_json::json!(password));
157                obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
158            }
159        } else {
160            let mut obj = serde_json::Map::new();
161            obj.insert("password".to_string(), serde_json::json!(password));
162            obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
163            message.payload = Some(serde_json::Value::Object(obj));
164        }
165
166        let encrypted_payload = encryptor.encrypt_command(message.payload.unwrap())?;
167        message.payload = Some(encrypted_payload);
168
169        Ok(message)
170    }
171
172    async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
173        let encrypted_msg = self.encrypt_command(message.clone())?;
174        let data = encrypted_msg.encode()?;
175
176        timeout(self.timeout, self.stream.write_all(&data))
177            .await
178            .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
179
180        Ok(())
181    }
182
183    async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
184        const SUFFIX: &[u8] = b"\x1b\\";
185
186        let mut buffer = Vec::new();
187
188        loop {
189            let mut chunk = vec![0u8; 8192];
190            let n = timeout(self.timeout, self.stream.read(&mut chunk))
191                .await
192                .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
193
194            if n == 0 {
195                break;
196            }
197
198            buffer.extend_from_slice(&chunk[..n]);
199
200            if buffer.ends_with(SUFFIX) {
201                break;
202            }
203        }
204
205        if buffer.is_empty() {
206            return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
207        }
208
209        Ok(KittyResponse::decode(&buffer)?)
210    }
211
212    pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
213        self.send(message).await?;
214        self.receive().await
215    }
216
217    pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
218        if message.needs_streaming() {
219            for chunk in message.clone().into_chunks() {
220                let encrypted_chunk = self.encrypt_command(chunk)?;
221                self.send(&encrypted_chunk).await?;
222            }
223        } else {
224            let encrypted_msg = self.encrypt_command(message.clone())?;
225            self.send(&encrypted_msg).await?;
226        }
227
228        Ok(())
229    }
230
231    pub async fn execute_all(
232        &mut self,
233        message: &KittyMessage,
234    ) -> Result<KittyResponse, KittyError> {
235        self.send_all(message).await?;
236        self.receive().await
237    }
238
239    pub async fn send_command<T: Into<KittyMessage>>(
240        &mut self,
241        command: T,
242    ) -> Result<(), KittyError> {
243        self.send_all(&command.into()).await
244    }
245
246    pub async fn reconnect(&mut self) -> Result<(), KittyError> {
247        let _ = self.stream.shutdown().await;
248
249        let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
250            .await
251            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
252            .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
253
254        self.stream = new_stream;
255        Ok(())
256    }
257
258    pub async fn close(&mut self) -> Result<(), KittyError> {
259        self.stream.shutdown().await.ok();
260        Ok(())
261    }
262}
263
// `Kitty` deliberately has no `Drop` impl.
//
// `AsyncWriteExt::shutdown` returns a future that does nothing unless it is
// awaited, and `Drop::drop` cannot await. Calling it from `drop` would
// therefore be a no-op (clippy: `let_underscore_future`).
//
// The `tokio::net::UnixStream` field closes its socket when it is dropped.
// Use `Kitty::close` for an explicit, awaited shutdown.
269
#[cfg(test)]
mod tests {
    use super::*;

    // Builder setters store exactly what they are given.

    #[test]
    fn test_builder_creation() {
        let expected_timeout = Duration::from_secs(5);
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(expected_timeout);

        assert_eq!(builder.socket_path.as_deref(), Some("/tmp/test.sock"));
        assert_eq!(builder.timeout, expected_timeout);
    }

    #[test]
    fn test_builder_with_password() {
        let builder = KittyBuilder::new().password("test-password");
        assert_eq!(builder.password.as_deref(), Some("test-password"));
    }

    #[test]
    fn test_builder_with_public_key() {
        let builder = KittyBuilder::new().public_key("1:abc123");
        assert_eq!(builder.public_key.as_deref(), Some("1:abc123"));
    }

    // PID extraction from `kitty-<pid>.sock` file names.

    #[test]
    fn test_extract_pid_from_socket_standard() {
        assert_eq!(
            KittyBuilder::extract_pid_from_socket("/tmp/kitty-12345.sock"),
            Some(12345)
        );
    }

    #[test]
    fn test_extract_pid_from_socket_xdg_runtime_dir() {
        let path = "/run/user/1000/kitty/kitty-67890.sock";
        assert_eq!(KittyBuilder::extract_pid_from_socket(path), Some(67890));
    }

    #[test]
    fn test_extract_pid_from_socket_invalid() {
        assert!(KittyBuilder::extract_pid_from_socket("/tmp/invalid.sock").is_none());
    }

    #[test]
    fn test_extract_pid_from_socket_no_prefix() {
        assert!(KittyBuilder::extract_pid_from_socket("/tmp/12345.sock").is_none());
    }

    #[test]
    fn test_extract_pid_from_socket_invalid_pid() {
        assert!(KittyBuilder::extract_pid_from_socket("/tmp/kitty-abc.sock").is_none());
    }

    // Connecting without a socket path must fail before any I/O.

    #[tokio::test]
    async fn test_builder_missing_socket() {
        let outcome = KittyBuilder::new().connect().await;
        assert!(outcome.is_err());
    }
}