1use crate::error::{ConnectionError, KittyError};
2use crate::protocol::{KittyMessage, KittyResponse};
3use std::path::Path;
4use std::time::Duration;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::UnixStream;
7use tokio::time::timeout;
8
9pub struct KittyClient {
10 socket_path: String,
11 stream: Option<UnixStream>,
12 timeout: Duration,
13}
14
15impl KittyClient {
16 pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Self, KittyError> {
17 Self::connect_with_timeout(path, Duration::from_secs(10)).await
18 }
19
20 pub async fn connect_with_timeout<P: AsRef<Path>>(
21 path: P,
22 timeout_duration: Duration,
23 ) -> Result<Self, KittyError> {
24 let path_str = path.as_ref().to_string_lossy().to_string();
25 let stream = timeout(timeout_duration, UnixStream::connect(&path))
26 .await
27 .map_err(|_| ConnectionError::TimeoutError(timeout_duration))?
28 .map_err(|e| ConnectionError::ConnectionFailed(path_str.clone(), e))?;
29
30 Ok(Self {
31 socket_path: path_str,
32 stream: Some(stream),
33 timeout: timeout_duration,
34 })
35 }
36
37 pub fn with_timeout(mut self, timeout: Duration) -> Self {
38 self.timeout = timeout;
39 self
40 }
41
42 async fn ensure_connected(&mut self) -> Result<(), KittyError> {
43 if self.stream.is_none() {
44 let stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
45 .await
46 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
47 .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
48 self.stream = Some(stream);
49 }
50 Ok(())
51 }
52
53 pub async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
54 self.ensure_connected().await?;
55
56 let data = message.encode()?;
57 let stream = self.stream.as_mut().ok_or(KittyError::Connection(ConnectionError::ConnectionClosed))?;
58
59 timeout(self.timeout, stream.write_all(&data))
60 .await
61 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
62
63 Ok(())
64 }
65
66 pub async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
67 let stream = self.stream.as_mut().ok_or(KittyError::Connection(ConnectionError::ConnectionClosed))?;
68
69 const SUFFIX: &[u8] = b"\x1b\\";
70 let mut buffer = Vec::new();
71
72 loop {
73 let mut chunk = vec![0u8; 8192];
74 let n = timeout(self.timeout, stream.read(&mut chunk))
75 .await
76 .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
77
78 if n == 0 {
79 break;
80 }
81
82 buffer.extend_from_slice(&chunk[..n]);
83
84 if buffer.ends_with(SUFFIX) {
85 break;
86 }
87 }
88
89 if buffer.is_empty() {
90 return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
91 }
92
93 Ok(KittyResponse::decode(&buffer)?)
94 }
95
96 pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
97 self.send(message).await?;
98 self.receive().await
99 }
100
101 pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
102 if message.needs_streaming() {
103 let chunks = message.clone().into_chunks();
104 for chunk in chunks {
105 self.send(&chunk).await?;
106 }
107 } else {
108 self.send(message).await?;
109 }
110 Ok(())
111 }
112
113 pub async fn execute_all(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
114 self.send_all(message).await?;
115 self.receive().await
116 }
117
118 pub async fn reconnect(&mut self) -> Result<(), KittyError> {
119 if let Some(mut stream) = self.stream.take() {
120 let _ = stream.shutdown().await;
121 }
122
123 let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
124 .await
125 .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
126 .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
127
128 self.stream = Some(new_stream);
129 Ok(())
130 }
131
132 pub async fn close(&mut self) -> Result<(), KittyError> {
133 if let Some(mut stream) = self.stream.take() {
134 stream.shutdown().await.ok();
135 }
136 Ok(())
137 }
138}
139
140impl Drop for KittyClient {
141 fn drop(&mut self) {
142 if let Some(_stream) = self.stream.take() {
143 }
145 }
146}
147
148pub struct ConnectionPool {
149 socket_path: String,
150 timeout: Duration,
151 max_size: usize,
152 connections: Vec<KittyClient>,
153}
154
155impl ConnectionPool {
156 pub fn new<P: AsRef<Path>>(path: P) -> Self {
157 Self {
158 socket_path: path.as_ref().to_string_lossy().to_string(),
159 timeout: Duration::from_secs(10),
160 max_size: 10,
161 connections: Vec::new(),
162 }
163 }
164
165 pub fn with_timeout(mut self, timeout: Duration) -> Self {
166 self.timeout = timeout;
167 self
168 }
169
170 pub fn with_max_size(mut self, max_size: usize) -> Self {
171 self.max_size = max_size;
172 self
173 }
174
175 pub async fn acquire(&mut self) -> Result<KittyClient, KittyError> {
176 if let Some(mut client) = self.connections.pop() {
177 client.ensure_connected().await?;
178 Ok(client)
179 } else {
180 KittyClient::connect_with_timeout(&self.socket_path, self.timeout).await
181 }
182 }
183
184 pub fn release(&mut self, client: KittyClient) {
185 if self.connections.len() < self.max_size {
186 self.connections.push(client);
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{ConnectionError, KittyError};

    /// Builds a client that has never connected, without touching the filesystem.
    fn offline_client() -> KittyClient {
        KittyClient {
            socket_path: "/nonexistent/socket".into(),
            stream: None,
            timeout: Duration::from_millis(100),
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = KittyClient::connect("/nonexistent/socket").await;
        assert!(client.is_err());
    }

    #[tokio::test]
    async fn test_client_timeout() {
        let client =
            KittyClient::connect_with_timeout("/nonexistent/socket", Duration::from_millis(100))
                .await;
        assert!(client.is_err());
    }

    // Builder methods are synchronous; no async runtime is needed.
    #[test]
    fn test_pool_creation() {
        let pool = ConnectionPool::new("/tmp/test.sock")
            .with_timeout(Duration::from_secs(5))
            .with_max_size(5);

        assert_eq!(pool.max_size, 5);
        assert_eq!(pool.timeout, Duration::from_secs(5));
        assert_eq!(pool.socket_path, "/tmp/test.sock");
        assert!(pool.connections.is_empty());
    }

    #[test]
    fn test_pool_release_respects_max_size() {
        let mut pool = ConnectionPool::new("/nonexistent/socket").with_max_size(1);
        pool.release(offline_client());
        pool.release(offline_client());
        assert_eq!(pool.connections.len(), 1);
    }

    #[tokio::test]
    async fn test_pool_acquire_fails_without_server() {
        let mut pool = ConnectionPool::new("/nonexistent/socket")
            .with_timeout(Duration::from_millis(100));
        assert!(pool.acquire().await.is_err());
    }

    #[tokio::test]
    async fn test_receive_without_stream_reports_closed() {
        let mut client = offline_client();
        let result = client.receive().await;
        assert!(matches!(
            result,
            Err(KittyError::Connection(ConnectionError::ConnectionClosed))
        ));
    }

    #[tokio::test]
    async fn test_close_is_idempotent() {
        let mut client = offline_client();
        assert!(client.close().await.is_ok());
        assert!(client.close().await.is_ok());
        assert!(client.stream.is_none());
    }

    #[test]
    fn test_error_display() {
        let err = ConnectionError::ConnectionClosed;
        assert_eq!(err.to_string(), "Connection closed unexpectedly");
    }
}