arete_sdk/
connection.rs

use super::{Error, Stats, system};
use crate::stats::ConnectionState;
use serde::Deserialize;
use serde_json::Value;
use std::io::ErrorKind;
use std::time::{Duration, Instant, SystemTime};
use std::{
    collections::HashMap,
    net::TcpStream,
    sync::{
        Arc, Mutex,
        atomic::{AtomicU64, Ordering},
    },
};
use strum_macros::{AsRefStr, Display};
use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};
use uuid::Uuid;
18
/// Number of seconds the blocking helpers in this file (`add_node`, `add_system_`) wait for
/// the server to answer a command before giving up.
const DEFAULT_TIMEOUT_SECS: u64 = 5;
20
/// Message format announced in the `format` field of every outgoing command.
///
/// The string form (via `Display` / `AsRef<str>`) is the value placed on the wire.
#[derive(AsRefStr, Clone, Debug, Display)]
pub enum Format {
    /// Rendered as the string `json`.
    #[strum(serialize = "json")]
    Json,
}
26
/// Outcome of a command, recorded by the reader thread once the server has answered.
#[derive(Clone, Debug)]
struct Response {
    /// `None` when the server replied with an empty string; otherwise the content of the
    /// reply's `error` field.
    error: Option<String>,
}
31
/// Handle to a server connection over a WebSocket.
///
/// `cache`, `requests` and `socket` are shared with the background reader thread that
/// `Connection::new` spawns.
pub struct Connection {
    /// Local copy of the server state, updated by the reader thread.
    cache: Arc<Mutex<Cache>>,
    /// Source of transaction ids for outgoing commands; starts at 1.
    next_transaction_id: AtomicU64,
    /// Commands keyed by transaction id: `None` while awaiting a reply, `Some` once the
    /// reader thread has recorded one.
    requests: Arc<Mutex<HashMap<u64, Option<Response>>>>,
    /// The underlying WebSocket, locked for each individual read or send.
    socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
}
38
/// Accumulated state built up from the sparse updates received on the socket.
#[derive(Debug, Default, Deserialize)]
struct Cache {
    /// Version string from the server; empty until the first update carrying a version has
    /// been merged (`wait_for_open` polls for it becoming non-empty).
    version: String,
    /// Latest known statistics.
    stats: Stats,
    /// Latest known value for every key received so far.
    keys: HashMap<String, Value>,
}
45
/// Partial statistics update: every field is optional, and `Connection::merge` overwrites
/// only the fields that are present.
#[derive(Clone, Debug, Default, Deserialize)]
struct SparseStats {
    started: Option<String>,
    reads: Option<u32>,
    writes: Option<u32>,
    updates: Option<u32>,
    errors: Option<u32>,
    connection: Option<ConnectionState>,
}
55
/// Partial cache update as decoded from an incoming message that carries no transaction id.
///
/// Absent sections leave the corresponding part of `Cache` untouched when merged.
#[derive(Debug, Default, Deserialize)]
struct SparseCache {
    /// Statistics fields to overwrite, if any.
    stats: Option<SparseStats>,
    /// New version string, if any.
    version: Option<String>,
    /// Keys to insert or overwrite, if any.
    keys: Option<HashMap<String, Value>>,
}
62
63impl Connection {
64    pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
65        let next_transaction_id = AtomicU64::new(1);
66        let cache = Arc::new(Mutex::new(Cache::default()));
67        let requests = Arc::new(Mutex::new(HashMap::new()));
68
69        // Spawn a thread to handle incoming messages
70        let socket_2 = socket.clone();
71        let cache_2 = cache.clone();
72        let requests_2 = requests.clone();
73        std::thread::spawn(move || {
74            loop {
75                let maybe_message = {
76                    if let Ok(mut socket) = socket_2.lock() {
77                        match socket.read() {
78                            Ok(message) => Some(message),
79                            Err(e) => match e {
80                                tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
81                                _ => panic!("{e:?}"),
82                            },
83                        }
84                    } else {
85                        continue;
86                    }
87                };
88                let message = match maybe_message {
89                    Some(message) => message,
90                    None => {
91                        std::thread::sleep(Duration::from_millis(20));
92                        continue;
93                    }
94                };
95                if let Message::Text(ref message) = message {
96                    let incoming: HashMap<String, Value> = serde_json::from_slice(message.as_bytes()).unwrap();
97
98                    if let Some(Value::Number(transaction)) = incoming.get("transaction") {
99                        let transaction: u64 = transaction.as_u64().unwrap();
100                        if let Some(response) = incoming.get("response") {
101                            let mut requests = requests_2.lock().unwrap();
102                            if response.is_string() && response.as_str().unwrap_or_default().is_empty() {
103                                requests.insert(transaction, Some(Response { error: None }));
104                            } else if let Value::Object(response) = response {
105                                if let Some(error_msg) = response.get("error") {
106                                    requests.insert(
107                                        transaction,
108                                        Some(Response {
109                                            error: Some(error_msg.to_string()),
110                                        }),
111                                    );
112                                }
113                            }
114                        }
115                        continue;
116                    }
117
118                    let payload: SparseCache = serde_json::from_slice(message.as_bytes()).unwrap();
119                    if let Ok(mut cache) = cache_2.lock() {
120                        Self::merge(&mut cache, &payload);
121                    }
122                }
123            }
124        });
125
126        Self {
127            cache,
128            next_transaction_id,
129            requests,
130            socket,
131        }
132    }
133
134    pub fn add_node(&mut self, id: &str, name: &str, upstream: bool, token: Option<String>) -> Result<(), Error> {
135        let system_id = system::get_system_id()?;
136        let upstream_arg = if upstream { "yes".to_string() } else { "no".to_string() };
137        let token_arg = token.unwrap_or("$uuid".to_string());
138        let args = vec![
139            system_id.to_string(),
140            id.to_string(),
141            name.to_string(),
142            upstream_arg,
143            token_arg,
144        ];
145        let transaction = self.send(Format::Json, "nodes", &args)?;
146        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
147        Ok(())
148    }
149
150    pub fn add_system(&mut self) -> Result<(), Error> {
151        let id = system::get_system_id()?;
152        let name = hostname::get()?.to_str().unwrap().to_string();
153        self.add_system_(&id, &name)?;
154        Ok(())
155    }
156
157    fn add_system_(&mut self, id: &Uuid, name: &str) -> Result<(), Error> {
158        let args = vec![id.to_string(), name.to_string()];
159        let transaction = self.send(Format::Json, "systems", &args)?;
160        let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
161        Ok(())
162    }
163
164    pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
165        let cache = self.cache.lock()?;
166        let value = match cache.keys.get(key) {
167            Some(value) => Some(value.clone()),
168            None => default_value,
169        };
170        Ok(value)
171    }
172
173    pub fn keys(&self) -> Result<Vec<String>, Error> {
174        let mut vec = vec![];
175        let cache = self.cache.lock()?;
176        for (key, _) in cache.keys.iter() {
177            vec.push(key.clone());
178        }
179        Ok(vec)
180    }
181
182    fn merge(target: &mut Cache, source: &SparseCache) {
183        if let Some(ref stats) = source.stats {
184            if let Some(ref started) = stats.started {
185                target.stats.started = started.clone();
186            }
187            if let Some(reads) = stats.reads {
188                target.stats.reads = reads;
189            }
190            if let Some(writes) = stats.writes {
191                target.stats.writes = writes;
192            }
193            if let Some(updates) = stats.updates {
194                target.stats.updates = updates;
195            }
196            if let Some(errors) = stats.errors {
197                target.stats.errors = errors;
198            }
199            if let Some(ref connection) = stats.connection {
200                target.stats.connection = connection.clone();
201            }
202        }
203        if let Some(ref version) = source.version {
204            target.version = version.clone();
205        }
206        if let Some(ref keys) = source.keys {
207            for (k, v) in keys.iter() {
208                target.keys.insert(k.to_string(), v.clone());
209            }
210        }
211    }
212
213    pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
214        let args = vec![format!("\"{key}\""), value.to_string()];
215        let _ = self.send(Format::Json, "put", &args)?;
216        Ok(())
217    }
218
219    pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<u64, Error> {
220        let mut cmd = cmd.to_string();
221        for arg in args {
222            cmd = format!("{cmd} \"{arg}\"");
223        }
224
225        let mut msg = HashMap::new();
226        let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::SeqCst);
227        msg.insert("format".to_string(), Value::String(format.to_string()));
228        msg.insert("transaction".to_string(), Value::from(transaction_id));
229        msg.insert("command".to_string(), Value::String(cmd));
230
231        {
232            let mut requests = self.requests.lock()?;
233            requests.insert(transaction_id, None);
234        }
235
236        match format {
237            Format::Json => {
238                let msg_as_json = serde_json::to_string(&msg)?;
239                let message = Message::text(msg_as_json);
240                self.send_message(message)?;
241                Ok(transaction_id)
242            }
243        }
244    }
245
246    fn send_message(&mut self, message: Message) -> Result<(), Error> {
247        let mut socket = self.socket.lock()?;
248        socket.send(message)?;
249        Ok(())
250    }
251
252    pub fn stats(&self) -> Result<Stats, Error> {
253        let cache = self.cache.lock()?;
254        Ok(cache.stats.clone())
255    }
256
257    pub fn version(&self) -> Result<String, Error> {
258        let cache = self.cache.lock()?;
259        Ok(cache.version.clone())
260    }
261
262    pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
263        let start_time = SystemTime::now();
264        let sleep_for = Duration::from_millis(100);
265        while SystemTime::now().duration_since(start_time)? < timeout {
266            {
267                let cache = self.cache.lock()?;
268                if !cache.version.is_empty() {
269                    return Ok(());
270                }
271            }
272            std::thread::sleep(sleep_for);
273        }
274        Err(Error::Timeout("Timed out waiting for open".to_string()))
275    }
276
277    fn wait_for_response(&self, transaction: u64, timeout: Duration) -> Result<Response, Error> {
278        let start_time = SystemTime::now();
279        let sleep_for = Duration::from_millis(100);
280        while SystemTime::now().duration_since(start_time)? < timeout {
281            {
282                let requests = self.requests.lock()?;
283                let response = match requests.get(&transaction) {
284                    None => return Err(Error::Default("No such transaction".to_string())),
285                    Some(response) => response.clone(),
286                };
287                if let Some(response) = response {
288                    match response.error.as_ref() {
289                        None => return Ok(response.clone()),
290                        Some(error_msg) => return Err(Error::Default(error_msg.clone())),
291                    }
292                }
293            }
294            std::thread::sleep(sleep_for);
295        }
296        Err(Error::Timeout("Timed out waiting for response".to_string()))
297    }
298}