arete_sdk/
client.rs

1use super::{Cache, Error, Stats, system};
2use crate::stats::ConnectionState;
3use serde::Deserialize;
4use serde_json::Value;
5use std::{
6    collections::HashMap,
7    io::ErrorKind,
8    net::TcpStream,
9    sync::{
10        Arc, Mutex,
11        atomic::{AtomicU64, Ordering},
12        mpsc::{self, Receiver, Sender},
13    },
14    time::{Duration, SystemTime},
15};
16use strum_macros::{AsRefStr, Display};
17use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};
18use uuid::Uuid;
19
/// Number of seconds a request waits for the server's reply before failing with a timeout.
const DEFAULT_TIMEOUT_SECS: u64 = 5;
21
22#[derive(AsRefStr, Clone, Debug, Display)]
23pub enum Format {
24    #[strum(serialize = "json")]
25    Json,
26}
27
/// Outcome recorded for a request transaction.
///
/// `error` is `None` when the request succeeded and `Some(text)` when the reply
/// carried an error.
#[derive(Debug, Clone)]
struct Response {
    error: Option<String>,
}
32
33pub struct Client {
34    cache: Arc<Mutex<Cache>>,
35    next_transaction_id: AtomicU64,
36    requests: Arc<Mutex<HashMap<u64, Option<Response>>>>,
37    socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
38    subscribers: Arc<Mutex<Vec<Sender<Cache>>>>,
39}
40
41#[derive(Clone, Debug, Default, Deserialize)]
42struct SparseStats {
43    started: Option<String>,
44    reads: Option<u32>,
45    writes: Option<u32>,
46    updates: Option<u32>,
47    errors: Option<u32>,
48    connection: Option<ConnectionState>,
49}
50
51#[derive(Debug, Default, Deserialize)]
52struct SparseCache {
53    stats: Option<SparseStats>,
54    version: Option<String>,
55    keys: Option<HashMap<String, Value>>,
56}
57
58impl From<SparseCache> for Cache {
59    fn from(other: SparseCache) -> Self {
60        let mut cache = Cache::default();
61        if let Some(stats) = other.stats {
62            if let Some(started) = stats.started {
63                cache.stats.started = started.clone();
64            }
65            if let Some(reads) = stats.reads {
66                cache.stats.reads = reads;
67            }
68            if let Some(writes) = stats.writes {
69                cache.stats.writes = writes;
70            }
71            if let Some(updates) = stats.updates {
72                cache.stats.updates = updates;
73            }
74            if let Some(errors) = stats.errors {
75                cache.stats.errors = errors;
76            }
77            if let Some(connection) = stats.connection {
78                cache.stats.connection = connection.clone();
79            }
80        }
81        if let Some(version) = other.version {
82            cache.version = version.clone();
83        }
84        if let Some(keys) = other.keys {
85            cache.keys = keys.clone();
86        }
87        cache
88    }
89}
90
91impl Client {
92    pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
93        let next_transaction_id = AtomicU64::new(1);
94        let cache = Arc::new(Mutex::new(Cache::default()));
95        let requests = Arc::new(Mutex::new(HashMap::new()));
96        let subscribers = Arc::new(Mutex::<Vec<Sender<Cache>>>::new(vec![]));
97
98        // Spawn a thread to handle incoming messages
99        let socket_2 = socket.clone();
100        let cache_2 = cache.clone();
101        let requests_2 = requests.clone();
102        let subscribers_2 = subscribers.clone();
103        std::thread::spawn(move || {
104            loop {
105                let maybe_message = {
106                    if let Ok(mut socket) = socket_2.lock() {
107                        match socket.read() {
108                            Ok(message) => Some(message),
109                            Err(e) => match e {
110                                tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
111                                _ => panic!("{e:?}"),
112                            },
113                        }
114                    } else {
115                        continue;
116                    }
117                };
118                let message = match maybe_message {
119                    Some(message) => message,
120                    None => {
121                        std::thread::sleep(Duration::from_millis(20));
122                        continue;
123                    }
124                };
125                if let Message::Text(ref message) = message {
126                    let incoming: HashMap<String, Value> = serde_json::from_slice(message.as_bytes()).unwrap();
127
128                    if let Some(Value::Number(transaction)) = incoming.get("transaction") {
129                        let transaction: u64 = transaction.as_u64().unwrap();
130                        if let Some(response) = incoming.get("response") {
131                            let mut requests = requests_2.lock().unwrap();
132                            if response.is_string() && response.as_str().unwrap_or_default().is_empty() {
133                                requests.insert(transaction, Some(Response { error: None }));
134                            } else if let Value::Object(response) = response {
135                                if let Some(error_msg) = response.get("error") {
136                                    requests.insert(
137                                        transaction,
138                                        Some(Response {
139                                            error: Some(error_msg.to_string()),
140                                        }),
141                                    );
142                                }
143                            }
144                        }
145                        continue;
146                    }
147
148                    // Cache the new update
149                    let payload: SparseCache = serde_json::from_slice(message.as_bytes()).unwrap();
150                    if let Ok(mut cache) = cache_2.lock() {
151                        Self::merge(&mut cache, &payload);
152                    }
153
154                    // Notify subscribers of the new update
155                    {
156                        let subscribers = subscribers_2.lock().unwrap();
157                        let event: Cache = payload.into();
158                        for tx in subscribers.iter() {
159                            tx.send(event.clone()).unwrap();
160                        }
161                    }
162                }
163            }
164        });
165
166        Self {
167            cache,
168            next_transaction_id,
169            requests,
170            socket,
171            subscribers,
172        }
173    }
174
175    pub fn add_consumer(&mut self, node_id: &str, context_id: &str, profile: &str) -> Result<(), Error> {
176        let system_id = system::get_system_id()?;
177        let args = vec![
178            system_id.to_string(),
179            node_id.to_string(),
180            context_id.to_string(),
181            profile.to_string(),
182        ];
183        let transaction = self.send(Format::Json, "consumers", &args)?;
184        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
185        Ok(())
186    }
187
188    pub fn add_context(&mut self, node_id: &str, id: &str, name: &str) -> Result<(), Error> {
189        let system_id = system::get_system_id()?;
190        let args = vec![
191            system_id.to_string(),
192            node_id.to_string(),
193            id.to_string(),
194            name.to_string(),
195        ];
196        let transaction = self.send(Format::Json, "contexts", &args)?;
197        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
198        Ok(())
199    }
200
201    pub fn add_node(&mut self, id: &str, name: &str, upstream: bool, token: Option<String>) -> Result<(), Error> {
202        let system_id = system::get_system_id()?;
203        let upstream_arg = if upstream { "yes".to_string() } else { "no".to_string() };
204        let token_arg = token.unwrap_or("$uuid".to_string());
205        let args = vec![
206            system_id.to_string(),
207            id.to_string(),
208            name.to_string(),
209            upstream_arg,
210            token_arg,
211        ];
212        let transaction = self.send(Format::Json, "nodes", &args)?;
213        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
214        Ok(())
215    }
216
217    pub fn add_provider(&mut self, node_id: &str, context_id: &str, profile: &str) -> Result<(), Error> {
218        let system_id = system::get_system_id()?;
219        let args = vec![
220            system_id.to_string(),
221            node_id.to_string(),
222            context_id.to_string(),
223            profile.to_string(),
224        ];
225        let transaction = self.send(Format::Json, "providers", &args)?;
226        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
227        Ok(())
228    }
229
230    pub fn add_system(&mut self) -> Result<(), Error> {
231        let id = system::get_system_id()?;
232        let name = hostname::get()?.to_str().unwrap().to_string();
233        self.add_system_(&id, &name)?;
234        Ok(())
235    }
236
237    fn add_system_(&mut self, id: &Uuid, name: &str) -> Result<(), Error> {
238        let args = vec![id.to_string(), name.to_string()];
239        let transaction = self.send(Format::Json, "systems", &args)?;
240        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
241        Ok(())
242    }
243
244    pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
245        let cache = self.cache.lock()?;
246        let value = match cache.keys.get(key) {
247            Some(value) => Some(value.clone()),
248            None => default_value,
249        };
250        Ok(value)
251    }
252
253    pub fn keys(&self) -> Result<Vec<String>, Error> {
254        let mut vec = vec![];
255        let cache = self.cache.lock()?;
256        for (key, _) in cache.keys.iter() {
257            vec.push(key.clone());
258        }
259        Ok(vec)
260    }
261
262    fn merge(target: &mut Cache, source: &SparseCache) {
263        if let Some(ref stats) = source.stats {
264            if let Some(ref started) = stats.started {
265                target.stats.started = started.clone();
266            }
267            if let Some(reads) = stats.reads {
268                target.stats.reads = reads;
269            }
270            if let Some(writes) = stats.writes {
271                target.stats.writes = writes;
272            }
273            if let Some(updates) = stats.updates {
274                target.stats.updates = updates;
275            }
276            if let Some(errors) = stats.errors {
277                target.stats.errors = errors;
278            }
279            if let Some(ref connection) = stats.connection {
280                target.stats.connection = connection.clone();
281            }
282        }
283        if let Some(ref version) = source.version {
284            target.version = version.clone();
285        }
286        if let Some(ref keys) = source.keys {
287            for (k, v) in keys.iter() {
288                target.keys.insert(k.to_string(), v.clone());
289            }
290        }
291    }
292
293    pub fn on_update(&mut self) -> Result<Receiver<Cache>, Error> {
294        let (tx, rx) = mpsc::channel();
295        let mut subscribers = self.subscribers.lock()?;
296        subscribers.push(tx);
297        Ok(rx)
298    }
299
300    pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
301        let args = vec![key.to_string(), value.to_string()];
302        let _ = self.send(Format::Json, "put", &args)?;
303        Ok(())
304    }
305
306    pub fn put_property(
307        &mut self,
308        node_id: &str,
309        context_id: &str,
310        profile: &str,
311        property: &str,
312        value: &str,
313    ) -> Result<(), Error> {
314        let system_id = system::get_system_id()?;
315        let key =
316            format!("cns/{system_id}/nodes/{node_id}/contexts/{context_id}/provider/{profile}/properties/{property}");
317        self.put(&key, value)
318    }
319
320    pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<u64, Error> {
321        let mut cmd = cmd.to_string();
322        for arg in args {
323            cmd = format!("{cmd} \"{arg}\"");
324        }
325
326        let mut msg = HashMap::new();
327        let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::SeqCst);
328        msg.insert("format".to_string(), Value::String(format.to_string()));
329        msg.insert("transaction".to_string(), Value::from(transaction_id));
330        msg.insert("command".to_string(), Value::String(cmd));
331
332        match format {
333            Format::Json => {
334                let msg_as_json = serde_json::to_string(&msg)?;
335                let message = Message::text(msg_as_json);
336                self.send_message(message)?;
337                {
338                    let mut requests = self.requests.lock()?;
339                    requests.insert(transaction_id, None);
340                }
341                Ok(transaction_id)
342            }
343        }
344    }
345
346    fn send_message(&mut self, message: Message) -> Result<(), Error> {
347        let mut socket = self.socket.lock()?;
348        socket.send(message)?;
349        Ok(())
350    }
351
352    pub fn stats(&self) -> Result<Stats, Error> {
353        let cache = self.cache.lock()?;
354        Ok(cache.stats.clone())
355    }
356
357    pub fn version(&self) -> Result<String, Error> {
358        let cache = self.cache.lock()?;
359        Ok(cache.version.clone())
360    }
361
362    pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
363        let start_time = SystemTime::now();
364        let sleep_for = Duration::from_millis(100);
365        while SystemTime::now().duration_since(start_time)? < timeout {
366            {
367                let cache = self.cache.lock()?;
368                if !cache.version.is_empty() {
369                    return Ok(());
370                }
371            }
372            std::thread::sleep(sleep_for);
373        }
374        Err(Error::Timeout("Timed out waiting for open".to_string()))
375    }
376
377    fn wait_for_response(&self, transaction: u64, timeout: Duration) -> Result<Response, Error> {
378        let start_time = SystemTime::now();
379        let sleep_for = Duration::from_millis(100);
380        while SystemTime::now().duration_since(start_time)? < timeout {
381            {
382                let requests = self.requests.lock()?;
383                let response = match requests.get(&transaction) {
384                    None => return Err(Error::Default("No such transaction".to_string())),
385                    Some(response) => response.clone(),
386                };
387                if let Some(response) = response {
388                    match response.error.as_ref() {
389                        None => return Ok(response.clone()),
390                        Some(error_msg) => return Err(Error::Default(error_msg.clone())),
391                    }
392                }
393            }
394            std::thread::sleep(sleep_for);
395        }
396        Err(Error::Timeout("Timed out waiting for response".to_string()))
397    }
398}