1use std::time::Duration;
2
3use serde::{de::DeserializeOwned, Serialize};
4
5use crate::backend::Backend;
6use crate::error::CachekitError;
7use crate::serializer;
8
/// Shared, reference-counted handle to a [`Backend`] trait object.
///
/// This variant is selected on targets that are neither `wasm32` nor built
/// with the `unsync` feature, and uses an atomically reference-counted pointer.
9#[cfg(not(any(target_arch = "wasm32", feature = "unsync")))]
17pub type SharedBackend = std::sync::Arc<dyn Backend>;
18
/// Shared, reference-counted handle to a [`Backend`] trait object.
///
/// This variant is selected on `wasm32` or when the `unsync` feature is
/// enabled, and uses a non-atomic `Rc`.
19#[cfg(any(target_arch = "wasm32", feature = "unsync"))]
21pub type SharedBackend = std::rc::Rc<dyn Backend>;
22
// Shared handle to the encryption layer; `Arc` flavour, chosen under the same
// target/feature conditions as the `Arc`-based `SharedBackend`.
23#[cfg(all(
31 feature = "encryption",
32 not(any(target_arch = "wasm32", feature = "unsync"))
33))]
34type SharedEncryption = std::sync::Arc<crate::encryption::EncryptionLayer>;
35
// Shared handle to the encryption layer; `Rc` flavour for `wasm32` / `unsync`.
36#[cfg(all(
37 feature = "encryption",
38 any(target_arch = "wasm32", feature = "unsync")
39))]
40type SharedEncryption = std::rc::Rc<crate::encryption::EncryptionLayer>;
41
/// Maximum accepted length of a caller-supplied key, in bytes (checked by
/// `validate_key` before any namespace prefix is added).
42const MAX_KEY_BYTES: usize = 1024;
45
/// Upper bound, in seconds, on the TTL given to L1 entries that are
/// back-filled after a backend read.
46const L1_BACKFILL_TTL_SECS: u64 = 30;
49
50fn validate_key(key: &str) -> Result<(), CachekitError> {
51 if key.is_empty() {
52 return Err(CachekitError::InvalidKey(
53 "key must not be empty".to_owned(),
54 ));
55 }
56 if key.len() > MAX_KEY_BYTES {
57 return Err(CachekitError::InvalidKey(format!(
58 "key is {} bytes (limit: {MAX_KEY_BYTES})",
59 key.len()
60 )));
61 }
62 for b in key.bytes() {
63 if b < 0x20 || b == 0x7F {
64 return Err(CachekitError::InvalidKey(format!(
65 "key contains illegal control character 0x{b:02X}"
66 )));
67 }
68 }
69 Ok(())
70}
71
/// Cache client that serializes values and stores them in a shared backend,
/// optionally fronted by an in-process L1 cache and optionally paired with an
/// encryption layer (see `secure`).
///
/// Instances are created through `CacheKit::builder()`.
72pub struct CacheKit {
    /// Backend that every read, write, delete and existence check goes to.
76 backend: SharedBackend,
    /// TTL used by `set`; also caps the TTL of back-filled L1 entries.
77 default_ttl: Duration,
    /// Optional prefix; when set, keys are stored as `{namespace}:{key}`.
78 namespace: Option<String>,
    /// Largest payload size, in bytes, accepted on write and on read.
79 max_payload_bytes: usize,
80
    /// In-process L1 cache; `None` when L1 was disabled on the builder.
81 #[cfg(feature = "l1")]
82 l1: Option<crate::l1::L1Cache>,
83
    /// Encryption layer backing `secure()`; `None` when not configured.
84 #[cfg(feature = "encryption")]
85 encryption: Option<SharedEncryption>,
86}
87
88impl CacheKit {
89 pub fn builder() -> CacheKitBuilder {
91 CacheKitBuilder::default()
92 }
93
94 #[cfg(all(feature = "cachekitio", not(target_arch = "wasm32")))]
99 pub fn from_env() -> Result<CacheKitBuilder, CachekitError> {
100 use crate::backend::cachekitio::CachekitIO;
101 use crate::config::CachekitConfig;
102
103 let config = CachekitConfig::from_env()?;
104
105 let api_key_z = config
106 .api_key
107 .ok_or_else(|| CachekitError::Config("CACHEKIT_API_KEY is required".to_owned()))?;
108
109 let backend = CachekitIO::builder()
110 .api_key(api_key_z.as_str())
111 .api_url(config.api_url)
112 .build()
113 .map_err(|e| CachekitError::Config(e.to_string()))?;
114
115 #[cfg(not(feature = "unsync"))]
116 let shared: SharedBackend = std::sync::Arc::new(backend);
117 #[cfg(feature = "unsync")]
118 let shared: SharedBackend = std::rc::Rc::new(backend);
119
120 let mut builder = CacheKitBuilder::default()
121 .backend(shared)
122 .default_ttl(config.default_ttl)
123 .max_payload_bytes(config.max_payload_bytes)
124 .l1_capacity(config.l1_capacity);
125
126 if let Some(ns) = config.namespace.clone() {
127 builder = builder.namespace(ns);
128 }
129
130 #[cfg(feature = "encryption")]
132 if let Some(ref master_key) = config.master_key {
133 let namespace = config.namespace.as_deref().unwrap_or("default");
134 builder = builder.encryption_from_bytes(master_key, namespace)?;
135 }
136
137 Ok(builder)
138 }
139
140 fn namespaced_key(&self, key: &str) -> String {
143 match &self.namespace {
144 Some(ns) => format!("{ns}:{key}"),
145 None => key.to_owned(),
146 }
147 }
148
149 fn resolve_key(&self, key: &str) -> Result<String, CachekitError> {
151 validate_key(key)?;
152 Ok(self.namespaced_key(key))
153 }
154
155 #[cfg(feature = "l1")]
159 fn l1_get(&self, full_key: &str) -> Option<Vec<u8>> {
160 self.l1.as_ref().and_then(|l1| l1.get(full_key))
161 }
162
163 #[cfg(feature = "l1")]
165 fn l1_backfill(&self, full_key: &str, bytes: &[u8]) {
166 if let Some(ref l1) = self.l1 {
167 let l1_ttl = std::cmp::min(self.default_ttl, Duration::from_secs(L1_BACKFILL_TTL_SECS));
168 l1.set(full_key, bytes, l1_ttl);
169 }
170 }
171
172 #[cfg(feature = "l1")]
174 fn l1_set(&self, full_key: &str, bytes: &[u8], ttl: Duration) {
175 if let Some(ref l1) = self.l1 {
176 l1.set(full_key, bytes, ttl);
177 }
178 }
179
180 #[cfg(feature = "l1")]
182 fn l1_delete(&self, full_key: &str) {
183 if let Some(ref l1) = self.l1 {
184 l1.delete(full_key);
185 }
186 }
187
188 fn validate_ttl(ttl: Duration) -> Result<(), CachekitError> {
190 if ttl < Duration::from_secs(1) {
191 return Err(CachekitError::Config(format!(
192 "TTL must be at least 1 second; got {ttl:?}"
193 )));
194 }
195 Ok(())
196 }
197
198 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CachekitError> {
205 let full_key = self.resolve_key(key)?;
206
207 #[cfg(feature = "l1")]
209 if let Some(bytes) = self.l1_get(&full_key) {
210 self.check_payload_size(bytes.len())?;
211 return Ok(Some(serializer::deserialize(&bytes)?));
212 }
213
214 let bytes = match self.backend.get(&full_key).await? {
216 Some(b) => b,
217 None => return Ok(None),
218 };
219
220 self.check_payload_size(bytes.len())?;
221
222 #[cfg(feature = "l1")]
224 self.l1_backfill(&full_key, &bytes);
225
226 Ok(Some(serializer::deserialize(&bytes)?))
227 }
228
229 pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), CachekitError> {
231 self.set_with_ttl(key, value, self.default_ttl).await
232 }
233
234 pub async fn set_with_ttl<T: Serialize>(
238 &self,
239 key: &str,
240 value: &T,
241 ttl: Duration,
242 ) -> Result<(), CachekitError> {
243 Self::validate_ttl(ttl)?;
244
245 let bytes = serializer::serialize(value)?;
246 self.check_payload_size(bytes.len())?;
247
248 let full_key = self.resolve_key(key)?;
249
250 #[cfg(feature = "l1")]
252 {
253 let l1_bytes = bytes.clone();
254 self.backend.set(&full_key, bytes, Some(ttl)).await?;
255 self.l1_set(&full_key, &l1_bytes, ttl);
256 }
257 #[cfg(not(feature = "l1"))]
258 {
259 self.backend.set(&full_key, bytes, Some(ttl)).await?;
260 }
261
262 Ok(())
263 }
264
265 pub async fn delete(&self, key: &str) -> Result<bool, CachekitError> {
269 let full_key = self.resolve_key(key)?;
270
271 #[cfg(feature = "l1")]
274 self.l1_delete(&full_key);
275
276 Ok(self.backend.delete(&full_key).await?)
277 }
278
279 pub async fn exists(&self, key: &str) -> Result<bool, CachekitError> {
281 let full_key = self.resolve_key(key)?;
282
283 #[cfg(feature = "l1")]
285 if self.l1_get(&full_key).is_some() {
286 return Ok(true);
287 }
288
289 Ok(self.backend.exists(&full_key).await?)
290 }
291
292 #[cfg(feature = "encryption")]
304 pub fn secure(&self) -> Result<SecureCache<'_>, CachekitError> {
305 let enc = self.encryption.as_ref().ok_or_else(|| {
306 CachekitError::Config(
307 "encryption requires CACHEKIT_MASTER_KEY or .encryption() on builder".to_owned(),
308 )
309 })?;
310 Ok(SecureCache {
311 client: self,
312 encryption: enc,
313 })
314 }
315
316 fn check_payload_size(&self, size: usize) -> Result<(), CachekitError> {
319 if size > self.max_payload_bytes {
320 return Err(CachekitError::PayloadTooLarge {
321 size,
322 limit: self.max_payload_bytes,
323 });
324 }
325 Ok(())
326 }
327}
328
/// Borrowed view over a [`CacheKit`] whose `set`/`get` encrypt and decrypt
/// values with the client's encryption layer. Created by `CacheKit::secure`.
329#[cfg(feature = "encryption")]
336pub struct SecureCache<'a> {
    /// Client whose backend, L1 cache, key rules and limits are reused.
337 client: &'a CacheKit,
    /// Encryption layer used to encrypt on write and decrypt on read.
338 encryption: &'a crate::encryption::EncryptionLayer,
339}
340
#[cfg(feature = "encryption")]
impl std::fmt::Debug for SecureCache<'_> {
    /// Prints only the tenant id reported by the encryption layer; the client
    /// and the layer itself are left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tenant_id = self.encryption.tenant_id();
        let mut out = f.debug_struct("SecureCache");
        out.field("tenant_id", &tenant_id);
        out.finish()
    }
}
349
#[cfg(feature = "encryption")]
impl SecureCache<'_> {
    /// Encrypts and stores `value` under `key` using the client's default TTL.
    ///
    /// See [`set_with_ttl`](Self::set_with_ttl) for the error conditions.
    pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), CachekitError> {
        self.set_with_ttl(key, value, self.client.default_ttl).await
    }

    /// Serializes `value`, encrypts it, and stores the ciphertext under `key`
    /// with an explicit TTL.
    ///
    /// The caller-supplied (un-namespaced) `key` is passed to the encryption
    /// layer; the namespaced key is used for storage. The backend is written
    /// first and L1 is updated only after that write succeeds.
    ///
    /// # Errors
    ///
    /// Fails when `ttl` is under one second, serialization fails, the
    /// plaintext or the ciphertext exceeds the configured payload limit, the
    /// key is invalid, encryption fails, or the backend write fails.
    pub async fn set_with_ttl<T: Serialize>(
        &self,
        key: &str,
        value: &T,
        ttl: Duration,
    ) -> Result<(), CachekitError> {
        CacheKit::validate_ttl(ttl)?;

        let plaintext = serializer::serialize(value)?;
        self.client.check_payload_size(plaintext.len())?;

        // Validate the key before handing it to the encryption layer so an
        // invalid key is rejected without doing any cryptographic work.
        let full_key = self.client.resolve_key(key)?;

        let ciphertext = self.encryption.encrypt(&plaintext, key)?;
        // `get` applies the payload limit to the ciphertext. Enforce the same
        // bound here so a value that is accepted on write can always be read
        // back (the ciphertext may be longer than the plaintext).
        self.client.check_payload_size(ciphertext.len())?;

        // The backend consumes the ciphertext; copy it for L1 only when an L1
        // tier is actually configured on the client.
        #[cfg(feature = "l1")]
        let l1_bytes = self.client.l1.as_ref().map(|_| ciphertext.clone());

        self.client
            .backend
            .set(&full_key, ciphertext, Some(ttl))
            .await?;

        #[cfg(feature = "l1")]
        if let Some(l1_bytes) = l1_bytes {
            self.client.l1_set(&full_key, &l1_bytes, ttl);
        }

        Ok(())
    }

    /// Fetches the ciphertext stored under `key`, decrypts it, and
    /// deserializes the result.
    ///
    /// L1 is consulted first (when enabled); on a miss the backend is queried
    /// and the ciphertext is back-filled into L1 before decryption. Returns
    /// `Ok(None)` when the backend has no entry.
    ///
    /// # Errors
    ///
    /// Fails on an invalid key, a backend error, a ciphertext larger than the
    /// configured limit, a decryption error, or a deserialization error.
    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CachekitError> {
        let full_key = self.client.resolve_key(key)?;

        #[cfg(feature = "l1")]
        if let Some(ciphertext) = self.client.l1_get(&full_key) {
            self.client.check_payload_size(ciphertext.len())?;
            let plaintext = self.encryption.decrypt(&ciphertext, key)?;
            return Ok(Some(serializer::deserialize(&plaintext)?));
        }

        let ciphertext = match self.client.backend.get(&full_key).await? {
            Some(b) => b,
            None => return Ok(None),
        };

        self.client.check_payload_size(ciphertext.len())?;

        #[cfg(feature = "l1")]
        self.client.l1_backfill(&full_key, &ciphertext);

        let plaintext = self.encryption.decrypt(&ciphertext, key)?;
        Ok(Some(serializer::deserialize(&plaintext)?))
    }

    /// Deletes `key`; delegates to [`CacheKit::delete`].
    pub async fn delete(&self, key: &str) -> Result<bool, CachekitError> {
        self.client.delete(key).await
    }

    /// Reports whether `key` exists; delegates to [`CacheKit::exists`].
    pub async fn exists(&self, key: &str) -> Result<bool, CachekitError> {
        self.client.exists(key).await
    }
}
434
/// Builder for [`CacheKit`]. Every option starts unset; `build` requires a
/// backend and fills in defaults for the remaining unset options.
435#[derive(Default)]
439#[must_use]
440pub struct CacheKitBuilder {
    /// Storage backend; `build` fails when this is still `None`.
441 backend: Option<SharedBackend>,
    /// Default TTL; `build` falls back to 300 seconds.
442 default_ttl: Option<Duration>,
    /// Optional key namespace, validated by `build`.
443 namespace: Option<String>,
    /// Payload size limit in bytes; `build` falls back to 5 * 1024 * 1024.
444 max_payload_bytes: Option<usize>,
445
    /// Capacity passed to the L1 cache; `build` falls back to 1000.
446 #[cfg(feature = "l1")]
447 l1_capacity: Option<usize>,
448
    /// When `true`, `build` creates the client without an L1 cache.
449 #[cfg(feature = "l1")]
450 no_l1: bool,
451
    /// Encryption layer handed to the client, if one was configured.
452 #[cfg(feature = "encryption")]
453 encryption: Option<SharedEncryption>,
454}
455
456impl CacheKitBuilder {
457 pub fn backend(mut self, backend: SharedBackend) -> Self {
459 self.backend = Some(backend);
460 self
461 }
462
463 pub fn default_ttl(mut self, ttl: Duration) -> Self {
465 self.default_ttl = Some(ttl);
466 self
467 }
468
469 pub fn namespace(mut self, ns: impl Into<String>) -> Self {
471 self.namespace = Some(ns.into());
472 self
473 }
474
475 pub fn max_payload_bytes(mut self, limit: usize) -> Self {
477 self.max_payload_bytes = Some(limit);
478 self
479 }
480
481 #[cfg(feature = "l1")]
483 pub fn l1_capacity(mut self, capacity: usize) -> Self {
484 self.l1_capacity = Some(capacity);
485 self
486 }
487
488 #[cfg(feature = "l1")]
490 pub fn no_l1(mut self) -> Self {
491 self.no_l1 = true;
492 self
493 }
494
495 #[cfg(not(feature = "l1"))]
497 pub fn l1_capacity(self, _capacity: usize) -> Self {
498 self
499 }
500
501 #[cfg(not(feature = "l1"))]
502 pub fn no_l1(self) -> Self {
503 self
504 }
505
506 #[cfg(feature = "encryption")]
511 pub fn encryption_from_bytes(
512 mut self,
513 master_key: &[u8],
514 tenant_id: &str,
515 ) -> Result<Self, CachekitError> {
516 let layer = crate::encryption::EncryptionLayer::new(master_key, tenant_id)?;
517 self.encryption = Some(SharedEncryption::new(layer));
518 Ok(self)
519 }
520
521 #[cfg(feature = "encryption")]
526 pub fn encryption(self, hex_key: &str, tenant_id: &str) -> Result<Self, CachekitError> {
527 let bytes = hex::decode(hex_key)
528 .map_err(|e| CachekitError::Config(format!("master key is not valid hex: {e}")))?;
529 self.encryption_from_bytes(&bytes, tenant_id)
530 }
531
532 #[cfg(not(feature = "encryption"))]
534 pub fn encryption_from_bytes(
535 self,
536 _master_key: &[u8],
537 _tenant_id: &str,
538 ) -> Result<Self, CachekitError> {
539 Ok(self)
540 }
541
542 #[cfg(not(feature = "encryption"))]
543 pub fn encryption(self, _hex_key: &str, _tenant_id: &str) -> Result<Self, CachekitError> {
544 Ok(self)
545 }
546
547 pub fn build(self) -> Result<CacheKit, CachekitError> {
551 let backend = self.backend.ok_or_else(|| {
552 CachekitError::Config("a backend must be provided via .backend()".to_owned())
553 })?;
554
555 if let Some(ref ns) = self.namespace {
557 if ns.is_empty() {
558 return Err(CachekitError::Config("namespace cannot be empty".into()));
559 }
560 if ns.len() > 255 {
561 return Err(CachekitError::Config("namespace exceeds 255 bytes".into()));
562 }
563 if !ns.bytes().all(|b| (0x20..=0x7E).contains(&b)) {
564 return Err(CachekitError::Config(
565 "namespace must be ASCII printable".into(),
566 ));
567 }
568 }
569
570 #[cfg(feature = "l1")]
571 let l1 = if self.no_l1 {
572 None
573 } else {
574 let capacity = self.l1_capacity.unwrap_or(1000);
575 Some(crate::l1::L1Cache::new(capacity))
576 };
577
578 Ok(CacheKit {
579 backend,
580 default_ttl: self.default_ttl.unwrap_or(Duration::from_secs(300)),
581 namespace: self.namespace,
582 max_payload_bytes: self.max_payload_bytes.unwrap_or(5 * 1024 * 1024),
583
584 #[cfg(feature = "l1")]
585 l1,
586
587 #[cfg(feature = "encryption")]
588 encryption: self.encryption,
589 })
590 }
591}