Skip to main content

offline_intelligence/context_engine/
tier_manager.rs

1//! Manages the three-tier memory system with robust persistence and indexing
2
3use crate::memory::Message;
4use crate::memory_db::{MemoryDatabase, StoredMessage, SessionMetadata};
5use crate::cache_management::cache_scorer::score_message_importance;
6use moka::sync::Cache;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tracing::{debug, info};
10
/// Tunable limits and switches for the tiered conversation memory.
#[derive(Debug, Clone)]
pub struct TierManagerConfig {
    /// Upper bound on the number of messages kept per session in the Tier 1
    /// hot cache; only the most recent messages up to this count are cached.
    /// `from_ctx_size` derives it as 40% of the context window divided by an
    /// assumed average of 20 tokens per message.
    pub tier1_max_messages: usize,
    /// When `false`, writes to Tier 3 (the database) are skipped.
    pub enable_tier3_persistence: bool,
}
19
20impl Default for TierManagerConfig {
21    fn default() -> Self {
22        Self {
23            tier1_max_messages: 50,
24            enable_tier3_persistence: true,
25        }
26    }
27}
28
29impl TierManagerConfig {
30    /// Derive Tier 1 limits from the model's context window.
31    /// 40% of CTX_SIZE is reserved for current-session hot messages.
32    /// Divided by an average of 20 tokens per message to get the message count cap.
33    pub fn from_ctx_size(ctx_size: u32) -> Self {
34        let tier1_token_budget = (ctx_size as f32 * 0.40) as usize;
35        let avg_tokens_per_message = 20;
36        Self {
37            tier1_max_messages: (tier1_token_budget / avg_tokens_per_message).max(10),
38            enable_tier3_persistence: true,
39        }
40    }
41}
42
/// Per-session message counts for each storage tier.
#[derive(Debug, Clone, Default)]
pub struct TierStats {
    /// Number of messages currently held in the Tier 1 cache for the session.
    pub tier1_count: usize,
    /// Number of messages read back from the Tier 3 database for the session.
    pub tier3_count: usize,
}
49
50pub struct TierManager {
51    database: Arc<MemoryDatabase>,
52    tier1_cache: Cache<String, (Vec<Message>, Instant)>,
53    pub config: TierManagerConfig,
54}
55
56impl TierManager {
57    pub fn new(
58        database: Arc<MemoryDatabase>,
59        config: TierManagerConfig
60    ) -> Self {
61        Self {
62            database,
63            tier1_cache: Cache::builder()
64                .max_capacity(1000)
65                .time_to_idle(Duration::from_secs(3600))
66                .build(),
67            config,
68        }
69    }
70
71    // --- Tier 1 (Cache) Methods ---
72
73    pub async fn store_tier1_content(&self, session_id: &str, messages: &[Message]) {
74        // Apply tier1 max messages limit
75        let messages_to_store = if messages.len() > self.config.tier1_max_messages {
76            &messages[messages.len() - self.config.tier1_max_messages..]
77        } else {
78            messages
79        };
80        
81        self.tier1_cache.insert(session_id.to_string(), (messages_to_store.to_vec(), Instant::now()));
82    }
83
84    pub async fn get_tier1_content(&self, session_id: &str) -> Option<Vec<Message>> {
85        self.tier1_cache.get(session_id).map(|(m, _)| m)
86    }
87
88    // --- Tier 3 (Database) Methods ---
89
90    pub async fn get_tier3_content(
91        &self, 
92        session_id: &str, 
93        limit: Option<i32>, 
94        offset: Option<i32>
95    ) -> anyhow::Result<Vec<StoredMessage>> {
96        self.database.conversations.get_session_messages(session_id, limit, offset)
97    }
98
99    pub async fn search_tier3_content(
100        &self, 
101        session_id: &str, 
102        query: &str, 
103        limit: usize
104    ) -> anyhow::Result<Vec<StoredMessage>> {
105        let messages = self.database.conversations.get_session_messages(session_id, Some(1000), None)?;
106        let query_lower = query.to_lowercase();
107        
108        let filtered = messages.into_iter()
109            .filter(|m| m.content.to_lowercase().contains(&query_lower))
110            .take(limit)
111            .collect();
112        
113        Ok(filtered)
114    }
115
116    pub async fn store_tier3_content(&self, session_id: &str, messages: &[Message]) -> anyhow::Result<()> {
117        if !self.config.enable_tier3_persistence || messages.is_empty() {
118            return Ok(());
119        }
120        
121        // Ensure session exists in database
122        self.ensure_session_exists(session_id, None).await?;
123        
124        // Get existing messages to find the next index AND check for duplicates
125        let existing_messages = self.database.conversations.get_session_messages(
126            session_id, Some(10000), Some(0)
127        ).unwrap_or_else(|_| vec![]);
128        
129        // Filter out messages that already exist (simple content-based deduplication)
130        let new_messages: Vec<&Message> = messages.iter()
131            .filter(|new_msg| {
132                !existing_messages.iter().any(|existing| {
133                    existing.content == new_msg.content && 
134                    existing.role == new_msg.role
135                })
136            })
137            .collect();
138        
139        if new_messages.is_empty() {
140            debug!("No new messages to save, all already exist in database");
141            return Ok(()); // Nothing new to save
142        }
143        
144        let start_index = existing_messages.len() as i32;
145        
146        // Create batch data for ONLY new messages
147        let batch_data: Vec<(String, String, i32, i32, f32)> = new_messages
148            .iter()
149            .enumerate()
150            .map(|(offset, m)| (
151                m.role.clone(),
152                m.content.clone(),
153                start_index + offset as i32, // Ensure unique index
154                (m.content.len() / 4) as i32,
155                score_message_importance(&m.role, &m.content)
156            ))
157            .collect();
158        
159        if !batch_data.is_empty() {
160            self.database.conversations.store_messages_batch(session_id, &batch_data)?;
161            info!("📝 Stored {} new messages to database for session {}", batch_data.len(), session_id);
162        }
163        
164        Ok(())
165    }
166
167    // --- Cross-Session Content Methods ---
168
169    /// Searches across all sessions except the current one based on keyword extraction
170    pub async fn search_cross_session_content(
171        &self,
172        current_session_id: &str,
173        query: &str,
174        limit: usize,
175    ) -> anyhow::Result<Vec<StoredMessage>> {
176        // Extract keywords from query
177        let keywords = self.extract_keywords(query);
178        
179        if keywords.is_empty() {
180            return Ok(vec![]);
181        }
182
183        // Search across ALL sessions except current one
184        self.database.conversations.search_messages_by_topic_across_sessions(
185            &keywords,
186            limit,
187            Some(current_session_id), // Exclude current session
188        ).await
189    }
190
191    fn extract_keywords(&self, text: &str) -> Vec<String> {
192        let words: Vec<&str> = text.split_whitespace().collect();
193        words.iter()
194            .filter(|w| w.len() > 3)
195            .map(|w| w.to_lowercase())
196            .filter(|w| !self.is_stop_word(w))
197            .collect()
198    }
199
200    fn is_stop_word(&self, word: &str) -> bool {
201        let stop_words = [
202            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
203            "of", "with", "by", "is", "am", "are", "was", "were", "be", "been",
204            "being", "have", "has", "had", "do", "does", "did", "will", "would",
205            "shall", "should", "may", "might", "must", "can", "could",
206        ];
207        stop_words.contains(&word)
208    }
209
210    // --- Maintenance & Stats ---
211
212    pub async fn get_tier_stats(&self, session_id: &str) -> TierStats {
213        let tier1_count = self.get_tier1_content(session_id).await
214            .map(|m| m.len())
215            .unwrap_or(0);
216
217        let tier3_count = match self.database.conversations.get_session_messages(session_id, Some(10000), None) {
218            Ok(messages) => messages.len(),
219            Err(_) => 0,
220        };
221
222        TierStats { tier1_count, tier3_count }
223    }
224
225    pub async fn cleanup_cache(&self, _older_than_seconds: u64) -> usize {
226        let count = self.tier1_cache.entry_count();
227        self.tier1_cache.invalidate_all();
228        count as usize
229    }
230
231    /// Chat persistence: Ensure session exists in database with provided ID (no auto-generated placeholders)
232    pub async fn ensure_session_exists(
233        &self, 
234        session_id: &str, 
235        title: Option<String>
236    ) -> anyhow::Result<()> {
237        let exists = self.database.conversations.get_session(session_id)?;
238        if exists.is_none() {
239            // Create session with null title initially - title set via API after generation
240            let metadata = SessionMetadata {
241                title, // None initially; title updated later via update_conversation_title API
242                ..Default::default()
243            };
244            self.database.conversations.create_session_with_id(session_id, Some(metadata))?;
245        }
246        Ok(())
247    }
248}
249
250impl Clone for TierManager {
251    fn clone(&self) -> Self {
252        Self {
253            database: self.database.clone(),
254            tier1_cache: Cache::builder()
255                .max_capacity(1000)
256                .time_to_idle(Duration::from_secs(3600))
257                .build(),
258            config: self.config.clone(),
259        }
260    }
261}