Skip to main content

gestalt/
cache.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::time::Duration;
3
4use hyper_util::rt::TokioIo;
5use tokio::net::UnixStream;
6use tonic::codegen::async_trait;
7use tonic::transport::{Channel, Endpoint, Uri};
8use tower::service_fn;
9
10use crate::api::RuntimeMetadata;
11use crate::error::Result;
12use crate::generated::v1::{self as pb, cache_client::CacheClient};
13
/// Environment variable holding the Unix socket path of the default (unnamed) cache.
///
/// Named caches read a derived variable instead: the name is appended after an underscore
/// and normalised by [`cache_socket_env`] (e.g. `"redis"` → `GESTALT_CACHE_SOCKET_REDIS`).
pub const ENV_CACHE_SOCKET: &str = "GESTALT_CACHE_SOCKET";
15
/// One key/value pair, used for batch writes such as `set_many`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CacheEntry {
    /// Key the value is stored under.
    pub key: String,
    /// Raw value bytes; the cache does not interpret them.
    pub value: Vec<u8>,
}
21
/// Per-write options accepted by `set` and `set_many`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CacheSetOptions {
    /// Time-to-live for the written entries. `None` (the `Default`) and a zero duration
    /// are both sent without a TTL field.
    pub ttl: Option<Duration>,
}
26
27#[derive(Debug, thiserror::Error)]
28pub enum CacheError {
29    #[error("{0}")]
30    Transport(#[from] tonic::transport::Error),
31    #[error("{0}")]
32    Status(#[from] tonic::Status),
33    #[error("{0}")]
34    Env(String),
35}
36
37#[async_trait]
38pub trait CacheProvider: Send + Sync + 'static {
39    async fn configure(
40        &self,
41        _name: &str,
42        _config: serde_json::Map<String, serde_json::Value>,
43    ) -> Result<()> {
44        Ok(())
45    }
46
47    fn metadata(&self) -> Option<RuntimeMetadata> {
48        None
49    }
50
51    fn warnings(&self) -> Vec<String> {
52        Vec::new()
53    }
54
55    async fn health_check(&self) -> Result<()> {
56        Ok(())
57    }
58
59    async fn close(&self) -> Result<()> {
60        Ok(())
61    }
62
63    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
64
65    async fn get_many(&self, keys: &[String]) -> Result<BTreeMap<String, Vec<u8>>> {
66        let mut values = BTreeMap::new();
67        for key in keys {
68            if let Some(value) = self.get(key).await? {
69                values.insert(key.clone(), value);
70            }
71        }
72        Ok(values)
73    }
74
75    async fn set(&self, key: &str, value: &[u8], options: CacheSetOptions) -> Result<()>;
76
77    async fn set_many(&self, entries: &[CacheEntry], options: CacheSetOptions) -> Result<()> {
78        for entry in entries {
79            self.set(&entry.key, &entry.value, options).await?;
80        }
81        Ok(())
82    }
83
84    async fn delete(&self, key: &str) -> Result<bool>;
85
86    async fn delete_many(&self, keys: &[String]) -> Result<i64> {
87        let mut deleted = 0_i64;
88        let mut seen = BTreeSet::new();
89        for key in keys {
90            if !seen.insert(key.as_str()) {
91                continue;
92            }
93            if self.delete(key).await? {
94                deleted += 1;
95            }
96        }
97        Ok(deleted)
98    }
99
100    async fn touch(&self, key: &str, ttl: Duration) -> Result<bool>;
101}
102
103pub struct Cache {
104    client: CacheClient<Channel>,
105}
106
107impl Cache {
108    pub async fn connect() -> std::result::Result<Self, CacheError> {
109        Self::connect_named("").await
110    }
111
112    pub async fn connect_named(name: &str) -> std::result::Result<Self, CacheError> {
113        let env_name = cache_socket_env(name);
114        let socket_path = std::env::var(&env_name)
115            .map_err(|_| CacheError::Env(format!("{env_name} is not set")))?;
116
117        let channel = Endpoint::try_from("http://[::]:50051")?
118            .connect_with_connector(service_fn(move |_: Uri| {
119                let path = socket_path.clone();
120                async move { UnixStream::connect(path).await.map(TokioIo::new) }
121            }))
122            .await?;
123
124        Ok(Self {
125            client: CacheClient::new(channel),
126        })
127    }
128
129    pub async fn get(&mut self, key: &str) -> std::result::Result<Option<Vec<u8>>, CacheError> {
130        let response = self
131            .client
132            .get(pb::CacheGetRequest {
133                key: key.to_string(),
134            })
135            .await?
136            .into_inner();
137        if !response.found {
138            return Ok(None);
139        }
140        Ok(Some(response.value))
141    }
142
143    pub async fn get_many<S>(
144        &mut self,
145        keys: &[S],
146    ) -> std::result::Result<BTreeMap<String, Vec<u8>>, CacheError>
147    where
148        S: AsRef<str>,
149    {
150        let request_keys: Vec<String> = keys.iter().map(|key| key.as_ref().to_string()).collect();
151        let response = self
152            .client
153            .get_many(pb::CacheGetManyRequest { keys: request_keys })
154            .await?
155            .into_inner();
156        let mut values = BTreeMap::new();
157        for entry in response.entries {
158            if entry.found {
159                values.insert(entry.key, entry.value);
160            }
161        }
162        Ok(values)
163    }
164
165    pub async fn set(
166        &mut self,
167        key: &str,
168        value: &[u8],
169        options: CacheSetOptions,
170    ) -> std::result::Result<(), CacheError> {
171        self.client
172            .set(pb::CacheSetRequest {
173                key: key.to_string(),
174                value: value.to_vec(),
175                ttl: duration_to_proto(options.ttl),
176            })
177            .await?;
178        Ok(())
179    }
180
181    pub async fn set_many(
182        &mut self,
183        entries: &[CacheEntry],
184        options: CacheSetOptions,
185    ) -> std::result::Result<(), CacheError> {
186        self.client
187            .set_many(pb::CacheSetManyRequest {
188                entries: entries
189                    .iter()
190                    .map(|entry| pb::CacheSetEntry {
191                        key: entry.key.clone(),
192                        value: entry.value.clone(),
193                    })
194                    .collect(),
195                ttl: duration_to_proto(options.ttl),
196            })
197            .await?;
198        Ok(())
199    }
200
201    pub async fn delete(&mut self, key: &str) -> std::result::Result<bool, CacheError> {
202        let response = self
203            .client
204            .delete(pb::CacheDeleteRequest {
205                key: key.to_string(),
206            })
207            .await?
208            .into_inner();
209        Ok(response.deleted)
210    }
211
212    pub async fn delete_many<S>(&mut self, keys: &[S]) -> std::result::Result<i64, CacheError>
213    where
214        S: AsRef<str>,
215    {
216        let response = self
217            .client
218            .delete_many(pb::CacheDeleteManyRequest {
219                keys: keys.iter().map(|key| key.as_ref().to_string()).collect(),
220            })
221            .await?
222            .into_inner();
223        Ok(response.deleted)
224    }
225
226    pub async fn touch(
227        &mut self,
228        key: &str,
229        ttl: Duration,
230    ) -> std::result::Result<bool, CacheError> {
231        let response = self
232            .client
233            .touch(pb::CacheTouchRequest {
234                key: key.to_string(),
235                ttl: duration_to_proto(Some(ttl)),
236            })
237            .await?
238            .into_inner();
239        Ok(response.touched)
240    }
241}
242
243pub fn cache_socket_env(name: &str) -> String {
244    let trimmed = name.trim();
245    if trimmed.is_empty() {
246        return ENV_CACHE_SOCKET.to_string();
247    }
248    let mut env = String::from(ENV_CACHE_SOCKET);
249    env.push('_');
250    for ch in trimmed.chars() {
251        if ch.is_ascii_alphanumeric() {
252            env.push(ch.to_ascii_uppercase());
253        } else {
254            env.push('_');
255        }
256    }
257    env
258}
259
260fn duration_to_proto(ttl: Option<Duration>) -> Option<prost_types::Duration> {
261    let ttl = ttl.filter(|ttl| !ttl.is_zero())?;
262    Some(prost_types::Duration {
263        seconds: i64::try_from(ttl.as_secs()).unwrap_or(i64::MAX),
264        nanos: i32::try_from(ttl.subsec_nanos()).unwrap_or(i32::MAX),
265    })
266}