crabtalk_proxy/ext/
cache.rs1use axum::{Router, http::StatusCode, routing::delete};
2use crabtalk_core::{
3 BoxFuture, ChatCompletionRequest, ChatCompletionResponse, Prefix, RequestContext, Storage,
4 storage_key,
5};
6use sha2::{Digest, Sha256};
7use std::{
8 sync::Arc,
9 time::{SystemTime, UNIX_EPOCH},
10};
11
12struct DigestWriter<'a>(&'a mut Sha256);
14
15impl std::io::Write for DigestWriter<'_> {
16 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
17 self.0.update(buf);
18 Ok(buf.len())
19 }
20 fn flush(&mut self) -> std::io::Result<()> {
21 Ok(())
22 }
23}
24
25pub struct Cache {
26 storage: Arc<dyn Storage>,
27 ttl_seconds: u64,
28}
29
30impl Cache {
31 const PREFIX: Prefix = *b"cach";
32
33 pub fn new(config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
34 let ttl_seconds = config
35 .get("ttl_seconds")
36 .and_then(|v| v.as_i64())
37 .unwrap_or(300) as u64;
38
39 Ok(Self {
40 storage,
41 ttl_seconds,
42 })
43 }
44
45 fn cache_key(request: &ChatCompletionRequest) -> Vec<u8> {
46 let mut hasher = Sha256::new();
47 let _ = serde_json::to_writer(DigestWriter(&mut hasher), request);
49 storage_key(&Self::PREFIX, &hasher.finalize())
50 }
51
52 fn now_secs() -> u64 {
53 SystemTime::now()
54 .duration_since(UNIX_EPOCH)
55 .unwrap_or_default()
56 .as_secs()
57 }
58
59 pub fn admin_routes(&self) -> Router {
60 let storage = self.storage.clone();
61 let prefix = Self::PREFIX;
62 Router::new().route(
63 "/v1/cache",
64 delete(move || {
65 let storage = storage.clone();
66 async move {
67 let pairs = storage.list(&prefix).await.unwrap_or_default();
68 for (key, _) in pairs {
69 let _ = storage.delete(&key).await;
70 }
71 StatusCode::NO_CONTENT
72 }
73 }),
74 )
75 }
76}
77
78impl crabtalk_core::Extension for Cache {
79 fn name(&self) -> &str {
80 "cache"
81 }
82
83 fn prefix(&self) -> Prefix {
84 Self::PREFIX
85 }
86
87 fn on_cache_lookup(
88 &self,
89 request: &ChatCompletionRequest,
90 ) -> BoxFuture<'_, Option<ChatCompletionResponse>> {
91 let key = Self::cache_key(request);
92 let ttl = self.ttl_seconds;
93
94 Box::pin(async move {
95 let data = self.storage.get(&key).await.ok()??;
96 if data.len() < 8 {
97 return None;
98 }
99
100 let timestamp = u64::from_be_bytes(data[..8].try_into().ok()?);
101 if Self::now_secs().saturating_sub(timestamp) > ttl {
102 let _ = self.storage.delete(&key).await;
103 return None;
104 }
105
106 serde_json::from_slice(&data[8..]).ok()
107 })
108 }
109
110 fn on_response(
111 &self,
112 ctx: &RequestContext,
113 request: &ChatCompletionRequest,
114 response: &ChatCompletionResponse,
115 ) -> BoxFuture<'_, ()> {
116 if ctx.is_stream {
117 return Box::pin(async {});
118 }
119
120 let key = Self::cache_key(request);
121 let Ok(json) = serde_json::to_vec(response) else {
122 return Box::pin(async {});
123 };
124
125 let mut value = Vec::with_capacity(8 + json.len());
126 value.extend_from_slice(&Self::now_secs().to_be_bytes());
127 value.extend_from_slice(&json);
128
129 Box::pin(async move {
130 let _ = self.storage.set(&key, value).await;
131 })
132 }
133}