1use crate::encryption::Encryptor;
2use crate::error::{ConnectionError, EncryptionError, KittyError};
3use crate::protocol::{KittyMessage, KittyResponse};
4use std::path::Path;
5use std::process::Command;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::UnixStream;
9use tokio::time::timeout;
10
11pub struct Kitty {
12 stream: UnixStream,
13 timeout: Duration,
14 socket_path: String,
15 password: Option<String>,
16 encryptor: Option<Encryptor>,
17}
18
/// Builder for [`Kitty`] connections.
///
/// `Debug` is implemented by hand so that the password is never printed.
#[derive(Clone)]
pub struct KittyBuilder {
    /// Path of the kitty remote-control socket; required by `connect`.
    socket_path: Option<String>,
    /// Remote-control password; enables payload encryption when set.
    password: Option<String>,
    /// Explicit public key for encryption. When absent, `connect` may look one
    /// up based on the socket file name.
    public_key: Option<String>,
    /// Timeout for connecting and for each read/write on the connection.
    timeout: Duration,
}

impl std::fmt::Debug for KittyBuilder {
    /// Formats the builder settings with the password redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KittyBuilder")
            .field("socket_path", &self.socket_path)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("public_key", &self.public_key)
            .field("timeout", &self.timeout)
            .finish()
    }
}
25
26impl KittyBuilder {
27 pub fn new() -> Self {
28 Self {
29 socket_path: None,
30 password: None,
31 public_key: None,
32 timeout: Duration::from_secs(10),
33 }
34 }
35
36 fn extract_pid_from_socket(socket_path: &str) -> Option<u32> {
37 let filename = Path::new(socket_path)
38 .file_name()?
39 .to_str()?;
40
41 let pid_str = filename.strip_prefix("kitty-")?;
42 let pid_str = pid_str.strip_suffix(".sock")?;
43 pid_str.parse().ok()
44 }
45
46 fn query_public_key_database(pid: u32) -> Result<Option<String>, EncryptionError> {
47 let output = Command::new("kitty-pubkey-db")
48 .arg("get")
49 .arg(pid.to_string())
50 .output()
51 .map_err(|e| {
52 EncryptionError::PublicKeyDatabaseError(format!("Failed to run kitty-pubkey-db: {}", e))
53 })?;
54
55 if !output.status.success() {
56 return Ok(None);
57 }
58
59 let pubkey = String::from_utf8(output.stdout)
60 .map_err(|e| {
61 EncryptionError::PublicKeyDatabaseError(format!("Invalid UTF-8 output: {}", e))
62 })?
63 .trim()
64 .to_string();
65
66 if pubkey.is_empty() {
67 Ok(None)
68 } else {
69 Ok(Some(pubkey))
70 }
71 }
72
73 pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
74 self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
75 self
76 }
77
78 pub fn timeout(mut self, duration: Duration) -> Self {
79 self.timeout = duration;
80 self
81 }
82
83 pub fn password(mut self, password: impl Into<String>) -> Self {
84 self.password = Some(password.into());
85 self
86 }
87
88 pub fn public_key(mut self, public_key: impl Into<String>) -> Self {
109 self.public_key = Some(public_key.into());
110 self
111 }
112
113 pub async fn connect(self) -> Result<Kitty, KittyError> {
122 let socket_path = self.socket_path.ok_or_else(|| {
123 KittyError::Connection(ConnectionError::SocketNotFound(
124 "No socket path provided".to_string(),
125 ))
126 })?;
127
128 let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
129 .await
130 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
131 .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
132
133 let encryptor = if self.password.is_some() {
134 let public_key = if let Some(pk) = self.public_key {
135 Some(pk)
136 } else if let Some(pid) = Self::extract_pid_from_socket(&socket_path) {
137 Self::query_public_key_database(pid).map_err(KittyError::Encryption)?
138 } else {
139 None
140 };
141
142 Some(Encryptor::new_with_public_key(public_key.as_deref())?)
143 } else {
144 None
145 };
146
147 Ok(Kitty {
148 stream,
149 timeout: self.timeout,
150 socket_path,
151 password: self.password,
152 encryptor,
153 })
154 }
155}
156
157impl Kitty {
158 pub fn builder() -> KittyBuilder {
159 KittyBuilder::new()
160 }
161
162 fn encrypt_command(&self, mut message: KittyMessage) -> Result<KittyMessage, KittyError> {
163 let Some(encryptor) = &self.encryptor else {
164 return Ok(message);
165 };
166
167 let Some(password) = &self.password else {
168 return Ok(message);
169 };
170
171 let timestamp = SystemTime::now()
172 .duration_since(UNIX_EPOCH)
173 .map_err(|_| {
174 KittyError::Encryption(crate::error::EncryptionError::EncryptionFailed(
175 "Failed to get timestamp".to_string(),
176 ))
177 })?
178 .as_nanos();
179
180 if let Some(payload) = &mut message.payload {
181 if let Some(obj) = payload.as_object_mut() {
182 obj.insert("password".to_string(), serde_json::json!(password));
183 obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
184 }
185 } else {
186 let mut obj = serde_json::Map::new();
187 obj.insert("password".to_string(), serde_json::json!(password));
188 obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
189 message.payload = Some(serde_json::Value::Object(obj));
190 }
191
192 let encrypted_payload = encryptor.encrypt_command(message.payload.unwrap())?;
193 message.payload = Some(encrypted_payload);
194
195 Ok(message)
196 }
197
198 async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
199 let encrypted_msg = self.encrypt_command(message.clone())?;
200 let data = encrypted_msg.encode()?;
201
202 timeout(self.timeout, self.stream.write_all(&data))
203 .await
204 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
205
206 Ok(())
207 }
208
209 async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
210 const SUFFIX: &[u8] = b"\x1b\\";
211
212 let mut buffer = Vec::new();
213
214 loop {
215 let mut chunk = vec![0u8; 8192];
216 let n = timeout(self.timeout, self.stream.read(&mut chunk))
217 .await
218 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
219
220 if n == 0 {
221 break;
222 }
223
224 buffer.extend_from_slice(&chunk[..n]);
225
226 if buffer.ends_with(SUFFIX) {
227 break;
228 }
229 }
230
231 if buffer.is_empty() {
232 return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
233 }
234
235 Ok(KittyResponse::decode(&buffer)?)
236 }
237
238 pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
239 self.send(message).await?;
240 self.receive().await
241 }
242
243 pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
244 if message.needs_streaming() {
245 for chunk in message.clone().into_chunks() {
246 let encrypted_chunk = self.encrypt_command(chunk)?;
247 self.send(&encrypted_chunk).await?;
248 }
249 } else {
250 let encrypted_msg = self.encrypt_command(message.clone())?;
251 self.send(&encrypted_msg).await?;
252 }
253
254 Ok(())
255 }
256
257 pub async fn execute_all(
258 &mut self,
259 message: &KittyMessage,
260 ) -> Result<KittyResponse, KittyError> {
261 self.send_all(message).await?;
262 self.receive().await
263 }
264
265 pub async fn send_command<T: Into<KittyMessage>>(
266 &mut self,
267 command: T,
268 ) -> Result<(), KittyError> {
269 self.send_all(&command.into()).await
270 }
271
272 pub async fn reconnect(&mut self) -> Result<(), KittyError> {
273 let _ = self.stream.shutdown().await;
274
275 let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
276 .await
277 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
278 .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
279
280 self.stream = new_stream;
281 Ok(())
282 }
283
284 pub async fn close(&mut self) -> Result<(), KittyError> {
285 self.stream.shutdown().await.ok();
286 Ok(())
287 }
288}
289
// `Kitty` deliberately has no `Drop` impl. Calling the async
// `AsyncWriteExt::shutdown()` from `drop` only builds a future that is never
// polled, so it would do nothing. The socket is closed when the `UnixStream`
// field itself is dropped; call `Kitty::close` for an explicit shutdown.
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_creation() {
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(Duration::from_secs(5));

        assert_eq!(builder.socket_path, Some("/tmp/test.sock".to_string()));
        assert_eq!(builder.timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_builder_with_password() {
        let builder = KittyBuilder::new().password("test-password");

        assert_eq!(builder.password, Some("test-password".to_string()));
    }

    #[test]
    fn test_builder_with_public_key() {
        let builder = KittyBuilder::new().public_key("1:abc123");

        assert_eq!(builder.public_key, Some("1:abc123".to_string()));
    }

    #[test]
    fn test_extract_pid_from_socket_standard() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/kitty-12345.sock");
        assert_eq!(pid, Some(12345));
    }

    #[test]
    fn test_extract_pid_from_socket_xdg_runtime_dir() {
        let pid = KittyBuilder::extract_pid_from_socket("/run/user/1000/kitty/kitty-67890.sock");
        assert_eq!(pid, Some(67890));
    }

    #[test]
    fn test_extract_pid_from_socket_invalid() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/invalid.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_no_prefix() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/12345.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_invalid_pid() {
        let pid = KittyBuilder::extract_pid_from_socket("/tmp/kitty-abc.sock");
        assert_eq!(pid, None);
    }

    #[test]
    fn test_extract_pid_from_socket_pid_bounds() {
        // An empty PID or one that overflows u32 is rejected; u32::MAX itself is accepted.
        assert_eq!(KittyBuilder::extract_pid_from_socket("/tmp/kitty-.sock"), None);
        assert_eq!(
            KittyBuilder::extract_pid_from_socket("/tmp/kitty-4294967296.sock"),
            None
        );
        assert_eq!(
            KittyBuilder::extract_pid_from_socket("/tmp/kitty-4294967295.sock"),
            Some(u32::MAX)
        );
    }

    #[tokio::test]
    async fn test_builder_missing_socket() {
        let result = KittyBuilder::new().connect().await;

        // Without a socket path, `connect` must fail up front with
        // `SocketNotFound`, not with some later connection error.
        assert!(matches!(
            result,
            Err(KittyError::Connection(ConnectionError::SocketNotFound(_)))
        ));
    }
}