1use anyhow::Result;
2use async_trait::async_trait;
3use foyer::{BlockEngineBuilder, DeviceBuilder, FsDeviceBuilder, HybridCache, HybridCacheBuilder};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tempfile;
7use tokio::sync::mpsc;
8
9use crate::{Example, Prediction};
10
11type CacheKey = Vec<(String, Value)>;
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct CallResult {
15 pub prompt: String,
16 pub prediction: Prediction,
17}
18
19#[async_trait]
20pub trait Cache: Send + Sync {
21 async fn new() -> Self;
22 async fn get(&self, key: Example) -> Result<Option<Prediction>>;
23 async fn insert(&mut self, key: Example, rx: mpsc::Receiver<CallResult>) -> Result<()>;
24 async fn get_history(&self, n: usize) -> Result<Vec<CallResult>>;
25}
26
27#[derive(Clone)]
28pub struct ResponseCache {
29 handler: HybridCache<CacheKey, CallResult>,
30 window_size: usize,
31 history_window: Vec<CallResult>,
32}
33
34#[async_trait]
35impl Cache for ResponseCache {
36 async fn new() -> Self {
37 let dir = tempfile::tempdir().unwrap();
38
39 let device = FsDeviceBuilder::new(dir.path())
40 .with_capacity(1024 * 1024 * 1024)
41 .build()
42 .unwrap();
43
44 let hybrid: HybridCache<CacheKey, CallResult> = HybridCacheBuilder::new()
45 .memory(256 * 1024 * 1024)
46 .storage()
47 .with_engine_config(BlockEngineBuilder::new(device))
48 .build()
49 .await
50 .unwrap();
51 Self {
52 handler: hybrid,
53 window_size: 100,
54 history_window: Vec::new(),
55 }
56 }
57
58 async fn get(&self, key: Example) -> Result<Option<Prediction>> {
59 let key = key.into_iter().collect::<CacheKey>();
60
61 let value = self.handler.get(&key).await?.map(|v| v.value().clone());
62
63 Ok(value.map(|entry| entry.prediction))
64 }
65
66 async fn insert(&mut self, key: Example, mut rx: mpsc::Receiver<CallResult>) -> Result<()> {
67 let key = key.into_iter().collect::<CacheKey>();
68 let value = rx.recv().await.unwrap();
69
70 self.history_window.insert(0, value.clone());
71 if self.history_window.len() > self.window_size {
72 self.history_window.pop();
73 }
74 self.handler.insert(key, value.clone());
75
76 Ok(())
77 }
78
79 async fn get_history(&self, n: usize) -> Result<Vec<CallResult>> {
80 let actual_n = n.min(self.history_window.len());
81 Ok(self.history_window[..actual_n].to_vec())
82 }
83}