1use crate::error::{ConnectionError, KittyError};
2use crate::protocol::{KittyMessage, KittyResponse};
3use std::path::Path;
4use std::time::Duration;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::UnixStream;
7use tokio::time::timeout;
8
9pub struct Kitty {
10 stream: UnixStream,
11 timeout: Duration,
12 socket_path: String,
13}
14
/// Builder for configuring and opening a [`Kitty`] connection.
///
/// Obtain one with [`KittyBuilder::new`] or [`Kitty::builder`].
#[derive(Debug, Clone)]
pub struct KittyBuilder {
    /// Socket path to connect to; `None` until `socket_path` is called.
    socket_path: Option<String>,
    /// Timeout for connecting and for each later I/O operation.
    timeout: Duration,
}
19
20impl KittyBuilder {
21 pub fn new() -> Self {
22 Self {
23 socket_path: None,
24 timeout: Duration::from_secs(10),
25 }
26 }
27
28 pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
29 self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
30 self
31 }
32
33 pub fn timeout(mut self, duration: Duration) -> Self {
34 self.timeout = duration;
35 self
36 }
37
38 pub async fn connect(self) -> Result<Kitty, KittyError> {
39 let socket_path = self
40 .socket_path
41 .ok_or_else(|| KittyError::Connection(ConnectionError::SocketNotFound(
42 "No socket path provided".to_string(),
43 )))?;
44
45 let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
46 .await
47 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
48 .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
49
50 Ok(Kitty {
51 stream,
52 timeout: self.timeout,
53 socket_path,
54 })
55 }
56}
57
58impl Kitty {
59 pub fn builder() -> KittyBuilder {
60 KittyBuilder::new()
61 }
62 async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
63 let data = message.encode()?;
64
65 timeout(self.timeout, self.stream.write_all(&data))
66 .await
67 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
68
69 Ok(())
70 }
71
72 async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
73 const SUFFIX: &[u8] = b"\x1b\\";
74 let mut buffer = Vec::new();
75
76 loop {
77 let mut chunk = vec![0u8; 8192];
78 let n = timeout(self.timeout, self.stream.read(&mut chunk))
79 .await
80 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
81
82 if n == 0 {
83 break;
84 }
85
86 buffer.extend_from_slice(&chunk[..n]);
87
88 if buffer.ends_with(SUFFIX) {
89 break;
90 }
91 }
92
93 if buffer.is_empty() {
94 return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
95 }
96
97 Ok(KittyResponse::decode(&buffer)?)
98 }
99
100 pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
101 self.send(message).await?;
102 self.receive().await
103 }
104
105 pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
106 if message.needs_streaming() {
107 for chunk in message.clone().into_chunks() {
108 self.send(&chunk).await?;
109 }
110 } else {
111 self.send(message).await?;
112 }
113 Ok(())
114 }
115
116 pub async fn execute_all(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
117 self.send_all(message).await?;
118 self.receive().await
119 }
120
121 pub async fn send_command<T: Into<KittyMessage>>(&mut self, command: T) -> Result<(), KittyError> {
122 self.send_all(&command.into()).await
123 }
124
125 pub async fn reconnect(&mut self) -> Result<(), KittyError> {
126 let _ = self.stream.shutdown().await;
127
128 let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
129 .await
130 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
131 .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
132
133 self.stream = new_stream;
134 Ok(())
135 }
136
137 pub async fn close(&mut self) -> Result<(), KittyError> {
138 self.stream.shutdown().await.ok();
139 Ok(())
140 }
141}
142
// `Kitty` deliberately has no `Drop` impl. The owned `UnixStream` closes its
// socket when it is dropped. An async `shutdown()` cannot be awaited inside
// `Drop`, and creating its future without polling it does nothing.
// Call `Kitty::close` for an explicit, awaited shutdown.
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_creation() {
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(Duration::from_secs(5));

        assert_eq!(builder.socket_path, Some("/tmp/test.sock".to_string()));
        assert_eq!(builder.timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_builder_defaults() {
        let builder = KittyBuilder::new();

        assert_eq!(builder.socket_path, None);
        assert_eq!(builder.timeout, Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_builder_missing_socket() {
        let result = KittyBuilder::new().connect().await;

        // Pin the specific error variant rather than accepting any error.
        assert!(matches!(
            result,
            Err(KittyError::Connection(ConnectionError::SocketNotFound(_)))
        ));
    }
}