Skip to main content

kitty_rc/
transport.rs

1use crate::error::{ConnectionError, KittyError};
2use crate::protocol::{KittyMessage, KittyResponse};
3use std::path::Path;
4use std::time::Duration;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::UnixStream;
7use tokio::time::timeout;
8
9pub struct KittyClient {
10    socket_path: String,
11    stream: Option<UnixStream>,
12    timeout: Duration,
13}
14
15impl KittyClient {
16    pub async fn connect<P: AsRef<Path>>(path: P) -> Result<Self, KittyError> {
17        Self::connect_with_timeout(path, Duration::from_secs(10)).await
18    }
19
20    pub async fn connect_with_timeout<P: AsRef<Path>>(
21        path: P,
22        timeout_duration: Duration,
23    ) -> Result<Self, KittyError> {
24        let path_str = path.as_ref().to_string_lossy().to_string();
25        let stream = timeout(timeout_duration, UnixStream::connect(&path))
26            .await
27            .map_err(|_| ConnectionError::TimeoutError(timeout_duration))?
28            .map_err(|e| ConnectionError::ConnectionFailed(path_str.clone(), e))?;
29
30        Ok(Self {
31            socket_path: path_str,
32            stream: Some(stream),
33            timeout: timeout_duration,
34        })
35    }
36
37    pub fn with_timeout(mut self, timeout: Duration) -> Self {
38        self.timeout = timeout;
39        self
40    }
41
42    async fn ensure_connected(&mut self) -> Result<(), KittyError> {
43        if self.stream.is_none() {
44            let stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
45                .await
46                .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
47                .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
48            self.stream = Some(stream);
49        }
50        Ok(())
51    }
52
53    pub async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
54        self.ensure_connected().await?;
55
56        let data = message.encode()?;
57        let stream = self.stream.as_mut().ok_or(KittyError::Connection(ConnectionError::ConnectionClosed))?;
58
59        timeout(self.timeout, stream.write_all(&data))
60            .await
61            .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
62
63        Ok(())
64    }
65
66    pub async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
67        let stream = self.stream.as_mut().ok_or(KittyError::Connection(ConnectionError::ConnectionClosed))?;
68        
69        const SUFFIX: &[u8] = b"\x1b\\";
70        let mut buffer = Vec::new();
71
72        loop {
73            let mut chunk = vec![0u8; 8192];
74            let n = timeout(self.timeout, stream.read(&mut chunk))
75                .await
76                .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
77
78            if n == 0 {
79                break;
80            }
81
82            buffer.extend_from_slice(&chunk[..n]);
83
84            if buffer.ends_with(SUFFIX) {
85                break;
86            }
87        }
88
89        if buffer.is_empty() {
90            return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
91        }
92
93        Ok(KittyResponse::decode(&buffer)?)
94    }
95
96    pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
97        self.send(message).await?;
98        self.receive().await
99    }
100
101    pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
102        if message.needs_streaming() {
103            let chunks = message.clone().into_chunks();
104            for chunk in chunks {
105                self.send(&chunk).await?;
106            }
107        } else {
108            self.send(message).await?;
109        }
110        Ok(())
111    }
112
113    pub async fn execute_all(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
114        self.send_all(message).await?;
115        self.receive().await
116    }
117
118    pub async fn reconnect(&mut self) -> Result<(), KittyError> {
119        if let Some(mut stream) = self.stream.take() {
120            let _ = stream.shutdown().await;
121        }
122
123        let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
124            .await
125            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
126            .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
127
128        self.stream = Some(new_stream);
129        Ok(())
130    }
131
132    pub async fn close(&mut self) -> Result<(), KittyError> {
133        if let Some(mut stream) = self.stream.take() {
134            stream.shutdown().await.ok();
135        }
136        Ok(())
137    }
138}
139
140impl Drop for KittyClient {
141    fn drop(&mut self) {
142        if let Some(_stream) = self.stream.take() {
143            // The stream will be closed when dropped
144        }
145    }
146}
147
148pub struct ConnectionPool {
149    socket_path: String,
150    timeout: Duration,
151    max_size: usize,
152    connections: Vec<KittyClient>,
153}
154
155impl ConnectionPool {
156    pub fn new<P: AsRef<Path>>(path: P) -> Self {
157        Self {
158            socket_path: path.as_ref().to_string_lossy().to_string(),
159            timeout: Duration::from_secs(10),
160            max_size: 10,
161            connections: Vec::new(),
162        }
163    }
164
165    pub fn with_timeout(mut self, timeout: Duration) -> Self {
166        self.timeout = timeout;
167        self
168    }
169
170    pub fn with_max_size(mut self, max_size: usize) -> Self {
171        self.max_size = max_size;
172        self
173    }
174
175    pub async fn acquire(&mut self) -> Result<KittyClient, KittyError> {
176        if let Some(mut client) = self.connections.pop() {
177            client.ensure_connected().await?;
178            Ok(client)
179        } else {
180            KittyClient::connect_with_timeout(&self.socket_path, self.timeout).await
181        }
182    }
183
184    pub fn release(&mut self, client: KittyClient) {
185        if self.connections.len() < self.max_size {
186            self.connections.push(client);
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ConnectionError;

    /// Wraps one end of an in-process socket pair in a client.
    ///
    /// Must be called inside a Tokio runtime. The peer is returned so the test
    /// controls when it closes.
    fn paired_client() -> (KittyClient, UnixStream) {
        let (ours, theirs) = UnixStream::pair().expect("socketpair should succeed");
        let client = KittyClient {
            socket_path: "/nonexistent/socket".to_string(),
            stream: Some(ours),
            timeout: Duration::from_secs(1),
        };
        (client, theirs)
    }

    /// Connecting to a socket path that does not exist is an error.
    #[tokio::test]
    async fn test_client_creation() {
        let client = KittyClient::connect("/nonexistent/socket").await;
        assert!(client.is_err());
    }

    /// An explicit short timeout still yields an error for a missing socket.
    #[tokio::test]
    async fn test_client_timeout() {
        let client =
            KittyClient::connect_with_timeout("/nonexistent/socket", Duration::from_millis(100))
                .await;
        assert!(client.is_err());
    }

    /// Builder methods only configure the pool; no runtime or socket is needed.
    #[test]
    fn test_pool_creation() {
        let pool = ConnectionPool::new("/tmp/test.sock")
            .with_timeout(Duration::from_secs(5))
            .with_max_size(5);

        assert_eq!(pool.max_size, 5);
        assert_eq!(pool.timeout, Duration::from_secs(5));
        assert_eq!(pool.socket_path, "/tmp/test.sock");
        assert!(pool.connections.is_empty());
    }

    /// `release` keeps at most `max_size` idle clients.
    #[tokio::test]
    async fn test_pool_release_respects_max_size() {
        let mut pool = ConnectionPool::new("/nonexistent/socket").with_max_size(1);
        let (first, _first_peer) = paired_client();
        let (second, _second_peer) = paired_client();

        pool.release(first);
        pool.release(second);

        assert_eq!(pool.connections.len(), 1);
    }

    /// `acquire` hands back a pooled, still-connected client without dialing.
    #[tokio::test]
    async fn test_pool_acquire_reuses_pooled_client() {
        let mut pool = ConnectionPool::new("/nonexistent/socket");
        let (client, _peer) = paired_client();
        pool.release(client);

        // The socket path does not exist, so success shows no new connection was opened.
        assert!(pool.acquire().await.is_ok());
        assert!(pool.connections.is_empty());
    }

    /// Reading after the peer hung up reports an error instead of a response.
    #[tokio::test]
    async fn test_receive_after_peer_closed_is_err() {
        let (mut client, peer) = paired_client();
        drop(peer);

        assert!(client.receive().await.is_err());
    }

    #[test]
    fn test_error_display() {
        let err = ConnectionError::ConnectionClosed;
        assert_eq!(err.to_string(), "Connection closed unexpectedly");
    }
}