Skip to main content

gestalt/
cache.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use hyper_util::rt::TokioIo;
5use tokio::net::UnixStream;
6use tonic::Request;
7use tonic::codegen::async_trait;
8use tonic::metadata::MetadataValue;
9use tonic::service::Interceptor;
10use tonic::service::interceptor::InterceptedService;
11use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
12use tower::service_fn;
13
14use crate::api::RuntimeMetadata;
15use crate::error::Result;
16use crate::generated::v1::{self as pb, cache_client::CacheClient};
17
18type CacheTransport = InterceptedService<Channel, RelayTokenInterceptor>;
19
20/// Default Unix-socket environment variable used by [`Cache::connect`].
21pub const ENV_CACHE_SOCKET: &str = "GESTALT_CACHE_SOCKET";
22/// Default relay-token environment variable used by [`Cache::connect`].
23pub const ENV_CACHE_SOCKET_TOKEN: &str = "GESTALT_CACHE_SOCKET_TOKEN";
24/// Suffix added to named cache socket variables for relay-token variables.
25pub const ENV_CACHE_SOCKET_TOKEN_SUFFIX: &str = "_TOKEN";
26const CACHE_RELAY_TOKEN_HEADER: &str = "x-gestalt-host-service-relay-token";
27
/// A single key/value pair handed to [`Cache::set_many`] or [`CacheProvider::set_many`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CacheEntry {
    /// Key under which `value` is written.
    pub key: String,
    /// Raw bytes written for `key`.
    pub value: Vec<u8>,
}
36
/// Per-write settings shared by the single and batch cache write operations.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CacheSetOptions {
    /// How long the written value should live. `None` (the default) or a zero duration sends
    /// the write without a TTL.
    pub ttl: Option<Duration>,
}
43
44#[derive(Debug, thiserror::Error)]
45/// Errors returned by the cache client transport.
46pub enum CacheError {
47    /// The host-service transport could not be created.
48    #[error("{0}")]
49    Transport(#[from] tonic::transport::Error),
50    /// The host-service RPC returned a gRPC status.
51    #[error("{0}")]
52    Status(#[from] tonic::Status),
53    /// Required environment or target configuration was invalid.
54    #[error("{0}")]
55    Env(String),
56}
57
58#[async_trait]
59/// Lifecycle and RPC contract for cache providers.
60pub trait CacheProvider: Send + Sync + 'static {
61    /// Configures the provider before it starts serving requests.
62    async fn configure(
63        &self,
64        _name: &str,
65        _config: serde_json::Map<String, serde_json::Value>,
66    ) -> Result<()> {
67        Ok(())
68    }
69
70    /// Returns runtime metadata that should augment the static manifest.
71    fn metadata(&self) -> Option<RuntimeMetadata> {
72        None
73    }
74
75    /// Returns non-fatal warnings the host should surface to users.
76    fn warnings(&self) -> Vec<String> {
77        Vec::new()
78    }
79
80    /// Performs an optional health check.
81    async fn health_check(&self) -> Result<()> {
82        Ok(())
83    }
84
85    /// Starts provider-owned background work after configuration.
86    async fn start(&self) -> Result<()> {
87        Ok(())
88    }
89
90    /// Shuts the provider down before the runtime exits.
91    async fn close(&self) -> Result<()> {
92        Ok(())
93    }
94
95    /// Loads one cache value.
96    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
97
98    /// Loads many cache values, defaulting to repeated [`CacheProvider::get`]
99    /// calls.
100    async fn get_many(&self, keys: &[String]) -> Result<BTreeMap<String, Vec<u8>>> {
101        let mut values = BTreeMap::new();
102        for key in keys {
103            if let Some(value) = self.get(key).await? {
104                values.insert(key.clone(), value);
105            }
106        }
107        Ok(values)
108    }
109
110    /// Stores one cache value.
111    async fn set(&self, key: &str, value: &[u8], options: CacheSetOptions) -> Result<()>;
112
113    /// Stores many cache values, defaulting to repeated [`CacheProvider::set`]
114    /// calls.
115    async fn set_many(&self, entries: &[CacheEntry], options: CacheSetOptions) -> Result<()> {
116        for entry in entries {
117            self.set(&entry.key, &entry.value, options).await?;
118        }
119        Ok(())
120    }
121
122    /// Deletes one cache key.
123    async fn delete(&self, key: &str) -> Result<bool>;
124
125    /// Deletes many cache keys, defaulting to repeated
126    /// [`CacheProvider::delete`] calls.
127    async fn delete_many(&self, keys: &[String]) -> Result<i64> {
128        let mut deleted = 0_i64;
129        let mut seen = BTreeSet::new();
130        for key in keys {
131            if !seen.insert(key.as_str()) {
132                continue;
133            }
134            if self.delete(key).await? {
135                deleted += 1;
136            }
137        }
138        Ok(deleted)
139    }
140
141    /// Updates the TTL for one cache key.
142    async fn touch(&self, key: &str, ttl: Duration) -> Result<bool>;
143}
144
145/// Client for a running cache provider.
146pub struct Cache {
147    client: CacheClient<CacheTransport>,
148}
149
150impl Cache {
151    /// Connects to the default cache transport socket.
152    pub async fn connect() -> std::result::Result<Self, CacheError> {
153        Self::connect_named("").await
154    }
155
156    /// Connects to a named cache transport socket.
157    pub async fn connect_named(name: &str) -> std::result::Result<Self, CacheError> {
158        let env_name = cache_socket_env(name);
159        let target = std::env::var(&env_name)
160            .map_err(|_| CacheError::Env(format!("{env_name} is not set")))?;
161        let relay_token =
162            std::env::var(cache_socket_token_env(name)).unwrap_or_else(|_| String::new());
163
164        let channel = match parse_cache_target(&target)? {
165            CacheTarget::Unix(path) => {
166                Endpoint::try_from("http://[::]:50051")?
167                    .connect_with_connector(service_fn(move |_: Uri| {
168                        let path = path.clone();
169                        async move { UnixStream::connect(path).await.map(TokioIo::new) }
170                    }))
171                    .await?
172            }
173            CacheTarget::Tcp(address) => {
174                Endpoint::from_shared(format!("http://{address}"))?
175                    .connect()
176                    .await?
177            }
178            CacheTarget::Tls(address) => {
179                Endpoint::from_shared(format!("https://{address}"))?
180                    .tls_config(ClientTlsConfig::new().with_native_roots())?
181                    .connect()
182                    .await?
183            }
184        };
185
186        Ok(Self {
187            client: CacheClient::with_interceptor(
188                channel,
189                relay_token_interceptor(relay_token.trim())?,
190            ),
191        })
192    }
193
194    /// Loads one cache value.
195    pub async fn get(&mut self, key: &str) -> std::result::Result<Option<Vec<u8>>, CacheError> {
196        let response = self
197            .client
198            .get(pb::CacheGetRequest {
199                key: key.to_string(),
200            })
201            .await?
202            .into_inner();
203        if !response.found {
204            return Ok(None);
205        }
206        Ok(Some(response.value))
207    }
208
209    /// Loads all present values for keys.
210    pub async fn get_many<S>(
211        &mut self,
212        keys: &[S],
213    ) -> std::result::Result<BTreeMap<String, Vec<u8>>, CacheError>
214    where
215        S: AsRef<str>,
216    {
217        let request_keys: Vec<String> = keys.iter().map(|key| key.as_ref().to_string()).collect();
218        let response = self
219            .client
220            .get_many(pb::CacheGetManyRequest { keys: request_keys })
221            .await?
222            .into_inner();
223        let mut values = BTreeMap::new();
224        for entry in response.entries {
225            if entry.found {
226                values.insert(entry.key, entry.value);
227            }
228        }
229        Ok(values)
230    }
231
232    /// Stores one cache value.
233    pub async fn set(
234        &mut self,
235        key: &str,
236        value: &[u8],
237        options: CacheSetOptions,
238    ) -> std::result::Result<(), CacheError> {
239        self.client
240            .set(pb::CacheSetRequest {
241                key: key.to_string(),
242                value: value.to_vec(),
243                ttl: duration_to_proto(options.ttl),
244            })
245            .await?;
246        Ok(())
247    }
248
249    /// Stores multiple cache values in one RPC.
250    pub async fn set_many(
251        &mut self,
252        entries: &[CacheEntry],
253        options: CacheSetOptions,
254    ) -> std::result::Result<(), CacheError> {
255        self.client
256            .set_many(pb::CacheSetManyRequest {
257                entries: entries
258                    .iter()
259                    .map(|entry| pb::CacheSetEntry {
260                        key: entry.key.clone(),
261                        value: entry.value.clone(),
262                    })
263                    .collect(),
264                ttl: duration_to_proto(options.ttl),
265            })
266            .await?;
267        Ok(())
268    }
269
270    /// Deletes one cache key.
271    pub async fn delete(&mut self, key: &str) -> std::result::Result<bool, CacheError> {
272        let response = self
273            .client
274            .delete(pb::CacheDeleteRequest {
275                key: key.to_string(),
276            })
277            .await?
278            .into_inner();
279        Ok(response.deleted)
280    }
281
282    /// Deletes many cache keys.
283    pub async fn delete_many<S>(&mut self, keys: &[S]) -> std::result::Result<i64, CacheError>
284    where
285        S: AsRef<str>,
286    {
287        let response = self
288            .client
289            .delete_many(pb::CacheDeleteManyRequest {
290                keys: keys.iter().map(|key| key.as_ref().to_string()).collect(),
291            })
292            .await?
293            .into_inner();
294        Ok(response.deleted)
295    }
296
297    /// Updates the TTL for one cache key.
298    pub async fn touch(
299        &mut self,
300        key: &str,
301        ttl: Duration,
302    ) -> std::result::Result<bool, CacheError> {
303        let response = self
304            .client
305            .touch(pb::CacheTouchRequest {
306                key: key.to_string(),
307                ttl: duration_to_proto(Some(ttl)),
308            })
309            .await?
310            .into_inner();
311        Ok(response.touched)
312    }
313}
314
315/// Returns the environment variable used for a named cache socket.
316pub fn cache_socket_env(name: &str) -> String {
317    let trimmed = name.trim();
318    if trimmed.is_empty() {
319        return ENV_CACHE_SOCKET.to_string();
320    }
321    let mut env = String::from(ENV_CACHE_SOCKET);
322    env.push('_');
323    for ch in trimmed.chars() {
324        if ch.is_ascii_alphanumeric() {
325            env.push(ch.to_ascii_uppercase());
326        } else {
327            env.push('_');
328        }
329    }
330    env
331}
332
333/// Returns the environment variable used for a named cache relay token.
334pub fn cache_socket_token_env(name: &str) -> String {
335    format!(
336        "{env}{}",
337        ENV_CACHE_SOCKET_TOKEN_SUFFIX,
338        env = cache_socket_env(name)
339    )
340}
341
/// A transport target after parsing, telling the client how to dial the cache provider.
enum CacheTarget {
    /// Path of a Unix domain socket.
    Unix(String),
    /// Plaintext TCP address in `host:port` form.
    Tcp(String),
    /// TLS-protected TCP address in `host:port` form.
    Tls(String),
}
347
348fn parse_cache_target(raw_target: &str) -> std::result::Result<CacheTarget, CacheError> {
349    let target = raw_target.trim();
350    if target.is_empty() {
351        return Err(CacheError::Env(
352            "cache: transport target is required".to_string(),
353        ));
354    }
355    if let Some(address) = target.strip_prefix("tcp://") {
356        let address = address.trim();
357        if address.is_empty() {
358            return Err(CacheError::Env(format!(
359                "cache: tcp target {raw_target:?} is missing host:port"
360            )));
361        }
362        return Ok(CacheTarget::Tcp(address.to_string()));
363    }
364    if let Some(address) = target.strip_prefix("tls://") {
365        let address = address.trim();
366        if address.is_empty() {
367            return Err(CacheError::Env(format!(
368                "cache: tls target {raw_target:?} is missing host:port"
369            )));
370        }
371        return Ok(CacheTarget::Tls(address.to_string()));
372    }
373    if let Some(path) = target.strip_prefix("unix://") {
374        let path = path.trim();
375        if path.is_empty() {
376            return Err(CacheError::Env(format!(
377                "cache: unix target {raw_target:?} is missing a socket path"
378            )));
379        }
380        return Ok(CacheTarget::Unix(path.to_string()));
381    }
382    if target.contains("://") {
383        let scheme = target.split("://").next().unwrap_or_default();
384        return Err(CacheError::Env(format!(
385            "cache: unsupported target scheme {scheme:?}"
386        )));
387    }
388    Ok(CacheTarget::Unix(target.to_string()))
389}
390
391fn relay_token_interceptor(token: &str) -> std::result::Result<RelayTokenInterceptor, CacheError> {
392    let header =
393        if token.trim().is_empty() {
394            None
395        } else {
396            Some(MetadataValue::try_from(token.to_string()).map_err(|err| {
397                CacheError::Env(format!("invalid cache relay token metadata: {err}"))
398            })?)
399        };
400    Ok(RelayTokenInterceptor { header })
401}
402
403#[derive(Clone)]
404struct RelayTokenInterceptor {
405    header: Option<MetadataValue<tonic::metadata::Ascii>>,
406}
407
408impl Interceptor for RelayTokenInterceptor {
409    fn call(
410        &mut self,
411        mut request: Request<()>,
412    ) -> std::result::Result<Request<()>, tonic::Status> {
413        if let Some(header) = self.header.clone() {
414            request
415                .metadata_mut()
416                .insert(CACHE_RELAY_TOKEN_HEADER, header);
417        }
418        Ok(request)
419    }
420}
421
422fn duration_to_proto(ttl: Option<Duration>) -> Option<prost_types::Duration> {
423    let ttl = ttl.filter(|ttl| !ttl.is_zero())?;
424    Some(prost_types::Duration {
425        seconds: i64::try_from(ttl.as_secs()).unwrap_or(i64::MAX),
426        nanos: i32::try_from(ttl.subsec_nanos()).unwrap_or(i32::MAX),
427    })
428}