1use crate::encryption::Encryptor;
2use crate::error::{ConnectionError, EncryptionError, KittyError};
3use crate::protocol::{KittyMessage, KittyResponse};
4use std::path::Path;
5use std::process::Command;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::UnixStream;
9use tokio::time::timeout;
10
11pub struct Kitty {
12 stream: UnixStream,
13 timeout: Duration,
14 socket_path: String,
15 password: Option<String>,
16 encryptor: Option<Encryptor>,
17}
18
/// Configures and opens a [`Kitty`] connection.
///
/// A socket path must be set before connecting; setting a password enables
/// payload encryption.
#[must_use = "a KittyBuilder does nothing until `connect` is called"]
pub struct KittyBuilder {
    /// Unix socket to connect to; `connect` fails when this is `None`.
    socket_path: Option<String>,
    /// Remote-control password; when set, `connect` builds an encryptor.
    password: Option<String>,
    /// Explicit public key for encryption; takes precedence over any lookup.
    public_key: Option<String>,
    /// Deadline for connecting and for each subsequent read/write.
    timeout: Duration,
}

// Hand-written rather than derived so the password is never printed.
impl std::fmt::Debug for KittyBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KittyBuilder")
            .field("socket_path", &self.socket_path)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("public_key", &self.public_key)
            .field("timeout", &self.timeout)
            .finish()
    }
}
25
26impl KittyBuilder {
27 pub fn new() -> Self {
28 Self {
29 socket_path: None,
30 password: None,
31 public_key: None,
32 timeout: Duration::from_secs(10),
33 }
34 }
35
36 fn extract_pid_from_socket(socket_path: &str) -> Option<u32> {
37 let filename = Path::new(socket_path)
38 .file_name()?
39 .to_str()?;
40
41 let pid_str = filename.strip_prefix("kitty-")?;
42 let pid_str = pid_str.strip_suffix(".sock")?;
43 pid_str.parse().ok()
44 }
45
46 fn query_public_key_database(pid: u32) -> Result<Option<String>, EncryptionError> {
47 let output = Command::new("kitty-pubkey-db")
48 .arg("get")
49 .arg(pid.to_string())
50 .output()
51 .map_err(|e| {
52 EncryptionError::PublicKeyDatabaseError(format!("Failed to run kitty-pubkey-db: {}", e))
53 })?;
54
55 if !output.status.success() {
56 return Ok(None);
57 }
58
59 let pubkey = String::from_utf8(output.stdout)
60 .map_err(|e| {
61 EncryptionError::PublicKeyDatabaseError(format!("Invalid UTF-8 output: {}", e))
62 })?
63 .trim()
64 .to_string();
65
66 if pubkey.is_empty() {
67 Ok(None)
68 } else {
69 Ok(Some(pubkey))
70 }
71 }
72
73 pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
74 self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
75 self
76 }
77
78 pub fn timeout(mut self, duration: Duration) -> Self {
79 self.timeout = duration;
80 self
81 }
82
83 pub fn password(mut self, password: impl Into<String>) -> Self {
84 self.password = Some(password.into());
85 self
86 }
87
88 pub fn public_key(mut self, public_key: impl Into<String>) -> Self {
89 self.public_key = Some(public_key.into());
90 self
91 }
92
93 pub async fn connect(self) -> Result<Kitty, KittyError> {
94 let socket_path = self.socket_path.ok_or_else(|| {
95 KittyError::Connection(ConnectionError::SocketNotFound(
96 "No socket path provided".to_string(),
97 ))
98 })?;
99
100 let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
101 .await
102 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
103 .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
104
105 let encryptor = if self.password.is_some() {
106 let public_key = if let Some(pk) = self.public_key {
107 Some(pk)
108 } else if std::env::var("KITTY_PUBLIC_KEY").is_ok() {
109 None
110 } else if let Some(pid) = Self::extract_pid_from_socket(&socket_path) {
111 Self::query_public_key_database(pid).map_err(KittyError::Encryption)?
112 } else {
113 None
114 };
115
116 Some(Encryptor::new_with_public_key(public_key.as_deref())?)
117 } else {
118 None
119 };
120
121 Ok(Kitty {
122 stream,
123 timeout: self.timeout,
124 socket_path,
125 password: self.password,
126 encryptor,
127 })
128 }
129}
130
131impl Kitty {
132 pub fn builder() -> KittyBuilder {
133 KittyBuilder::new()
134 }
135
136 fn encrypt_command(&self, mut message: KittyMessage) -> Result<KittyMessage, KittyError> {
137 let Some(encryptor) = &self.encryptor else {
138 return Ok(message);
139 };
140
141 let Some(password) = &self.password else {
142 return Ok(message);
143 };
144
145 let timestamp = SystemTime::now()
146 .duration_since(UNIX_EPOCH)
147 .map_err(|_| {
148 KittyError::Encryption(crate::error::EncryptionError::EncryptionFailed(
149 "Failed to get timestamp".to_string(),
150 ))
151 })?
152 .as_nanos();
153
154 if let Some(payload) = &mut message.payload {
155 if let Some(obj) = payload.as_object_mut() {
156 obj.insert("password".to_string(), serde_json::json!(password));
157 obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
158 }
159 } else {
160 let mut obj = serde_json::Map::new();
161 obj.insert("password".to_string(), serde_json::json!(password));
162 obj.insert("timestamp".to_string(), serde_json::json!(timestamp));
163 message.payload = Some(serde_json::Value::Object(obj));
164 }
165
166 let encrypted_payload = encryptor.encrypt_command(message.payload.unwrap())?;
167 message.payload = Some(encrypted_payload);
168
169 Ok(message)
170 }
171
172 async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
173 let encrypted_msg = self.encrypt_command(message.clone())?;
174 let data = encrypted_msg.encode()?;
175
176 timeout(self.timeout, self.stream.write_all(&data))
177 .await
178 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
179
180 Ok(())
181 }
182
183 async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
184 const SUFFIX: &[u8] = b"\x1b\\";
185
186 let mut buffer = Vec::new();
187
188 loop {
189 let mut chunk = vec![0u8; 8192];
190 let n = timeout(self.timeout, self.stream.read(&mut chunk))
191 .await
192 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
193
194 if n == 0 {
195 break;
196 }
197
198 buffer.extend_from_slice(&chunk[..n]);
199
200 if buffer.ends_with(SUFFIX) {
201 break;
202 }
203 }
204
205 if buffer.is_empty() {
206 return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
207 }
208
209 Ok(KittyResponse::decode(&buffer)?)
210 }
211
212 pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
213 self.send(message).await?;
214 self.receive().await
215 }
216
217 pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
218 if message.needs_streaming() {
219 for chunk in message.clone().into_chunks() {
220 let encrypted_chunk = self.encrypt_command(chunk)?;
221 self.send(&encrypted_chunk).await?;
222 }
223 } else {
224 let encrypted_msg = self.encrypt_command(message.clone())?;
225 self.send(&encrypted_msg).await?;
226 }
227
228 Ok(())
229 }
230
231 pub async fn execute_all(
232 &mut self,
233 message: &KittyMessage,
234 ) -> Result<KittyResponse, KittyError> {
235 self.send_all(message).await?;
236 self.receive().await
237 }
238
239 pub async fn send_command<T: Into<KittyMessage>>(
240 &mut self,
241 command: T,
242 ) -> Result<(), KittyError> {
243 self.send_all(&command.into()).await
244 }
245
246 pub async fn reconnect(&mut self) -> Result<(), KittyError> {
247 let _ = self.stream.shutdown().await;
248
249 let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
250 .await
251 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
252 .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
253
254 self.stream = new_stream;
255 Ok(())
256 }
257
258 pub async fn close(&mut self) -> Result<(), KittyError> {
259 self.stream.shutdown().await.ok();
260 Ok(())
261 }
262}
263
// `Kitty` deliberately has no custom `Drop` impl. `drop` is synchronous, so it
// cannot await `AsyncWriteExt::shutdown`, and merely creating that future does
// nothing because futures are lazy. Dropping the `UnixStream` field already
// closes the socket. Call `Kitty::close` for an explicit, awaited shutdown.
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_creation() {
        let five_secs = Duration::from_secs(5);
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(five_secs);

        assert_eq!(builder.timeout, five_secs);
        assert_eq!(builder.socket_path.as_deref(), Some("/tmp/test.sock"));
    }

    #[test]
    fn test_builder_with_password() {
        let builder = KittyBuilder::new().password("test-password");
        assert_eq!(builder.password.as_deref(), Some("test-password"));
    }

    #[test]
    fn test_builder_with_public_key() {
        let builder = KittyBuilder::new().public_key("1:abc123");
        assert_eq!(builder.public_key.as_deref(), Some("1:abc123"));
    }

    /// Shorthand for the function under test in the PID-extraction cases below.
    fn pid_of(path: &str) -> Option<u32> {
        KittyBuilder::extract_pid_from_socket(path)
    }

    #[test]
    fn test_extract_pid_from_socket_standard() {
        assert_eq!(pid_of("/tmp/kitty-12345.sock"), Some(12345));
    }

    #[test]
    fn test_extract_pid_from_socket_xdg_runtime_dir() {
        assert_eq!(pid_of("/run/user/1000/kitty/kitty-67890.sock"), Some(67890));
    }

    #[test]
    fn test_extract_pid_from_socket_invalid() {
        assert!(pid_of("/tmp/invalid.sock").is_none());
    }

    #[test]
    fn test_extract_pid_from_socket_no_prefix() {
        assert!(pid_of("/tmp/12345.sock").is_none());
    }

    #[test]
    fn test_extract_pid_from_socket_invalid_pid() {
        assert!(pid_of("/tmp/kitty-abc.sock").is_none());
    }

    #[tokio::test]
    async fn test_builder_missing_socket() {
        let outcome = KittyBuilder::new().connect().await;
        assert!(outcome.is_err(), "connect must fail without a socket path");
    }
}