Skip to main content

heliosdb_proxy/distribcache/ai/
conversation.rs

1//! Conversation context cache for AI agents
2//!
3//! Caches recent conversation turns for quick context retrieval.
4
5use dashmap::DashMap;
6use std::collections::VecDeque;
7use std::sync::Mutex;
8use std::time::Instant;
9
/// Identifier naming a single conversation (a plain owned string).
pub type ConversationId = String;
12
13/// A conversation turn
14#[derive(Debug, Clone)]
15pub struct Turn {
16    /// Turn identifier
17    pub id: String,
18    /// Role (user, assistant, system)
19    pub role: String,
20    /// Content
21    pub content: String,
22    /// Timestamp
23    pub timestamp: Instant,
24    /// Token count (approximate)
25    pub token_count: usize,
26    /// Metadata
27    pub metadata: Option<serde_json::Value>,
28}
29
30impl Turn {
31    /// Create a new turn
32    pub fn new(id: impl Into<String>, role: impl Into<String>, content: impl Into<String>) -> Self {
33        let content = content.into();
34        let token_count = content.split_whitespace().count() * 4 / 3; // Rough estimate
35
36        Self {
37            id: id.into(),
38            role: role.into(),
39            content,
40            timestamp: Instant::now(),
41            token_count,
42            metadata: None,
43        }
44    }
45
46    /// Add metadata
47    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
48        self.metadata = Some(metadata);
49        self
50    }
51
52    /// Approximate size in bytes
53    pub fn size(&self) -> usize {
54        self.id.len() + self.role.len() + self.content.len() + 64
55    }
56}
57
58/// Conversation context
59#[derive(Debug)]
60pub struct ConversationContext {
61    /// Conversation ID
62    pub id: ConversationId,
63    /// Conversation turns
64    pub turns: VecDeque<Turn>,
65    /// Maximum turns to keep
66    max_turns: usize,
67    /// Total token count
68    total_tokens: usize,
69    /// Last access time
70    last_access: Instant,
71}
72
73impl ConversationContext {
74    fn new(id: ConversationId, max_turns: usize) -> Self {
75        Self {
76            id,
77            turns: VecDeque::with_capacity(max_turns),
78            max_turns,
79            total_tokens: 0,
80            last_access: Instant::now(),
81        }
82    }
83
84    fn append(&mut self, turn: Turn) {
85        self.total_tokens += turn.token_count;
86        self.turns.push_back(turn);
87
88        // Maintain size limit
89        while self.turns.len() > self.max_turns {
90            if let Some(removed) = self.turns.pop_front() {
91                self.total_tokens = self.total_tokens.saturating_sub(removed.token_count);
92            }
93        }
94
95        self.last_access = Instant::now();
96    }
97
98    fn get_recent(&self, count: usize) -> Vec<Turn> {
99        self.turns.iter()
100            .rev()
101            .take(count)
102            .rev()
103            .cloned()
104            .collect()
105    }
106
107    fn size(&self) -> usize {
108        self.turns.iter().map(|t| t.size()).sum()
109    }
110}
111
112/// LRU cache for conversation eviction
113struct LruTracker {
114    order: Mutex<VecDeque<ConversationId>>,
115    max_size: usize,
116}
117
118impl LruTracker {
119    fn new(max_size: usize) -> Self {
120        Self {
121            order: Mutex::new(VecDeque::with_capacity(max_size)),
122            max_size,
123        }
124    }
125
126    fn touch(&self, id: &ConversationId) {
127        let mut order = self.order.lock().unwrap();
128
129        // Remove existing entry
130        if let Some(pos) = order.iter().position(|x| x == id) {
131            order.remove(pos);
132        }
133
134        // Add to end (most recent)
135        order.push_back(id.clone());
136    }
137
138    fn evict_oldest(&self) -> Option<ConversationId> {
139        self.order.lock().unwrap().pop_front()
140    }
141}
142
143/// Conversation context cache
144pub struct ConversationContextCache {
145    /// Contexts per conversation
146    contexts: DashMap<ConversationId, ConversationContext>,
147
148    /// LRU tracker
149    lru: LruTracker,
150
151    /// Maximum turns per conversation
152    max_turns: usize,
153
154    /// Maximum conversations to cache
155    max_conversations: usize,
156}
157
158impl ConversationContextCache {
159    /// Create a new cache
160    pub fn new(max_conversations: usize, max_turns: usize) -> Self {
161        Self {
162            contexts: DashMap::new(),
163            lru: LruTracker::new(max_conversations),
164            max_turns,
165            max_conversations,
166        }
167    }
168
169    /// Get context for a conversation
170    pub fn get_context(&self, conv_id: &str, max_turns: usize) -> Option<Vec<Turn>> {
171        self.lru.touch(&conv_id.to_string());
172
173        let ctx = self.contexts.get(conv_id)?;
174        Some(ctx.get_recent(max_turns))
175    }
176
177    /// Get full context
178    pub fn get_full_context(&self, conv_id: &str) -> Option<Vec<Turn>> {
179        self.lru.touch(&conv_id.to_string());
180
181        let ctx = self.contexts.get(conv_id)?;
182        Some(ctx.turns.iter().cloned().collect())
183    }
184
185    /// Append a turn to a conversation
186    pub fn append_turn(&self, conv_id: &str, turn: Turn) {
187        self.lru.touch(&conv_id.to_string());
188
189        // Evict if at capacity
190        while self.contexts.len() >= self.max_conversations {
191            if let Some(old_id) = self.lru.evict_oldest() {
192                self.contexts.remove(&old_id);
193            } else {
194                break;
195            }
196        }
197
198        // Get or create context
199        let mut ctx = self.contexts
200            .entry(conv_id.to_string())
201            .or_insert_with(|| ConversationContext::new(conv_id.to_string(), self.max_turns));
202
203        ctx.append(turn);
204    }
205
206    /// Clear a conversation
207    pub fn clear_conversation(&self, conv_id: &str) {
208        self.contexts.remove(conv_id);
209    }
210
211    /// Get conversation count
212    pub fn conversation_count(&self) -> usize {
213        self.contexts.len()
214    }
215
216    /// Get total tokens cached
217    pub fn total_tokens(&self) -> usize {
218        self.contexts.iter()
219            .map(|ctx| ctx.total_tokens)
220            .sum()
221    }
222
223    /// Get stats
224    pub fn stats(&self) -> ConversationCacheStats {
225        let mut total_turns = 0;
226        let mut total_size = 0;
227
228        for ctx in self.contexts.iter() {
229            total_turns += ctx.turns.len();
230            total_size += ctx.size();
231        }
232
233        ConversationCacheStats {
234            conversations: self.contexts.len(),
235            total_turns,
236            total_size_bytes: total_size,
237            total_tokens: self.total_tokens(),
238        }
239    }
240}
241
/// Point-in-time summary of what the conversation cache holds.
#[derive(Debug, Clone)]
pub struct ConversationCacheStats {
    /// Number of cached conversations.
    pub conversations: usize,
    /// Number of turns summed over all conversations.
    pub total_turns: usize,
    /// Approximate size in bytes of all cached turns.
    pub total_size_bytes: usize,
    /// Approximate token count summed over all conversations.
    pub total_tokens: usize,
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// A new turn keeps its role and gets a non-zero token estimate.
    #[test]
    fn test_turn_creation() {
        let turn = Turn::new("1", "user", "Hello, how are you?");

        assert_eq!(turn.role, "user");
        assert!(turn.token_count > 0);
    }

    /// Asking for two turns yields the two newest ones, oldest first.
    #[test]
    fn test_append_and_get_context() {
        let cache = ConversationContextCache::new(100, 50);

        let script = [
            ("1", "user", "Hello"),
            ("2", "assistant", "Hi there!"),
            ("3", "user", "How are you?"),
        ];
        for (id, role, text) in script {
            cache.append_turn("conv-1", Turn::new(id, role, text));
        }

        let context = cache.get_context("conv-1", 2).unwrap();
        let contents: Vec<&str> = context.iter().map(|t| t.content.as_str()).collect();
        assert_eq!(contents, ["Hi there!", "How are you?"]);
    }

    /// Only the newest `max_turns` turns of a conversation are retained.
    #[test]
    fn test_max_turns_limit() {
        let cache = ConversationContextCache::new(100, 3);

        (0..5).for_each(|i| {
            let turn = Turn::new(i.to_string(), "user", format!("Message {}", i));
            cache.append_turn("conv-1", turn);
        });

        let context = cache.get_full_context("conv-1").unwrap();
        assert_eq!(context.len(), 3);
        assert_eq!(context[0].content, "Message 2");
    }

    /// Adding a third conversation to a two-slot cache evicts the oldest one.
    #[test]
    fn test_lru_eviction() {
        let cache = ConversationContextCache::new(2, 10);

        for (conv, text) in [("conv-1", "Hello 1"), ("conv-2", "Hello 2")] {
            cache.append_turn(conv, Turn::new("1", "user", text));
        }

        // This should evict conv-1
        cache.append_turn("conv-3", Turn::new("1", "user", "Hello 3"));

        assert!(cache.get_context("conv-1", 1).is_none());
        assert!(cache.get_context("conv-2", 1).is_some());
        assert!(cache.get_context("conv-3", 1).is_some());
    }

    /// Stats count conversations and turns across the whole cache.
    #[test]
    fn test_stats() {
        let cache = ConversationContextCache::new(100, 50);

        let script = [
            ("conv-1", "1", "user", "Hello"),
            ("conv-1", "2", "assistant", "Hi"),
            ("conv-2", "1", "user", "Test"),
        ];
        for (conv, id, role, text) in script {
            cache.append_turn(conv, Turn::new(id, role, text));
        }

        let stats = cache.stats();
        assert_eq!(stats.conversations, 2);
        assert_eq!(stats.total_turns, 3);
        assert!(stats.total_tokens > 0);
    }
}