1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use hyper_util::rt::TokioIo;
5use tokio::net::UnixStream;
6use tonic::Request;
7use tonic::codegen::async_trait;
8use tonic::metadata::MetadataValue;
9use tonic::service::Interceptor;
10use tonic::service::interceptor::InterceptedService;
11use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
12use tower::service_fn;
13
14use crate::api::RuntimeMetadata;
15use crate::error::Result;
16use crate::generated::v1::{self as pb, cache_client::CacheClient};
17
/// Transport used by the generated cache client: a tonic [`Channel`] whose
/// outgoing requests pass through [`RelayTokenInterceptor`].
type CacheTransport = InterceptedService<Channel, RelayTokenInterceptor>;

/// Environment variable holding the transport target of the default
/// (unnamed) cache. Named caches use this value as a prefix; see
/// [`cache_socket_env`].
pub const ENV_CACHE_SOCKET: &str = "GESTALT_CACHE_SOCKET";
/// Environment variable holding the relay token of the default (unnamed)
/// cache. This is the same string [`cache_socket_token_env`] produces for an
/// empty name.
pub const ENV_CACHE_SOCKET_TOKEN: &str = "GESTALT_CACHE_SOCKET_TOKEN";
/// Suffix appended to a socket variable name to form the matching relay-token
/// variable name.
pub const ENV_CACHE_SOCKET_TOKEN_SUFFIX: &str = "_TOKEN";
/// Request metadata key under which [`RelayTokenInterceptor`] sends the relay
/// token.
const CACHE_RELAY_TOKEN_HEADER: &str = "x-gestalt-host-service-relay-token";
27
/// A key/value pair, used as the element type of the `set_many` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheEntry {
    /// Key the value is stored under.
    pub key: String,
    /// Raw bytes stored for `key`.
    pub value: Vec<u8>,
}
36
/// Options applied to `set` / `set_many` calls.
///
/// The `Default` value has no TTL.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CacheSetOptions {
    /// Optional time-to-live for the written entries. When sent through
    /// [`Cache`], `None` and a zero duration both result in no TTL being put
    /// on the request (see `duration_to_proto`).
    pub ttl: Option<Duration>,
}
43
/// Errors returned by the [`Cache`] client.
///
/// Every variant displays as the message of the value it wraps.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// A tonic transport error, e.g. from building an endpoint or connecting.
    /// Converted automatically by `?` via `From`.
    #[error("{0}")]
    Transport(#[from] tonic::transport::Error),
    /// A gRPC status returned by a cache call. Converted automatically by `?`
    /// via `From`.
    #[error("{0}")]
    Status(#[from] tonic::Status),
    /// A configuration problem described by the contained message: a missing
    /// socket variable, an invalid transport target, or a relay token that is
    /// not valid request metadata.
    #[error("{0}")]
    Env(String),
}
57
58#[async_trait]
59pub trait CacheProvider: Send + Sync + 'static {
61 async fn configure(
63 &self,
64 _name: &str,
65 _config: serde_json::Map<String, serde_json::Value>,
66 ) -> Result<()> {
67 Ok(())
68 }
69
70 fn metadata(&self) -> Option<RuntimeMetadata> {
72 None
73 }
74
75 fn warnings(&self) -> Vec<String> {
77 Vec::new()
78 }
79
80 async fn health_check(&self) -> Result<()> {
82 Ok(())
83 }
84
85 async fn start(&self) -> Result<()> {
87 Ok(())
88 }
89
90 async fn close(&self) -> Result<()> {
92 Ok(())
93 }
94
95 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
97
98 async fn get_many(&self, keys: &[String]) -> Result<BTreeMap<String, Vec<u8>>> {
101 let mut values = BTreeMap::new();
102 for key in keys {
103 if let Some(value) = self.get(key).await? {
104 values.insert(key.clone(), value);
105 }
106 }
107 Ok(values)
108 }
109
110 async fn set(&self, key: &str, value: &[u8], options: CacheSetOptions) -> Result<()>;
112
113 async fn set_many(&self, entries: &[CacheEntry], options: CacheSetOptions) -> Result<()> {
116 for entry in entries {
117 self.set(&entry.key, &entry.value, options).await?;
118 }
119 Ok(())
120 }
121
122 async fn delete(&self, key: &str) -> Result<bool>;
124
125 async fn delete_many(&self, keys: &[String]) -> Result<i64> {
128 let mut deleted = 0_i64;
129 let mut seen = BTreeSet::new();
130 for key in keys {
131 if !seen.insert(key.as_str()) {
132 continue;
133 }
134 if self.delete(key).await? {
135 deleted += 1;
136 }
137 }
138 Ok(deleted)
139 }
140
141 async fn touch(&self, key: &str, ttl: Duration) -> Result<bool>;
143}
144
/// Client handle for the cache service.
///
/// Wraps the generated gRPC `CacheClient` running over [`CacheTransport`].
/// Create one with [`Cache::connect`] or [`Cache::connect_named`].
pub struct Cache {
    /// Generated gRPC client that every request method calls through.
    client: CacheClient<CacheTransport>,
}
149
150impl Cache {
151 pub async fn connect() -> std::result::Result<Self, CacheError> {
153 Self::connect_named("").await
154 }
155
156 pub async fn connect_named(name: &str) -> std::result::Result<Self, CacheError> {
158 let env_name = cache_socket_env(name);
159 let target = std::env::var(&env_name)
160 .map_err(|_| CacheError::Env(format!("{env_name} is not set")))?;
161 let relay_token =
162 std::env::var(cache_socket_token_env(name)).unwrap_or_else(|_| String::new());
163
164 let channel = match parse_cache_target(&target)? {
165 CacheTarget::Unix(path) => {
166 Endpoint::try_from("http://[::]:50051")?
167 .connect_with_connector(service_fn(move |_: Uri| {
168 let path = path.clone();
169 async move { UnixStream::connect(path).await.map(TokioIo::new) }
170 }))
171 .await?
172 }
173 CacheTarget::Tcp(address) => {
174 Endpoint::from_shared(format!("http://{address}"))?
175 .connect()
176 .await?
177 }
178 CacheTarget::Tls(address) => {
179 Endpoint::from_shared(format!("https://{address}"))?
180 .tls_config(ClientTlsConfig::new().with_native_roots())?
181 .connect()
182 .await?
183 }
184 };
185
186 Ok(Self {
187 client: CacheClient::with_interceptor(
188 channel,
189 relay_token_interceptor(relay_token.trim())?,
190 ),
191 })
192 }
193
194 pub async fn get(&mut self, key: &str) -> std::result::Result<Option<Vec<u8>>, CacheError> {
196 let response = self
197 .client
198 .get(pb::CacheGetRequest {
199 key: key.to_string(),
200 })
201 .await?
202 .into_inner();
203 if !response.found {
204 return Ok(None);
205 }
206 Ok(Some(response.value))
207 }
208
209 pub async fn get_many<S>(
211 &mut self,
212 keys: &[S],
213 ) -> std::result::Result<BTreeMap<String, Vec<u8>>, CacheError>
214 where
215 S: AsRef<str>,
216 {
217 let request_keys: Vec<String> = keys.iter().map(|key| key.as_ref().to_string()).collect();
218 let response = self
219 .client
220 .get_many(pb::CacheGetManyRequest { keys: request_keys })
221 .await?
222 .into_inner();
223 let mut values = BTreeMap::new();
224 for entry in response.entries {
225 if entry.found {
226 values.insert(entry.key, entry.value);
227 }
228 }
229 Ok(values)
230 }
231
232 pub async fn set(
234 &mut self,
235 key: &str,
236 value: &[u8],
237 options: CacheSetOptions,
238 ) -> std::result::Result<(), CacheError> {
239 self.client
240 .set(pb::CacheSetRequest {
241 key: key.to_string(),
242 value: value.to_vec(),
243 ttl: duration_to_proto(options.ttl),
244 })
245 .await?;
246 Ok(())
247 }
248
249 pub async fn set_many(
251 &mut self,
252 entries: &[CacheEntry],
253 options: CacheSetOptions,
254 ) -> std::result::Result<(), CacheError> {
255 self.client
256 .set_many(pb::CacheSetManyRequest {
257 entries: entries
258 .iter()
259 .map(|entry| pb::CacheSetEntry {
260 key: entry.key.clone(),
261 value: entry.value.clone(),
262 })
263 .collect(),
264 ttl: duration_to_proto(options.ttl),
265 })
266 .await?;
267 Ok(())
268 }
269
270 pub async fn delete(&mut self, key: &str) -> std::result::Result<bool, CacheError> {
272 let response = self
273 .client
274 .delete(pb::CacheDeleteRequest {
275 key: key.to_string(),
276 })
277 .await?
278 .into_inner();
279 Ok(response.deleted)
280 }
281
282 pub async fn delete_many<S>(&mut self, keys: &[S]) -> std::result::Result<i64, CacheError>
284 where
285 S: AsRef<str>,
286 {
287 let response = self
288 .client
289 .delete_many(pb::CacheDeleteManyRequest {
290 keys: keys.iter().map(|key| key.as_ref().to_string()).collect(),
291 })
292 .await?
293 .into_inner();
294 Ok(response.deleted)
295 }
296
297 pub async fn touch(
299 &mut self,
300 key: &str,
301 ttl: Duration,
302 ) -> std::result::Result<bool, CacheError> {
303 let response = self
304 .client
305 .touch(pb::CacheTouchRequest {
306 key: key.to_string(),
307 ttl: duration_to_proto(Some(ttl)),
308 })
309 .await?
310 .into_inner();
311 Ok(response.touched)
312 }
313}
314
315pub fn cache_socket_env(name: &str) -> String {
317 let trimmed = name.trim();
318 if trimmed.is_empty() {
319 return ENV_CACHE_SOCKET.to_string();
320 }
321 let mut env = String::from(ENV_CACHE_SOCKET);
322 env.push('_');
323 for ch in trimmed.chars() {
324 if ch.is_ascii_alphanumeric() {
325 env.push(ch.to_ascii_uppercase());
326 } else {
327 env.push('_');
328 }
329 }
330 env
331}
332
333pub fn cache_socket_token_env(name: &str) -> String {
335 format!(
336 "{env}{}",
337 ENV_CACHE_SOCKET_TOKEN_SUFFIX,
338 env = cache_socket_env(name)
339 )
340}
341
/// Parsed form of the cache transport target produced by
/// `parse_cache_target`.
enum CacheTarget {
    /// Path of a Unix domain socket to connect to.
    Unix(String),
    /// Address taken from a `tcp://` target; dialed as `http://{address}`.
    Tcp(String),
    /// Address taken from a `tls://` target; dialed as `https://{address}`
    /// with a TLS configuration.
    Tls(String),
}
347
348fn parse_cache_target(raw_target: &str) -> std::result::Result<CacheTarget, CacheError> {
349 let target = raw_target.trim();
350 if target.is_empty() {
351 return Err(CacheError::Env(
352 "cache: transport target is required".to_string(),
353 ));
354 }
355 if let Some(address) = target.strip_prefix("tcp://") {
356 let address = address.trim();
357 if address.is_empty() {
358 return Err(CacheError::Env(format!(
359 "cache: tcp target {raw_target:?} is missing host:port"
360 )));
361 }
362 return Ok(CacheTarget::Tcp(address.to_string()));
363 }
364 if let Some(address) = target.strip_prefix("tls://") {
365 let address = address.trim();
366 if address.is_empty() {
367 return Err(CacheError::Env(format!(
368 "cache: tls target {raw_target:?} is missing host:port"
369 )));
370 }
371 return Ok(CacheTarget::Tls(address.to_string()));
372 }
373 if let Some(path) = target.strip_prefix("unix://") {
374 let path = path.trim();
375 if path.is_empty() {
376 return Err(CacheError::Env(format!(
377 "cache: unix target {raw_target:?} is missing a socket path"
378 )));
379 }
380 return Ok(CacheTarget::Unix(path.to_string()));
381 }
382 if target.contains("://") {
383 let scheme = target.split("://").next().unwrap_or_default();
384 return Err(CacheError::Env(format!(
385 "cache: unsupported target scheme {scheme:?}"
386 )));
387 }
388 Ok(CacheTarget::Unix(target.to_string()))
389}
390
391fn relay_token_interceptor(token: &str) -> std::result::Result<RelayTokenInterceptor, CacheError> {
392 let header =
393 if token.trim().is_empty() {
394 None
395 } else {
396 Some(MetadataValue::try_from(token.to_string()).map_err(|err| {
397 CacheError::Env(format!("invalid cache relay token metadata: {err}"))
398 })?)
399 };
400 Ok(RelayTokenInterceptor { header })
401}
402
/// Request interceptor that adds the relay token header when one is
/// configured.
#[derive(Clone)]
struct RelayTokenInterceptor {
    /// Pre-validated header value; `None` when no token was supplied, in
    /// which case requests pass through unchanged.
    header: Option<MetadataValue<tonic::metadata::Ascii>>,
}
407
408impl Interceptor for RelayTokenInterceptor {
409 fn call(
410 &mut self,
411 mut request: Request<()>,
412 ) -> std::result::Result<Request<()>, tonic::Status> {
413 if let Some(header) = self.header.clone() {
414 request
415 .metadata_mut()
416 .insert(CACHE_RELAY_TOKEN_HEADER, header);
417 }
418 Ok(request)
419 }
420}
421
422fn duration_to_proto(ttl: Option<Duration>) -> Option<prost_types::Duration> {
423 let ttl = ttl.filter(|ttl| !ttl.is_zero())?;
424 Some(prost_types::Duration {
425 seconds: i64::try_from(ttl.as_secs()).unwrap_or(i64::MAX),
426 nanos: i32::try_from(ttl.subsec_nanos()).unwrap_or(i32::MAX),
427 })
428}