// rok_cache/lib.rs

1pub mod driver;
2pub mod drivers;
3
4mod cache;
5mod error;
6
7#[cfg(feature = "axum")]
8mod cache_layer;
9
10pub use cache::{scope_cache, Cache};
11pub use driver::{build_driver, CacheHandle, Driver};
12pub use error::CacheError;
13
14#[cfg(feature = "axum")]
15pub use cache_layer::CacheLayer;
16
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use crate::driver::{CacheHandle, Driver};
    use crate::drivers::MemoryDriver;
    use crate::CacheError;
    use crate::{scope_cache, Cache};

    /// Builds a shareable handle around a brand-new `MemoryDriver`, passing
    /// `prefix` straight through to `CacheHandle::new`.
    ///
    /// Each call constructs its own driver, so handles from separate calls
    /// presumably do not share storage (assumes `MemoryDriver::new` keeps no
    /// global state — confirm).
    fn handle(prefix: &str) -> Arc<CacheHandle> {
        Arc::new(CacheHandle::new(
            Driver::Memory(MemoryDriver::new()),
            prefix,
        ))
    }

    /// A string written with `set` is read back unchanged by `get`.
    #[tokio::test]
    async fn set_and_get_string() {
        scope_cache(handle(""), async {
            Cache::set("greeting", &"hello", None).await.unwrap();
            let v: Option<String> = Cache::get("greeting").await.unwrap();
            assert_eq!(v.as_deref(), Some("hello"));
        })
        .await;
    }

    /// Reading a key that was never written yields `Ok(None)`, not an error.
    #[tokio::test]
    async fn get_missing_returns_none() {
        scope_cache(handle(""), async {
            let v: Option<i32> = Cache::get("nothing").await.unwrap();
            assert_eq!(v, None);
        })
        .await;
    }

    /// A serde-serializable struct round-trips through `set`/`get`.
    #[tokio::test]
    async fn set_and_get_struct() {
        use serde::{Deserialize, Serialize};

        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct User {
            name: String,
            age: u32,
        }

        scope_cache(handle(""), async {
            let u = User {
                name: "Alice".into(),
                age: 30,
            };
            Cache::set("user", &u, None).await.unwrap();
            let got: Option<User> = Cache::get("user").await.unwrap();
            assert_eq!(
                got,
                Some(User {
                    name: "Alice".into(),
                    age: 30
                })
            );
        })
        .await;
    }

    /// `forget` removes a single key so a later `get` sees nothing.
    #[tokio::test]
    async fn forget_removes_key() {
        scope_cache(handle(""), async {
            Cache::set("tmp", &99i32, None).await.unwrap();
            Cache::forget("tmp").await.unwrap();
            let v: Option<i32> = Cache::get("tmp").await.unwrap();
            assert_eq!(v, None);
        })
        .await;
    }

    /// `flush` drops every stored key.
    #[tokio::test]
    async fn flush_clears_everything() {
        scope_cache(handle(""), async {
            Cache::set("a", &1i32, None).await.unwrap();
            Cache::set("b", &2i32, None).await.unwrap();
            Cache::flush().await.unwrap();
            let a: Option<i32> = Cache::get("a").await.unwrap();
            let b: Option<i32> = Cache::get("b").await.unwrap();
            // Separate assertions so a failure reports which key survived.
            assert_eq!(a, None, "flush should remove `a`");
            assert_eq!(b, None, "flush should remove `b`");
        })
        .await;
    }

    /// A key stored with a TTL is readable before the TTL and gone after it.
    #[tokio::test]
    async fn ttl_expiry_returns_none_after_expiry() {
        scope_cache(handle(""), async {
            Cache::set("exp", &"value", Some(Duration::from_millis(50)))
                .await
                .unwrap();

            // Positive control: without it, the final `None` check would also
            // pass if a TTL'd write were silently dropped or expired instantly.
            let before: Option<String> = Cache::get("exp").await.unwrap();
            assert_eq!(before.as_deref(), Some("value"));

            // Real wall-clock sleep: the driver's clock source is not visible
            // here, so `tokio::time::pause` cannot be assumed to advance it.
            tokio::time::sleep(Duration::from_millis(100)).await;
            let v: Option<String> = Cache::get("exp").await.unwrap();
            assert_eq!(v, None, "expired key should return None");
        })
        .await;
    }

    /// A key stored without a TTL is still present after some time passes.
    #[tokio::test]
    async fn no_expiry_persists() {
        scope_cache(handle(""), async {
            Cache::set("perm", &"stay", None).await.unwrap();
            tokio::time::sleep(Duration::from_millis(20)).await;
            let v: Option<String> = Cache::get("perm").await.unwrap();
            assert_eq!(v.as_deref(), Some("stay"));
        })
        .await;
    }

    /// Keys written through one handle are not visible through another.
    #[tokio::test]
    async fn namespace_isolation() {
        let h_a = handle("ns_a:");
        let h_b = handle("ns_b:");

        // NOTE(review): `handle` builds a separate `MemoryDriver` per call, so
        // B presumably could not see A's keys even if prefixes were ignored.
        // To exercise the prefix itself, both handles need to wrap the same
        // underlying store — TODO once a shared-driver constructor is exposed.

        scope_cache(Arc::clone(&h_a), async {
            Cache::set("key", &"from_a", None).await.unwrap();
            // Positive control: the write must be visible through its own
            // handle, otherwise B's `None` below proves nothing.
            let v: Option<String> = Cache::get("key").await.unwrap();
            assert_eq!(v.as_deref(), Some("from_a"));
        })
        .await;

        scope_cache(h_b, async {
            let v: Option<String> = Cache::get("key").await.unwrap();
            assert_eq!(v, None, "namespace B should not see A's keys");
        })
        .await;

        // A's entry must be untouched after B's scope ran.
        scope_cache(h_a, async {
            let v: Option<String> = Cache::get("key").await.unwrap();
            assert_eq!(v.as_deref(), Some("from_a"));
        })
        .await;
    }

    /// `remember` runs the fetcher only on the first miss. Later calls return
    /// the cached value without invoking it again.
    #[tokio::test]
    async fn remember_calls_fetcher_once() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let calls = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&calls);

        scope_cache(handle(""), async move {
            for _ in 0..4 {
                let value: String = Cache::remember("expensive", Duration::from_secs(60), || {
                    let c = Arc::clone(&c);
                    async move {
                        c.fetch_add(1, Ordering::SeqCst);
                        Ok::<_, String>("computed".to_string())
                    }
                })
                .await
                .unwrap();
                // Both the first (fetched) and later (cached) calls must yield
                // the fetcher's value.
                assert_eq!(value, "computed");
            }
        })
        .await;

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Outside any `scope_cache` scope, facade calls report `NotConfigured`.
    #[tokio::test]
    async fn get_without_layer_returns_not_configured() {
        let res: Result<Option<String>, CacheError> = Cache::get("k").await;
        assert!(matches!(res, Err(CacheError::NotConfigured)));
    }
}