Skip to main content

crabtalk_proxy/ext/
cache.rs

1use axum::{Router, http::StatusCode, routing::delete};
2use crabtalk_core::{
3    BoxFuture, ChatCompletionRequest, ChatCompletionResponse, Prefix, RequestContext, Storage,
4    storage_key,
5};
6use sha2::{Digest, Sha256};
7use std::{
8    sync::Arc,
9    time::{SystemTime, UNIX_EPOCH},
10};
11
12/// Adapter that feeds bytes directly into a SHA-256 digest (no intermediate buffer).
13struct DigestWriter<'a>(&'a mut Sha256);
14
15impl std::io::Write for DigestWriter<'_> {
16    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
17        self.0.update(buf);
18        Ok(buf.len())
19    }
20    fn flush(&mut self) -> std::io::Result<()> {
21        Ok(())
22    }
23}
24
25pub struct Cache {
26    storage: Arc<dyn Storage>,
27    ttl_seconds: u64,
28}
29
30impl Cache {
31    const PREFIX: Prefix = *b"cach";
32
33    pub fn new(config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
34        let ttl_seconds = config
35            .get("ttl_seconds")
36            .and_then(|v| v.as_i64())
37            .unwrap_or(300) as u64;
38
39        Ok(Self {
40            storage,
41            ttl_seconds,
42        })
43    }
44
45    fn cache_key(request: &ChatCompletionRequest) -> Vec<u8> {
46        let mut hasher = Sha256::new();
47        // Write JSON directly into the hasher — no intermediate String allocation.
48        let _ = serde_json::to_writer(DigestWriter(&mut hasher), request);
49        storage_key(&Self::PREFIX, &hasher.finalize())
50    }
51
52    fn now_secs() -> u64 {
53        SystemTime::now()
54            .duration_since(UNIX_EPOCH)
55            .unwrap_or_default()
56            .as_secs()
57    }
58
59    pub fn admin_routes(&self) -> Router {
60        let storage = self.storage.clone();
61        let prefix = Self::PREFIX;
62        Router::new().route(
63            "/v1/cache",
64            delete(move || {
65                let storage = storage.clone();
66                async move {
67                    let pairs = storage.list(&prefix).await.unwrap_or_default();
68                    for (key, _) in pairs {
69                        let _ = storage.delete(&key).await;
70                    }
71                    StatusCode::NO_CONTENT
72                }
73            }),
74        )
75    }
76}
77
78impl crabtalk_core::Extension for Cache {
79    fn name(&self) -> &str {
80        "cache"
81    }
82
83    fn prefix(&self) -> Prefix {
84        Self::PREFIX
85    }
86
87    fn on_cache_lookup(
88        &self,
89        request: &ChatCompletionRequest,
90    ) -> BoxFuture<'_, Option<ChatCompletionResponse>> {
91        let key = Self::cache_key(request);
92        let ttl = self.ttl_seconds;
93
94        Box::pin(async move {
95            let data = self.storage.get(&key).await.ok()??;
96            if data.len() < 8 {
97                return None;
98            }
99
100            let timestamp = u64::from_be_bytes(data[..8].try_into().ok()?);
101            if Self::now_secs().saturating_sub(timestamp) > ttl {
102                let _ = self.storage.delete(&key).await;
103                return None;
104            }
105
106            serde_json::from_slice(&data[8..]).ok()
107        })
108    }
109
110    fn on_response(
111        &self,
112        ctx: &RequestContext,
113        request: &ChatCompletionRequest,
114        response: &ChatCompletionResponse,
115    ) -> BoxFuture<'_, ()> {
116        if ctx.is_stream {
117            return Box::pin(async {});
118        }
119
120        let key = Self::cache_key(request);
121        let Ok(json) = serde_json::to_vec(response) else {
122            return Box::pin(async {});
123        };
124
125        let mut value = Vec::with_capacity(8 + json.len());
126        value.extend_from_slice(&Self::now_secs().to_be_bytes());
127        value.extend_from_slice(&json);
128
129        Box::pin(async move {
130            let _ = self.storage.set(&key, value).await;
131        })
132    }
133}