ckeylock_api/
lib.rs

1use std::str::FromStr;
2
3use ckeylock_core::response::ErrorResponse;
4use ckeylock_core::{Request, RequestWrapper, Response};
5use futures_util::{SinkExt, StreamExt};
6use std::cell::RefCell;
7use thiserror::Error;
8use tokio::net::TcpStream;
9use tokio_tungstenite::tungstenite::Error as WsError;
10use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
11use tokio_tungstenite::{
12    MaybeTlsStream, WebSocketStream, connect_async,
13    tungstenite::{ClientRequestBuilder, http::Uri, protocol::Message},
14};
15
/// Connection settings for a CKeyLock server.
///
/// Constructing this value only stores the address and the optional password;
/// no network activity happens until `connect` is called.
pub struct CKeyLockAPI {
    /// Optional value sent as the `Authorization` header of the WebSocket handshake.
    password: Option<String>,
    /// Server address as `"host:port"`, without a URL scheme.
    bind: String,
}
20impl CKeyLockAPI {
21    /// Creates a new instance of `CKeyLockAPI` with the given bind address.
22    /// The bind address should be in the format "host:port".
23    pub fn new(bind: &str, password: Option<&str>) -> Self {
24        CKeyLockAPI {
25            bind: bind.to_owned(),
26            password: password.map(|p| p.to_owned()),
27        }
28    }
29
30    pub async fn connect(&self) -> Result<CKeyLockConnection, Error> {
31        let url = format!("ws://{}", self.bind);
32        let request = match &self.password {
33            Some(password) => ClientRequestBuilder::new(Uri::from_str(&url)?)
34                .with_header("Authorization", password)
35                .into_client_request()?,
36            None => url.into_client_request()?,
37        };
38        let (ws_stream, _) = connect_async(request).await?;
39
40        Ok(CKeyLockConnection {
41            ws_stream: RefCell::new(ws_stream),
42        })
43    }
44}
45
46pub struct CKeyLockConnection {
47    ws_stream: RefCell<WebSocketStream<MaybeTlsStream<TcpStream>>>,
48}
49impl CKeyLockConnection {
50    async fn send_request(&self, request: Request) -> Result<Response, Error> {
51        let wrapper = RequestWrapper::new(request);
52        let mut ws_stream = self.ws_stream.borrow_mut();
53        ws_stream
54            .send(request_into_message(wrapper.clone()))
55            .await?;
56        while let Some(msg) = ws_stream.next().await {
57            let msg = msg?;
58            if let Message::Text(text) = msg {
59                if let Ok(response) = serde_json::from_str::<Response>(&text) {
60                    if response.reqid() == wrapper.id() {
61                        return Ok(response);
62                    }
63                } else if let Ok(err_response) = serde_json::from_str::<ErrorResponse>(&text) {
64                    if err_response.reqid == wrapper.id() {
65                        return Err(Error::Custom(err_response.message));
66                    }
67                }
68            }
69        }
70        Err(Error::Custom(
71            "Response with matching ID not found".to_string(),
72        ))
73    }
74    pub async fn set(&self, key: Vec<u8>, value: Vec<u8>) -> Result<Vec<u8>, Error> {
75        let res = self.send_request(Request::Set { key, value }).await?;
76        if let Some(ckeylock_core::ResponseData::SetResponse { key }) = res.data() {
77            Ok(key.to_vec())
78        } else {
79            Err(Error::WrongResponseFormat)
80        }
81    }
82
83    pub async fn get(&self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
84        let res = self.send_request(Request::Get { key }).await?;
85        if let Some(ckeylock_core::ResponseData::GetResponse { value }) = res.data() {
86            Ok(value.as_ref().map(|v| v.to_vec()))
87        } else {
88            Err(Error::WrongResponseFormat)
89        }
90    }
91
92    pub async fn delete(&self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
93        let res = self.send_request(Request::Delete { key }).await?;
94        if let Some(ckeylock_core::ResponseData::DeleteResponse { key }) = res.data() {
95            Ok(key.as_ref().map(|v| v.to_vec()))
96        } else {
97            Err(Error::WrongResponseFormat)
98        }
99    }
100
101    pub async fn list(&self) -> Result<Vec<Vec<u8>>, Error> {
102        let res = self.send_request(Request::List).await?;
103        if let Some(ckeylock_core::ResponseData::ListResponse { keys }) = res.data() {
104            Ok(keys.clone())
105        } else {
106            Err(Error::WrongResponseFormat)
107        }
108    }
109
110    pub async fn exists(&self, key: Vec<u8>) -> Result<bool, Error> {
111        let res = self.send_request(Request::Exists { key }).await?;
112        if let Some(ckeylock_core::ResponseData::ExistsResponse { exists }) = res.data() {
113            Ok(*exists)
114        } else {
115            Err(Error::WrongResponseFormat)
116        }
117    }
118
119    pub async fn count(&self) -> Result<usize, Error> {
120        let res = self.send_request(Request::Count).await?;
121        if let Some(ckeylock_core::ResponseData::CountResponse { count }) = res.data() {
122            Ok(*count)
123        } else {
124            Err(Error::WrongResponseFormat)
125        }
126    }
127
128    pub async fn clear(&self) -> Result<(), Error> {
129        let res = self.send_request(Request::Clear).await?;
130        if let Some(ckeylock_core::ResponseData::ClearResponse) = res.data() {
131            Ok(())
132        } else {
133            Err(Error::WrongResponseFormat)
134        }
135    }
136    pub async fn close(&self) -> Result<(), Box<dyn std::error::Error>> {
137        let mut ws_stream = self.ws_stream.borrow_mut();
138        Ok(ws_stream.close(None).await?)
139    }
140}
141
142fn request_into_message(req: ckeylock_core::RequestWrapper) -> Message {
143    Message::Text(req.to_string().into())
144}
145
146#[derive(Error, Debug)]
147pub enum Error {
148    #[error("WebSocket error: {0}")]
149    WsError(#[from] WsError),
150    #[error("Wrong response format")]
151    WrongResponseFormat,
152    #[error("Failed to parse uri: {0}")]
153    UriParseError(#[from] tokio_tungstenite::tungstenite::http::uri::InvalidUri),
154    #[error("{0}")]
155    Custom(String),
156}
#[cfg(test)]
mod tests {
    //! Integration tests that talk to a live CKeyLock server.
    //!
    //! **Running them**
    //!
    //! - They are `#[ignore]`d so a plain `cargo test` does not fail when no server is
    //!   running. Run them with `cargo test -- --ignored`.
    //! - The server address defaults to `127.0.0.1:8080` and can be overridden with the
    //!   `CKEYLOCK_BIND` environment variable.
    //! - The password defaults to `helloworld` and can be overridden with
    //!   `CKEYLOCK_PASSWORD`.
    //!
    //! **Isolation**
    //!
    //! The test harness runs tests in parallel against the same server, so every test
    //! uses its own keys. Sharing one key let one test's `delete` race another test's
    //! `get`.
    use super::*;

    /// Opens a connection to the configured test server, panicking if that fails.
    async fn connect_to_test_server() -> CKeyLockConnection {
        let bind =
            std::env::var("CKEYLOCK_BIND").unwrap_or_else(|_| "127.0.0.1:8080".to_string());
        let password =
            std::env::var("CKEYLOCK_PASSWORD").unwrap_or_else(|_| "helloworld".to_string());
        CKeyLockAPI::new(&bind, Some(password.as_str()))
            .connect()
            .await
            .expect("failed to connect to the ckeylock test server")
    }

    #[tokio::test]
    #[ignore = "requires a running ckeylock server"]
    async fn test_set() {
        let connection = connect_to_test_server().await;

        let key = b"test_set_key".to_vec();
        let value = b"test_set_value".to_vec();

        let result = connection.set(key.clone(), value).await;
        assert_eq!(result.unwrap(), key);
        connection.close().await.unwrap();
    }

    #[tokio::test]
    #[ignore = "requires a running ckeylock server"]
    async fn test_get() {
        let connection = connect_to_test_server().await;

        let key = b"test_get_key".to_vec();
        let value = b"test_get_value".to_vec();

        connection.set(key.clone(), value.clone()).await.unwrap();
        let fetched = connection.get(key).await.unwrap();
        assert_eq!(fetched, Some(value));
        connection.close().await.unwrap();
    }

    #[tokio::test]
    #[ignore = "requires a running ckeylock server"]
    async fn test_delete() {
        let connection = connect_to_test_server().await;

        let key = b"test_delete_key".to_vec();
        let value = b"test_delete_value".to_vec();

        connection.set(key.clone(), value).await.unwrap();
        let deleted = connection.delete(key.clone()).await.unwrap();
        assert_eq!(deleted, Some(key));
        connection.close().await.unwrap();
    }

    #[tokio::test]
    #[ignore = "requires a running ckeylock server"]
    async fn test_list() {
        let connection = connect_to_test_server().await;

        let key1 = b"test_list_key1".to_vec();
        let key2 = b"test_list_key2".to_vec();
        let value = b"test_list_value".to_vec();

        connection.set(key1.clone(), value.clone()).await.unwrap();
        connection.set(key2.clone(), value).await.unwrap();

        let keys = connection.list().await.unwrap();
        assert!(keys.contains(&key1));
        assert!(keys.contains(&key2));
        connection.close().await.unwrap();
    }

    #[tokio::test]
    #[ignore = "requires a running ckeylock server"]
    async fn test_exists() {
        let connection = connect_to_test_server().await;

        let key = b"test_exists_key".to_vec();

        connection
            .set(key.clone(), b"test_exists_value".to_vec())
            .await
            .unwrap();
        assert!(connection.exists(key).await.unwrap());
        connection.close().await.unwrap();
    }
}