Skip to main content

kitty_rc/
client.rs

1use crate::encryption::Encryptor;
2use crate::error::{ConnectionError, EncryptionError, KittyError};
3use crate::protocol::{KittyMessage, KittyResponse};
4use std::path::Path;
5use std::process::Command;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::UnixStream;
9use tokio::time::timeout;
10
11pub struct Kitty {
12    stream: UnixStream,
13    timeout: Duration,
14    socket_path: String,
15    password: Option<String>,
16    encryptor: Option<Encryptor>,
17}
18
/// Configuration collected before opening a [`Kitty`] connection.
///
/// Create with `KittyBuilder::new()` or `Kitty::builder()`, chain the setters,
/// then call `connect()`.
pub struct KittyBuilder {
    /// Path of kitty's remote-control Unix socket (required by `connect`).
    socket_path: Option<String>,
    /// Remote-control password; enables encryption when set.
    password: Option<String>,
    /// Explicit kitty public key, overriding automatic key lookup.
    public_key: Option<String>,
    /// Deadline for each connect, read and write operation.
    timeout: Duration,
}
25
26impl KittyBuilder {
27    pub fn new() -> Self {
28        Self {
29            socket_path: None,
30            password: None,
31            public_key: None,
32            timeout: Duration::from_secs(10),
33        }
34    }
35
36    fn extract_pid_from_socket(socket_path: &str) -> Option<u32> {
37        let filename = Path::new(socket_path)
38            .file_name()?
39            .to_str()?;
40
41        let pid_str = filename.strip_prefix("kitty-")?;
42        let pid_str = pid_str.strip_suffix(".sock")?;
43        pid_str.parse().ok()
44    }
45
46    fn query_public_key_database(pid: u32) -> Result<Option<String>, EncryptionError> {
47        let output = Command::new("kitty-pubkey-db")
48            .arg("get")
49            .arg(pid.to_string())
50            .output()
51            .map_err(|e| {
52                EncryptionError::PublicKeyDatabaseError(format!("Failed to run kitty-pubkey-db: {}", e))
53            })?;
54
55        if !output.status.success() {
56            return Ok(None);
57        }
58
59        let pubkey = String::from_utf8(output.stdout)
60            .map_err(|e| {
61                EncryptionError::PublicKeyDatabaseError(format!("Invalid UTF-8 output: {}", e))
62            })?
63            .trim()
64            .to_string();
65
66        if pubkey.is_empty() {
67            Ok(None)
68        } else {
69            Ok(Some(pubkey))
70        }
71    }
72
73    pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
74        self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
75        self
76    }
77
78    pub fn timeout(mut self, duration: Duration) -> Self {
79        self.timeout = duration;
80        self
81    }
82
83    pub fn password(mut self, password: impl Into<String>) -> Self {
84        self.password = Some(password.into());
85        self
86    }
87
88    /// Set kitty's public key explicitly.
89    ///
90    /// Format: `1:<base85_encoded_key>` where `1` is protocol version.
91    ///
92    /// When set, this key is used instead of querying KITTY_PUBLIC_KEY
93    /// env var or kitty-pubkey-db database.
94    ///
95    /// Example:
96    /// ```no_run
97    /// use kitty_rc::Kitty;
98    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
99    /// let kitty = Kitty::builder()
100    ///     .socket_path("/run/user/1000/kitty/kitty-12345.sock")
101    ///     .password("your-password")
102    ///     .public_key("1:z3;{}!NzNzgiXreB&ywA!8y1H8hq^$cMG!OE$QNa")
103    ///     .connect()
104    ///     .await?;
105    /// # Ok(())
106    /// # }
107    /// ```
108    pub fn public_key(mut self, public_key: impl Into<String>) -> Self {
109        self.public_key = Some(public_key.into());
110        self
111    }
112
113    /// Connect to kitty instance with configured authentication.
114    ///
115    /// Public key resolution order (when password is set):
116    /// 1. Explicit key set via `.public_key()` method
117    /// 2. Query kitty-pubkey-db database (extracts PID from socket path)
118    /// 3. KITTY_PUBLIC_KEY environment variable (set by kitty when launching subprocesses)
119    ///
120    /// When no password is set, no encryption is used.
121    pub async fn connect(self) -> Result<Kitty, KittyError> {
122        let socket_path = self.socket_path.ok_or_else(|| {
123            KittyError::Connection(ConnectionError::SocketNotFound(
124                "No socket path provided".to_string(),
125            ))
126        })?;
127
128        let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
129            .await
130            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
131            .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
132
133        let encryptor = if self.password.is_some() {
134            let public_key = if let Some(pk) = self.public_key {
135                Some(pk)
136            } else if let Some(pid) = Self::extract_pid_from_socket(&socket_path) {
137                Self::query_public_key_database(pid).map_err(KittyError::Encryption)?
138            } else {
139                None
140            };
141
142            Some(Encryptor::new_with_public_key(public_key.as_deref())?)
143        } else {
144            None
145        };
146
147        Ok(Kitty {
148            stream,
149            timeout: self.timeout,
150            socket_path,
151            password: self.password,
152            encryptor,
153        })
154    }
155}
156
157impl Kitty {
158    pub fn builder() -> KittyBuilder {
159        KittyBuilder::new()
160    }
161
162    fn encrypt_command(&self, mut message: KittyMessage) -> Result<KittyMessage, KittyError> {
163        let Some(encryptor) = &self.encryptor else {
164            return Ok(message);
165        };
166
167        let Some(password) = &self.password else {
168            return Ok(message);
169        };
170
171        let timestamp = SystemTime::now()
172            .duration_since(UNIX_EPOCH)
173            .map_err(|_| {
174                KittyError::Encryption(crate::error::EncryptionError::EncryptionFailed(
175                    "Failed to get timestamp".to_string(),
176                ))
177            })?
178            .as_nanos();
179
180        if let Some(payload) = &mut message.payload {
181            if let Some(obj) = payload.as_object_mut() {
182                obj.insert("password".to_string(), serde_json::json!(password));
183                obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
184            }
185        } else {
186            let mut obj = serde_json::Map::new();
187            obj.insert("password".to_string(), serde_json::json!(password));
188            obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
189            message.payload = Some(serde_json::Value::Object(obj));
190        }
191
192        let encrypted_payload = encryptor.encrypt_command(message.payload.unwrap())?;
193        message.payload = Some(encrypted_payload);
194
195        Ok(message)
196    }
197
198    async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
199        let encrypted_msg = self.encrypt_command(message.clone())?;
200        let data = encrypted_msg.encode()?;
201
202        timeout(self.timeout, self.stream.write_all(&data))
203            .await
204            .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
205
206        Ok(())
207    }
208
209    async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
210        const SUFFIX: &[u8] = b"\x1b\\";
211
212        let mut buffer = Vec::new();
213
214        loop {
215            let mut chunk = vec![0u8; 8192];
216            let n = timeout(self.timeout, self.stream.read(&mut chunk))
217                .await
218                .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
219
220            if n == 0 {
221                break;
222            }
223
224            buffer.extend_from_slice(&chunk[..n]);
225
226            if buffer.ends_with(SUFFIX) {
227                break;
228            }
229        }
230
231        if buffer.is_empty() {
232            return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
233        }
234
235        Ok(KittyResponse::decode(&buffer)?)
236    }
237
238    pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
239        self.send(message).await?;
240        self.receive().await
241    }
242
243    pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
244        if message.needs_streaming() {
245            for chunk in message.clone().into_chunks() {
246                let encrypted_chunk = self.encrypt_command(chunk)?;
247                self.send(&encrypted_chunk).await?;
248            }
249        } else {
250            let encrypted_msg = self.encrypt_command(message.clone())?;
251            self.send(&encrypted_msg).await?;
252        }
253
254        Ok(())
255    }
256
257    pub async fn execute_all(
258        &mut self,
259        message: &KittyMessage,
260    ) -> Result<KittyResponse, KittyError> {
261        self.send_all(message).await?;
262        self.receive().await
263    }
264
265    pub async fn send_command<T: Into<KittyMessage>>(
266        &mut self,
267        command: T,
268    ) -> Result<(), KittyError> {
269        self.send_all(&command.into()).await
270    }
271
272    pub async fn reconnect(&mut self) -> Result<(), KittyError> {
273        let _ = self.stream.shutdown().await;
274
275        let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
276            .await
277            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
278            .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
279
280        self.stream = new_stream;
281        Ok(())
282    }
283
284    pub async fn close(&mut self) -> Result<(), KittyError> {
285        self.stream.shutdown().await.ok();
286        Ok(())
287    }
288}
289
290impl Drop for Kitty {
291    fn drop(&mut self) {
292        let _ = self.stream.shutdown();
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_creation() {
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(Duration::from_secs(5));

        assert_eq!(builder.socket_path, Some("/tmp/test.sock".to_string()));
        assert_eq!(builder.timeout, Duration::from_secs(5));
    }

    /// Pins the defaults documented on `KittyBuilder::new`.
    #[test]
    fn test_builder_defaults() {
        let builder = KittyBuilder::new();

        assert_eq!(builder.socket_path, None);
        assert_eq!(builder.password, None);
        assert_eq!(builder.public_key, None);
        assert_eq!(builder.timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_builder_with_password() {
        let builder = KittyBuilder::new().password("test-password");

        assert_eq!(builder.password, Some("test-password".to_string()));
    }

    #[test]
    fn test_builder_with_public_key() {
        let builder = KittyBuilder::new().public_key("1:abc123");

        assert_eq!(builder.public_key, Some("1:abc123".to_string()));
    }

    #[test]
    fn test_extract_pid_from_socket_standard() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/kitty-12345.sock");
        assert_eq!(pid, Some(12345));
    }

    #[test]
    fn test_extract_pid_from_socket_xdg_runtime_dir() {
        let pid = KittyBuilder::extract_pid_from_socket("/run/user/1000/kitty/kitty-67890.sock");
        assert_eq!(pid, Some(67890));
    }

    #[test]
    fn test_extract_pid_from_socket_invalid() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/invalid.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_no_prefix() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/12345.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_invalid_pid() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/kitty-abc.sock");
        assert_eq!(pid, None);
    }

    /// A missing socket path must fail with `SocketNotFound` specifically,
    /// not with some other error.
    #[tokio::test]
    async fn test_builder_missing_socket() {
        let result = KittyBuilder::new().connect().await;

        assert!(matches!(
            result,
            Err(KittyError::Connection(ConnectionError::SocketNotFound(_)))
        ));
    }
}