Skip to main content

cachekit/
client.rs

1use std::time::Duration;
2
3use serde::{de::DeserializeOwned, Serialize};
4
5use crate::backend::Backend;
6use crate::error::CachekitError;
7use crate::serializer;
8
9// ── SharedBackend type alias ──────────────────────────────────────────────────
10
/// Reference-counted pointer to a heap-allocated backend.
///
/// On native targets (without `unsync`) this is an `Arc`. Note that `Arc`
/// itself does not impose `Send + Sync`; a shared `Arc<dyn Backend>` is only
/// usable across threads if the backend is `Send + Sync` — presumably enforced
/// by supertrait bounds on `Backend` for native builds (confirm in
/// `crate::backend`). On `wasm32` or with the `unsync` feature, `Rc` is used
/// instead — the runtime is treated as single-threaded, so `Send` bounds are
/// unnecessary and reference counting is non-atomic.
#[cfg(not(any(target_arch = "wasm32", feature = "unsync")))]
pub type SharedBackend = std::sync::Arc<dyn Backend>;

/// Reference-counted pointer to a heap-allocated backend (`?Send` variant).
///
/// Selected on `wasm32` targets or when the `unsync` feature is enabled.
#[cfg(any(target_arch = "wasm32", feature = "unsync"))]
pub type SharedBackend = std::rc::Rc<dyn Backend>;

// ── SharedEncryption type alias ──────────────────────────────────────────────

/// Reference-counted pointer to the encryption layer.
///
/// On native targets (without `unsync`) `Arc` is used (requires `Sync`).
/// On `wasm32` or with `unsync`, `Rc` is used — avoids the `!Sync` problem
/// caused by `Cell<u64>` inside cachekit-core's nonce counter.
///
/// Only defined when the `encryption` feature is enabled. Constructed via
/// `SharedEncryption::new(layer)` in the builder, which works for both
/// pointer types because `EncryptionLayer` is a sized type.
#[cfg(all(
    feature = "encryption",
    not(any(target_arch = "wasm32", feature = "unsync"))
))]
type SharedEncryption = std::sync::Arc<crate::encryption::EncryptionLayer>;

// `?Sync` variant of `SharedEncryption`, selected on `wasm32` or with `unsync`.
#[cfg(all(
    feature = "encryption",
    any(target_arch = "wasm32", feature = "unsync")
))]
type SharedEncryption = std::rc::Rc<crate::encryption::EncryptionLayer>;
41
42// ── Key validation ────────────────────────────────────────────────────────────
43
44const MAX_KEY_BYTES: usize = 1024;
45
46/// Maximum TTL for L1 entries populated from L2 cache hits.
47/// Uses a short ceiling to limit staleness when the original TTL is unknown.
48const L1_BACKFILL_TTL_SECS: u64 = 30;
49
50fn validate_key(key: &str) -> Result<(), CachekitError> {
51    if key.is_empty() {
52        return Err(CachekitError::InvalidKey(
53            "key must not be empty".to_owned(),
54        ));
55    }
56    if key.len() > MAX_KEY_BYTES {
57        return Err(CachekitError::InvalidKey(format!(
58            "key is {} bytes (limit: {MAX_KEY_BYTES})",
59            key.len()
60        )));
61    }
62    for b in key.bytes() {
63        if b < 0x20 || b == 0x7F {
64            return Err(CachekitError::InvalidKey(format!(
65                "key contains illegal control character 0x{b:02X}"
66            )));
67        }
68    }
69    Ok(())
70}
71
72// ── CacheKit ─────────────────────────────────────────────────────────────────
73
/// Production-ready cache client with optional L1 in-process cache layer.
///
/// Reads consult the in-process L1 cache (when present) before the backend;
/// writes go to the backend first and are then written through to L1.
/// Construct one with [`CacheKit::builder`].
pub struct CacheKit {
    /// Shared L2 storage backend that every operation ultimately reaches.
    backend: SharedBackend,
    /// TTL used by [`CacheKit::set`]; also caps the TTL of L1 back-fills.
    default_ttl: Duration,
    /// Optional prefix; when set, keys are stored as `{namespace}:{key}`.
    namespace: Option<String>,
    /// Payload size limit in bytes, enforced on both writes and reads.
    max_payload_bytes: usize,

    /// In-process L1 cache; `None` when disabled through the builder's `no_l1()`.
    #[cfg(feature = "l1")]
    l1: Option<crate::l1::L1Cache>,

    /// Encryption layer handed to `SecureCache`; `None` when not configured.
    #[cfg(feature = "encryption")]
    encryption: Option<SharedEncryption>,
}
87
88impl CacheKit {
89    /// Create a new builder.
90    pub fn builder() -> CacheKitBuilder {
91        CacheKitBuilder::default()
92    }
93
94    /// Build from environment variables via [`crate::config::CachekitConfig::from_env`].
95    ///
96    /// Creates a [`crate::backend::cachekitio::CachekitIO`] backend from the
97    /// config. Requires the `cachekitio` feature.
98    #[cfg(all(feature = "cachekitio", not(target_arch = "wasm32")))]
99    pub fn from_env() -> Result<CacheKitBuilder, CachekitError> {
100        use crate::backend::cachekitio::CachekitIO;
101        use crate::config::CachekitConfig;
102
103        let config = CachekitConfig::from_env()?;
104
105        let api_key_z = config
106            .api_key
107            .ok_or_else(|| CachekitError::Config("CACHEKIT_API_KEY is required".to_owned()))?;
108
109        let backend = CachekitIO::builder()
110            .api_key(api_key_z.as_str())
111            .api_url(config.api_url)
112            .build()
113            .map_err(|e| CachekitError::Config(e.to_string()))?;
114
115        #[cfg(not(feature = "unsync"))]
116        let shared: SharedBackend = std::sync::Arc::new(backend);
117        #[cfg(feature = "unsync")]
118        let shared: SharedBackend = std::rc::Rc::new(backend);
119
120        let mut builder = CacheKitBuilder::default()
121            .backend(shared)
122            .default_ttl(config.default_ttl)
123            .max_payload_bytes(config.max_payload_bytes)
124            .l1_capacity(config.l1_capacity);
125
126        if let Some(ns) = config.namespace.clone() {
127            builder = builder.namespace(ns);
128        }
129
130        // Wire up encryption if master key is configured
131        #[cfg(feature = "encryption")]
132        if let Some(ref master_key) = config.master_key {
133            let namespace = config.namespace.as_deref().unwrap_or("default");
134            builder = builder.encryption_from_bytes(master_key, namespace)?;
135        }
136
137        Ok(builder)
138    }
139
140    // ── Namespacing ───────────────────────────────────────────────────────────
141
142    fn namespaced_key(&self, key: &str) -> String {
143        match &self.namespace {
144            Some(ns) => format!("{ns}:{key}"),
145            None => key.to_owned(),
146        }
147    }
148
149    /// Validate key and return the namespaced version.
150    fn resolve_key(&self, key: &str) -> Result<String, CachekitError> {
151        validate_key(key)?;
152        Ok(self.namespaced_key(key))
153    }
154
155    // ── L1 helpers ───────────────────────────────────────────────────────────
156
157    /// Try L1 cache first. Returns Some(bytes) on hit.
158    #[cfg(feature = "l1")]
159    fn l1_get(&self, full_key: &str) -> Option<Vec<u8>> {
160        self.l1.as_ref().and_then(|l1| l1.get(full_key))
161    }
162
163    /// Populate L1 from an L2 hit with capped TTL to limit staleness.
164    #[cfg(feature = "l1")]
165    fn l1_backfill(&self, full_key: &str, bytes: &[u8]) {
166        if let Some(ref l1) = self.l1 {
167            let l1_ttl = std::cmp::min(self.default_ttl, Duration::from_secs(L1_BACKFILL_TTL_SECS));
168            l1.set(full_key, bytes, l1_ttl);
169        }
170    }
171
172    /// Write-through to L1.
173    #[cfg(feature = "l1")]
174    fn l1_set(&self, full_key: &str, bytes: &[u8], ttl: Duration) {
175        if let Some(ref l1) = self.l1 {
176            l1.set(full_key, bytes, ttl);
177        }
178    }
179
180    /// Invalidate L1 entry.
181    #[cfg(feature = "l1")]
182    fn l1_delete(&self, full_key: &str) {
183        if let Some(ref l1) = self.l1 {
184            l1.delete(full_key);
185        }
186    }
187
188    /// Validate TTL is at least 1 second.
189    fn validate_ttl(ttl: Duration) -> Result<(), CachekitError> {
190        if ttl < Duration::from_secs(1) {
191            return Err(CachekitError::Config(format!(
192                "TTL must be at least 1 second; got {ttl:?}"
193            )));
194        }
195        Ok(())
196    }
197
198    // ── Public operations ─────────────────────────────────────────────────────
199
200    /// Retrieve and deserialize a value stored under `key`.
201    ///
202    /// Returns `None` if the key does not exist.
203    /// Checks L1 cache before hitting the backend.
204    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CachekitError> {
205        let full_key = self.resolve_key(key)?;
206
207        // L1 hit
208        #[cfg(feature = "l1")]
209        if let Some(bytes) = self.l1_get(&full_key) {
210            self.check_payload_size(bytes.len())?;
211            return Ok(Some(serializer::deserialize(&bytes)?));
212        }
213
214        // L2 backend
215        let bytes = match self.backend.get(&full_key).await? {
216            Some(b) => b,
217            None => return Ok(None),
218        };
219
220        self.check_payload_size(bytes.len())?;
221
222        // Populate L1 on L2 hit (capped TTL to limit staleness)
223        #[cfg(feature = "l1")]
224        self.l1_backfill(&full_key, &bytes);
225
226        Ok(Some(serializer::deserialize(&bytes)?))
227    }
228
229    /// Serialize and store `value` under `key` using the client's default TTL.
230    pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), CachekitError> {
231        self.set_with_ttl(key, value, self.default_ttl).await
232    }
233
234    /// Serialize and store `value` under `key` with an explicit `ttl`.
235    ///
236    /// Returns [`CachekitError::Config`] if `ttl` is less than 1 second.
237    pub async fn set_with_ttl<T: Serialize>(
238        &self,
239        key: &str,
240        value: &T,
241        ttl: Duration,
242    ) -> Result<(), CachekitError> {
243        Self::validate_ttl(ttl)?;
244
245        let bytes = serializer::serialize(value)?;
246        self.check_payload_size(bytes.len())?;
247
248        let full_key = self.resolve_key(key)?;
249
250        // Only clone bytes when L1 needs a copy after the backend consumes them.
251        #[cfg(feature = "l1")]
252        {
253            let l1_bytes = bytes.clone();
254            self.backend.set(&full_key, bytes, Some(ttl)).await?;
255            self.l1_set(&full_key, &l1_bytes, ttl);
256        }
257        #[cfg(not(feature = "l1"))]
258        {
259            self.backend.set(&full_key, bytes, Some(ttl)).await?;
260        }
261
262        Ok(())
263    }
264
265    /// Delete `key` and return `true` if it existed.
266    ///
267    /// Invalidates the L1 entry regardless of the backend result.
268    pub async fn delete(&self, key: &str) -> Result<bool, CachekitError> {
269        let full_key = self.resolve_key(key)?;
270
271        // Invalidate L1 first so callers never read a stale value even if the
272        // backend delete fails partway through.
273        #[cfg(feature = "l1")]
274        self.l1_delete(&full_key);
275
276        Ok(self.backend.delete(&full_key).await?)
277    }
278
279    /// Return `true` if `key` exists without fetching the value.
280    pub async fn exists(&self, key: &str) -> Result<bool, CachekitError> {
281        let full_key = self.resolve_key(key)?;
282
283        // Check L1 first — avoids a network round-trip for warm entries.
284        #[cfg(feature = "l1")]
285        if self.l1_get(&full_key).is_some() {
286            return Ok(true);
287        }
288
289        Ok(self.backend.exists(&full_key).await?)
290    }
291
292    // ── Secure cache ─────────────────────────────────────────────────────────
293
294    /// Return a [`SecureCache`] handle that encrypts all values before storage.
295    ///
296    /// L1 stores **ciphertext** (not plaintext) to preserve the zero-knowledge
297    /// property across all cache layers.
298    ///
299    /// # Errors
300    /// Returns `CachekitError::Config` if no encryption layer is configured.
301    /// Configure encryption via [`CacheKitBuilder::encryption`] or
302    /// [`CacheKitBuilder::encryption_from_bytes`].
303    #[cfg(feature = "encryption")]
304    pub fn secure(&self) -> Result<SecureCache<'_>, CachekitError> {
305        let enc = self.encryption.as_ref().ok_or_else(|| {
306            CachekitError::Config(
307                "encryption requires CACHEKIT_MASTER_KEY or .encryption() on builder".to_owned(),
308            )
309        })?;
310        Ok(SecureCache {
311            client: self,
312            encryption: enc,
313        })
314    }
315
316    // ── Private helpers ───────────────────────────────────────────────────────
317
318    fn check_payload_size(&self, size: usize) -> Result<(), CachekitError> {
319        if size > self.max_payload_bytes {
320            return Err(CachekitError::PayloadTooLarge {
321                size,
322                limit: self.max_payload_bytes,
323            });
324        }
325        Ok(())
326    }
327}
328
329// ── SecureCache ──────────────────────────────────────────────────────────────
330
/// Encrypted cache handle returned by [`CacheKit::secure()`].
///
/// All values are serialized, then encrypted with AES-256-GCM before storage.
/// L1 stores ciphertext to maintain zero-knowledge guarantees.
///
/// Only values are encrypted: keys are validated and namespaced exactly as in
/// [`CacheKit`] and reach the backend in plaintext. The handle borrows the
/// client, so it cannot outlive it.
#[cfg(feature = "encryption")]
pub struct SecureCache<'a> {
    /// Client whose backend, L1 cache, default TTL and payload limit are reused.
    client: &'a CacheKit,
    /// Encryption layer borrowed from the client's configuration.
    encryption: &'a crate::encryption::EncryptionLayer,
}
340
#[cfg(feature = "encryption")]
impl std::fmt::Debug for SecureCache<'_> {
    // Only the tenant ID is printed; the client reference and the encryption
    // layer itself are left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tenant = self.encryption.tenant_id();
        let mut out = f.debug_struct("SecureCache");
        out.field("tenant_id", &tenant);
        out.finish()
    }
}
349
#[cfg(feature = "encryption")]
impl SecureCache<'_> {
    /// Encrypt and store `value` under `key` using the client's default TTL.
    ///
    /// # Errors
    /// Same as [`SecureCache::set_with_ttl`].
    pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), CachekitError> {
        self.set_with_ttl(key, value, self.client.default_ttl).await
    }

    /// Encrypt and store `value` under `key` with an explicit `ttl`.
    ///
    /// The value is serialized, encrypted (the un-namespaced `key` is passed to
    /// the encryption layer), written to the backend, and then written through
    /// to L1 as ciphertext.
    ///
    /// # Errors
    /// - [`CachekitError::Config`] if `ttl` is less than 1 second.
    /// - [`CachekitError::InvalidKey`] if `key` fails validation.
    /// - [`CachekitError::PayloadTooLarge`] if either the serialized plaintext
    ///   or the resulting ciphertext exceeds the client's payload limit.
    /// - Serialization, encryption and backend errors are propagated.
    pub async fn set_with_ttl<T: Serialize>(
        &self,
        key: &str,
        value: &T,
        ttl: Duration,
    ) -> Result<(), CachekitError> {
        CacheKit::validate_ttl(ttl)?;

        // Validate the key before any crypto work so a rejected key never
        // costs an encryption operation.
        let full_key = self.client.resolve_key(key)?;

        // Serialize, reject oversize plaintext cheaply, then encrypt.
        let plaintext = serializer::serialize(value)?;
        self.client.check_payload_size(plaintext.len())?;
        let ciphertext = self.encryption.encrypt(&plaintext, key)?;

        // `get` enforces the limit on the stored ciphertext, which can be larger
        // than the plaintext. Enforce it here too, otherwise a write could
        // succeed for a value that every later read rejects.
        self.client.check_payload_size(ciphertext.len())?;

        // The backend takes ownership of `ciphertext`; keep a copy for the L1
        // write-through only when L1 is compiled in.
        #[cfg(feature = "l1")]
        let l1_bytes = ciphertext.clone();

        self.client
            .backend
            .set(&full_key, ciphertext, Some(ttl))
            .await?;

        #[cfg(feature = "l1")]
        self.client.l1_set(&full_key, &l1_bytes, ttl);

        Ok(())
    }

    /// Retrieve, decrypt, and deserialize a value stored under `key`.
    ///
    /// Returns `None` if the key does not exist. Checks L1 (which holds
    /// ciphertext) before the backend. An L2 hit back-fills L1 only once the
    /// ciphertext has decrypted successfully.
    ///
    /// # Errors
    /// Returns an error for an invalid key, an oversize payload, a backend
    /// failure, a decryption failure, or a value not deserializable as `T`.
    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CachekitError> {
        let full_key = self.client.resolve_key(key)?;

        // L1 hit (ciphertext)
        #[cfg(feature = "l1")]
        if let Some(ciphertext) = self.client.l1_get(&full_key) {
            self.client.check_payload_size(ciphertext.len())?;
            let plaintext = self.encryption.decrypt(&ciphertext, key)?;
            return Ok(Some(serializer::deserialize(&plaintext)?));
        }

        // L2 backend
        let ciphertext = match self.client.backend.get(&full_key).await? {
            Some(b) => b,
            None => return Ok(None),
        };

        self.client.check_payload_size(ciphertext.len())?;

        let plaintext = self.encryption.decrypt(&ciphertext, key)?;

        // Back-fill L1 only with ciphertext that decrypted, so entries that fail
        // decryption don't occupy L1 capacity (capped TTL limits staleness).
        #[cfg(feature = "l1")]
        self.client.l1_backfill(&full_key, &ciphertext);

        Ok(Some(serializer::deserialize(&plaintext)?))
    }

    /// Delete an encrypted key. Behaves identically to [`CacheKit::delete`],
    /// since keys are not encrypted.
    pub async fn delete(&self, key: &str) -> Result<bool, CachekitError> {
        self.client.delete(key).await
    }

    /// Check if an encrypted key exists. Behaves identically to
    /// [`CacheKit::exists`]; no decryption is attempted.
    pub async fn exists(&self, key: &str) -> Result<bool, CachekitError> {
        self.client.exists(key).await
    }
}
434
435// ── CacheKitBuilder ───────────────────────────────────────────────────────────
436
/// Fluent builder for [`CacheKit`].
///
/// Every setting is optional except the backend; unset values fall back to the
/// defaults applied in [`CacheKitBuilder::build`].
#[derive(Default)]
#[must_use]
pub struct CacheKitBuilder {
    /// Storage backend; `build()` fails when this is `None`.
    backend: Option<SharedBackend>,
    /// Default TTL for `set()`; `build()` falls back to 300 seconds.
    default_ttl: Option<Duration>,
    /// Optional key prefix, validated by `build()`.
    namespace: Option<String>,
    /// Payload size limit in bytes; `build()` falls back to 5 MiB.
    max_payload_bytes: Option<usize>,

    /// L1 capacity in entries; `build()` falls back to 1000.
    #[cfg(feature = "l1")]
    l1_capacity: Option<usize>,

    /// Set by `no_l1()`; when true `build()` creates no L1 cache.
    #[cfg(feature = "l1")]
    no_l1: bool,

    /// Encryption layer configured through `encryption()` / `encryption_from_bytes()`.
    #[cfg(feature = "encryption")]
    encryption: Option<SharedEncryption>,
}
455
456impl CacheKitBuilder {
457    /// Set the storage backend.
458    pub fn backend(mut self, backend: SharedBackend) -> Self {
459        self.backend = Some(backend);
460        self
461    }
462
463    /// Override the default TTL (used when no per-call TTL is specified).
464    pub fn default_ttl(mut self, ttl: Duration) -> Self {
465        self.default_ttl = Some(ttl);
466        self
467    }
468
469    /// Set a namespace prefix. All keys will be stored as `{namespace}:{key}`.
470    pub fn namespace(mut self, ns: impl Into<String>) -> Self {
471        self.namespace = Some(ns.into());
472        self
473    }
474
475    /// Set the maximum accepted payload size in bytes.
476    pub fn max_payload_bytes(mut self, limit: usize) -> Self {
477        self.max_payload_bytes = Some(limit);
478        self
479    }
480
481    /// Set the L1 cache capacity (max entries).
482    #[cfg(feature = "l1")]
483    pub fn l1_capacity(mut self, capacity: usize) -> Self {
484        self.l1_capacity = Some(capacity);
485        self
486    }
487
488    /// Disable the L1 cache entirely.
489    #[cfg(feature = "l1")]
490    pub fn no_l1(mut self) -> Self {
491        self.no_l1 = true;
492        self
493    }
494
495    // Stubs for when the l1 feature is disabled — still compile cleanly.
496    #[cfg(not(feature = "l1"))]
497    pub fn l1_capacity(self, _capacity: usize) -> Self {
498        self
499    }
500
501    #[cfg(not(feature = "l1"))]
502    pub fn no_l1(self) -> Self {
503        self
504    }
505
506    /// Configure encryption from raw master key bytes and tenant ID.
507    ///
508    /// The master key must be at least 16 bytes (32 recommended).
509    /// Keys are derived per-tenant via HKDF-SHA256.
510    #[cfg(feature = "encryption")]
511    pub fn encryption_from_bytes(
512        mut self,
513        master_key: &[u8],
514        tenant_id: &str,
515    ) -> Result<Self, CachekitError> {
516        let layer = crate::encryption::EncryptionLayer::new(master_key, tenant_id)?;
517        self.encryption = Some(SharedEncryption::new(layer));
518        Ok(self)
519    }
520
521    /// Configure encryption from a hex-encoded master key string.
522    ///
523    /// Convenience wrapper that hex-decodes then delegates to
524    /// [`Self::encryption_from_bytes`].
525    #[cfg(feature = "encryption")]
526    pub fn encryption(self, hex_key: &str, tenant_id: &str) -> Result<Self, CachekitError> {
527        let bytes = hex::decode(hex_key)
528            .map_err(|e| CachekitError::Config(format!("master key is not valid hex: {e}")))?;
529        self.encryption_from_bytes(&bytes, tenant_id)
530    }
531
532    // Stub for when encryption feature is disabled.
533    #[cfg(not(feature = "encryption"))]
534    pub fn encryption_from_bytes(
535        self,
536        _master_key: &[u8],
537        _tenant_id: &str,
538    ) -> Result<Self, CachekitError> {
539        Ok(self)
540    }
541
542    #[cfg(not(feature = "encryption"))]
543    pub fn encryption(self, _hex_key: &str, _tenant_id: &str) -> Result<Self, CachekitError> {
544        Ok(self)
545    }
546
547    /// Finalise and build the [`CacheKit`] client.
548    ///
549    /// Returns an error if no backend was provided.
550    pub fn build(self) -> Result<CacheKit, CachekitError> {
551        let backend = self.backend.ok_or_else(|| {
552            CachekitError::Config("a backend must be provided via .backend()".to_owned())
553        })?;
554
555        // Validate namespace if provided
556        if let Some(ref ns) = self.namespace {
557            if ns.is_empty() {
558                return Err(CachekitError::Config("namespace cannot be empty".into()));
559            }
560            if ns.len() > 255 {
561                return Err(CachekitError::Config("namespace exceeds 255 bytes".into()));
562            }
563            if !ns.bytes().all(|b| (0x20..=0x7E).contains(&b)) {
564                return Err(CachekitError::Config(
565                    "namespace must be ASCII printable".into(),
566                ));
567            }
568        }
569
570        #[cfg(feature = "l1")]
571        let l1 = if self.no_l1 {
572            None
573        } else {
574            let capacity = self.l1_capacity.unwrap_or(1000);
575            Some(crate::l1::L1Cache::new(capacity))
576        };
577
578        Ok(CacheKit {
579            backend,
580            default_ttl: self.default_ttl.unwrap_or(Duration::from_secs(300)),
581            namespace: self.namespace,
582            max_payload_bytes: self.max_payload_bytes.unwrap_or(5 * 1024 * 1024),
583
584            #[cfg(feature = "l1")]
585            l1,
586
587            #[cfg(feature = "encryption")]
588            encryption: self.encryption,
589        })
590    }
591}