use super::{Error, Stats, system};
use crate::stats::ConnectionState;
use serde::Deserialize;
use serde_json::Value;
use std::io::ErrorKind;
use std::time::{Duration, Instant, SystemTime};
use std::{
    collections::HashMap,
    net::TcpStream,
    sync::{
        Arc, Mutex,
        atomic::{AtomicU64, Ordering},
    },
};
use strum_macros::{AsRefStr, Display};
use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};
use uuid::Uuid;
18
/// How long, in seconds, `add_node` / `add_system_` wait for the server's reply.
const DEFAULT_TIMEOUT_SECS: u64 = 5;
20
21#[derive(AsRefStr, Clone, Debug, Display)]
22pub enum Format {
23 #[strum(serialize = "json")]
24 Json,
25}
26
/// Outcome of one request, as recorded by the connection's reader thread.
#[derive(Debug, Clone)]
struct Response {
    /// Error text reported by the server, or `None` when the request succeeded.
    error: Option<String>,
}
31
32pub struct Connection {
33 cache: Arc<Mutex<Cache>>,
34 next_transaction_id: AtomicU64,
35 requests: Arc<Mutex<HashMap<u64, Option<Response>>>>,
36 socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
37}
38
39#[derive(Debug, Default, Deserialize)]
40struct Cache {
41 version: String,
42 stats: Stats,
43 keys: HashMap<String, Value>,
44}
45
46#[derive(Clone, Debug, Default, Deserialize)]
47struct SparseStats {
48 started: Option<String>,
49 reads: Option<u32>,
50 writes: Option<u32>,
51 updates: Option<u32>,
52 errors: Option<u32>,
53 connection: Option<ConnectionState>,
54}
55
56#[derive(Debug, Default, Deserialize)]
57struct SparseCache {
58 stats: Option<SparseStats>,
59 version: Option<String>,
60 keys: Option<HashMap<String, Value>>,
61}
62
63impl Connection {
64 pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
65 let next_transaction_id = AtomicU64::new(1);
66 let cache = Arc::new(Mutex::new(Cache::default()));
67 let requests = Arc::new(Mutex::new(HashMap::new()));
68
69 let socket_2 = socket.clone();
71 let cache_2 = cache.clone();
72 let requests_2 = requests.clone();
73 std::thread::spawn(move || {
74 loop {
75 let maybe_message = {
76 if let Ok(mut socket) = socket_2.lock() {
77 match socket.read() {
78 Ok(message) => Some(message),
79 Err(e) => match e {
80 tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
81 _ => panic!("{e:?}"),
82 },
83 }
84 } else {
85 continue;
86 }
87 };
88 let message = match maybe_message {
89 Some(message) => message,
90 None => {
91 std::thread::sleep(Duration::from_millis(20));
92 continue;
93 }
94 };
95 if let Message::Text(ref message) = message {
96 let incoming: HashMap<String, Value> = serde_json::from_slice(message.as_bytes()).unwrap();
97
98 if let Some(Value::Number(transaction)) = incoming.get("transaction") {
99 let transaction: u64 = transaction.as_u64().unwrap();
100 if let Some(response) = incoming.get("response") {
101 let mut requests = requests_2.lock().unwrap();
102 if response.is_string() && response.as_str().unwrap_or_default().is_empty() {
103 requests.insert(transaction, Some(Response { error: None }));
104 } else if let Value::Object(response) = response {
105 if let Some(error_msg) = response.get("error") {
106 requests.insert(
107 transaction,
108 Some(Response {
109 error: Some(error_msg.to_string()),
110 }),
111 );
112 }
113 }
114 }
115 continue;
116 }
117
118 let payload: SparseCache = serde_json::from_slice(message.as_bytes()).unwrap();
119 if let Ok(mut cache) = cache_2.lock() {
120 Self::merge(&mut cache, &payload);
121 }
122 }
123 }
124 });
125
126 Self {
127 cache,
128 next_transaction_id,
129 requests,
130 socket,
131 }
132 }
133
134 pub fn add_node(&mut self, id: &str, name: &str, upstream: bool, token: Option<String>) -> Result<(), Error> {
135 let system_id = system::get_system_id()?;
136 let upstream_arg = if upstream { "yes".to_string() } else { "no".to_string() };
137 let token_arg = token.unwrap_or("$uuid".to_string());
138 let args = vec![
139 system_id.to_string(),
140 id.to_string(),
141 name.to_string(),
142 upstream_arg,
143 token_arg,
144 ];
145 let transaction = self.send(Format::Json, "nodes", &args)?;
146 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
147 Ok(())
148 }
149
150 pub fn add_system(&mut self) -> Result<(), Error> {
151 let id = system::get_system_id()?;
152 let name = hostname::get()?.to_str().unwrap().to_string();
153 self.add_system_(&id, &name)?;
154 Ok(())
155 }
156
157 fn add_system_(&mut self, id: &Uuid, name: &str) -> Result<(), Error> {
158 let args = vec![id.to_string(), name.to_string()];
159 let transaction = self.send(Format::Json, "systems", &args)?;
160 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
161 Ok(())
162 }
163
164 pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
165 let cache = self.cache.lock()?;
166 let value = match cache.keys.get(key) {
167 Some(value) => Some(value.clone()),
168 None => default_value,
169 };
170 Ok(value)
171 }
172
173 pub fn keys(&self) -> Result<Vec<String>, Error> {
174 let mut vec = vec![];
175 let cache = self.cache.lock()?;
176 for (key, _) in cache.keys.iter() {
177 vec.push(key.clone());
178 }
179 Ok(vec)
180 }
181
182 fn merge(target: &mut Cache, source: &SparseCache) {
183 if let Some(ref stats) = source.stats {
184 if let Some(ref started) = stats.started {
185 target.stats.started = started.clone();
186 }
187 if let Some(reads) = stats.reads {
188 target.stats.reads = reads;
189 }
190 if let Some(writes) = stats.writes {
191 target.stats.writes = writes;
192 }
193 if let Some(updates) = stats.updates {
194 target.stats.updates = updates;
195 }
196 if let Some(errors) = stats.errors {
197 target.stats.errors = errors;
198 }
199 if let Some(ref connection) = stats.connection {
200 target.stats.connection = connection.clone();
201 }
202 }
203 if let Some(ref version) = source.version {
204 target.version = version.clone();
205 }
206 if let Some(ref keys) = source.keys {
207 for (k, v) in keys.iter() {
208 target.keys.insert(k.to_string(), v.clone());
209 }
210 }
211 }
212
213 pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
214 let args = vec![format!("\"{key}\""), value.to_string()];
215 let _ = self.send(Format::Json, "put", &args)?;
216 Ok(())
217 }
218
219 pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<u64, Error> {
220 let mut cmd = cmd.to_string();
221 for arg in args {
222 cmd = format!("{cmd} \"{arg}\"");
223 }
224
225 let mut msg = HashMap::new();
226 let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::SeqCst);
227 msg.insert("format".to_string(), Value::String(format.to_string()));
228 msg.insert("transaction".to_string(), Value::from(transaction_id));
229 msg.insert("command".to_string(), Value::String(cmd));
230
231 {
232 let mut requests = self.requests.lock()?;
233 requests.insert(transaction_id, None);
234 }
235
236 match format {
237 Format::Json => {
238 let msg_as_json = serde_json::to_string(&msg)?;
239 let message = Message::text(msg_as_json);
240 self.send_message(message)?;
241 Ok(transaction_id)
242 }
243 }
244 }
245
246 fn send_message(&mut self, message: Message) -> Result<(), Error> {
247 let mut socket = self.socket.lock()?;
248 socket.send(message)?;
249 Ok(())
250 }
251
252 pub fn stats(&self) -> Result<Stats, Error> {
253 let cache = self.cache.lock()?;
254 Ok(cache.stats.clone())
255 }
256
257 pub fn version(&self) -> Result<String, Error> {
258 let cache = self.cache.lock()?;
259 Ok(cache.version.clone())
260 }
261
262 pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
263 let start_time = SystemTime::now();
264 let sleep_for = Duration::from_millis(100);
265 while SystemTime::now().duration_since(start_time)? < timeout {
266 {
267 let cache = self.cache.lock()?;
268 if !cache.version.is_empty() {
269 return Ok(());
270 }
271 }
272 std::thread::sleep(sleep_for);
273 }
274 Err(Error::Timeout("Timed out waiting for open".to_string()))
275 }
276
277 fn wait_for_response(&self, transaction: u64, timeout: Duration) -> Result<Response, Error> {
278 let start_time = SystemTime::now();
279 let sleep_for = Duration::from_millis(100);
280 while SystemTime::now().duration_since(start_time)? < timeout {
281 {
282 let requests = self.requests.lock()?;
283 let response = match requests.get(&transaction) {
284 None => return Err(Error::Default("No such transaction".to_string())),
285 Some(response) => response.clone(),
286 };
287 if let Some(response) = response {
288 match response.error.as_ref() {
289 None => return Ok(response.clone()),
290 Some(error_msg) => return Err(Error::Default(error_msg.clone())),
291 }
292 }
293 }
294 std::thread::sleep(sleep_for);
295 }
296 Err(Error::Timeout("Timed out waiting for response".to_string()))
297 }
298}