1use crate::{Cache, Error, Stats, System, stats::ConnectionState, system};
2use serde::Deserialize;
3use serde_json::Value;
4use std::{
5 collections::HashMap,
6 io::ErrorKind,
7 net::TcpStream,
8 sync::{
9 Arc, Mutex,
10 atomic::{AtomicU64, Ordering},
11 mpsc::{self, Receiver, Sender},
12 },
13 time::{Duration, SystemTime},
14};
15use strum_macros::{AsRefStr, Display};
16use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};
17
/// Default number of seconds to wait for the server to answer a request
/// (used, for example, when [`Client::system`] waits for its reply).
pub const DEFAULT_TIMEOUT_SECS: u64 = 5;
19
20#[derive(AsRefStr, Clone, Debug, Display)]
21pub enum Format {
22 #[strum(serialize = "json")]
23 Json,
24}
25
/// Outcome of a server reply to a single transaction.
#[derive(Clone, Debug)]
pub(crate) struct Response {
    /// `None` when the server acknowledged the request; otherwise the error it reported.
    error: Option<String>,
}
30
31struct State {
32 cache: Arc<Mutex<Cache>>,
33 next_transaction_id: AtomicU64,
34 requests: Arc<Mutex<HashMap<u64, Option<Response>>>>,
35 socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
36 subscribers: Arc<Mutex<Vec<Sender<Cache>>>>,
37}
38
39#[derive(Clone)]
40pub struct Client {
41 state: Arc<State>,
42}
43
44#[derive(Clone, Debug, Default, Deserialize)]
45struct SparseStats {
46 started: Option<String>,
47 reads: Option<u32>,
48 writes: Option<u32>,
49 updates: Option<u32>,
50 errors: Option<u32>,
51 connection: Option<ConnectionState>,
52}
53
54#[derive(Debug, Default, Deserialize)]
55struct SparseCache {
56 stats: Option<SparseStats>,
57 version: Option<String>,
58 keys: Option<HashMap<String, Value>>,
59}
60
61impl From<SparseCache> for Cache {
62 fn from(other: SparseCache) -> Self {
63 let mut cache = Cache::default();
64 if let Some(stats) = other.stats {
65 if let Some(started) = stats.started {
66 cache.stats.started = started.clone();
67 }
68 if let Some(reads) = stats.reads {
69 cache.stats.reads = reads;
70 }
71 if let Some(writes) = stats.writes {
72 cache.stats.writes = writes;
73 }
74 if let Some(updates) = stats.updates {
75 cache.stats.updates = updates;
76 }
77 if let Some(errors) = stats.errors {
78 cache.stats.errors = errors;
79 }
80 if let Some(connection) = stats.connection {
81 cache.stats.connection = connection.clone();
82 }
83 }
84 if let Some(version) = other.version {
85 cache.version = version.clone();
86 }
87 if let Some(keys) = other.keys {
88 cache.keys = keys.clone();
89 }
90 cache
91 }
92}
93
94impl Client {
95 pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
96 let next_transaction_id = AtomicU64::new(1);
97 let cache = Arc::new(Mutex::new(Cache::default()));
98 let requests = Arc::new(Mutex::new(HashMap::new()));
99 let subscribers = Arc::new(Mutex::<Vec<Sender<Cache>>>::new(vec![]));
100
101 let socket_2 = socket.clone();
103 let cache_2 = cache.clone();
104 let requests_2 = requests.clone();
105 let subscribers_2 = subscribers.clone();
106 std::thread::spawn(move || {
107 loop {
108 let maybe_message = {
109 if let Ok(mut socket) = socket_2.lock() {
110 match socket.read() {
111 Ok(message) => Some(message),
112 Err(e) => match e {
113 tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
114 _ => panic!("{e:?}"),
115 },
116 }
117 } else {
118 continue;
119 }
120 };
121 let message = match maybe_message {
122 Some(message) => message,
123 None => {
124 std::thread::sleep(Duration::from_millis(20));
125 continue;
126 }
127 };
128 if let Message::Text(ref message) = message {
129 let incoming: HashMap<String, Value> = serde_json::from_slice(message.as_bytes()).unwrap();
130
131 if let Some(Value::Number(transaction)) = incoming.get("transaction") {
132 let transaction: u64 = transaction.as_u64().unwrap();
133 if let Some(response) = incoming.get("response") {
134 let mut requests = requests_2.lock().unwrap();
135 if response.is_string() && response.as_str().unwrap_or_default().is_empty() {
136 requests.insert(transaction, Some(Response { error: None }));
137 } else if let Value::Object(response) = response {
138 if let Some(error_msg) = response.get("error") {
139 requests.insert(
140 transaction,
141 Some(Response {
142 error: Some(error_msg.to_string()),
143 }),
144 );
145 }
146 }
147 }
148 continue;
149 }
150
151 let payload: SparseCache = serde_json::from_slice(message.as_bytes()).unwrap();
153 if let Ok(mut cache) = cache_2.lock() {
154 Self::merge(&mut cache, &payload);
155 }
156
157 {
159 let subscribers = subscribers_2.lock().unwrap();
160 let event: Cache = payload.into();
161 for tx in subscribers.iter() {
162 tx.send(event.clone()).unwrap();
163 }
164 }
165 }
166 }
167 });
168
169 let state = Arc::new(State {
170 cache,
171 next_transaction_id,
172 requests,
173 socket,
174 subscribers,
175 });
176 Self { state }
177 }
178
179 pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
180 let cache = self.state.cache.lock()?;
181 let value = match cache.keys.get(key) {
182 Some(value) => Some(value.clone()),
183 None => default_value,
184 };
185 Ok(value)
186 }
187
188 pub fn keys(&self) -> Result<Vec<(String, Value)>, Error> {
189 let mut vec = vec![];
190 let cache = self.state.cache.lock()?;
191 for (k, v) in cache.keys.iter() {
192 vec.push((k.clone(), v.clone()));
193 }
194 Ok(vec)
195 }
196
197 fn merge(target: &mut Cache, source: &SparseCache) {
198 if let Some(ref stats) = source.stats {
199 if let Some(ref started) = stats.started {
200 target.stats.started = started.clone();
201 }
202 if let Some(reads) = stats.reads {
203 target.stats.reads = reads;
204 }
205 if let Some(writes) = stats.writes {
206 target.stats.writes = writes;
207 }
208 if let Some(updates) = stats.updates {
209 target.stats.updates = updates;
210 }
211 if let Some(errors) = stats.errors {
212 target.stats.errors = errors;
213 }
214 if let Some(ref connection) = stats.connection {
215 target.stats.connection = connection.clone();
216 }
217 }
218 if let Some(ref version) = source.version {
219 target.version = version.clone();
220 }
221 if let Some(ref keys) = source.keys {
222 for (k, v) in keys.iter() {
223 target.keys.insert(k.to_string(), v.clone());
224 }
225 }
226 }
227
228 pub fn on_update(&mut self) -> Result<Receiver<Cache>, Error> {
229 let (tx, rx) = mpsc::channel();
230 let mut subscribers = self.state.subscribers.lock()?;
231 subscribers.push(tx);
232 Ok(rx)
233 }
234
235 pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
236 let args = vec![key.to_string(), value.to_string()];
237 let _ = self.send(Format::Json, "put", &args)?;
238 Ok(())
239 }
240
241 pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<u64, Error> {
242 let mut cmd = cmd.to_string();
243 for arg in args {
244 cmd = format!("{cmd} \"{arg}\"");
245 }
246
247 let mut msg = HashMap::new();
248 let transaction_id = self.state.next_transaction_id.fetch_add(1, Ordering::SeqCst);
249 msg.insert("format".to_string(), Value::String(format.to_string()));
250 msg.insert("transaction".to_string(), Value::from(transaction_id));
251 msg.insert("command".to_string(), Value::String(cmd));
252
253 match format {
254 Format::Json => {
255 let msg_as_json = serde_json::to_string(&msg)?;
256 let message = Message::text(msg_as_json);
257 self.send_message(message)?;
258 {
259 let mut requests = self.state.requests.lock()?;
260 requests.insert(transaction_id, None);
261 }
262 Ok(transaction_id)
263 }
264 }
265 }
266
267 fn send_message(&mut self, message: Message) -> Result<(), Error> {
268 let mut socket = self.state.socket.lock()?;
269 socket.send(message)?;
270 Ok(())
271 }
272
273 pub fn stats(&self) -> Result<Stats, Error> {
274 let cache = self.state.cache.lock()?;
275 Ok(cache.stats.clone())
276 }
277
278 pub fn system(&mut self) -> Result<Arc<System>, Error> {
279 let id = system::get_system_id()?;
280 let name = hostname::get()?.to_str().unwrap().to_string();
281 let args = vec![id.to_string(), name.to_string()];
282 let transaction = self.send(Format::Json, "systems", &args)?;
283 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
284 Ok(Arc::new(System::new(self.clone(), id)))
285 }
286
287 pub fn version(&self) -> Result<String, Error> {
288 let cache = self.state.cache.lock()?;
289 Ok(cache.version.clone())
290 }
291
292 pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
293 let start_time = SystemTime::now();
294 let sleep_for = Duration::from_millis(100);
295 while SystemTime::now().duration_since(start_time)? < timeout {
296 {
297 let cache = self.state.cache.lock()?;
298 if !cache.version.is_empty() {
299 return Ok(());
300 }
301 }
302 std::thread::sleep(sleep_for);
303 }
304 Err(Error::Timeout("Timed out waiting for open".to_string()))
305 }
306
307 pub(crate) fn wait_for_response(&self, transaction: u64, timeout: Duration) -> Result<Response, Error> {
308 let start_time = SystemTime::now();
309 let sleep_for = Duration::from_millis(100);
310 while SystemTime::now().duration_since(start_time)? < timeout {
311 {
312 let requests = self.state.requests.lock()?;
313 let response = match requests.get(&transaction) {
314 None => return Err(Error::Default("No such transaction".to_string())),
315 Some(response) => response.clone(),
316 };
317 if let Some(response) = response {
318 match response.error.as_ref() {
319 None => return Ok(response.clone()),
320 Some(error_msg) => return Err(Error::Default(error_msg.clone())),
321 }
322 }
323 }
324 std::thread::sleep(sleep_for);
325 }
326 Err(Error::Timeout("Timed out waiting for response".to_string()))
327 }
328}