1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use hyper_util::rt::TokioIo;
5use tokio::net::UnixStream;
6use tonic::codegen::async_trait;
7use tonic::transport::{Channel, Endpoint, Uri};
8use tower::service_fn;
9
10use crate::api::RuntimeMetadata;
11use crate::error::Result;
12use crate::generated::v1::{self as pb, cache_client::CacheClient};
13
/// Name of the environment variable that holds the cache socket path for the unnamed (default)
/// cache. `cache_socket_env` also uses it as the prefix of the variable name for named caches.
pub const ENV_CACHE_SOCKET: &str = "GESTALT_CACHE_SOCKET";
15
/// A single key/value pair, used as the element type for batched writes.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CacheEntry {
    /// Key under which the value is stored.
    pub key: String,
    /// Raw bytes stored for `key`.
    pub value: Vec<u8>,
}
21
/// Options that accompany a cache write.
///
/// The `Default` value carries no TTL.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CacheSetOptions {
    /// Optional time-to-live for the written entries. `None` means no TTL is requested.
    pub ttl: Option<Duration>,
}
26
27#[derive(Debug, thiserror::Error)]
28pub enum CacheError {
29 #[error("{0}")]
30 Transport(#[from] tonic::transport::Error),
31 #[error("{0}")]
32 Status(#[from] tonic::Status),
33 #[error("{0}")]
34 Env(String),
35}
36
37#[async_trait]
38pub trait CacheProvider: Send + Sync + 'static {
39 async fn configure(
40 &self,
41 _name: &str,
42 _config: serde_json::Map<String, serde_json::Value>,
43 ) -> Result<()> {
44 Ok(())
45 }
46
47 fn metadata(&self) -> Option<RuntimeMetadata> {
48 None
49 }
50
51 fn warnings(&self) -> Vec<String> {
52 Vec::new()
53 }
54
55 async fn health_check(&self) -> Result<()> {
56 Ok(())
57 }
58
59 async fn close(&self) -> Result<()> {
60 Ok(())
61 }
62
63 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
64
65 async fn get_many(&self, keys: &[String]) -> Result<BTreeMap<String, Vec<u8>>> {
66 let mut values = BTreeMap::new();
67 for key in keys {
68 if let Some(value) = self.get(key).await? {
69 values.insert(key.clone(), value);
70 }
71 }
72 Ok(values)
73 }
74
75 async fn set(&self, key: &str, value: &[u8], options: CacheSetOptions) -> Result<()>;
76
77 async fn set_many(&self, entries: &[CacheEntry], options: CacheSetOptions) -> Result<()> {
78 for entry in entries {
79 self.set(&entry.key, &entry.value, options).await?;
80 }
81 Ok(())
82 }
83
84 async fn delete(&self, key: &str) -> Result<bool>;
85
86 async fn delete_many(&self, keys: &[String]) -> Result<i64> {
87 let mut deleted = 0_i64;
88 let mut seen = BTreeSet::new();
89 for key in keys {
90 if !seen.insert(key.as_str()) {
91 continue;
92 }
93 if self.delete(key).await? {
94 deleted += 1;
95 }
96 }
97 Ok(deleted)
98 }
99
100 async fn touch(&self, key: &str, ttl: Duration) -> Result<bool>;
101}
102
103pub struct Cache {
104 client: CacheClient<Channel>,
105}
106
107impl Cache {
108 pub async fn connect() -> std::result::Result<Self, CacheError> {
109 Self::connect_named("").await
110 }
111
112 pub async fn connect_named(name: &str) -> std::result::Result<Self, CacheError> {
113 let env_name = cache_socket_env(name);
114 let socket_path = std::env::var(&env_name)
115 .map_err(|_| CacheError::Env(format!("{env_name} is not set")))?;
116
117 let channel = Endpoint::try_from("http://[::]:50051")?
118 .connect_with_connector(service_fn(move |_: Uri| {
119 let path = socket_path.clone();
120 async move { UnixStream::connect(path).await.map(TokioIo::new) }
121 }))
122 .await?;
123
124 Ok(Self {
125 client: CacheClient::new(channel),
126 })
127 }
128
129 pub async fn get(&mut self, key: &str) -> std::result::Result<Option<Vec<u8>>, CacheError> {
130 let response = self
131 .client
132 .get(pb::CacheGetRequest {
133 key: key.to_string(),
134 })
135 .await?
136 .into_inner();
137 if !response.found {
138 return Ok(None);
139 }
140 Ok(Some(response.value))
141 }
142
143 pub async fn get_many<S>(
144 &mut self,
145 keys: &[S],
146 ) -> std::result::Result<BTreeMap<String, Vec<u8>>, CacheError>
147 where
148 S: AsRef<str>,
149 {
150 let request_keys: Vec<String> = keys.iter().map(|key| key.as_ref().to_string()).collect();
151 let response = self
152 .client
153 .get_many(pb::CacheGetManyRequest { keys: request_keys })
154 .await?
155 .into_inner();
156 let mut values = BTreeMap::new();
157 for entry in response.entries {
158 if entry.found {
159 values.insert(entry.key, entry.value);
160 }
161 }
162 Ok(values)
163 }
164
165 pub async fn set(
166 &mut self,
167 key: &str,
168 value: &[u8],
169 options: CacheSetOptions,
170 ) -> std::result::Result<(), CacheError> {
171 self.client
172 .set(pb::CacheSetRequest {
173 key: key.to_string(),
174 value: value.to_vec(),
175 ttl: duration_to_proto(options.ttl),
176 })
177 .await?;
178 Ok(())
179 }
180
181 pub async fn set_many(
182 &mut self,
183 entries: &[CacheEntry],
184 options: CacheSetOptions,
185 ) -> std::result::Result<(), CacheError> {
186 self.client
187 .set_many(pb::CacheSetManyRequest {
188 entries: entries
189 .iter()
190 .map(|entry| pb::CacheSetEntry {
191 key: entry.key.clone(),
192 value: entry.value.clone(),
193 })
194 .collect(),
195 ttl: duration_to_proto(options.ttl),
196 })
197 .await?;
198 Ok(())
199 }
200
201 pub async fn delete(&mut self, key: &str) -> std::result::Result<bool, CacheError> {
202 let response = self
203 .client
204 .delete(pb::CacheDeleteRequest {
205 key: key.to_string(),
206 })
207 .await?
208 .into_inner();
209 Ok(response.deleted)
210 }
211
212 pub async fn delete_many<S>(&mut self, keys: &[S]) -> std::result::Result<i64, CacheError>
213 where
214 S: AsRef<str>,
215 {
216 let response = self
217 .client
218 .delete_many(pb::CacheDeleteManyRequest {
219 keys: keys.iter().map(|key| key.as_ref().to_string()).collect(),
220 })
221 .await?
222 .into_inner();
223 Ok(response.deleted)
224 }
225
226 pub async fn touch(
227 &mut self,
228 key: &str,
229 ttl: Duration,
230 ) -> std::result::Result<bool, CacheError> {
231 let response = self
232 .client
233 .touch(pb::CacheTouchRequest {
234 key: key.to_string(),
235 ttl: duration_to_proto(Some(ttl)),
236 })
237 .await?
238 .into_inner();
239 Ok(response.touched)
240 }
241}
242
243pub fn cache_socket_env(name: &str) -> String {
244 let trimmed = name.trim();
245 if trimmed.is_empty() {
246 return ENV_CACHE_SOCKET.to_string();
247 }
248 let mut env = String::from(ENV_CACHE_SOCKET);
249 env.push('_');
250 for ch in trimmed.chars() {
251 if ch.is_ascii_alphanumeric() {
252 env.push(ch.to_ascii_uppercase());
253 } else {
254 env.push('_');
255 }
256 }
257 env
258}
259
260fn duration_to_proto(ttl: Option<Duration>) -> Option<prost_types::Duration> {
261 let ttl = ttl.filter(|ttl| !ttl.is_zero())?;
262 Some(prost_types::Duration {
263 seconds: i64::try_from(ttl.as_secs()).unwrap_or(i64::MAX),
264 nanos: i32::try_from(ttl.subsec_nanos()).unwrap_or(i32::MAX),
265 })
266}