ckeylock_api/
lib.rs

1use std::str::FromStr;
2
3use ckeylock_core::response::ErrorResponse;
4use ckeylock_core::{Request, RequestWrapper, Response};
5use futures_util::{SinkExt, StreamExt};
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio_tungstenite::tungstenite::Error as WsError;
9use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
10use tokio_tungstenite::{
11    MaybeTlsStream, WebSocketStream, connect_async,
12    tungstenite::{ClientRequestBuilder, http::Uri, protocol::Message},
13};
14
/// Connection settings for a ckeylock server.
///
/// Holds the server address and an optional password. Call
/// [`CKeyLockAPI::connect`] to open a [`CKeyLockConnection`]. The type is
/// `Clone`, so the same settings can be reused to open several connections.
#[derive(Clone)]
pub struct CKeyLockAPI {
    /// Server address in `host:port` form, without a scheme.
    bind: String,
    /// Value sent as the `Authorization` header during the handshake, if any.
    password: Option<String>,
}

// Written by hand instead of derived: a derived `Debug` would print the
// password in clear text wherever the settings get logged.
impl std::fmt::Debug for CKeyLockAPI {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CKeyLockAPI")
            .field("bind", &self.bind)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
19impl CKeyLockAPI {
20    /// Creates a new instance of `CKeyLockAPI` with the given bind address.
21    /// The bind address should be in the format "host:port".
22    pub fn new(bind: &str, password: Option<&str>) -> Self {
23        CKeyLockAPI {
24            bind: bind.to_owned(),
25            password: password.map(|p| p.to_owned()),
26        }
27    }
28
29    pub async fn connect(&self) -> Result<CKeyLockConnection, Box<dyn std::error::Error>> {
30        let url = format!("ws://{}", self.bind);
31        let request = match &self.password {
32            Some(password) => ClientRequestBuilder::new(Uri::from_str(&url)?)
33                .with_header("Authorization", password)
34                .into_client_request()?,
35            None => url.into_client_request()?,
36        };
37        let (ws_stream, _) = connect_async(request).await?;
38
39        Ok(CKeyLockConnection { ws_stream })
40    }
41}
42
43pub struct CKeyLockConnection {
44    ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
45}
46impl CKeyLockConnection {
47    async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
48        let wrapper = RequestWrapper::new(request);
49        self.ws_stream
50            .send(request_into_message(wrapper.clone()))
51            .await?;
52        while let Some(msg) = self.ws_stream.next().await {
53            let msg = msg?;
54            if let Message::Text(text) = msg {
55                if let Ok(response) = serde_json::from_str::<Response>(&text) {
56                    if response.reqid() == wrapper.id() {
57                        return Ok(response);
58                    }
59                } else if let Ok(err_response) = serde_json::from_str::<ErrorResponse>(&text) {
60                    if err_response.reqid == wrapper.id() {
61                        return Err(Error::Custom(err_response.message));
62                    }
63                }
64            }
65        }
66        Err(Error::Custom(
67            "Response with matching ID not found".to_string(),
68        ))
69    }
70    pub async fn set(&mut self, key: Vec<u8>, value: Vec<u8>) -> Result<Vec<u8>, Error> {
71        let res = self.send_request(Request::Set { key, value }).await?;
72        if let Some(ckeylock_core::ResponseData::SetResponse { key }) = res.data() {
73            Ok(key.to_vec())
74        } else {
75            Err(Error::WrongResponseFormat)
76        }
77    }
78
79    pub async fn get(&mut self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
80        let res = self.send_request(Request::Get { key }).await?;
81        if let Some(ckeylock_core::ResponseData::GetResponse { value }) = res.data() {
82            Ok(value.as_ref().map(|v| v.to_vec()))
83        } else {
84            Err(Error::WrongResponseFormat)
85        }
86    }
87
88    pub async fn delete(&mut self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
89        let res = self.send_request(Request::Delete { key }).await?;
90        if let Some(ckeylock_core::ResponseData::DeleteResponse { key }) = res.data() {
91            Ok(key.as_ref().map(|v| v.to_vec()))
92        } else {
93            Err(Error::WrongResponseFormat)
94        }
95    }
96
97    pub async fn list(&mut self) -> Result<Vec<Vec<u8>>, Error> {
98        let res = self.send_request(Request::List).await?;
99        if let Some(ckeylock_core::ResponseData::ListResponse { keys }) = res.data() {
100            Ok(keys.clone())
101        } else {
102            Err(Error::WrongResponseFormat)
103        }
104    }
105
106    pub async fn exists(&mut self, key: Vec<u8>) -> Result<bool, Error> {
107        let res = self.send_request(Request::Exists { key }).await?;
108        if let Some(ckeylock_core::ResponseData::ExistsResponse { exists }) = res.data() {
109            Ok(*exists)
110        } else {
111            Err(Error::WrongResponseFormat)
112        }
113    }
114
115    pub async fn count(&mut self) -> Result<usize, Error> {
116        let res = self.send_request(Request::Count).await?;
117        if let Some(ckeylock_core::ResponseData::CountResponse { count }) = res.data() {
118            Ok(*count)
119        } else {
120            Err(Error::WrongResponseFormat)
121        }
122    }
123
124    pub async fn clear(&mut self) -> Result<(), Error> {
125        let res = self.send_request(Request::Clear).await?;
126        if let Some(ckeylock_core::ResponseData::ClearResponse) = res.data() {
127            Ok(())
128        } else {
129            Err(Error::WrongResponseFormat)
130        }
131    }
132    pub async fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
133        Ok(self.ws_stream.close(None).await?)
134    }
135}
136
137fn request_into_message(req: ckeylock_core::RequestWrapper) -> Message {
138    Message::Text(req.to_string().into())
139}
140
141#[derive(Error, Debug)]
142pub enum Error {
143    #[error("WebSocket error: {0}")]
144    WsError(#[from] WsError),
145    #[error("Wrong response format")]
146    WrongResponseFormat,
147    #[error("{0}")]
148    Custom(String),
149}
#[cfg(test)]
mod tests {
    //! Integration tests against a live ckeylock server.
    //!
    //! They expect a server listening on `127.0.0.1:8080` that accepts the
    //! password `helloworld`. Cargo runs tests in parallel and all of them hit
    //! the same server, so every test uses its own keys. With a shared key,
    //! `test_delete` could remove the key between `test_get`'s `set` and `get`.
    use super::*;

    const BIND: &str = "127.0.0.1:8080";
    const PASSWORD: &str = "helloworld";

    /// Opens a fresh authenticated connection to the test server.
    async fn open_connection() -> CKeyLockConnection {
        CKeyLockAPI::new(BIND, Some(PASSWORD))
            .connect()
            .await
            .expect("failed to connect to the ckeylock test server")
    }

    #[tokio::test]
    async fn test_set() {
        let mut connection = open_connection().await;

        let key = b"test_set_key".to_vec();
        let value = b"test_value".to_vec();

        let stored_key = connection
            .set(key.clone(), value)
            .await
            .expect("set failed");
        assert_eq!(stored_key, key);
    }

    #[tokio::test]
    async fn test_get() {
        let mut connection = open_connection().await;

        let key = b"test_get_key".to_vec();
        let value = b"test_value".to_vec();

        connection
            .set(key.clone(), value.clone())
            .await
            .expect("set failed");
        let fetched = connection.get(key).await.expect("get failed");
        println!("Value: {:?}", fetched);
        assert_eq!(fetched, Some(value));
    }

    #[tokio::test]
    async fn test_delete() {
        let mut connection = open_connection().await;

        let key = b"test_delete_key".to_vec();
        let value = b"test_value".to_vec();

        connection
            .set(key.clone(), value)
            .await
            .expect("set failed");
        let removed = connection.delete(key.clone()).await.expect("delete failed");
        assert_eq!(removed, Some(key));
    }

    #[tokio::test]
    async fn test_list() {
        let mut connection = open_connection().await;

        let key1 = b"test_list_key1".to_vec();
        let key2 = b"test_list_key2".to_vec();
        let value = b"test_value".to_vec();

        connection
            .set(key1.clone(), value.clone())
            .await
            .expect("set key1 failed");
        connection
            .set(key2.clone(), value)
            .await
            .expect("set key2 failed");

        let keys = connection.list().await.expect("list failed");
        assert!(keys.contains(&key1));
        assert!(keys.contains(&key2));
    }

    /// Round-trips one key and prints the value.
    ///
    /// This used to read a key it never wrote, so it depended on another test
    /// having run first. It now writes its own key before reading it.
    #[tokio::test]
    async fn req() {
        let mut connection = open_connection().await;

        let key = b"test_req_key".to_vec();
        let value = b"test_value".to_vec();

        connection
            .set(key.clone(), value.clone())
            .await
            .expect("set failed");
        let fetched = connection
            .get(key)
            .await
            .expect("get failed")
            .expect("key written just above should exist");
        println!("Response: {:?}", fetched);
        assert_eq!(fetched, value);
    }
}
221        println!("Response: {:?}", res.unwrap());
222    }
223}