1use std::str::FromStr;
2
3use ckeylock_core::response::ErrorResponse;
4use ckeylock_core::{Request, RequestWrapper, Response};
5use futures_util::{SinkExt, StreamExt};
6use thiserror::Error;
7use tokio::net::TcpStream;
8use tokio_tungstenite::tungstenite::Error as WsError;
9use tokio_tungstenite::tungstenite::client::IntoClientRequest as _;
10use tokio_tungstenite::{
11 MaybeTlsStream, WebSocketStream, connect_async,
12 tungstenite::{ClientRequestBuilder, http::Uri, protocol::Message},
13};
14
/// Connection settings for a ckeylock server.
///
/// Stores the server address and an optional password; [`CKeyLockAPI::connect`] turns these
/// settings into a live [`CKeyLockConnection`].
pub struct CKeyLockAPI {
    /// Server address inserted after `ws://` when connecting (e.g. `127.0.0.1:8080`).
    bind: String,
    /// Optional password, sent verbatim as the `Authorization` header of the handshake.
    password: Option<String>,
}
19impl CKeyLockAPI {
20 pub fn new(bind: &str, password: Option<&str>) -> Self {
23 CKeyLockAPI {
24 bind: bind.to_owned(),
25 password: password.map(|p| p.to_owned()),
26 }
27 }
28
29 pub async fn connect(&self) -> Result<CKeyLockConnection, Box<dyn std::error::Error>> {
30 let url = format!("ws://{}", self.bind);
31 let request = match &self.password {
32 Some(password) => ClientRequestBuilder::new(Uri::from_str(&url)?)
33 .with_header("Authorization", password)
34 .into_client_request()?,
35 None => url.into_client_request()?,
36 };
37 let (ws_stream, _) = connect_async(request).await?;
38
39 Ok(CKeyLockConnection { ws_stream })
40 }
41}
42
43pub struct CKeyLockConnection {
44 ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
45}
46impl CKeyLockConnection {
47 async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
48 let wrapper = RequestWrapper::new(request);
49 self.ws_stream
50 .send(request_into_message(wrapper.clone()))
51 .await?;
52 while let Some(msg) = self.ws_stream.next().await {
53 let msg = msg?;
54 if let Message::Text(text) = msg {
55 if let Ok(response) = serde_json::from_str::<Response>(&text) {
56 if response.reqid() == wrapper.id() {
57 return Ok(response);
58 }
59 } else if let Ok(err_response) = serde_json::from_str::<ErrorResponse>(&text) {
60 if err_response.reqid == wrapper.id() {
61 return Err(Error::Custom(err_response.message));
62 }
63 }
64 }
65 }
66 Err(Error::Custom(
67 "Response with matching ID not found".to_string(),
68 ))
69 }
70 pub async fn set(&mut self, key: Vec<u8>, value: Vec<u8>) -> Result<Vec<u8>, Error> {
71 let res = self.send_request(Request::Set { key, value }).await?;
72 if let Some(ckeylock_core::ResponseData::SetResponse { key }) = res.data() {
73 Ok(key.to_vec())
74 } else {
75 Err(Error::WrongResponseFormat)
76 }
77 }
78
79 pub async fn get(&mut self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
80 let res = self.send_request(Request::Get { key }).await?;
81 if let Some(ckeylock_core::ResponseData::GetResponse { value }) = res.data() {
82 Ok(value.as_ref().map(|v| v.to_vec()))
83 } else {
84 Err(Error::WrongResponseFormat)
85 }
86 }
87
88 pub async fn delete(&mut self, key: Vec<u8>) -> Result<Option<Vec<u8>>, Error> {
89 let res = self.send_request(Request::Delete { key }).await?;
90 if let Some(ckeylock_core::ResponseData::DeleteResponse { key }) = res.data() {
91 Ok(key.as_ref().map(|v| v.to_vec()))
92 } else {
93 Err(Error::WrongResponseFormat)
94 }
95 }
96
97 pub async fn list(&mut self) -> Result<Vec<Vec<u8>>, Error> {
98 let res = self.send_request(Request::List).await?;
99 if let Some(ckeylock_core::ResponseData::ListResponse { keys }) = res.data() {
100 Ok(keys.clone())
101 } else {
102 Err(Error::WrongResponseFormat)
103 }
104 }
105
106 pub async fn exists(&mut self, key: Vec<u8>) -> Result<bool, Error> {
107 let res = self.send_request(Request::Exists { key }).await?;
108 if let Some(ckeylock_core::ResponseData::ExistsResponse { exists }) = res.data() {
109 Ok(*exists)
110 } else {
111 Err(Error::WrongResponseFormat)
112 }
113 }
114
115 pub async fn count(&mut self) -> Result<usize, Error> {
116 let res = self.send_request(Request::Count).await?;
117 if let Some(ckeylock_core::ResponseData::CountResponse { count }) = res.data() {
118 Ok(*count)
119 } else {
120 Err(Error::WrongResponseFormat)
121 }
122 }
123
124 pub async fn clear(&mut self) -> Result<(), Error> {
125 let res = self.send_request(Request::Clear).await?;
126 if let Some(ckeylock_core::ResponseData::ClearResponse) = res.data() {
127 Ok(())
128 } else {
129 Err(Error::WrongResponseFormat)
130 }
131 }
132 pub async fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
133 Ok(self.ws_stream.close(None).await?)
134 }
135}
136
137fn request_into_message(req: ckeylock_core::RequestWrapper) -> Message {
138 Message::Text(req.to_string().into())
139}
140
141#[derive(Error, Debug)]
142pub enum Error {
143 #[error("WebSocket error: {0}")]
144 WsError(#[from] WsError),
145 #[error("Wrong response format")]
146 WrongResponseFormat,
147 #[error("{0}")]
148 Custom(String),
149}
#[cfg(test)]
mod tests {
    //! Integration tests against a live server.
    //!
    //! They expect a ckeylock server listening on `127.0.0.1:8080` that accepts the password
    //! `helloworld`. The test harness runs tests in parallel by default, so every test uses
    //! its own keys; otherwise one test (e.g. `test_delete`) could remove a key another test
    //! is in the middle of reading.
    use super::*;

    const ADDR: &str = "127.0.0.1:8080";
    const PASSWORD: &str = "helloworld";

    /// Opens a fresh authenticated connection to the test server.
    async fn open_connection() -> CKeyLockConnection {
        CKeyLockAPI::new(ADDR, Some(PASSWORD))
            .connect()
            .await
            .expect("failed to connect to the ckeylock test server")
    }

    #[tokio::test]
    async fn test_set() {
        let mut connection = open_connection().await;
        let key = b"test_set_key".to_vec();
        let value = b"test_value".to_vec();

        let stored = connection.set(key.clone(), value).await.expect("set failed");
        assert_eq!(stored, key);
    }

    #[tokio::test]
    async fn test_get() {
        let mut connection = open_connection().await;
        let key = b"test_get_key".to_vec();
        let value = b"test_value".to_vec();

        connection
            .set(key.clone(), value.clone())
            .await
            .expect("set failed");
        let fetched = connection.get(key).await.expect("get failed");
        println!("Value: {:?}", fetched);
        assert_eq!(fetched, Some(value));
    }

    #[tokio::test]
    async fn test_delete() {
        let mut connection = open_connection().await;
        let key = b"test_delete_key".to_vec();
        let value = b"test_value".to_vec();

        connection.set(key.clone(), value).await.expect("set failed");
        let deleted = connection.delete(key.clone()).await.expect("delete failed");
        assert_eq!(deleted, Some(key));
    }

    #[tokio::test]
    async fn test_list() {
        let mut connection = open_connection().await;
        let key1 = b"test_list_key1".to_vec();
        let key2 = b"test_list_key2".to_vec();
        let value = b"test_value".to_vec();

        connection
            .set(key1.clone(), value.clone())
            .await
            .expect("set key1 failed");
        connection
            .set(key2.clone(), value)
            .await
            .expect("set key2 failed");

        let keys = connection.list().await.expect("list failed");
        assert!(keys.contains(&key1));
        assert!(keys.contains(&key2));
    }

    #[tokio::test]
    async fn req() {
        let mut connection = open_connection().await;
        let key = b"test_req_key".to_vec();
        let value = b"test_value".to_vec();

        // Seed the key so the lookup does not depend on state left behind by other tests;
        // unwrapping a missing value would otherwise panic.
        connection
            .set(key.clone(), value.clone())
            .await
            .expect("set failed");
        let res = connection.get(key).await.expect("get failed");
        println!("Response: {:?}", res);
        assert_eq!(res, Some(value));
    }
}