// heliosdb_proxy/distribcache/ai/conversation.rs
use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::time::Instant;

/// Key under which a conversation's turns are cached (a plain owned string).
pub type ConversationId = String;

13#[derive(Debug, Clone)]
15pub struct Turn {
16 pub id: String,
18 pub role: String,
20 pub content: String,
22 pub timestamp: Instant,
24 pub token_count: usize,
26 pub metadata: Option<serde_json::Value>,
28}
29
30impl Turn {
31 pub fn new(id: impl Into<String>, role: impl Into<String>, content: impl Into<String>) -> Self {
33 let content = content.into();
34 let token_count = content.split_whitespace().count() * 4 / 3; Self {
37 id: id.into(),
38 role: role.into(),
39 content,
40 timestamp: Instant::now(),
41 token_count,
42 metadata: None,
43 }
44 }
45
46 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
48 self.metadata = Some(metadata);
49 self
50 }
51
52 pub fn size(&self) -> usize {
54 self.id.len() + self.role.len() + self.content.len() + 64
55 }
56}
57
58#[derive(Debug)]
60pub struct ConversationContext {
61 pub id: ConversationId,
63 pub turns: VecDeque<Turn>,
65 max_turns: usize,
67 total_tokens: usize,
69 last_access: Instant,
71}
72
73impl ConversationContext {
74 fn new(id: ConversationId, max_turns: usize) -> Self {
75 Self {
76 id,
77 turns: VecDeque::with_capacity(max_turns),
78 max_turns,
79 total_tokens: 0,
80 last_access: Instant::now(),
81 }
82 }
83
84 fn append(&mut self, turn: Turn) {
85 self.total_tokens += turn.token_count;
86 self.turns.push_back(turn);
87
88 while self.turns.len() > self.max_turns {
90 if let Some(removed) = self.turns.pop_front() {
91 self.total_tokens = self.total_tokens.saturating_sub(removed.token_count);
92 }
93 }
94
95 self.last_access = Instant::now();
96 }
97
98 fn get_recent(&self, count: usize) -> Vec<Turn> {
99 self.turns.iter()
100 .rev()
101 .take(count)
102 .rev()
103 .cloned()
104 .collect()
105 }
106
107 fn size(&self) -> usize {
108 self.turns.iter().map(|t| t.size()).sum()
109 }
110}
111
112struct LruTracker {
114 order: Mutex<VecDeque<ConversationId>>,
115 max_size: usize,
116}
117
118impl LruTracker {
119 fn new(max_size: usize) -> Self {
120 Self {
121 order: Mutex::new(VecDeque::with_capacity(max_size)),
122 max_size,
123 }
124 }
125
126 fn touch(&self, id: &ConversationId) {
127 let mut order = self.order.lock().unwrap();
128
129 if let Some(pos) = order.iter().position(|x| x == id) {
131 order.remove(pos);
132 }
133
134 order.push_back(id.clone());
136 }
137
138 fn evict_oldest(&self) -> Option<ConversationId> {
139 self.order.lock().unwrap().pop_front()
140 }
141}
142
143pub struct ConversationContextCache {
145 contexts: DashMap<ConversationId, ConversationContext>,
147
148 lru: LruTracker,
150
151 max_turns: usize,
153
154 max_conversations: usize,
156}
157
158impl ConversationContextCache {
159 pub fn new(max_conversations: usize, max_turns: usize) -> Self {
161 Self {
162 contexts: DashMap::new(),
163 lru: LruTracker::new(max_conversations),
164 max_turns,
165 max_conversations,
166 }
167 }
168
169 pub fn get_context(&self, conv_id: &str, max_turns: usize) -> Option<Vec<Turn>> {
171 self.lru.touch(&conv_id.to_string());
172
173 let ctx = self.contexts.get(conv_id)?;
174 Some(ctx.get_recent(max_turns))
175 }
176
177 pub fn get_full_context(&self, conv_id: &str) -> Option<Vec<Turn>> {
179 self.lru.touch(&conv_id.to_string());
180
181 let ctx = self.contexts.get(conv_id)?;
182 Some(ctx.turns.iter().cloned().collect())
183 }
184
185 pub fn append_turn(&self, conv_id: &str, turn: Turn) {
187 self.lru.touch(&conv_id.to_string());
188
189 while self.contexts.len() >= self.max_conversations {
191 if let Some(old_id) = self.lru.evict_oldest() {
192 self.contexts.remove(&old_id);
193 } else {
194 break;
195 }
196 }
197
198 let mut ctx = self.contexts
200 .entry(conv_id.to_string())
201 .or_insert_with(|| ConversationContext::new(conv_id.to_string(), self.max_turns));
202
203 ctx.append(turn);
204 }
205
206 pub fn clear_conversation(&self, conv_id: &str) {
208 self.contexts.remove(conv_id);
209 }
210
211 pub fn conversation_count(&self) -> usize {
213 self.contexts.len()
214 }
215
216 pub fn total_tokens(&self) -> usize {
218 self.contexts.iter()
219 .map(|ctx| ctx.total_tokens)
220 .sum()
221 }
222
223 pub fn stats(&self) -> ConversationCacheStats {
225 let mut total_turns = 0;
226 let mut total_size = 0;
227
228 for ctx in self.contexts.iter() {
229 total_turns += ctx.turns.len();
230 total_size += ctx.size();
231 }
232
233 ConversationCacheStats {
234 conversations: self.contexts.len(),
235 total_turns,
236 total_size_bytes: total_size,
237 total_tokens: self.total_tokens(),
238 }
239 }
240}
241
/// Snapshot of aggregate figures returned by `ConversationContextCache::stats`.
#[derive(Debug, Clone)]
pub struct ConversationCacheStats {
    /// Number of cached conversations.
    pub conversations: usize,
    /// Retained turns across all conversations.
    pub total_turns: usize,
    /// Approximate bytes, summed from `Turn::size` (an estimate, not an allocator figure).
    pub total_size_bytes: usize,
    /// Sum of the per-conversation token estimates.
    pub total_tokens: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Appends a turn built from `(id, role, content)` to conversation `conv`.
    fn push(cache: &ConversationContextCache, conv: &str, id: &str, role: &str, content: &str) {
        cache.append_turn(conv, Turn::new(id, role, content));
    }

    #[test]
    fn test_turn_creation() {
        let turn = Turn::new("1", "user", "Hello, how are you?");
        assert_eq!(turn.role, "user");
        assert!(turn.token_count > 0);
    }

    #[test]
    fn test_append_and_get_context() {
        let cache = ConversationContextCache::new(100, 50);
        let script = [
            ("1", "user", "Hello"),
            ("2", "assistant", "Hi there!"),
            ("3", "user", "How are you?"),
        ];
        for &(id, role, text) in &script {
            push(&cache, "conv-1", id, role, text);
        }

        let contents: Vec<String> = cache
            .get_context("conv-1", 2)
            .unwrap()
            .into_iter()
            .map(|turn| turn.content)
            .collect();
        assert_eq!(contents, ["Hi there!", "How are you?"]);
    }

    #[test]
    fn test_max_turns_limit() {
        let cache = ConversationContextCache::new(100, 3);
        for n in 0..5 {
            cache.append_turn(
                "conv-1",
                Turn::new(n.to_string(), "user", format!("Message {}", n)),
            );
        }

        let history = cache.get_full_context("conv-1").unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "Message 2");
    }

    #[test]
    fn test_lru_eviction() {
        let cache = ConversationContextCache::new(2, 10);
        let arrivals = [("conv-1", "Hello 1"), ("conv-2", "Hello 2"), ("conv-3", "Hello 3")];
        for &(conv, text) in &arrivals {
            push(&cache, conv, "1", "user", text);
        }

        assert!(cache.get_context("conv-1", 1).is_none());
        assert!(cache.get_context("conv-2", 1).is_some());
        assert!(cache.get_context("conv-3", 1).is_some());
    }

    #[test]
    fn test_stats() {
        let cache = ConversationContextCache::new(100, 50);
        push(&cache, "conv-1", "1", "user", "Hello");
        push(&cache, "conv-1", "2", "assistant", "Hi");
        push(&cache, "conv-2", "1", "user", "Test");

        let ConversationCacheStats {
            conversations,
            total_turns,
            total_tokens,
            ..
        } = cache.stats();
        assert_eq!(conversations, 2);
        assert_eq!(total_turns, 3);
        assert!(total_tokens > 0);
    }
}