1use super::{Cache, Error, Stats, system};
2use crate::stats::ConnectionState;
3use serde::Deserialize;
4use serde_json::Value;
5use std::{
6 collections::HashMap,
7 io::ErrorKind,
8 net::TcpStream,
9 sync::{
10 Arc, Mutex,
11 atomic::{AtomicU64, Ordering},
12 mpsc::{self, Receiver, Sender},
13 },
14 time::{Duration, SystemTime},
15};
16use strum_macros::{AsRefStr, Display};
17use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};
18use uuid::Uuid;
19
/// Number of seconds the `add_*` methods wait for the reply to a command before giving up.
const DEFAULT_TIMEOUT_SECS: u64 = 5;
21
/// Format named in the `format` field of every message built by [`Client::send`].
#[derive(AsRefStr, Clone, Debug, Display)]
pub enum Format {
    /// JSON; the strum attribute below makes the derived string form `json`.
    #[strum(serialize = "json")]
    Json,
}
27
/// Reply to a transaction, as recorded by the background reader thread.
#[derive(Clone, Debug)]
struct Response {
    /// Error text taken from the reply's `error` member; `None` when the reply was an empty
    /// string, which the reader treats as success.
    error: Option<String>,
}
32
/// Client side of the cache connection; shares its state with a background reader thread.
pub struct Client {
    /// Merged view of every cache payload received so far.
    cache: Arc<Mutex<Cache>>,
    /// Next transaction id to hand out; starts at 1 and is incremented by every `send`.
    next_transaction_id: AtomicU64,
    /// Transactions keyed by id; the value is `None` until the reader records a reply.
    requests: Arc<Mutex<HashMap<u64, Option<Response>>>>,
    /// The WebSocket, used for writes here and for reads by the reader thread.
    socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
    /// Sending halves of the channels handed out by `on_update`.
    subscribers: Arc<Mutex<Vec<Sender<Cache>>>>,
}
40
/// Statistics section of an incoming payload in which every field may be absent.
///
/// Only the fields that are present overwrite the matching field of the cached `Stats`.
#[derive(Clone, Debug, Default, Deserialize)]
struct SparseStats {
    started: Option<String>,
    reads: Option<u32>,
    writes: Option<u32>,
    updates: Option<u32>,
    errors: Option<u32>,
    connection: Option<ConnectionState>,
}
50
/// Incoming cache payload in which every top-level section may be absent.
///
/// Deserialized from each text frame that is not a transaction reply, then merged into the
/// shared `Cache` and converted into a `Cache` event for subscribers.
#[derive(Debug, Default, Deserialize)]
struct SparseCache {
    /// Partial statistics, if the payload carried any.
    stats: Option<SparseStats>,
    /// Version string, if the payload carried one.
    version: Option<String>,
    /// Key/value entries carried by the payload.
    keys: Option<HashMap<String, Value>>,
}
57
58impl From<SparseCache> for Cache {
59 fn from(other: SparseCache) -> Self {
60 let mut cache = Cache::default();
61 if let Some(stats) = other.stats {
62 if let Some(started) = stats.started {
63 cache.stats.started = started.clone();
64 }
65 if let Some(reads) = stats.reads {
66 cache.stats.reads = reads;
67 }
68 if let Some(writes) = stats.writes {
69 cache.stats.writes = writes;
70 }
71 if let Some(updates) = stats.updates {
72 cache.stats.updates = updates;
73 }
74 if let Some(errors) = stats.errors {
75 cache.stats.errors = errors;
76 }
77 if let Some(connection) = stats.connection {
78 cache.stats.connection = connection.clone();
79 }
80 }
81 if let Some(version) = other.version {
82 cache.version = version.clone();
83 }
84 if let Some(keys) = other.keys {
85 cache.keys = keys.clone();
86 }
87 cache
88 }
89}
90
91impl Client {
92 pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
93 let next_transaction_id = AtomicU64::new(1);
94 let cache = Arc::new(Mutex::new(Cache::default()));
95 let requests = Arc::new(Mutex::new(HashMap::new()));
96 let subscribers = Arc::new(Mutex::<Vec<Sender<Cache>>>::new(vec![]));
97
98 let socket_2 = socket.clone();
100 let cache_2 = cache.clone();
101 let requests_2 = requests.clone();
102 let subscribers_2 = subscribers.clone();
103 std::thread::spawn(move || {
104 loop {
105 let maybe_message = {
106 if let Ok(mut socket) = socket_2.lock() {
107 match socket.read() {
108 Ok(message) => Some(message),
109 Err(e) => match e {
110 tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
111 _ => panic!("{e:?}"),
112 },
113 }
114 } else {
115 continue;
116 }
117 };
118 let message = match maybe_message {
119 Some(message) => message,
120 None => {
121 std::thread::sleep(Duration::from_millis(20));
122 continue;
123 }
124 };
125 if let Message::Text(ref message) = message {
126 let incoming: HashMap<String, Value> = serde_json::from_slice(message.as_bytes()).unwrap();
127
128 if let Some(Value::Number(transaction)) = incoming.get("transaction") {
129 let transaction: u64 = transaction.as_u64().unwrap();
130 if let Some(response) = incoming.get("response") {
131 let mut requests = requests_2.lock().unwrap();
132 if response.is_string() && response.as_str().unwrap_or_default().is_empty() {
133 requests.insert(transaction, Some(Response { error: None }));
134 } else if let Value::Object(response) = response {
135 if let Some(error_msg) = response.get("error") {
136 requests.insert(
137 transaction,
138 Some(Response {
139 error: Some(error_msg.to_string()),
140 }),
141 );
142 }
143 }
144 }
145 continue;
146 }
147
148 let payload: SparseCache = serde_json::from_slice(message.as_bytes()).unwrap();
150 if let Ok(mut cache) = cache_2.lock() {
151 Self::merge(&mut cache, &payload);
152 }
153
154 {
156 let subscribers = subscribers_2.lock().unwrap();
157 let event: Cache = payload.into();
158 for tx in subscribers.iter() {
159 tx.send(event.clone()).unwrap();
160 }
161 }
162 }
163 }
164 });
165
166 Self {
167 cache,
168 next_transaction_id,
169 requests,
170 socket,
171 subscribers,
172 }
173 }
174
175 pub fn add_consumer(&mut self, node_id: &str, context_id: &str, profile: &str) -> Result<(), Error> {
176 let system_id = system::get_system_id()?;
177 let args = vec![
178 system_id.to_string(),
179 node_id.to_string(),
180 context_id.to_string(),
181 profile.to_string(),
182 ];
183 let transaction = self.send(Format::Json, "consumers", &args)?;
184 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
185 Ok(())
186 }
187
188 pub fn add_context(&mut self, node_id: &str, id: &str, name: &str) -> Result<(), Error> {
189 let system_id = system::get_system_id()?;
190 let args = vec![
191 system_id.to_string(),
192 node_id.to_string(),
193 id.to_string(),
194 name.to_string(),
195 ];
196 let transaction = self.send(Format::Json, "contexts", &args)?;
197 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
198 Ok(())
199 }
200
201 pub fn add_node(&mut self, id: &str, name: &str, upstream: bool, token: Option<String>) -> Result<(), Error> {
202 let system_id = system::get_system_id()?;
203 let upstream_arg = if upstream { "yes".to_string() } else { "no".to_string() };
204 let token_arg = token.unwrap_or("$uuid".to_string());
205 let args = vec![
206 system_id.to_string(),
207 id.to_string(),
208 name.to_string(),
209 upstream_arg,
210 token_arg,
211 ];
212 let transaction = self.send(Format::Json, "nodes", &args)?;
213 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
214 Ok(())
215 }
216
217 pub fn add_provider(&mut self, node_id: &str, context_id: &str, profile: &str) -> Result<(), Error> {
218 let system_id = system::get_system_id()?;
219 let args = vec![
220 system_id.to_string(),
221 node_id.to_string(),
222 context_id.to_string(),
223 profile.to_string(),
224 ];
225 let transaction = self.send(Format::Json, "providers", &args)?;
226 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
227 Ok(())
228 }
229
230 pub fn add_system(&mut self) -> Result<(), Error> {
231 let id = system::get_system_id()?;
232 let name = hostname::get()?.to_str().unwrap().to_string();
233 self.add_system_(&id, &name)?;
234 Ok(())
235 }
236
237 fn add_system_(&mut self, id: &Uuid, name: &str) -> Result<(), Error> {
238 let args = vec![id.to_string(), name.to_string()];
239 let transaction = self.send(Format::Json, "systems", &args)?;
240 let _response = self.wait_for_response(transaction, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
241 Ok(())
242 }
243
244 pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
245 let cache = self.cache.lock()?;
246 let value = match cache.keys.get(key) {
247 Some(value) => Some(value.clone()),
248 None => default_value,
249 };
250 Ok(value)
251 }
252
253 pub fn keys(&self) -> Result<Vec<String>, Error> {
254 let mut vec = vec![];
255 let cache = self.cache.lock()?;
256 for (key, _) in cache.keys.iter() {
257 vec.push(key.clone());
258 }
259 Ok(vec)
260 }
261
262 fn merge(target: &mut Cache, source: &SparseCache) {
263 if let Some(ref stats) = source.stats {
264 if let Some(ref started) = stats.started {
265 target.stats.started = started.clone();
266 }
267 if let Some(reads) = stats.reads {
268 target.stats.reads = reads;
269 }
270 if let Some(writes) = stats.writes {
271 target.stats.writes = writes;
272 }
273 if let Some(updates) = stats.updates {
274 target.stats.updates = updates;
275 }
276 if let Some(errors) = stats.errors {
277 target.stats.errors = errors;
278 }
279 if let Some(ref connection) = stats.connection {
280 target.stats.connection = connection.clone();
281 }
282 }
283 if let Some(ref version) = source.version {
284 target.version = version.clone();
285 }
286 if let Some(ref keys) = source.keys {
287 for (k, v) in keys.iter() {
288 target.keys.insert(k.to_string(), v.clone());
289 }
290 }
291 }
292
293 pub fn on_update(&mut self) -> Result<Receiver<Cache>, Error> {
294 let (tx, rx) = mpsc::channel();
295 let mut subscribers = self.subscribers.lock()?;
296 subscribers.push(tx);
297 Ok(rx)
298 }
299
300 pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
301 let args = vec![key.to_string(), value.to_string()];
302 let _ = self.send(Format::Json, "put", &args)?;
303 Ok(())
304 }
305
306 pub fn put_property(
307 &mut self,
308 node_id: &str,
309 context_id: &str,
310 profile: &str,
311 property: &str,
312 value: &str,
313 ) -> Result<(), Error> {
314 let system_id = system::get_system_id()?;
315 let key =
316 format!("cns/{system_id}/nodes/{node_id}/contexts/{context_id}/provider/{profile}/properties/{property}");
317 self.put(&key, value)
318 }
319
320 pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<u64, Error> {
321 let mut cmd = cmd.to_string();
322 for arg in args {
323 cmd = format!("{cmd} \"{arg}\"");
324 }
325
326 let mut msg = HashMap::new();
327 let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::SeqCst);
328 msg.insert("format".to_string(), Value::String(format.to_string()));
329 msg.insert("transaction".to_string(), Value::from(transaction_id));
330 msg.insert("command".to_string(), Value::String(cmd));
331
332 match format {
333 Format::Json => {
334 let msg_as_json = serde_json::to_string(&msg)?;
335 let message = Message::text(msg_as_json);
336 self.send_message(message)?;
337 {
338 let mut requests = self.requests.lock()?;
339 requests.insert(transaction_id, None);
340 }
341 Ok(transaction_id)
342 }
343 }
344 }
345
346 fn send_message(&mut self, message: Message) -> Result<(), Error> {
347 let mut socket = self.socket.lock()?;
348 socket.send(message)?;
349 Ok(())
350 }
351
352 pub fn stats(&self) -> Result<Stats, Error> {
353 let cache = self.cache.lock()?;
354 Ok(cache.stats.clone())
355 }
356
357 pub fn version(&self) -> Result<String, Error> {
358 let cache = self.cache.lock()?;
359 Ok(cache.version.clone())
360 }
361
362 pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
363 let start_time = SystemTime::now();
364 let sleep_for = Duration::from_millis(100);
365 while SystemTime::now().duration_since(start_time)? < timeout {
366 {
367 let cache = self.cache.lock()?;
368 if !cache.version.is_empty() {
369 return Ok(());
370 }
371 }
372 std::thread::sleep(sleep_for);
373 }
374 Err(Error::Timeout("Timed out waiting for open".to_string()))
375 }
376
377 fn wait_for_response(&self, transaction: u64, timeout: Duration) -> Result<Response, Error> {
378 let start_time = SystemTime::now();
379 let sleep_for = Duration::from_millis(100);
380 while SystemTime::now().duration_since(start_time)? < timeout {
381 {
382 let requests = self.requests.lock()?;
383 let response = match requests.get(&transaction) {
384 None => return Err(Error::Default("No such transaction".to_string())),
385 Some(response) => response.clone(),
386 };
387 if let Some(response) = response {
388 match response.error.as_ref() {
389 None => return Ok(response.clone()),
390 Some(error_msg) => return Err(Error::Default(error_msg.clone())),
391 }
392 }
393 }
394 std::thread::sleep(sleep_for);
395 }
396 Err(Error::Timeout("Timed out waiting for response".to_string()))
397 }
398}