arete_sdk/
connection.rs

1use super::{Error, Stats};
2use serde::Deserialize;
3use serde_json::Value;
4use std::io::ErrorKind;
5use std::time::{Duration, SystemTime};
6use std::{
7    collections::HashMap,
8    net::TcpStream,
9    sync::{
10        Arc, Mutex,
11        atomic::{AtomicU64, Ordering},
12    },
13};
14use strum_macros::{AsRefStr, Display};
15use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};
16
17#[derive(AsRefStr, Clone, Debug, Display)]
18pub enum Format {
19    #[strum(serialize = "json")]
20    Json,
21}
22
23pub struct Connection {
24    socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
25    next_transaction_id: AtomicU64,
26    cache: Arc<Mutex<Cache>>,
27}
28
29#[derive(Debug, Default, Deserialize)]
30struct Cache {
31    version: String,
32    stats: Stats,
33    keys: HashMap<String, Value>,
34}
35
36impl Connection {
37    pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
38        let next_transaction_id = AtomicU64::new(1);
39        let cache = Arc::new(Mutex::new(Cache::default()));
40
41        // Spawn a thread to handle incoming messages
42        let socket_2 = socket.clone();
43        let cache_2 = cache.clone();
44        std::thread::spawn(move || {
45            loop {
46                let maybe_message = {
47                    if let Ok(mut socket) = socket_2.lock() {
48                        match socket.read() {
49                            Ok(message) => Some(message),
50                            Err(e) => match e {
51                                tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
52                                _ => panic!("{e:?}"),
53                            },
54                        }
55                    } else {
56                        continue;
57                    }
58                };
59                let message = match maybe_message {
60                    Some(message) => message,
61                    None => {
62                        std::thread::sleep(Duration::from_millis(20));
63                        continue;
64                    }
65                };
66                if let Message::Text(ref message) = message {
67                    let payload: Cache = serde_json::from_slice(message.as_bytes()).unwrap();
68                    if let Ok(mut cache) = cache_2.lock() {
69                        Self::merge(&mut cache, &payload);
70                    }
71                }
72            }
73        });
74
75        Self {
76            socket,
77            next_transaction_id,
78            cache,
79        }
80    }
81
82    pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
83        let cache = self.cache.lock()?;
84        let value = match cache.keys.get(key) {
85            Some(value) => Some(value.clone()),
86            None => default_value,
87        };
88        Ok(value)
89    }
90
91    pub fn keys(&self) -> Result<Vec<String>, Error> {
92        let mut vec = vec![];
93        let cache = self.cache.lock()?;
94        for (key, _) in cache.keys.iter() {
95            vec.push(key.clone());
96        }
97        Ok(vec)
98    }
99
100    fn merge(target: &mut Cache, source: &Cache) {
101        target.stats = source.stats.clone();
102        target.version = source.version.clone();
103        for (k, v) in source.keys.iter() {
104            target.keys.insert(k.to_string(), v.clone());
105        }
106    }
107
108    pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
109        let args = vec![format!("\"{key}\""), value.to_string()];
110        self.send(Format::Json, "put", &args)
111    }
112
113    pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<(), Error> {
114        let mut cmd = cmd.to_string();
115        for arg in args {
116            cmd = format!("{cmd} \"{arg}\"");
117        }
118
119        let mut msg = HashMap::new();
120        let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::SeqCst);
121        msg.insert("format".to_string(), Value::String(format.to_string()));
122        msg.insert("transaction".to_string(), Value::from(transaction_id));
123        msg.insert("command".to_string(), Value::String(cmd));
124
125        let msg_as_json = serde_json::to_string(&msg)?;
126        let message = Message::text(msg_as_json);
127        self.send_message(message)
128    }
129
130    fn send_message(&mut self, message: Message) -> Result<(), Error> {
131        let mut socket = self.socket.lock()?;
132        socket.send(message)?;
133        Ok(())
134    }
135
136    pub fn stats(&self) -> Result<Stats, Error> {
137        let cache = self.cache.lock()?;
138        Ok(cache.stats.clone())
139    }
140
141    pub fn version(&self) -> Result<String, Error> {
142        let cache = self.cache.lock()?;
143        Ok(cache.version.clone())
144    }
145
146    pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
147        let start_time = SystemTime::now();
148        let sleep_for = Duration::from_millis(100);
149        while SystemTime::now().duration_since(start_time)? < timeout {
150            {
151                let cache = self.cache.lock()?;
152                if !cache.version.is_empty() {
153                    return Ok(());
154                }
155            }
156            std::thread::sleep(sleep_for);
157        }
158        Err(Error::Timeout("Timed out waiting for open".to_string()))
159    }
160}