arete_sdk/
client.rs

1use crate::{Cache, Error, Stats, System, stats::ConnectionState, system};
2use serde::Deserialize;
3use serde_json::Value;
4use std::{
5    collections::HashMap,
6    io::ErrorKind,
7    net::TcpStream,
8    sync::{
9        Arc, Mutex,
10        atomic::{AtomicU64, Ordering},
11        mpsc::{self, Receiver, Sender},
12    },
13    time::{Duration, SystemTime},
14};
15use strum_macros::{AsRefStr, Display};
16use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};

/// Default number of seconds to wait for the server to answer a request
/// (used by `Client::system`).
pub const DEFAULT_TIMEOUT_SECS: u64 = 5;

20#[derive(AsRefStr, Clone, Debug, Display)]
21pub enum Format {
22    #[strum(serialize = "json")]
23    Json,
24}
25
/// Outcome of a request as reported by the server.
///
/// `error` is `None` for a successful request and holds the server's error
/// text otherwise.
#[derive(Debug, Clone)]
pub(crate) struct Response {
    error: Option<String>,
}

31struct State {
32    cache: Arc<Mutex<Cache>>,
33    next_transaction_id: AtomicU64,
34    requests: Arc<Mutex<HashMap<u64, Option<Response>>>>,
35    socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
36    subscribers: Arc<Mutex<Vec<Sender<Cache>>>>,
37}
38
39#[derive(Clone)]
40pub struct Client {
41    state: Arc<State>,
42}
43
44#[derive(Clone, Debug, Default, Deserialize)]
45struct SparseStats {
46    started: Option<String>,
47    reads: Option<u32>,
48    writes: Option<u32>,
49    updates: Option<u32>,
50    errors: Option<u32>,
51    connection: Option<ConnectionState>,
52}
53
54#[derive(Debug, Default, Deserialize)]
55struct SparseCache {
56    stats: Option<SparseStats>,
57    version: Option<String>,
58    keys: Option<HashMap<String, Value>>,
59}
60
61impl From<SparseCache> for Cache {
62    fn from(other: SparseCache) -> Self {
63        let mut cache = Cache::default();
64        if let Some(stats) = other.stats {
65            if let Some(started) = stats.started {
66                cache.stats.started = started.clone();
67            }
68            if let Some(reads) = stats.reads {
69                cache.stats.reads = reads;
70            }
71            if let Some(writes) = stats.writes {
72                cache.stats.writes = writes;
73            }
74            if let Some(updates) = stats.updates {
75                cache.stats.updates = updates;
76            }
77            if let Some(errors) = stats.errors {
78                cache.stats.errors = errors;
79            }
80            if let Some(connection) = stats.connection {
81                cache.stats.connection = connection.clone();
82            }
83        }
84        if let Some(version) = other.version {
85            cache.version = version.clone();
86        }
87        if let Some(keys) = other.keys {
88            cache.keys = keys.clone();
89        }
90        cache
91    }
92}
93
94impl Client {
95    pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
96        let next_transaction_id = AtomicU64::new(1);
97        let cache = Arc::new(Mutex::new(Cache::default()));
98        let requests = Arc::new(Mutex::new(HashMap::new()));
99        let subscribers = Arc::new(Mutex::<Vec<Sender<Cache>>>::new(vec![]));
100
101        // Spawn a thread to handle incoming messages
102        let socket_2 = socket.clone();
103        let cache_2 = cache.clone();
104        let requests_2 = requests.clone();
105        let subscribers_2 = subscribers.clone();
106        std::thread::spawn(move || {
107            loop {
108                let maybe_message = {
109                    if let Ok(mut socket) = socket_2.lock() {
110                        match socket.read() {
111                            Ok(message) => Some(message),
112                            Err(e) => match e {
113                                tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
114                                _ => panic!("{e:?}"),
115                            },
116                        }
117                    } else {
118                        continue;
119                    }
120                };
121                let message = match maybe_message {
122                    Some(message) => message,
123                    None => {
124                        std::thread::sleep(Duration::from_millis(20));
125                        continue;
126                    }
127                };
128                if let Message::Text(ref message) = message {
129                    let incoming: HashMap<String, Value> = serde_json::from_slice(message.as_bytes()).unwrap();
130
131                    if let Some(Value::Number(transaction)) = incoming.get("transaction") {
132                        let transaction: u64 = transaction.as_u64().unwrap();
133                        if let Some(response) = incoming.get("response") {
134                            let mut requests = requests_2.lock().unwrap();
135                            if response.is_string() && response.as_str().unwrap_or_default().is_empty() {
136                                requests.insert(transaction, Some(Response { error: None }));
137                            } else if let Value::Object(response) = response {
138                                if let Some(error_msg) = response.get("error") {
139                                    requests.insert(
140                                        transaction,
141                                        Some(Response {
142                                            error: Some(error_msg.to_string()),
143                                        }),
144                                    );
145                                }
146                            }
147                        }
148                        continue;
149                    }
150
151                    // Cache the new update
152                    let payload: SparseCache = serde_json::from_slice(message.as_bytes()).unwrap();
153                    if let Ok(mut cache) = cache_2.lock() {
154                        Self::merge(&mut cache, &payload);
155                    }
156
157                    // Notify subscribers of the new update
158                    {
159                        let subscribers = subscribers_2.lock().unwrap();
160                        let event: Cache = payload.into();
161                        for tx in subscribers.iter() {
162                            tx.send(event.clone()).unwrap();
163                        }
164                    }
165                }
166            }
167        });
168
169        let state = Arc::new(State {
170            cache,
171            next_transaction_id,
172            requests,
173            socket,
174            subscribers,
175        });
176        Self { state }
177    }
178
179    pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
180        let cache = self.state.cache.lock()?;
181        let value = match cache.keys.get(key) {
182            Some(value) => Some(value.clone()),
183            None => default_value,
184        };
185        Ok(value)
186    }
187
188    pub fn keys(&self) -> Result<Vec<(String, Value)>, Error> {
189        let mut vec = vec![];
190        let cache = self.state.cache.lock()?;
191        for (k, v) in cache.keys.iter() {
192            vec.push((k.clone(), v.clone()));
193        }
194        Ok(vec)
195    }
196
197    fn merge(target: &mut Cache, source: &SparseCache) {
198        if let Some(ref stats) = source.stats {
199            if let Some(ref started) = stats.started {
200                target.stats.started = started.clone();
201            }
202            if let Some(reads) = stats.reads {
203                target.stats.reads = reads;
204            }
205            if let Some(writes) = stats.writes {
206                target.stats.writes = writes;
207            }
208            if let Some(updates) = stats.updates {
209                target.stats.updates = updates;
210            }
211            if let Some(errors) = stats.errors {
212                target.stats.errors = errors;
213            }
214            if let Some(ref connection) = stats.connection {
215                target.stats.connection = connection.clone();
216            }
217        }
218        if let Some(ref version) = source.version {
219            target.version = version.clone();
220        }
221        if let Some(ref keys) = source.keys {
222            for (k, v) in keys.iter() {
223                target.keys.insert(k.to_string(), v.clone());
224            }
225        }
226    }
227
228    pub fn on_update(&mut self) -> Result<Receiver<Cache>, Error> {
229        let (tx, rx) = mpsc::channel();
230        let mut subscribers = self.state.subscribers.lock()?;
231        subscribers.push(tx);
232        Ok(rx)
233    }
234
235    pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
236        let args = vec![key.to_string(), value.to_string()];
237        let _ = self.send(Format::Json, "put", &args)?;
238        Ok(())
239    }
240
241    pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<u64, Error> {
242        let mut cmd = cmd.to_string();
243        for arg in args {
244            cmd = format!("{cmd} \"{arg}\"");
245        }
246
247        let mut msg = HashMap::new();
248        let transaction_id = self.state.next_transaction_id.fetch_add(1, Ordering::SeqCst);
249        msg.insert("format".to_string(), Value::String(format.to_string()));
250        msg.insert("transaction".to_string(), Value::from(transaction_id));
251        msg.insert("command".to_string(), Value::String(cmd));
252
253        match format {
254            Format::Json => {
255                let msg_as_json = serde_json::to_string(&msg)?;
256                let message = Message::text(msg_as_json);
257                self.send_message(message)?;
258                {
259                    let mut requests = self.state.requests.lock()?;
260                    requests.insert(transaction_id, None);
261                }
262                Ok(transaction_id)
263            }
264        }
265    }
266
267    fn send_message(&mut self, message: Message) -> Result<(), Error> {
268        let mut socket = self.state.socket.lock()?;
269        socket.send(message)?;
270        Ok(())
271    }
272
273    pub fn stats(&self) -> Result<Stats, Error> {
274        let cache = self.state.cache.lock()?;
275        Ok(cache.stats.clone())
276    }
277
278    pub fn system(&mut self) -> Result<Arc<System>, Error> {
279        let id = system::get_system_id()?;
280        let name = hostname::get()?.to_str().unwrap().to_string();
281        let args = vec![id.to_string(), name.to_string()];
282        let transaction = self.send(Format::Json, "systems", &args)?;
283        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
284        Ok(Arc::new(System::new(self.clone(), id)))
285    }
286
287    pub fn version(&self) -> Result<String, Error> {
288        let cache = self.state.cache.lock()?;
289        Ok(cache.version.clone())
290    }
291
292    pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
293        let start_time = SystemTime::now();
294        let sleep_for = Duration::from_millis(100);
295        while SystemTime::now().duration_since(start_time)? < timeout {
296            {
297                let cache = self.state.cache.lock()?;
298                if !cache.version.is_empty() {
299                    return Ok(());
300                }
301            }
302            std::thread::sleep(sleep_for);
303        }
304        Err(Error::Timeout("Timed out waiting for open".to_string()))
305    }
306
307    pub(crate) fn wait_for_response(&self, transaction: u64, timeout: Duration) -> Result<Response, Error> {
308        let start_time = SystemTime::now();
309        let sleep_for = Duration::from_millis(100);
310        while SystemTime::now().duration_since(start_time)? < timeout {
311            {
312                let requests = self.state.requests.lock()?;
313                let response = match requests.get(&transaction) {
314                    None => return Err(Error::Default("No such transaction".to_string())),
315                    Some(response) => response.clone(),
316                };
317                if let Some(response) = response {
318                    match response.error.as_ref() {
319                        None => return Ok(response.clone()),
320                        Some(error_msg) => return Err(Error::Default(error_msg.clone())),
321                    }
322                }
323            }
324            std::thread::sleep(sleep_for);
325        }
326        Err(Error::Timeout("Timed out waiting for response".to_string()))
327    }
328}