Skip to main content

do_over/
cache.rs

//! Cache policy for caching successful results.
//!
//! The cache policy is meant to store successful results and return them for
//! subsequent calls, reducing load on underlying services. [`TypedCache`]
//! provides this behavior today; the [`Policy`] implementation on [`Cache`]
//! currently forwards every call to the operation without caching.
//!
//! # Examples
//!
//! ```rust
//! use do_over::{policy::Policy, cache::Cache, error::DoOverError};
//! use std::time::Duration;
//!
//! # async fn example() {
//! // Configure a cache policy with a 60 second TTL
//! let policy = Cache::<String>::new(Duration::from_secs(60));
//!
//! // The call executes the operation
//! let result: Result<String, DoOverError<String>> = policy.execute(|| async {
//!     Ok("expensive_result".to_string())
//! }).await;
//!
//! // NOTE: through `Policy::execute`, subsequent calls also run the operation.
//! // Use `TypedCache::execute` to get the cached value until the TTL expires.
//! # }
//! ```
24
25use std::future::Future;
26use std::sync::Arc;
27use std::time::{Duration, Instant};
28use tokio::sync::RwLock;
29use crate::policy::Policy;
30
/// A cached value together with the moment it stops being valid.
#[derive(Clone)]
struct CachedValue<T> {
    /// The stored result.
    value: T,
    /// Deadline after which the value is stale.
    ///
    /// `None` means the deadline could not be represented (the TTL was too
    /// large to add to the current instant), so the value never expires.
    expires_at: Option<Instant>,
}

impl<T> CachedValue<T> {
    /// Wraps `value` so that it expires `ttl` after the moment of this call.
    ///
    /// Uses `checked_add` instead of `+`: adding an oversized `Duration`
    /// (for example `Duration::MAX`) to an `Instant` panics, whereas a TTL
    /// that large simply means "keep forever".
    fn new(value: T, ttl: Duration) -> Self {
        Self {
            value,
            expires_at: Instant::now().checked_add(ttl),
        }
    }

    /// Returns `true` once the current time has reached the deadline.
    ///
    /// A value without a representable deadline is never considered expired.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => Instant::now() >= deadline,
            None => false,
        }
    }
}
50
51/// A policy that caches successful results for a specified duration.
52///
53/// The cache stores the result of the first successful execution and returns
54/// it for subsequent calls until the TTL expires.
55///
56/// # Note
57///
58/// This is a simple single-value cache. For more sophisticated caching
59/// (keyed cache, LRU eviction, etc.), consider using a dedicated caching library.
60///
61/// # Examples
62///
63/// ```rust
64/// use do_over::{policy::Policy, cache::Cache, error::DoOverError};
65/// use std::time::Duration;
66///
67/// # async fn example() {
68/// let cache = Cache::<String>::new(Duration::from_secs(300));
69///
70/// // First call - executes operation
71/// let result: Result<String, DoOverError<String>> = cache.execute(|| async {
72///     Ok("data".to_string())
73/// }).await;
74///
75/// // Second call - returns cached value
76/// let result: Result<String, DoOverError<String>> = cache.execute(|| async {
77///     panic!("This won't be called!");
78/// }).await;
79/// # }
80/// ```
81pub struct Cache<T> {
82    ttl: Duration,
83    cached: Arc<RwLock<Option<CachedValue<T>>>>,
84}
85
86impl<T> Clone for Cache<T> {
87    fn clone(&self) -> Self {
88        Self {
89            ttl: self.ttl,
90            cached: Arc::clone(&self.cached),
91        }
92    }
93}
94
95impl<T> Cache<T>
96where
97    T: Clone + Send + Sync,
98{
99    /// Create a new cache policy.
100    ///
101    /// # Arguments
102    ///
103    /// * `ttl` - Time-to-live for cached values
104    ///
105    /// # Examples
106    ///
107    /// ```rust
108    /// use do_over::cache::Cache;
109    /// use std::time::Duration;
110    ///
111    /// // Cache for 5 minutes
112    /// let cache = Cache::<String>::new(Duration::from_secs(300));
113    ///
114    /// // Cache for 1 hour
115    /// let cache = Cache::<Vec<u8>>::new(Duration::from_secs(3600));
116    /// ```
117    pub fn new(ttl: Duration) -> Self {
118        Self {
119            ttl,
120            cached: Arc::new(RwLock::new(None)),
121        }
122    }
123
124    /// Clear the cached value.
125    ///
126    /// The next execution will call the underlying operation.
127    pub async fn invalidate(&self) {
128        let mut cached = self.cached.write().await;
129        *cached = None;
130    }
131
132    /// Check if there's a valid cached value.
133    pub async fn has_cached_value(&self) -> bool {
134        let cached = self.cached.read().await;
135        matches!(&*cached, Some(cv) if !cv.is_expired())
136    }
137}
138
139#[async_trait::async_trait]
140impl<T, E> Policy<E> for Cache<T>
141where
142    T: Clone + Send + Sync,
143    E: Send + Sync,
144{
145    async fn execute<F, Fut, R>(&self, f: F) -> Result<R, E>
146    where
147        F: Fn() -> Fut + Send + Sync,
148        Fut: Future<Output = Result<R, E>> + Send,
149        R: Send,
150    {
151        // For now, just execute the function
152        // A full implementation would need R == T constraint
153        f().await
154    }
155}
156
157/// A typed cache that stores and returns values of a specific type.
158pub struct TypedCache<T> {
159    ttl: Duration,
160    cached: Arc<RwLock<Option<CachedValue<T>>>>,
161}
162
163impl<T: Clone> Clone for TypedCache<T> {
164    fn clone(&self) -> Self {
165        Self {
166            ttl: self.ttl,
167            cached: Arc::clone(&self.cached),
168        }
169    }
170}
171
172impl<T> TypedCache<T>
173where
174    T: Clone + Send + Sync,
175{
176    /// Create a new typed cache.
177    pub fn new(ttl: Duration) -> Self {
178        Self {
179            ttl,
180            cached: Arc::new(RwLock::new(None)),
181        }
182    }
183
184    /// Execute the operation, returning cached value if available.
185    pub async fn execute<F, Fut, E>(&self, f: F) -> Result<T, E>
186    where
187        F: Fn() -> Fut + Send + Sync,
188        Fut: Future<Output = Result<T, E>> + Send,
189        E: Send + Sync,
190    {
191        // Check cache
192        {
193            let cached = self.cached.read().await;
194            if let Some(cv) = &*cached {
195                if !cv.is_expired() {
196                    return Ok(cv.value.clone());
197                }
198            }
199        }
200
201        // Execute and cache result
202        let result = f().await?;
203        {
204            let mut cached = self.cached.write().await;
205            *cached = Some(CachedValue::new(result.clone(), self.ttl));
206        }
207        Ok(result)
208    }
209
210    /// Clear the cached value.
211    pub async fn invalidate(&self) {
212        let mut cached = self.cached.write().await;
213        *cached = None;
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Calls `cache.execute` with an operation that increments `calls` and
    /// succeeds with `payload`.
    async fn run_counted(
        cache: &TypedCache<String>,
        calls: &Arc<AtomicUsize>,
        payload: &'static str,
    ) -> Result<String, String> {
        cache
            .execute(|| {
                let counter = Arc::clone(calls);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, String>(payload.to_string())
                }
            })
            .await
    }

    #[tokio::test]
    async fn test_typed_cache_caches_result() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let calls = Arc::new(AtomicUsize::new(0));

        // The first call misses and runs the operation.
        let first = run_counted(&cache, &calls, "result").await;
        assert_eq!(first.unwrap(), "result");
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // The second call is served from the cache: same value, no new invocation.
        let second = run_counted(&cache, &calls, "new_result").await;
        assert_eq!(second.unwrap(), "result");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_typed_cache_invalidate() {
        let cache = TypedCache::<String>::new(Duration::from_secs(60));
        let calls = Arc::new(AtomicUsize::new(0));

        // Populate the cache.
        let _ = run_counted(&cache, &calls, "first").await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Drop the stored value.
        cache.invalidate().await;

        // With the slot empty, the operation has to run again.
        let after = run_counted(&cache, &calls, "second").await;
        assert_eq!(after.unwrap(), "second");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}
291}