Skip to main content

kitty_rc/
client.rs

1use crate::error::{ConnectionError, KittyError};
2use crate::protocol::{KittyMessage, KittyResponse};
3use std::path::Path;
4use std::time::Duration;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::UnixStream;
7use tokio::time::timeout;
8
9pub struct Kitty {
10    stream: UnixStream,
11    timeout: Duration,
12    socket_path: String,
13}
14
/// Configures and opens a [`Kitty`] connection.
#[derive(Debug, Clone)]
pub struct KittyBuilder {
    /// Socket to connect to; `connect` fails if this is still `None`.
    socket_path: Option<String>,
    /// Timeout for the connect attempt and for later per-operation I/O.
    timeout: Duration,
}
19
20impl KittyBuilder {
21    pub fn new() -> Self {
22        Self {
23            socket_path: None,
24            timeout: Duration::from_secs(10),
25        }
26    }
27
28    pub fn socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
29        self.socket_path = Some(path.as_ref().to_string_lossy().to_string());
30        self
31    }
32
33    pub fn timeout(mut self, duration: Duration) -> Self {
34        self.timeout = duration;
35        self
36    }
37
38    pub async fn connect(self) -> Result<Kitty, KittyError> {
39        let socket_path = self
40            .socket_path
41            .ok_or_else(|| KittyError::Connection(ConnectionError::SocketNotFound(
42                "No socket path provided".to_string(),
43            )))?;
44
45        let stream = timeout(self.timeout, UnixStream::connect(&socket_path))
46            .await
47            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
48            .map_err(|e| ConnectionError::ConnectionFailed(socket_path.clone(), e))?;
49
50        Ok(Kitty {
51            stream,
52            timeout: self.timeout,
53            socket_path,
54        })
55    }
56}
57
58impl Kitty {
59    pub fn builder() -> KittyBuilder {
60        KittyBuilder::new()
61    }
62    async fn send(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
63        let data = message.encode()?;
64
65        timeout(self.timeout, self.stream.write_all(&data))
66            .await
67            .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
68
69        Ok(())
70    }
71
72    async fn receive(&mut self) -> Result<KittyResponse, KittyError> {
73        const SUFFIX: &[u8] = b"\x1b\\";
74        let mut buffer = Vec::new();
75
76        loop {
77            let mut chunk = vec![0u8; 8192];
78            let n = timeout(self.timeout, self.stream.read(&mut chunk))
79                .await
80                .map_err(|_| ConnectionError::TimeoutError(self.timeout))??;
81
82            if n == 0 {
83                break;
84            }
85
86            buffer.extend_from_slice(&chunk[..n]);
87
88            if buffer.ends_with(SUFFIX) {
89                break;
90            }
91        }
92
93        if buffer.is_empty() {
94            return Err(KittyError::Connection(ConnectionError::ConnectionClosed));
95        }
96
97        Ok(KittyResponse::decode(&buffer)?)
98    }
99
100    pub async fn execute(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
101        self.send(message).await?;
102        self.receive().await
103    }
104
105    pub async fn send_all(&mut self, message: &KittyMessage) -> Result<(), KittyError> {
106        if message.needs_streaming() {
107            for chunk in message.clone().into_chunks() {
108                self.send(&chunk).await?;
109            }
110        } else {
111            self.send(message).await?;
112        }
113        Ok(())
114    }
115
116    pub async fn execute_all(&mut self, message: &KittyMessage) -> Result<KittyResponse, KittyError> {
117        self.send_all(message).await?;
118        self.receive().await
119    }
120
121    pub async fn send_command<T: Into<KittyMessage>>(&mut self, command: T) -> Result<(), KittyError> {
122        self.send_all(&command.into()).await
123    }
124
125    pub async fn reconnect(&mut self) -> Result<(), KittyError> {
126        let _ = self.stream.shutdown().await;
127
128        let new_stream = timeout(self.timeout, UnixStream::connect(&self.socket_path))
129            .await
130            .map_err(|_| ConnectionError::TimeoutError(self.timeout))?
131            .map_err(|e| ConnectionError::ConnectionFailed(self.socket_path.clone(), e))?;
132
133        self.stream = new_stream;
134        Ok(())
135    }
136
137    pub async fn close(&mut self) -> Result<(), KittyError> {
138        self.stream.shutdown().await.ok();
139        Ok(())
140    }
141}
142
// `Kitty` intentionally has no `Drop` impl. Shutdown here is the async
// `AsyncWriteExt::shutdown`, which cannot be awaited from a synchronous `drop`.
// Calling it without `.await` only builds a future that is discarded unpolled
// and does nothing. Dropping the inner `UnixStream` closes the socket; callers
// that want an explicit shutdown should `.await` `Kitty::close`.
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_creation() {
        let builder = KittyBuilder::new()
            .socket_path("/tmp/test.sock")
            .timeout(Duration::from_secs(5));

        assert_eq!(builder.socket_path, Some("/tmp/test.sock".to_string()));
        assert_eq!(builder.timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_builder_defaults() {
        let builder = KittyBuilder::new();

        assert_eq!(builder.socket_path, None);
        assert_eq!(builder.timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_builder_accepts_path_types() {
        let builder =
            KittyBuilder::new().socket_path(std::path::PathBuf::from("/tmp/other.sock"));

        assert_eq!(builder.socket_path.as_deref(), Some("/tmp/other.sock"));
    }

    #[tokio::test]
    async fn test_builder_missing_socket() {
        let result = KittyBuilder::new().connect().await;

        // Pin the exact failure, not just "some error".
        assert!(matches!(
            result,
            Err(KittyError::Connection(ConnectionError::SocketNotFound(_)))
        ));
    }
}
165        let builder = KittyBuilder::new();
166        let result = builder.connect().await;
167        assert!(result.is_err());
168    }
169}