1use std::str::FromStr;
2
3use ckeylock_core::response::ErrorResponse;
4use ckeylock_core::{Request, RequestWrapper, Response};
5use futures_util::{SinkExt, StreamExt};
6use std::cell::RefCell;
7use thiserror::Error;
8use tokio::net::TcpStream;
9use tokio_tungstenite::tungstenite::Error as WsError;
10use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
11use tokio_tungstenite::{
12 MaybeTlsStream, WebSocketStream, connect_async,
13 tungstenite::{ClientRequestBuilder, http::Uri, protocol::Message},
14};
15
16pub struct CKeyLockAPI {
17 bind: String,
18 password: Option<String>,
19}
20impl CKeyLockAPI {
21 pub fn new(bind: &str, password: Option<&str>) -> Self {
24 CKeyLockAPI {
25 bind: bind.to_owned(),
26 password: password.map(|p| p.to_owned()),
27 }
28 }
29
30 pub async fn connect(&self) -> Result<CKeyLockConnection, Error> {
31 let url = format!("ws://{}", self.bind);
32 let request = match &self.password {
33 Some(password) => ClientRequestBuilder::new(Uri::from_str(&url)?)
34 .with_header("Authorization", password)
35 .into_client_request()?,
36 None => url.into_client_request()?,
37 };
38 let (ws_stream, _) = connect_async(request).await?;
39
40 Ok(CKeyLockConnection {
41 ws_stream: RefCell::new(ws_stream),
42 })
43 }
44}
45
46pub struct CKeyLockConnection {
47 ws_stream: RefCell<WebSocketStream<MaybeTlsStream<TcpStream>>>,
48}
49impl CKeyLockConnection {
50 async fn send_request(&self, request: Request) -> Result<Response, Error> {
51 let wrapper = RequestWrapper::new(request);
52 let mut ws_stream = self.ws_stream.borrow_mut();
53 ws_stream
54 .send(request_into_message(wrapper.clone()))
55 .await?;
56 while let Some(msg) = ws_stream.next().await {
57 let msg = msg?;
58 if let Message::Text(text) = msg {
59 if let Ok(response) = serde_json::from_str::<Response>(&text) {
60 if response.reqid() == wrapper.id() {
61 return Ok(response);
62 }
63 } else if let Ok(err_response) = serde_json::from_str::<ErrorResponse>(&text) {
64 if err_response.reqid == wrapper.id() {
65 return Err(Error::Custom(err_response.message));
66 }
67 }
68 }
69 }
70 Err(Error::Custom(
71 "Response with matching ID not found".to_string(),
72 ))
73 }
74 pub async fn set(&self, key: Vec<u8>, value: Vec<u8>) -> Result<Vec<u8>, Error> {
75 let res = self.send_request(Request::Set { key, value }).await?;
76 if let Some(ckeylock_core::ResponseData::SetResponse { key }) = res.data() {
77 Ok(key.to_vec())
78 } else {
79 Err(Error::WrongResponseFormat)
80 }
81 }
82
83 pub async fn get(&self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
84 let res = self.send_request(Request::Get { key }).await?;
85 if let Some(ckeylock_core::ResponseData::GetResponse { value }) = res.data() {
86 Ok(value.as_ref().map(|v| v.to_vec()))
87 } else {
88 Err(Error::WrongResponseFormat)
89 }
90 }
91
92 pub async fn delete(&self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
93 let res = self.send_request(Request::Delete { key }).await?;
94 if let Some(ckeylock_core::ResponseData::DeleteResponse { key }) = res.data() {
95 Ok(key.as_ref().map(|v| v.to_vec()))
96 } else {
97 Err(Error::WrongResponseFormat)
98 }
99 }
100
101 pub async fn list(&self) -> Result<Vec<Vec<u8>>, Error> {
102 let res = self.send_request(Request::List).await?;
103 if let Some(ckeylock_core::ResponseData::ListResponse { keys }) = res.data() {
104 Ok(keys.clone())
105 } else {
106 Err(Error::WrongResponseFormat)
107 }
108 }
109
110 pub async fn exists(&self, key: Vec<u8>) -> Result<bool, Error> {
111 let res = self.send_request(Request::Exists { key }).await?;
112 if let Some(ckeylock_core::ResponseData::ExistsResponse { exists }) = res.data() {
113 Ok(*exists)
114 } else {
115 Err(Error::WrongResponseFormat)
116 }
117 }
118
119 pub async fn count(&self) -> Result<usize, Error> {
120 let res = self.send_request(Request::Count).await?;
121 if let Some(ckeylock_core::ResponseData::CountResponse { count }) = res.data() {
122 Ok(*count)
123 } else {
124 Err(Error::WrongResponseFormat)
125 }
126 }
127
128 pub async fn clear(&self) -> Result<(), Error> {
129 let res = self.send_request(Request::Clear).await?;
130 if let Some(ckeylock_core::ResponseData::ClearResponse) = res.data() {
131 Ok(())
132 } else {
133 Err(Error::WrongResponseFormat)
134 }
135 }
136 pub async fn close(&self) -> Result<(), Box<dyn std::error::Error>> {
137 let mut ws_stream = self.ws_stream.borrow_mut();
138 Ok(ws_stream.close(None).await?)
139 }
140}
141
142fn request_into_message(req: ckeylock_core::RequestWrapper) -> Message {
143 Message::Text(req.to_string().into())
144}
145
146#[derive(Error, Debug)]
147pub enum Error {
148 #[error("WebSocket error: {0}")]
149 WsError(#[from] WsError),
150 #[error("Wrong response format")]
151 WrongResponseFormat,
152 #[error("Failed to parse uri: {0}")]
153 UriParseError(#[from] tokio_tungstenite::tungstenite::http::uri::InvalidUri),
154 #[error("{0}")]
155 Custom(String),
156}
#[cfg(test)]
mod tests {
    //! Integration tests. They require a ckeylock server listening on
    //! `127.0.0.1:8080` that accepts the passwords below.
    //!
    //! `cargo test` runs these tests concurrently against that single server,
    //! so every test uses its own keys and never relies on data written by
    //! another test.

    use super::*;

    const ADDR: &str = "127.0.0.1:8080";

    // NOTE(review): `test_set` authenticates with this password while every
    // other test uses `PASSWORD`; unless the server accepts both, one group
    // fails — confirm which one the test server is configured with.
    const SET_PASSWORD: &str = "4f69e5532544b557dcee8dd077318f353af4ebcbed3280c8a9ceaa337ed45d283f9a7c9faebfb0c476a92e07a47c28bc603a93c74c090d41b3d625ea35902f35";
    const PASSWORD: &str = "helloworld";

    /// Opens a fresh connection to the test server, panicking on failure.
    async fn open_connection(password: &str) -> CKeyLockConnection {
        CKeyLockAPI::new(ADDR, Some(password))
            .connect()
            .await
            .expect("failed to connect to the ckeylock test server")
    }

    #[tokio::test]
    async fn test_set() {
        let connection = open_connection(SET_PASSWORD).await;

        let key = b"test_set_key".to_vec();
        let value = b"test_set_value".to_vec();

        let result = connection.set(key.clone(), value).await;
        assert_eq!(result.unwrap(), key);
    }

    #[tokio::test]
    async fn test_get() {
        let connection = open_connection(PASSWORD).await;

        let key = b"test_get_key".to_vec();
        let value = b"test_get_value".to_vec();

        connection.set(key.clone(), value.clone()).await.unwrap();
        let fetched = connection.get(key).await.unwrap();
        println!("Value: {:?}", fetched);
        assert_eq!(fetched, Some(value));
    }

    #[tokio::test]
    async fn test_delete() {
        let connection = open_connection(PASSWORD).await;

        let key = b"test_delete_key".to_vec();
        let value = b"test_delete_value".to_vec();

        connection.set(key.clone(), value).await.unwrap();
        let deleted = connection.delete(key.clone()).await.unwrap();
        assert_eq!(deleted, Some(key));
    }

    #[tokio::test]
    async fn test_list() {
        let connection = open_connection(PASSWORD).await;

        let key1 = b"test_list_key1".to_vec();
        let key2 = b"test_list_key2".to_vec();
        let value = b"test_list_value".to_vec();

        connection.set(key1.clone(), value.clone()).await.unwrap();
        connection.set(key2.clone(), value).await.unwrap();

        let keys = connection.list().await.unwrap();
        assert!(keys.contains(&key1));
        assert!(keys.contains(&key2));
    }

    #[tokio::test]
    async fn req() {
        let connection = open_connection(PASSWORD).await;

        let key = b"test_req_key".to_vec();
        let value = b"test_req_value".to_vec();

        // Write the key first so this test does not depend on another test
        // having populated it.
        connection.set(key.clone(), value.clone()).await.unwrap();
        let res = connection.get(key).await.unwrap();
        println!("Response: {:?}", res);
        assert_eq!(res, Some(value));
    }
}