// heliosdb_proxy/distribcache/ai/conversation.rs
use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::time::Instant;

/// Identifier under which a conversation's context is cached.
pub type ConversationId = String;

13#[derive(Debug, Clone)]
15pub struct Turn {
16 pub id: String,
18 pub role: String,
20 pub content: String,
22 pub timestamp: Instant,
24 pub token_count: usize,
26 pub metadata: Option<serde_json::Value>,
28}
29
30impl Turn {
31 pub fn new(id: impl Into<String>, role: impl Into<String>, content: impl Into<String>) -> Self {
33 let content = content.into();
34 let token_count = content.split_whitespace().count() * 4 / 3; Self {
37 id: id.into(),
38 role: role.into(),
39 content,
40 timestamp: Instant::now(),
41 token_count,
42 metadata: None,
43 }
44 }
45
46 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
48 self.metadata = Some(metadata);
49 self
50 }
51
52 pub fn size(&self) -> usize {
54 self.id.len() + self.role.len() + self.content.len() + 64
55 }
56}
57
58#[derive(Debug)]
60pub struct ConversationContext {
61 pub id: ConversationId,
63 pub turns: VecDeque<Turn>,
65 max_turns: usize,
67 total_tokens: usize,
69 last_access: Instant,
71}
72
73impl ConversationContext {
74 fn new(id: ConversationId, max_turns: usize) -> Self {
75 Self {
76 id,
77 turns: VecDeque::with_capacity(max_turns),
78 max_turns,
79 total_tokens: 0,
80 last_access: Instant::now(),
81 }
82 }
83
84 fn append(&mut self, turn: Turn) {
85 self.total_tokens += turn.token_count;
86 self.turns.push_back(turn);
87
88 while self.turns.len() > self.max_turns {
90 if let Some(removed) = self.turns.pop_front() {
91 self.total_tokens = self.total_tokens.saturating_sub(removed.token_count);
92 }
93 }
94
95 self.last_access = Instant::now();
96 }
97
98 fn get_recent(&self, count: usize) -> Vec<Turn> {
99 self.turns.iter().rev().take(count).rev().cloned().collect()
100 }
101
102 fn size(&self) -> usize {
103 self.turns.iter().map(|t| t.size()).sum()
104 }
105}
106
107struct LruTracker {
109 order: Mutex<VecDeque<ConversationId>>,
110 #[allow(dead_code)]
111 max_size: usize,
112}
113
114impl LruTracker {
115 fn new(max_size: usize) -> Self {
116 Self {
117 order: Mutex::new(VecDeque::with_capacity(max_size)),
118 max_size,
119 }
120 }
121
122 fn touch(&self, id: &ConversationId) {
123 let mut order = self.order.lock().unwrap();
124
125 if let Some(pos) = order.iter().position(|x| x == id) {
127 order.remove(pos);
128 }
129
130 order.push_back(id.clone());
132 }
133
134 fn evict_oldest(&self) -> Option<ConversationId> {
135 self.order.lock().unwrap().pop_front()
136 }
137}
138
139pub struct ConversationContextCache {
141 contexts: DashMap<ConversationId, ConversationContext>,
143
144 lru: LruTracker,
146
147 max_turns: usize,
149
150 max_conversations: usize,
152}
153
154impl ConversationContextCache {
155 pub fn new(max_conversations: usize, max_turns: usize) -> Self {
157 Self {
158 contexts: DashMap::new(),
159 lru: LruTracker::new(max_conversations),
160 max_turns,
161 max_conversations,
162 }
163 }
164
165 pub fn get_context(&self, conv_id: &str, max_turns: usize) -> Option<Vec<Turn>> {
167 self.lru.touch(&conv_id.to_string());
168
169 let ctx = self.contexts.get(conv_id)?;
170 Some(ctx.get_recent(max_turns))
171 }
172
173 pub fn get_full_context(&self, conv_id: &str) -> Option<Vec<Turn>> {
175 self.lru.touch(&conv_id.to_string());
176
177 let ctx = self.contexts.get(conv_id)?;
178 Some(ctx.turns.iter().cloned().collect())
179 }
180
181 pub fn append_turn(&self, conv_id: &str, turn: Turn) {
183 self.lru.touch(&conv_id.to_string());
184
185 while self.contexts.len() >= self.max_conversations {
187 if let Some(old_id) = self.lru.evict_oldest() {
188 self.contexts.remove(&old_id);
189 } else {
190 break;
191 }
192 }
193
194 let mut ctx = self
196 .contexts
197 .entry(conv_id.to_string())
198 .or_insert_with(|| ConversationContext::new(conv_id.to_string(), self.max_turns));
199
200 ctx.append(turn);
201 }
202
203 pub fn clear_conversation(&self, conv_id: &str) {
205 self.contexts.remove(conv_id);
206 }
207
208 pub fn conversation_count(&self) -> usize {
210 self.contexts.len()
211 }
212
213 pub fn total_tokens(&self) -> usize {
215 self.contexts.iter().map(|ctx| ctx.total_tokens).sum()
216 }
217
218 pub fn stats(&self) -> ConversationCacheStats {
220 let mut total_turns = 0;
221 let mut total_size = 0;
222
223 for ctx in self.contexts.iter() {
224 total_turns += ctx.turns.len();
225 total_size += ctx.size();
226 }
227
228 ConversationCacheStats {
229 conversations: self.contexts.len(),
230 total_turns,
231 total_size_bytes: total_size,
232 total_tokens: self.total_tokens(),
233 }
234 }
235}
236
/// Aggregate counters returned by [`ConversationContextCache::stats`].
#[derive(Debug, Clone)]
pub struct ConversationCacheStats {
    /// Number of conversations currently cached.
    pub conversations: usize,
    /// Turns retained across all conversations.
    pub total_turns: usize,
    /// Sum of the per-turn byte estimates from [`Turn::size`].
    pub total_size_bytes: usize,
    /// Sum of the per-conversation token totals.
    pub total_tokens: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Appends one turn per `(id, role, content)` triple to `conv_id`, in order.
    fn append_all(cache: &ConversationContextCache, conv_id: &str, turns: &[(&str, &str, &str)]) {
        for &(id, role, content) in turns {
            cache.append_turn(conv_id, Turn::new(id, role, content));
        }
    }

    #[test]
    fn test_turn_creation() {
        let Turn { role, token_count, .. } = Turn::new("1", "user", "Hello, how are you?");
        assert_eq!(role, "user");
        assert!(token_count > 0);
    }

    #[test]
    fn test_append_and_get_context() {
        let cache = ConversationContextCache::new(100, 50);
        append_all(
            &cache,
            "conv-1",
            &[
                ("1", "user", "Hello"),
                ("2", "assistant", "Hi there!"),
                ("3", "user", "How are you?"),
            ],
        );

        let contents: Vec<String> = cache
            .get_context("conv-1", 2)
            .expect("conv-1 was just populated")
            .into_iter()
            .map(|turn| turn.content)
            .collect();
        assert_eq!(contents, ["Hi there!", "How are you?"]);
    }

    #[test]
    fn test_max_turns_limit() {
        let cache = ConversationContextCache::new(100, 3);

        for i in 0..5 {
            let turn = Turn::new(i.to_string(), "user", format!("Message {}", i));
            cache.append_turn("conv-1", turn);
        }

        let history = cache.get_full_context("conv-1").unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history.first().map(|turn| turn.content.as_str()), Some("Message 2"));
    }

    #[test]
    fn test_lru_eviction() {
        let cache = ConversationContextCache::new(2, 10);
        let seeded = [("conv-1", "Hello 1"), ("conv-2", "Hello 2"), ("conv-3", "Hello 3")];

        for &(conv_id, text) in seeded.iter() {
            cache.append_turn(conv_id, Turn::new("1", "user", text));
        }

        // Only two conversations fit, so admitting conv-3 pushes out conv-1.
        assert!(cache.get_context("conv-1", 1).is_none());
        assert!(cache.get_context("conv-2", 1).is_some());
        assert!(cache.get_context("conv-3", 1).is_some());
    }

    #[test]
    fn test_stats() {
        let cache = ConversationContextCache::new(100, 50);
        append_all(&cache, "conv-1", &[("1", "user", "Hello"), ("2", "assistant", "Hi")]);
        append_all(&cache, "conv-2", &[("1", "user", "Test")]);

        let ConversationCacheStats {
            conversations,
            total_turns,
            total_tokens,
            ..
        } = cache.stats();
        assert_eq!(conversations, 2);
        assert_eq!(total_turns, 3);
        assert!(total_tokens > 0);
    }
}