1use crate::encryption::Encryptor;
2use crate::error::{ConnectionError, EncryptionError, KittyError};
3use crate::protocol::{KittyMessage, KittyResponse};
4use std::path::Path;
5use std::process::Command;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::UnixStream;
9use tokio::time::timeout;
10use xdg::BaseDirectories;
11
/// A connected client for kitty's remote-control Unix socket.
///
/// Created through [`KittyBuilder::connect`] (or [`Kitty::builder`]).
pub struct Kitty {
    /// Socket connection to kitty; replaced by `reconnect`.
    stream: UnixStream,
    /// Time limit applied to (re)connecting, to each `write_all`, and to each individual `read`.
    timeout: Duration,
    /// Socket path used by `reconnect` to reopen the connection.
    socket_path: String,
    /// Remote-control password inserted into encrypted commands; `None` disables encryption.
    password: Option<String>,
    /// Command encryptor; `KittyBuilder::connect` only creates one when a password is set.
    encryptor: Option<Encryptor>,
}
19
/// Builder for [`Kitty`] connections.
///
/// Start from [`KittyBuilder::new`] (or [`Kitty::builder`]), set a socket path and any
/// optional credentials, then call [`KittyBuilder::connect`]. Every setter consumes and
/// returns the builder, so ignoring a setter's return value discards the configuration.
#[must_use = "a KittyBuilder does nothing until `connect` is called"]
pub struct KittyBuilder {
    /// Path of the kitty Unix socket; `connect` fails if it is never set.
    socket_path: Option<String>,
    /// Remote-control password; when set, `connect` prepares command encryption.
    password: Option<String>,
    /// Kitty's public key; when unset, `connect` may look it up from the socket name's PID.
    public_key: Option<String>,
    /// Timeout used when connecting and for individual socket reads/writes.
    timeout: Duration,
}
26
27impl Default for KittyBuilder {
28 fn default() -> Self {
29 Self {
30 socket_path: None,
31 password: None,
32 public_key: None,
33 timeout: Duration::from_secs(10),
34 }
35 }
36}
37
38impl KittyBuilder {
39 pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
44 self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
45 self
46 }
47
48 pub fn from_pid(mut self, pid: u32) -> Self {
49 let xdg_dirs = BaseDirectories::new();
50 let runtime_dir = xdg_dirs.runtime_dir.clone()
51 .unwrap_or_else(|| Path::new("/tmp").to_path_buf());
52 let socket_path = runtime_dir.join(format!("kitty-{}.sock", pid));
53 self.socket_path = Some(socket_path.to_string_lossy().to_string());
54 self
55 }
56
57 pub fn timeout(mut self, duration: Duration) -> Self {
58 self.timeout = duration;
59 self
60 }
61
62 pub fn password(mut self, password: impl Into<String>) -> Self {
63 self.password = Some(password.into());
64 self
65 }
66
67 pub fn public_key(mut self, public_key: impl Into<String>) -> Self {
68 self.public_key = Some(public_key.into());
69 self
70 }
71
72 pub async fn connect(self) -> Result<Kitty, KittyError> {
73 let socket_path = self.socket_path.ok_or_else(|| {
74 KittyError::Connection(ConnectionError::SocketNotFound(
75 "No socket path provided".to_string(),
76 ))
77 })?;
78
79 let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
80 .await
81 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
82 .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
83
84 let encryptor = if self.password.is_some() {
85 let public_key = if let Some(pk) = self.public_key {
86 Some(pk)
87 } else if let Some(pid) = Self::extract_pid_from_socket(&socket_path) {
88 Self::query_public_key_database(pid).map_err(KittyError::Encryption)?
89 } else {
90 None
91 };
92
93 Some(Encryptor::new_with_public_key(public_key.as_deref())?)
94 } else {
95 None
96 };
97
98 Ok(Kitty {
99 stream,
100 timeout: self.timeout,
101 socket_path,
102 password: self.password,
103 encryptor,
104 })
105 }
106
107 fn extract_pid_from_socket(socket_path: &str) -> Option<u32> {
108 let filename = Path::new(socket_path)
109 .file_name()?
110 .to_str()?;
111
112 let pid_str = filename.strip_prefix("kitty-")?;
113 let pid_str = pid_str.strip_suffix(".sock")?;
114 pid_str.parse().ok()
115 }
116
117 fn query_public_key_database(pid: u32) -> Result<Option<String>, EncryptionError> {
118 let output = Command::new("kitty-pubkey-db")
119 .arg("get")
120 .arg(pid.to_string())
121 .output()
122 .map_err(|e| {
123 EncryptionError::PublicKeyDatabaseError(format!("Failed to run kitty-pubkey-db: {}", e))
124 })?;
125
126 if !output.status.success() {
127 return Ok(None);
128 }
129
130 let pubkey = String::from_utf8(output.stdout)
131 .map_err(|e| {
132 EncryptionError::PublicKeyDatabaseError(format!("Invalid UTF-8 output: {}", e))
133 })?
134 .trim()
135 .to_string();
136
137 if pubkey.is_empty() {
138 Ok(None)
139 } else {
140 Ok(Some(pubkey))
141 }
142 }
143}
144
145impl Kitty {
146 pub fn builder() -> KittyBuilder {
147 KittyBuilder::new()
148 }
149
150 fn encrypt_command(&self, message: KittyMessage) -> Result<KittyMessage, KittyError> {
151 let Some(encryptor) = &self.encryptor else {
152 return Ok(message);
153 };
154
155 let Some(password) = &self.password else {
156 return Ok(message);
157 };
158
159 let timestamp = SystemTime::now()
160 .duration_since(UNIX_EPOCH)
161 .map_err(|_| {
162 KittyError::Encryption(crate::error::EncryptionError::EncryptionFailed(
163 "Failed to get timestamp".to_string(),
164 ))
165 })?
166 .as_nanos();
167
168 let mut command_json = serde_json::to_value(&message)
169 .map_err(|e| KittyError::Encryption(EncryptionError::EncryptionFailed(e.to_string())))?;
170
171 if let Some(obj) = command_json.as_object_mut() {
172 obj.insert("password".to_string(), serde_json::json!(password));
173 obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
174 }
175
176 let encrypted = encryptor.encrypt_command(command_json)?;
177
178 Ok(KittyMessage {
179 cmd: String::new(),
180 version: vec![0, 43, 1],
181 no_response: None,
182 kitty_window_id: None,
183 payload: None,
184 async_id: None,
185 cancel_async: None,
186 stream_id: None,
187 stream: None,
188 encrypted: encrypted.get("encrypted").and_then(|v| v.as_str().map(String::from)),
189 iv: encrypted.get("iv").and_then(|v| v.as_str().map(String::from)),
190 tag: encrypted.get("tag").and_then(|v| v.as_str().map(String::from)),
191 pubkey: encrypted.get("pubkey").and_then(|v| v.as_str().map(String::from)),
192 })
193 }
194
195 async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
196 let encrypted_msg = self.encrypt_command(message.clone())?;
197 let data = encrypted_msg.encode()?;
198
199 timeout(self.timeout, self.stream.write_all(&data))
200 .await
201 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?;
202
203 Ok(())
204 }
205
206 async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
207 const SUFFIX: &[u8] = b"\x1b\\";
208
209 let mut buffer = Vec::new();
210
211 loop {
212 let mut chunk = vec![0u8; 8192];
213 let n = timeout(self.timeout, self.stream.read(&mut chunk))
214 .await
215 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
216
217 if n == 0 {
218 break;
219 }
220
221 buffer.extend_from_slice(&chunk[..n]);
222
223 if buffer.ends_with(SUFFIX) {
224 break;
225 }
226 }
227
228 if buffer.is_empty() {
229 return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
230 }
231
232 Ok(KittyResponse::decode(&buffer)?)
233 }
234
235 pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
236 self.send(message).await?;
237 self.receive().await
238 }
239
240 pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
241 if message.needs_streaming() {
242 for chunk in message.clone().into_chunks() {
243 let encrypted_chunk = self.encrypt_command(chunk)?;
244 self.send(&encrypted_chunk).await?;
245 }
246 } else {
247 let encrypted_msg = self.encrypt_command(message.clone())?;
248 self.send(&encrypted_msg).await?;
249 }
250
251 Ok(())
252 }
253
254 pub async fn execute_all(
255 &mut self,
256 message: &KittyMessage,
257 ) -> Result<KittyResponse, KittyError> {
258 self.send_all(message).await?;
259 self.receive().await
260 }
261
262 pub async fn send_command<T: Into<KittyMessage>>(
263 &mut self,
264 command: T,
265 ) -> Result<(), KittyError> {
266 self.send_all(&command.into()).await
267 }
268
269 pub async fn reconnect(&mut self) -> Result<(), KittyError> {
270 let _ = self.stream.shutdown().await;
271
272 let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
273 .await
274 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
275 .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
276
277 self.stream = new_stream;
278 Ok(())
279 }
280
281 pub async fn close(&mut self) -> Result<(), KittyError> {
282 self.stream.shutdown().await.ok();
283 Ok(())
284 }
285}
286
// `Kitty` intentionally has no `Drop` impl. An async `shutdown()` cannot be awaited from
// `drop`; calling it there only builds and discards a future. Dropping the inner
// `tokio::net::UnixStream` already closes the socket. Call `Kitty::close().await` for an
// explicit, awaited shutdown.
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_creation() {
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(Duration::from_secs(5));

        assert_eq!(builder.socket_path, Some("/tmp/test.sock".to_string()));
        assert_eq!(builder.timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_builder_defaults() {
        let builder = KittyBuilder::default();

        assert_eq!(builder.socket_path, None);
        assert_eq!(builder.password, None);
        assert_eq!(builder.public_key, None);
        assert_eq!(builder.timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_builder_with_password() {
        let builder = KittyBuilder::new().password("test-password");

        assert_eq!(builder.password, Some("test-password".to_string()));
    }

    #[test]
    fn test_builder_with_public_key() {
        let builder = KittyBuilder::new().public_key("1:abc123");

        assert_eq!(builder.public_key, Some("1:abc123".to_string()));
    }

    #[test]
    fn test_builder_from_pid() {
        let builder = KittyBuilder::new().from_pid(12345);

        assert!(builder.socket_path.is_some());
        assert!(builder.socket_path.as_ref().unwrap().ends_with("kitty-12345.sock"));
    }

    #[test]
    fn test_extract_pid_from_socket_standard() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/kitty-12345.sock");
        assert_eq!(pid, Some(12345));
    }

    #[test]
    fn test_extract_pid_from_socket_xdg_runtime_dir() {
        let pid = KittyBuilder::extract_pid_from_socket("/run/user/1000/kitty-67890.sock");
        assert_eq!(pid, Some(67890));
    }

    #[test]
    fn test_extract_pid_from_socket_invalid() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/invalid.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_no_prefix() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/12345.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_invalid_pid() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/kitty-abc.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_overflowing_pid() {
        // Larger than u32::MAX, so the parse fails.
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/kitty-99999999999.sock");
        assert_eq!(pid, None);
    }

    #[tokio::test]
    async fn test_builder_missing_socket() {
        let result = KittyBuilder::new().connect().await;

        assert!(matches!(
            result,
            Err(KittyError::Connection(ConnectionError::SocketNotFound(_)))
        ));
    }

    #[tokio::test]
    async fn test_connect_to_nonexistent_socket_fails() {
        let result = KittyBuilder::new()
            .socket_path("/nonexistent-kitty-test-dir/kitty-1.sock")
            .timeout(Duration::from_secs(1))
            .connect()
            .await;

        assert!(matches!(
            result,
            Err(KittyError::Connection(ConnectionError::ConnectionFailed(..)))
        ));
    }
}