// dspy_rs/utils/cache.rs

use std::collections::VecDeque;
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use foyer::{BlockEngineBuilder, DeviceBuilder, FsDeviceBuilder, HybridCache, HybridCacheBuilder};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tempfile;
use tokio::sync::mpsc;

use crate::{Example, Prediction};

11type CacheKey = Vec<(String, Value)>;
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct CallResult {
15    pub prompt: String,
16    pub prediction: Prediction,
17}
18
19#[async_trait]
20pub trait Cache: Send + Sync {
21    async fn new() -> Self;
22    async fn get(&self, key: Example) -> Result<Option<Prediction>>;
23    async fn insert(&mut self, key: Example, rx: mpsc::Receiver<CallResult>) -> Result<()>;
24    async fn get_history(&self, n: usize) -> Result<Vec<CallResult>>;
25}
26
27#[derive(Clone)]
28pub struct ResponseCache {
29    handler: HybridCache<CacheKey, CallResult>,
30    window_size: usize,
31    history_window: Vec<CallResult>,
32}
33
34#[async_trait]
35impl Cache for ResponseCache {
36    async fn new() -> Self {
37        let dir = tempfile::tempdir().unwrap();
38
39        let device = FsDeviceBuilder::new(dir.path())
40            .with_capacity(1024 * 1024 * 1024)
41            .build()
42            .unwrap();
43
44        let hybrid: HybridCache<CacheKey, CallResult> = HybridCacheBuilder::new()
45            .memory(256 * 1024 * 1024)
46            .storage()
47            .with_engine_config(BlockEngineBuilder::new(device))
48            .build()
49            .await
50            .unwrap();
51        Self {
52            handler: hybrid,
53            window_size: 100,
54            history_window: Vec::new(),
55        }
56    }
57
58    async fn get(&self, key: Example) -> Result<Option<Prediction>> {
59        let key = key.into_iter().collect::<CacheKey>();
60
61        let value = self.handler.get(&key).await?.map(|v| v.value().clone());
62
63        Ok(value.map(|entry| entry.prediction))
64    }
65
66    async fn insert(&mut self, key: Example, mut rx: mpsc::Receiver<CallResult>) -> Result<()> {
67        let key = key.into_iter().collect::<CacheKey>();
68        let value = rx.recv().await.unwrap();
69
70        self.history_window.insert(0, value.clone());
71        if self.history_window.len() > self.window_size {
72            self.history_window.pop();
73        }
74        self.handler.insert(key, value.clone());
75
76        Ok(())
77    }
78
79    async fn get_history(&self, n: usize) -> Result<Vec<CallResult>> {
80        let actual_n = n.min(self.history_window.len());
81        Ok(self.history_window[..actual_n].to_vec())
82    }
83}