Skip to main content

heliosdb_proxy/distribcache/ai/
conversation.rs

1//! Conversation context cache for AI agents
2//!
3//! Caches recent conversation turns for quick context retrieval.
4
5use dashmap::DashMap;
6use std::collections::VecDeque;
7use std::sync::Mutex;
8use std::time::Instant;
9
/// Identifier of a cached conversation.
///
/// A plain alias for `String`; this module uses it as the key type of the
/// context map and as the element type of the LRU order queue.
pub type ConversationId = String;
12
13/// A conversation turn
14#[derive(Debug, Clone)]
15pub struct Turn {
16    /// Turn identifier
17    pub id: String,
18    /// Role (user, assistant, system)
19    pub role: String,
20    /// Content
21    pub content: String,
22    /// Timestamp
23    pub timestamp: Instant,
24    /// Token count (approximate)
25    pub token_count: usize,
26    /// Metadata
27    pub metadata: Option<serde_json::Value>,
28}
29
30impl Turn {
31    /// Create a new turn
32    pub fn new(id: impl Into<String>, role: impl Into<String>, content: impl Into<String>) -> Self {
33        let content = content.into();
34        let token_count = content.split_whitespace().count() * 4 / 3; // Rough estimate
35
36        Self {
37            id: id.into(),
38            role: role.into(),
39            content,
40            timestamp: Instant::now(),
41            token_count,
42            metadata: None,
43        }
44    }
45
46    /// Add metadata
47    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
48        self.metadata = Some(metadata);
49        self
50    }
51
52    /// Approximate size in bytes
53    pub fn size(&self) -> usize {
54        self.id.len() + self.role.len() + self.content.len() + 64
55    }
56}
57
58/// Conversation context
59#[derive(Debug)]
60pub struct ConversationContext {
61    /// Conversation ID
62    pub id: ConversationId,
63    /// Conversation turns
64    pub turns: VecDeque<Turn>,
65    /// Maximum turns to keep
66    max_turns: usize,
67    /// Total token count
68    total_tokens: usize,
69    /// Last access time
70    last_access: Instant,
71}
72
73impl ConversationContext {
74    fn new(id: ConversationId, max_turns: usize) -> Self {
75        Self {
76            id,
77            turns: VecDeque::with_capacity(max_turns),
78            max_turns,
79            total_tokens: 0,
80            last_access: Instant::now(),
81        }
82    }
83
84    fn append(&mut self, turn: Turn) {
85        self.total_tokens += turn.token_count;
86        self.turns.push_back(turn);
87
88        // Maintain size limit
89        while self.turns.len() > self.max_turns {
90            if let Some(removed) = self.turns.pop_front() {
91                self.total_tokens = self.total_tokens.saturating_sub(removed.token_count);
92            }
93        }
94
95        self.last_access = Instant::now();
96    }
97
98    fn get_recent(&self, count: usize) -> Vec<Turn> {
99        self.turns.iter().rev().take(count).rev().cloned().collect()
100    }
101
102    fn size(&self) -> usize {
103        self.turns.iter().map(|t| t.size()).sum()
104    }
105}
106
107/// LRU cache for conversation eviction
108struct LruTracker {
109    order: Mutex<VecDeque<ConversationId>>,
110    #[allow(dead_code)]
111    max_size: usize,
112}
113
114impl LruTracker {
115    fn new(max_size: usize) -> Self {
116        Self {
117            order: Mutex::new(VecDeque::with_capacity(max_size)),
118            max_size,
119        }
120    }
121
122    fn touch(&self, id: &ConversationId) {
123        let mut order = self.order.lock().unwrap();
124
125        // Remove existing entry
126        if let Some(pos) = order.iter().position(|x| x == id) {
127            order.remove(pos);
128        }
129
130        // Add to end (most recent)
131        order.push_back(id.clone());
132    }
133
134    fn evict_oldest(&self) -> Option<ConversationId> {
135        self.order.lock().unwrap().pop_front()
136    }
137}
138
139/// Conversation context cache
140pub struct ConversationContextCache {
141    /// Contexts per conversation
142    contexts: DashMap<ConversationId, ConversationContext>,
143
144    /// LRU tracker
145    lru: LruTracker,
146
147    /// Maximum turns per conversation
148    max_turns: usize,
149
150    /// Maximum conversations to cache
151    max_conversations: usize,
152}
153
154impl ConversationContextCache {
155    /// Create a new cache
156    pub fn new(max_conversations: usize, max_turns: usize) -> Self {
157        Self {
158            contexts: DashMap::new(),
159            lru: LruTracker::new(max_conversations),
160            max_turns,
161            max_conversations,
162        }
163    }
164
165    /// Get context for a conversation
166    pub fn get_context(&self, conv_id: &str, max_turns: usize) -> Option<Vec<Turn>> {
167        self.lru.touch(&conv_id.to_string());
168
169        let ctx = self.contexts.get(conv_id)?;
170        Some(ctx.get_recent(max_turns))
171    }
172
173    /// Get full context
174    pub fn get_full_context(&self, conv_id: &str) -> Option<Vec<Turn>> {
175        self.lru.touch(&conv_id.to_string());
176
177        let ctx = self.contexts.get(conv_id)?;
178        Some(ctx.turns.iter().cloned().collect())
179    }
180
181    /// Append a turn to a conversation
182    pub fn append_turn(&self, conv_id: &str, turn: Turn) {
183        self.lru.touch(&conv_id.to_string());
184
185        // Evict if at capacity
186        while self.contexts.len() >= self.max_conversations {
187            if let Some(old_id) = self.lru.evict_oldest() {
188                self.contexts.remove(&old_id);
189            } else {
190                break;
191            }
192        }
193
194        // Get or create context
195        let mut ctx = self
196            .contexts
197            .entry(conv_id.to_string())
198            .or_insert_with(|| ConversationContext::new(conv_id.to_string(), self.max_turns));
199
200        ctx.append(turn);
201    }
202
203    /// Clear a conversation
204    pub fn clear_conversation(&self, conv_id: &str) {
205        self.contexts.remove(conv_id);
206    }
207
208    /// Get conversation count
209    pub fn conversation_count(&self) -> usize {
210        self.contexts.len()
211    }
212
213    /// Get total tokens cached
214    pub fn total_tokens(&self) -> usize {
215        self.contexts.iter().map(|ctx| ctx.total_tokens).sum()
216    }
217
218    /// Get stats
219    pub fn stats(&self) -> ConversationCacheStats {
220        let mut total_turns = 0;
221        let mut total_size = 0;
222
223        for ctx in self.contexts.iter() {
224            total_turns += ctx.turns.len();
225            total_size += ctx.size();
226        }
227
228        ConversationCacheStats {
229            conversations: self.contexts.len(),
230            total_turns,
231            total_size_bytes: total_size,
232            total_tokens: self.total_tokens(),
233        }
234    }
235}
236
/// Aggregate figures reported by `ConversationContextCache::stats`.
///
/// Derives `Default`, `PartialEq` and `Eq` so a snapshot can be
/// zero-initialised and compared directly (e.g. in assertions).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConversationCacheStats {
    /// Number of cached conversations
    pub conversations: usize,
    /// Number of turns summed over all cached conversations
    pub total_turns: usize,
    /// Approximate size in bytes of all cached turns
    pub total_size_bytes: usize,
    /// Approximate token count summed over all cached conversations
    pub total_tokens: usize,
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// A new turn keeps its role and receives a non-zero token estimate.
    #[test]
    fn test_turn_creation() {
        let greeting = Turn::new("1", "user", "Hello, how are you?");

        assert_eq!(greeting.role, "user");
        assert!(greeting.token_count > 0);
    }

    /// Asking for the last two turns returns them in chronological order.
    #[test]
    fn test_append_and_get_context() {
        let cache = ConversationContextCache::new(100, 50);
        let script = [
            ("1", "user", "Hello"),
            ("2", "assistant", "Hi there!"),
            ("3", "user", "How are you?"),
        ];
        for &(id, role, text) in script.iter() {
            cache.append_turn("conv-1", Turn::new(id, role, text));
        }

        let recent = cache.get_context("conv-1", 2).unwrap();
        let contents: Vec<&str> = recent.iter().map(|t| t.content.as_str()).collect();

        assert_eq!(contents, ["Hi there!", "How are you?"]);
    }

    /// Only the newest `max_turns` turns of a conversation are retained.
    #[test]
    fn test_max_turns_limit() {
        let cache = ConversationContextCache::new(100, 3);

        (0..5).for_each(|n| {
            let turn = Turn::new(n.to_string(), "user", format!("Message {}", n));
            cache.append_turn("conv-1", turn);
        });

        let kept = cache.get_full_context("conv-1").unwrap();
        assert_eq!(kept.len(), 3);
        assert_eq!(kept[0].content, "Message 2");
    }

    /// Adding a third conversation to a two-slot cache evicts the oldest one.
    #[test]
    fn test_lru_eviction() {
        let cache = ConversationContextCache::new(2, 10);

        // The third append should evict conv-1.
        let arrivals = [
            ("conv-1", "Hello 1"),
            ("conv-2", "Hello 2"),
            ("conv-3", "Hello 3"),
        ];
        for &(conv, text) in arrivals.iter() {
            cache.append_turn(conv, Turn::new("1", "user", text));
        }

        assert!(cache.get_context("conv-1", 1).is_none());
        assert!(cache.get_context("conv-2", 1).is_some());
        assert!(cache.get_context("conv-3", 1).is_some());
    }

    /// Stats count conversations and turns and report a positive token total.
    #[test]
    fn test_stats() {
        let cache = ConversationContextCache::new(100, 50);
        let appended = [
            ("conv-1", "1", "user", "Hello"),
            ("conv-1", "2", "assistant", "Hi"),
            ("conv-2", "1", "user", "Test"),
        ];
        for &(conv, id, role, text) in appended.iter() {
            cache.append_turn(conv, Turn::new(id, role, text));
        }

        let snapshot = cache.stats();
        assert_eq!(snapshot.conversations, 2);
        assert_eq!(snapshot.total_turns, 3);
        assert!(snapshot.total_tokens > 0);
    }
}