use super::{Error, Stats};
use serde::Deserialize;
use serde_json::Value;
use std::io::ErrorKind;
use std::time::{Duration, Instant, SystemTime};
use std::{
    collections::HashMap,
    net::TcpStream,
    sync::{
        Arc, Mutex,
        atomic::{AtomicU64, Ordering},
    },
};
use strum_macros::{AsRefStr, Display};
use tungstenite::{Message, WebSocket, stream::MaybeTlsStream};
16
17#[derive(AsRefStr, Clone, Debug, Display)]
18pub enum Format {
19 #[strum(serialize = "json")]
20 Json,
21}
22
23pub struct Connection {
24 socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>,
25 next_transaction_id: AtomicU64,
26 cache: Arc<Mutex<Cache>>,
27}
28
29#[derive(Debug, Default, Deserialize)]
30struct Cache {
31 version: String,
32 stats: Stats,
33 keys: HashMap<String, Value>,
34}
35
36impl Connection {
37 pub(crate) fn new(socket: Arc<Mutex<WebSocket<MaybeTlsStream<TcpStream>>>>) -> Self {
38 let next_transaction_id = AtomicU64::new(1);
39 let cache = Arc::new(Mutex::new(Cache::default()));
40
41 let socket_2 = socket.clone();
43 let cache_2 = cache.clone();
44 std::thread::spawn(move || {
45 loop {
46 let maybe_message = {
47 if let Ok(mut socket) = socket_2.lock() {
48 match socket.read() {
49 Ok(message) => Some(message),
50 Err(e) => match e {
51 tungstenite::Error::Io(ref e) if e.kind() == ErrorKind::WouldBlock => None,
52 _ => panic!("{e:?}"),
53 },
54 }
55 } else {
56 continue;
57 }
58 };
59 let message = match maybe_message {
60 Some(message) => message,
61 None => {
62 std::thread::sleep(Duration::from_millis(20));
63 continue;
64 }
65 };
66 if let Message::Text(ref message) = message {
67 let payload: Cache = serde_json::from_slice(message.as_bytes()).unwrap();
68 if let Ok(mut cache) = cache_2.lock() {
69 Self::merge(&mut cache, &payload);
70 }
71 }
72 }
73 });
74
75 Self {
76 socket,
77 next_transaction_id,
78 cache,
79 }
80 }
81
82 pub fn get(&self, key: &str, default_value: Option<Value>) -> Result<Option<Value>, Error> {
83 let cache = self.cache.lock()?;
84 let value = match cache.keys.get(key) {
85 Some(value) => Some(value.clone()),
86 None => default_value,
87 };
88 Ok(value)
89 }
90
91 pub fn keys(&self) -> Result<Vec<String>, Error> {
92 let mut vec = vec![];
93 let cache = self.cache.lock()?;
94 for (key, _) in cache.keys.iter() {
95 vec.push(key.clone());
96 }
97 Ok(vec)
98 }
99
100 fn merge(target: &mut Cache, source: &Cache) {
101 target.stats = source.stats.clone();
102 target.version = source.version.clone();
103 for (k, v) in source.keys.iter() {
104 target.keys.insert(k.to_string(), v.clone());
105 }
106 }
107
108 pub fn put(&mut self, key: &str, value: &str) -> Result<(), Error> {
109 let args = vec![format!("\"{key}\""), value.to_string()];
110 self.send(Format::Json, "put", &args)
111 }
112
113 pub fn send(&mut self, format: Format, cmd: &str, args: &[String]) -> Result<(), Error> {
114 let mut cmd = cmd.to_string();
115 for arg in args {
116 cmd = format!("{cmd} \"{arg}\"");
117 }
118
119 let mut msg = HashMap::new();
120 let transaction_id = self.next_transaction_id.fetch_add(1, Ordering::SeqCst);
121 msg.insert("format".to_string(), Value::String(format.to_string()));
122 msg.insert("transaction".to_string(), Value::from(transaction_id));
123 msg.insert("command".to_string(), Value::String(cmd));
124
125 let msg_as_json = serde_json::to_string(&msg)?;
126 let message = Message::text(msg_as_json);
127 self.send_message(message)
128 }
129
130 fn send_message(&mut self, message: Message) -> Result<(), Error> {
131 let mut socket = self.socket.lock()?;
132 socket.send(message)?;
133 Ok(())
134 }
135
136 pub fn stats(&self) -> Result<Stats, Error> {
137 let cache = self.cache.lock()?;
138 Ok(cache.stats.clone())
139 }
140
141 pub fn version(&self) -> Result<String, Error> {
142 let cache = self.cache.lock()?;
143 Ok(cache.version.clone())
144 }
145
146 pub fn wait_for_open(&self, timeout: Duration) -> Result<(), Error> {
147 let start_time = SystemTime::now();
148 let sleep_for = Duration::from_millis(100);
149 while SystemTime::now().duration_since(start_time)? < timeout {
150 {
151 let cache = self.cache.lock()?;
152 if !cache.version.is_empty() {
153 return Ok(());
154 }
155 }
156 std::thread::sleep(sleep_for);
157 }
158 Err(Error::Timeout("Timed out waiting for open".to_string()))
159 }
160}