ckeylock_api/
lib.rs

1use std::str::FromStr;
2
3use ckeylock_core::response::ErrorResponse;
4use ckeylock_core::{Request, RequestWrapper, Response};
5use futures_util::{SinkExt, StreamExt};
6use std::sync::Arc;
7use thiserror::Error;
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use tokio_tungstenite::tungstenite::Error as WsError;
11use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
12use tokio_tungstenite::{
13    MaybeTlsStream, WebSocketStream, connect_async,
14    tungstenite::{ClientRequestBuilder, http::Uri, protocol::Message},
15};
16
17pub struct CKeyLockAPI {
18    bind: String,
19    password: Option<String>,
20}
21
22impl CKeyLockAPI {
23    pub fn new(bind: &str, password: Option<&str>) -> Self {
24        CKeyLockAPI {
25            bind: bind.to_owned(),
26            password: password.map(|p| p.to_owned()),
27        }
28    }
29
30    pub async fn connect(&self) -> Result<CKeyLockConnection, Error> {
31        let url = format!("ws://{}", self.bind);
32        let request = match &self.password {
33            Some(password) => ClientRequestBuilder::new(Uri::from_str(&url)?)
34                .with_header("Authorization", password)
35                .into_client_request()
36                .map_err(|e| Error::Custom(format!("Failed to build client request: {}", e)))?,
37            None => url
38                .into_client_request()
39                .map_err(|e| Error::Custom(format!("Failed to build client request: {}", e)))?,
40        };
41        let (ws_stream, _) = connect_async(request)
42            .await
43            .map_err(|e| Error::Custom(format!("Failed to connect to WebSocket: {}", e)))?;
44
45        Ok(CKeyLockConnection {
46            inner: CkeyLockConnectionInner::new(ws_stream).into(),
47        })
48    }
49}
50
51pub struct CKeyLockConnection {
52    inner: Arc<CkeyLockConnectionInner>,
53}
54
55impl CKeyLockConnection {
56    async fn send_request(&self, request: Request) -> Result<Response, Error> {
57        let request = RequestWrapper::new(request);
58
59        self.inner
60            .send(request_into_message(request.clone()))
61            .await?;
62
63        while let Some(msg) = self.inner.lock().await.next().await {
64            let msg =
65                msg.map_err(|e| Error::Custom(format!("Failed to receive message: {}", e)))?;
66            if let Some(parsed_response) = self.parse_response(&msg, request.id()) {
67                return parsed_response;
68            }
69        }
70        Err(Error::Custom(
71            "Response with matching ID not found".to_string(),
72        ))
73    }
74
75    fn parse_response(&self, msg: &Message, req_id: Vec<u8>) -> Option<Result<Response, Error>> {
76        if let Message::Text(text) = msg {
77            if let Ok(response) = serde_json::from_str::<Response>(text) {
78                if response.reqid() == req_id {
79                    return Some(Ok(response));
80                }
81            } else if let Ok(err_response) = serde_json::from_str::<ErrorResponse>(text) {
82                if err_response.reqid == req_id {
83                    return Some(Err(Error::Custom(format!(
84                        "Error response received: {}",
85                        err_response.message
86                    ))));
87                }
88            }
89        }
90        None
91    }
92
93    pub async fn set(&self, key: Vec<u8>, value: Vec<u8>) -> Result<Vec<u8>, Error> {
94        let res = self.send_request(Request::Set { key, value }).await?;
95        if let Some(ckeylock_core::ResponseData::SetResponse { key }) = res.data() {
96            Ok(key.to_vec())
97        } else {
98            Err(Error::WrongResponseFormat)
99        }
100    }
101
102    pub async fn get(&self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
103        let res = self.send_request(Request::Get { key }).await?;
104        if let Some(ckeylock_core::ResponseData::GetResponse { value }) = res.data() {
105            Ok(value.as_ref().map(|v| v.to_vec()))
106        } else {
107            Err(Error::WrongResponseFormat)
108        }
109    }
110    pub async fn batch_get(&self, keys: Vec<Vec<u8>>) -> Result<Vec<Option<Vec<u8>>>, Error> {
111        let res = self.send_request(Request::BatchGet { keys }).await?;
112        if let Some(ckeylock_core::ResponseData::BatchGetResponse { values }) = res.data() {
113            Ok(values.clone())
114        } else {
115            Err(Error::WrongResponseFormat)
116        }
117    }
118    pub async fn delete(&self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
119        let res = self.send_request(Request::Delete { key }).await?;
120        if let Some(ckeylock_core::ResponseData::DeleteResponse { key }) = res.data() {
121            Ok(key.as_ref().map(|v| v.to_vec()))
122        } else {
123            Err(Error::WrongResponseFormat)
124        }
125    }
126
127    pub async fn list(&self) -> Result<Vec<Vec<u8>>, Error> {
128        let res = self.send_request(Request::List).await?;
129        if let Some(ckeylock_core::ResponseData::ListResponse { keys }) = res.data() {
130            Ok(keys.clone())
131        } else {
132            Err(Error::WrongResponseFormat)
133        }
134    }
135
136    pub async fn exists(&self, key: Vec<u8>) -> Result<bool, Error> {
137        let res = self.send_request(Request::Exists { key }).await?;
138        if let Some(ckeylock_core::ResponseData::ExistsResponse { exists }) = res.data() {
139            Ok(*exists)
140        } else {
141            Err(Error::WrongResponseFormat)
142        }
143    }
144
145    pub async fn count(&self) -> Result<usize, Error> {
146        let res = self.send_request(Request::Count).await?;
147        if let Some(ckeylock_core::ResponseData::CountResponse { count }) = res.data() {
148            Ok(*count)
149        } else {
150            Err(Error::WrongResponseFormat)
151        }
152    }
153
154    pub async fn clear(&self) -> Result<(), Error> {
155        let res = self.send_request(Request::Clear).await?;
156        if let Some(ckeylock_core::ResponseData::ClearResponse) = res.data() {
157            Ok(())
158        } else {
159            Err(Error::WrongResponseFormat)
160        }
161    }
162
163    pub async fn close(&self) -> Result<(), Box<dyn std::error::Error>> {
164        self.inner
165            .lock()
166            .await
167            .close(None)
168            .await
169            .map_err(|e| Box::new(Error::Custom(format!("Failed to close WebSocket: {}", e))) as _)
170    }
171}
172
173fn request_into_message(req: ckeylock_core::RequestWrapper) -> Message {
174    Message::Text(req.to_string().into())
175}
176
177pub struct CkeyLockConnectionInner(Mutex<WebSocketStream<MaybeTlsStream<TcpStream>>>);
178
179impl CkeyLockConnectionInner {
180    pub fn new(ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
181        CkeyLockConnectionInner(Mutex::new(ws_stream))
182    }
183
184    pub async fn send(&self, msg: Message) -> Result<(), Error> {
185        self.0
186            .lock()
187            .await
188            .send(msg)
189            .await
190            .map_err(|e| Error::Custom(format!("Failed to send message: {}", e)))
191    }
192    pub async fn lock(
193        &self,
194    ) -> tokio::sync::MutexGuard<'_, WebSocketStream<MaybeTlsStream<TcpStream>>> {
195        self.0.lock().await
196    }
197}
198
199#[derive(Error, Debug)]
200pub enum Error {
201    #[error("WebSocket error: {0}")]
202    WsError(#[from] WsError),
203    #[error("Wrong response format")]
204    WrongResponseFormat,
205    #[error("Failed to parse uri: {0}")]
206    UriParseError(#[from] tokio_tungstenite::tungstenite::http::uri::InvalidUri),
207    #[error("{0}")]
208    Custom(String),
209}
#[cfg(test)]
mod tests {
    // Integration tests: each one connects to a server at `SERVER_ADDR` and
    // fails at `open_connection` when nothing is listening there.
    //
    // Tests run concurrently by default and all talk to the same server, so
    // every test uses its own keys; sharing a key would let one test's delete
    // land between another test's set and get.
    use super::*;

    const SERVER_ADDR: &str = "127.0.0.1:5830";
    const SERVER_PASSWORD: &str = "helloworld";

    /// Opens a fresh connection using the shared test settings.
    async fn open_connection() -> CKeyLockConnection {
        let api = CKeyLockAPI::new(SERVER_ADDR, Some(SERVER_PASSWORD));
        let connection = api
            .connect()
            .await
            .expect("failed to connect to the test server");
        connection
    }

    #[tokio::test]
    async fn test_set() {
        let connection = open_connection().await;

        let key = b"popa".to_vec();
        let value = b"pizdec".to_vec();

        let stored_key = connection
            .set(key.clone(), value)
            .await
            .expect("set should succeed");
        assert_eq!(stored_key, key);
    }

    #[tokio::test]
    async fn test_get() {
        let connection = open_connection().await;

        // Key is unique to this test so `test_delete` cannot remove it.
        let key = b"test_get_key".to_vec();
        let value = b"test_value".to_vec();

        connection.set(key.clone(), value.clone()).await.unwrap();
        let fetched = connection.get(key).await.expect("get should succeed");
        assert_eq!(fetched, Some(value));
        println!("Value: {:?}", fetched);
    }

    #[tokio::test]
    async fn test_delete() {
        let connection = open_connection().await;

        // Key is unique to this test so deleting it cannot affect `test_get`.
        let key = b"test_delete_key".to_vec();
        let value = b"test_value".to_vec();

        connection.set(key.clone(), value).await.unwrap();
        let deleted = connection
            .delete(key.clone())
            .await
            .expect("delete should succeed");
        assert_eq!(deleted, Some(key));
    }

    #[tokio::test]
    async fn test_list() {
        let connection = open_connection().await;

        let key1 = b"test_key1".to_vec();
        let key2 = b"test_key2".to_vec();
        let value = b"test_value".to_vec();

        connection.set(key1.clone(), value.clone()).await.unwrap();
        connection.set(key2.clone(), value).await.unwrap();

        let keys = connection.list().await.expect("list should succeed");
        assert!(keys.contains(&key1));
        assert!(keys.contains(&key2));
    }

    #[tokio::test]
    async fn test_batch_get() {
        let connection = open_connection().await;

        let key1 = b"batch_key1".to_vec();
        let value1 = b"batch_value1".to_vec();
        let key2 = b"batch_key2".to_vec();
        let value2 = b"batch_value2".to_vec();
        // Never written by these tests, so its slot is expected to be `None`.
        let key3 = b"batch_key3".to_vec();

        connection.set(key1.clone(), value1.clone()).await.unwrap();
        connection.set(key2.clone(), value2.clone()).await.unwrap();

        let values = connection
            .batch_get(vec![key1, key2, key3])
            .await
            .expect("batch_get should succeed");

        println!("Values: {:?}", values);
        assert_eq!(values.len(), 3);
        assert_eq!(values[0], Some(value1));
        assert_eq!(values[1], Some(value2));
        assert_eq!(values[2], None);
    }
}