Skip to main content

brainwires_memory/
tiered_memory.rs

1//! Tiered Memory Storage System
2//!
3//! Implements a three-tier memory hierarchy for conversation storage:
4//! - **Hot**: Full messages - recent, important, or recently accessed
5//! - **Warm**: Compressed summaries - older messages that may be needed
6//! - **Cold**: Ultra-compressed key facts - archival storage
7//!
8//! Messages flow from hot → warm → cold based on age and importance,
9//! and can be promoted back up when accessed.
10//!
11//! ## Persistence
12//!
13//! All tiers are backed by LanceDB for persistence:
14//! - Hot tier: MessageStore (messages table)
15//! - Warm tier: SummaryStore (summaries table)
16//! - Cold tier: FactStore (facts table)
17//! - Metadata: TierMetadataStore (tier_metadata table)
18
19use std::sync::Arc;
20
21use anyhow::Result;
22use chrono::Utc;
23use uuid::Uuid;
24
25use brainwires_storage::CachedEmbeddingProvider;
26use brainwires_storage::databases::{LanceDatabase, StorageBackend};
27
28use brainwires_stores::{
29    FactStore, FactType, KeyFact, MemoryAuthority, MemoryTier, MentalModel, MentalModelStore,
30    MessageMetadata, MessageStore, MessageSummary, ModelType, SummaryStore, TierMetadata,
31    TierMetadataStore,
32};
33// Weight constants (SIMILARITY_WEIGHT, RECENCY_WEIGHT, IMPORTANCE_WEIGHT)
34// are defined locally below — they live in brainwires-stores::tier_types
35// for the schema stores' use, and are duplicated here intentionally to
36// keep the orchestration crate self-contained.
37
/// Default share of the combined retrieval score taken by embedding similarity.
const SIMILARITY_WEIGHT: f32 = 0.5;
/// Default share of the combined retrieval score taken by recency.
const RECENCY_WEIGHT: f32 = 0.3;
/// Default share of the combined retrieval score taken by stored importance.
const IMPORTANCE_WEIGHT: f32 = 0.2;
/// Default for `TieredMemoryConfig::hot_retention_hours` (one day).
const DEFAULT_HOT_RETENTION_HOURS: u64 = 24;
/// Default for `TieredMemoryConfig::warm_retention_hours` (one week).
const DEFAULT_WARM_RETENTION_HOURS: u64 = 7 * 24;
/// Default for `TieredMemoryConfig::hot_importance_threshold`.
const DEFAULT_HOT_IMPORTANCE_THRESHOLD: f32 = 0.3;
/// Default for `TieredMemoryConfig::warm_importance_threshold`.
const DEFAULT_WARM_IMPORTANCE_THRESHOLD: f32 = 0.1;
/// Default for `TieredMemoryConfig::max_hot_messages`.
const DEFAULT_MAX_HOT_MESSAGES: usize = 1_000;
/// Default for `TieredMemoryConfig::max_warm_summaries`.
const DEFAULT_MAX_WARM_SUMMARIES: usize = 5_000;
/// Per-hour decay rate behind `MultiFactorScore::recency_from_hours_fast`.
const FAST_DECAY_RATE: f32 = 0.05;
48
/// Temporal keywords that imply a recency-sensitive query.
///
/// Entries are lowercase; an entry may be a multi-word phrase (words separated
/// by a single space), in which case the words must appear consecutively.
const TEMPORAL_KEYWORDS: &[&str] = &[
    "recent",
    "recently",
    "latest",
    "last",
    "current",
    "currently",
    "today",
    "yesterday",
    "this week",
    "now",
    "just",
    "new",
    "newest",
];

/// Detect whether a query is temporally sensitive.
///
/// The query is lowercased and split into words on every non-alphanumeric
/// character. Each entry of [`TEMPORAL_KEYWORDS`] that occurs as a whole word
/// (or, for phrases, as a run of consecutive whole words) contributes one hit,
/// regardless of how often it occurs.
///
/// Matching is on whole words rather than substrings so that unrelated words
/// do not trigger a hit ("know" is not "now", "renew" is not "new") and so
/// that "recently" counts once instead of also matching "recent".
///
/// Returns `hits / 3.0`, capped at `1.0`; a query with no keywords (including
/// the empty query) scores `0.0`.
fn detect_temporal_query(query: &str) -> f32 {
    let lower = query.to_lowercase();
    let words: Vec<&str> = lower
        .split(|c: char| !c.is_alphanumeric())
        .filter(|w| !w.is_empty())
        .collect();

    let hits = TEMPORAL_KEYWORDS
        .iter()
        .filter(|kw| {
            let phrase: Vec<&str> = kw.split_whitespace().collect();
            // `windows(0)` panics, so guard against an (unexpected) empty keyword.
            !phrase.is_empty()
                && words
                    .windows(phrase.len())
                    .any(|window| window == phrase.as_slice())
        })
        .count();

    (hits as f32 / 3.0).min(1.0)
}
79
80// MemoryAuthority, MemoryTier, TierMetadata, MessageSummary, KeyFact,
81// FactType — moved to brainwires-stores::tier_types (used by both the
82// schema stores and this orchestration layer).
83
/// Zero-sized capability required to write with `Canonical` memory authority.
///
/// The single field is private and the only constructor is `pub(crate)`, so
/// code outside this crate cannot create a token on its own. The intent is
/// that ordinary agent tool calls cannot silently promote their output to
/// canonical authority; a token has to be handed out by an authorised entry
/// point (e.g. a CLI-layer function or a privileged agent config).
///
/// ## Example
/// ```ignore
/// // Inside crate only:
/// let token = CanonicalWriteToken::new();
/// tiered_memory.add_canonical_message(message, 0.9, token).await?;
/// ```
#[derive(Debug)]
pub struct CanonicalWriteToken(());

impl CanonicalWriteToken {
    /// Mint a fresh token. Visible only inside this crate.
    // `dead_code` is allowed because no non-test caller is visible in this
    // module; the constructor exists for crate-internal entry points.
    #[allow(dead_code)]
    pub(crate) fn new() -> Self {
        CanonicalWriteToken(())
    }
}
107
/// Retrieval score that blends similarity, recency, and stored importance.
///
/// With the default weights the blend is
/// similarity × 0.50 + recency × 0.30 + importance × 0.20.
#[derive(Debug, Clone)]
pub struct MultiFactorScore {
    /// Raw similarity reported by the embedding search (0–1).
    pub similarity: f32,
    /// Recency factor, `exp(−rate × hours_since_last_access)`: `1.0` means
    /// just accessed and the value approaches 0 for very old entries. The
    /// rate is `0.01` per hour by default, or `0.05` when the fast curve is
    /// used.
    pub recency: f32,
    /// Stored importance (0–1), taken from [`TierMetadata::importance`].
    pub importance: f32,
    /// Weighted blend of the three components; this is the ranking key.
    pub combined: f32,
}
123
124impl MultiFactorScore {
125    /// Compute the combined score from its components using default weights.
126    pub fn compute(similarity: f32, recency: f32, importance: f32) -> Self {
127        Self::compute_with_weights(
128            similarity,
129            recency,
130            importance,
131            SIMILARITY_WEIGHT,
132            RECENCY_WEIGHT,
133            IMPORTANCE_WEIGHT,
134        )
135    }
136
137    /// Compute the combined score using caller-supplied weights.
138    ///
139    /// Weights need not sum to 1.0 — `combined` is the raw dot product and is
140    /// clamped to `[0.0, 1.0]` for consistency.
141    pub fn compute_with_weights(
142        similarity: f32,
143        recency: f32,
144        importance: f32,
145        sim_w: f32,
146        rec_w: f32,
147        imp_w: f32,
148    ) -> Self {
149        let combined = (similarity * sim_w + recency * rec_w + importance * imp_w).clamp(0.0, 1.0);
150        Self {
151            similarity,
152            recency,
153            importance,
154            combined,
155        }
156    }
157
158    /// Decay rate used for the recency factor (per hour).
159    const DECAY_RATE: f32 = 0.01;
160
161    /// Compute the recency factor from `hours_since_last_access`.
162    pub fn recency_from_hours(hours_since_access: f32) -> f32 {
163        (-Self::DECAY_RATE * hours_since_access).exp()
164    }
165
166    /// Compute the recency factor using the fast decay rate (`exp(-0.05 × h)`).
167    pub fn recency_from_hours_fast(hours_since_access: f32) -> f32 {
168        (-FAST_DECAY_RATE * hours_since_access).exp()
169    }
170}
171
172/// Result from adaptive search across tiers
173#[derive(Debug, Clone)]
174pub struct TieredSearchResult {
175    /// The content text.
176    pub content: String,
177    /// Raw similarity score returned by the vector store (0-1).
178    pub score: f32,
179    /// Memory tier this result came from.
180    pub tier: MemoryTier,
181    /// Original message identifier.
182    pub original_message_id: Option<String>,
183    /// Full message metadata if available.
184    pub metadata: Option<MessageMetadata>,
185    /// Multi-factor score blending similarity, recency, and importance.
186    /// Populated by [`TieredMemory::search_adaptive_multi_factor`]; `None` when
187    /// returned by the basic [`TieredMemory::search_adaptive`].
188    pub multi_factor_score: Option<MultiFactorScore>,
189}
190
/// Tunable parameters for [`TieredMemory`].
///
/// NOTE(review): the retention windows, importance thresholds, and capacity
/// limits are stored here, but no enforcement of them is visible in this part
/// of the file — presumably a caller drives demotion using these values.
#[derive(Debug, Clone)]
pub struct TieredMemoryConfig {
    /// Age, in hours, after which a hot entry is considered for demotion to warm.
    pub hot_retention_hours: u64,
    /// Age, in hours, after which a warm entry is considered for demotion to cold.
    pub warm_retention_hours: u64,
    /// Minimum importance score for an entry to stay in the hot tier.
    pub hot_importance_threshold: f32,
    /// Minimum importance score for an entry to stay in the warm tier.
    pub warm_importance_threshold: f32,
    /// Upper bound on the number of messages in the hot tier.
    pub max_hot_messages: usize,
    /// Upper bound on the number of summaries in the warm tier.
    pub max_warm_summaries: usize,
    /// Optional time-to-live for session messages, in seconds.
    ///
    /// When set, every message added via [`TieredMemory::add_message`] gets an
    /// `expires_at` timestamp of `now + session_ttl_secs`. Expired entries are
    /// deleted by [`TieredMemory::evict_expired`]; searches merely skip them
    /// without deleting.
    ///
    /// `None` (the default) disables the TTL — messages persist until they
    /// are explicitly deleted or demoted.
    pub session_ttl_secs: Option<u64>,
    /// Extra recency weight applied when a query looks temporally sensitive
    /// (e.g. contains "recent", "latest", "today").
    ///
    /// The extra weight scales with the query's temporal score (`0.0–1.0`),
    /// and the similarity and importance weights shrink so that the three
    /// weights still sum to `1.0`. Default: `0.3`.
    pub temporal_boost: f32,
    /// Switch to the faster recency decay (`exp(-0.05 × h)` instead of
    /// `exp(-0.01 × h)`) for temporally sensitive queries.
    ///
    /// Default: `false`.
    pub fast_decay: bool,
    /// Upper bound on the number of synthesised mental models to retain.
    ///
    /// Default: `500`.
    pub max_mental_models: usize,
}
233
234impl Default for TieredMemoryConfig {
235    fn default() -> Self {
236        Self {
237            hot_retention_hours: DEFAULT_HOT_RETENTION_HOURS,
238            warm_retention_hours: DEFAULT_WARM_RETENTION_HOURS,
239            hot_importance_threshold: DEFAULT_HOT_IMPORTANCE_THRESHOLD,
240            warm_importance_threshold: DEFAULT_WARM_IMPORTANCE_THRESHOLD,
241            max_hot_messages: DEFAULT_MAX_HOT_MESSAGES,
242            max_warm_summaries: DEFAULT_MAX_WARM_SUMMARIES,
243            session_ttl_secs: None,
244            temporal_boost: 0.3,
245            fast_decay: false,
246            max_mental_models: 500,
247        }
248    }
249}
250
251/// Three-tier memory storage system with persistence
252pub struct TieredMemory {
253    /// Hot tier: Full messages (LanceDB-backed)
254    pub hot: Arc<MessageStore>,
255
256    /// Warm tier: Summaries (LanceDB-backed)
257    warm: SummaryStore,
258
259    /// Cold tier: Key facts (LanceDB-backed)
260    cold: FactStore,
261
262    /// Metadata tracking for all messages (LanceDB-backed)
263    tier_metadata: TierMetadataStore,
264
265    /// Mental model tier: synthesised beliefs (LanceDB-backed)
266    mental_model: MentalModelStore,
267
268    /// Configuration
269    config: TieredMemoryConfig,
270
271    /// Embedding provider for searches
272    #[allow(dead_code)]
273    embeddings: Arc<CachedEmbeddingProvider>,
274}
275
276impl TieredMemory {
277    /// Create a new tiered memory system with persistent storage
278    pub async fn new(
279        hot_store: Arc<MessageStore>,
280        db: Arc<LanceDatabase>,
281        embeddings: Arc<CachedEmbeddingProvider>,
282        config: TieredMemoryConfig,
283    ) -> Self {
284        let mental_model = MentalModelStore::new(
285            Arc::clone(&db) as Arc<dyn StorageBackend>,
286            Arc::clone(&embeddings),
287        );
288        Self {
289            hot: hot_store,
290            warm: SummaryStore::new(Arc::clone(&db), Arc::clone(&embeddings)),
291            cold: FactStore::new(Arc::clone(&db), Arc::clone(&embeddings)),
292            tier_metadata: TierMetadataStore::new(db),
293            mental_model,
294            config,
295            embeddings,
296        }
297    }
298
299    /// Create with default configuration
300    pub async fn with_defaults(
301        hot_store: Arc<MessageStore>,
302        db: Arc<LanceDatabase>,
303        embeddings: Arc<CachedEmbeddingProvider>,
304    ) -> Self {
305        Self::new(hot_store, db, embeddings, TieredMemoryConfig::default()).await
306    }
307
308    /// Add a message to the hot tier with `Session` authority.
309    ///
310    /// If `TieredMemoryConfig::session_ttl_secs` is set, the message will be
311    /// assigned an expiry timestamp and will be removed by [`Self::evict_expired`]
312    /// after the configured duration.
313    pub async fn add_message(
314        &mut self,
315        mut message: MessageMetadata,
316        importance: f32,
317    ) -> Result<()> {
318        // Apply TTL if configured
319        if let Some(ttl_secs) = self.config.session_ttl_secs {
320            message.expires_at = Some(Utc::now().timestamp() + ttl_secs as i64);
321        }
322        let metadata = TierMetadata::new(message.message_id.clone(), importance);
323        self.tier_metadata.add(metadata).await?;
324        self.hot.add(message).await
325    }
326
327    /// Add a message to the hot tier with `Canonical` authority.
328    ///
329    /// Canonical entries are long-lived and immune to session-TTL eviction.
330    /// A [`CanonicalWriteToken`] is required to call this method; obtain one
331    /// through an authorised entry point in the CLI layer.
332    pub async fn add_canonical_message(
333        &mut self,
334        message: MessageMetadata,
335        importance: f32,
336        _token: CanonicalWriteToken,
337    ) -> Result<()> {
338        // Canonical entries intentionally have no TTL
339        let metadata = TierMetadata::with_authority(
340            message.message_id.clone(),
341            importance,
342            MemoryAuthority::Canonical,
343        );
344        self.tier_metadata.add(metadata).await?;
345        self.hot.add(message).await
346    }
347
348    /// Delete all hot-tier messages whose TTL has expired.
349    ///
350    /// Returns the number of entries evicted.  Call this at agent run
351    /// completion or on a periodic background schedule.
352    ///
353    /// Canonical-authority messages are never evicted here regardless of
354    /// any `expires_at` value, because they are expected to have `None`.
355    pub async fn evict_expired(&self) -> Result<usize> {
356        let evicted = self.hot.delete_expired().await?;
357        if evicted > 0 {
358            tracing::info!(
359                evicted,
360                "TieredMemory: evicted {} expired message(s)",
361                evicted
362            );
363        }
364        Ok(evicted)
365    }
366
367    /// Record access to a message (for promotion/retention decisions)
368    pub async fn record_access(&mut self, message_id: &str) -> Result<()> {
369        if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
370            meta.record_access();
371            self.tier_metadata.update(meta).await?;
372        }
373        Ok(())
374    }
375
376    /// Search across all tiers with adaptive resolution
377    pub async fn search_adaptive(
378        &mut self,
379        query: &str,
380        conversation_id: Option<&str>,
381    ) -> Result<Vec<TieredSearchResult>> {
382        let mut results = Vec::new();
383
384        // 1. Search hot tier first (full messages)
385        let hot_results = if let Some(conv_id) = conversation_id {
386            self.hot.search_conversation(conv_id, query, 5, 0.6).await?
387        } else {
388            self.hot.search(query, 5, 0.6).await?
389        };
390
391        for (msg, score) in hot_results {
392            // Lazy eviction: skip entries whose TTL has expired
393            if let Some(exp) = msg.expires_at
394                && exp <= Utc::now().timestamp()
395            {
396                continue;
397            }
398
399            // Record access for retention tracking
400            let _ = self.record_access(&msg.message_id).await;
401
402            results.push(TieredSearchResult {
403                content: msg.content.clone(),
404                score,
405                tier: MemoryTier::Hot,
406                original_message_id: Some(msg.message_id.clone()),
407                metadata: Some(msg),
408                multi_factor_score: None,
409            });
410        }
411
412        // If we have high-confidence hot results, return early
413        if results.iter().any(|r| r.score > 0.85) {
414            return Ok(results);
415        }
416
417        // 2. Search warm tier (summaries)
418        let warm_results = if let Some(conv_id) = conversation_id {
419            self.warm
420                .search_conversation(conv_id, query, 3, 0.5)
421                .await?
422        } else {
423            self.warm.search(query, 3, 0.5).await?
424        };
425
426        for (summary, score) in warm_results {
427            results.push(TieredSearchResult {
428                content: summary.summary.clone(),
429                score,
430                tier: MemoryTier::Warm,
431                original_message_id: Some(summary.original_message_id.clone()),
432                metadata: None,
433                multi_factor_score: None,
434            });
435        }
436
437        // 3. If still no good results, search cold tier
438        if results.iter().all(|r| r.score < 0.7) {
439            let cold_results = if let Some(conv_id) = conversation_id {
440                self.cold
441                    .search_conversation(conv_id, query, 3, 0.4)
442                    .await?
443            } else {
444                self.cold.search(query, 3, 0.4).await?
445            };
446
447            for (fact, score) in cold_results {
448                results.push(TieredSearchResult {
449                    content: fact.fact.clone(),
450                    score,
451                    tier: MemoryTier::Cold,
452                    original_message_id: fact.original_message_ids.first().cloned(),
453                    metadata: None,
454                    multi_factor_score: None,
455                });
456            }
457        }
458
459        // Sort by score descending
460        results.sort_by(|a, b| {
461            b.score
462                .partial_cmp(&a.score)
463                .unwrap_or(std::cmp::Ordering::Equal)
464        });
465
466        Ok(results)
467    }
468
469    /// Search across all tiers and score results using combined similarity,
470    /// recency, and importance signals.
471    ///
472    /// This is the preferred retrieval method for long-horizon agent tasks where
473    /// a pure similarity score can surface stale or low-importance results.
474    ///
475    /// The returned results are sorted by [`MultiFactorScore::combined`]
476    /// (descending).  Each result has `multi_factor_score` populated.
477    pub async fn search_adaptive_multi_factor(
478        &mut self,
479        query: &str,
480        conversation_id: Option<&str>,
481    ) -> Result<Vec<TieredSearchResult>> {
482        // Reuse the base search to get similarity-ranked results.
483        let mut results = self.search_adaptive(query, conversation_id).await?;
484
485        // Collect message IDs that have associated tier metadata (hot tier).
486        let ids: Vec<&str> = results
487            .iter()
488            .filter_map(|r| r.original_message_id.as_deref())
489            .collect();
490
491        let meta_map = self.tier_metadata.get_many(&ids).await.unwrap_or_default();
492
493        let now_secs = chrono::Utc::now().timestamp();
494
495        // Compute temporal sensitivity once for the whole query.
496        let temporal_factor = detect_temporal_query(query);
497        let use_fast_decay = self.config.fast_decay && temporal_factor > 0.0;
498
499        // Derive per-query weights, renormalised so they sum to 1.0.
500        let extra_recency = self.config.temporal_boost * temporal_factor;
501        let rec_w = (RECENCY_WEIGHT + extra_recency).min(1.0);
502        let remaining = 1.0 - rec_w;
503        let sim_share = SIMILARITY_WEIGHT / (SIMILARITY_WEIGHT + IMPORTANCE_WEIGHT);
504        let sim_w = sim_share * remaining;
505        let imp_w = remaining - sim_w;
506
507        for result in &mut results {
508            let similarity = result.score;
509
510            let (recency, importance) = if let Some(id) = &result.original_message_id {
511                if let Some(meta) = meta_map.get(id.as_str()) {
512                    let hours_since = (now_secs - meta.last_accessed).max(0) as f32 / 3600.0;
513                    let rec = if use_fast_decay {
514                        MultiFactorScore::recency_from_hours_fast(hours_since)
515                    } else {
516                        MultiFactorScore::recency_from_hours(hours_since)
517                    };
518                    (rec, meta.importance)
519                } else {
520                    (1.0_f32, 0.5_f32) // Fallback: assume fresh + average importance
521                }
522            } else {
523                (1.0_f32, 0.5_f32)
524            };
525
526            result.multi_factor_score = Some(MultiFactorScore::compute_with_weights(
527                similarity, recency, importance, sim_w, rec_w, imp_w,
528            ));
529        }
530
531        // Append mental model tier results (up to 5).
532        if let Ok(mm_results) = self.search_mental_models(query, 5).await {
533            for mut mm in mm_results {
534                mm.multi_factor_score = Some(MultiFactorScore::compute_with_weights(
535                    mm.score, 1.0, // mental models have no recency — treat as always fresh
536                    0.5, // default importance
537                    sim_w, rec_w, imp_w,
538                ));
539                results.push(mm);
540            }
541        }
542
543        // Re-sort by combined score (highest first).
544        results.sort_by(|a, b| {
545            let sa = a
546                .multi_factor_score
547                .as_ref()
548                .map_or(a.score, |s| s.combined);
549            let sb = b
550                .multi_factor_score
551                .as_ref()
552                .map_or(b.score, |s| s.combined);
553            sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
554        });
555
556        Ok(results)
557    }
558
559    /// Demote a message from hot to warm tier
560    pub async fn demote_to_warm(
561        &mut self,
562        message_id: &str,
563        summary: MessageSummary,
564    ) -> Result<()> {
565        // Update tier metadata
566        if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
567            meta.tier = MemoryTier::Warm;
568            self.tier_metadata.update(meta).await?;
569        }
570
571        // Add summary to warm tier
572        self.warm.add(summary).await
573    }
574
575    /// Demote a summary from warm to cold tier
576    pub async fn demote_to_cold(&mut self, summary_id: &str, fact: KeyFact) -> Result<()> {
577        // Remove from warm
578        self.warm.delete(summary_id).await?;
579
580        // Add to cold
581        self.cold.add(fact).await
582    }
583
584    /// Promote a message back to hot tier (re-fetch full content)
585    pub async fn promote_to_hot(&mut self, message_id: &str) -> Result<Option<MessageMetadata>> {
586        // Update metadata
587        if let Some(mut meta) = self.tier_metadata.get(message_id).await? {
588            meta.tier = MemoryTier::Hot;
589            meta.record_access();
590            self.tier_metadata.update(meta).await?;
591        }
592
593        // The message should still be in the hot store (we don't delete on demotion)
594        // Just update access tracking
595        Ok(None)
596    }
597
598    /// Get messages that should be considered for demotion
599    pub async fn get_demotion_candidates(
600        &self,
601        tier: MemoryTier,
602        count: usize,
603    ) -> Result<Vec<String>> {
604        let all_metadata = self.tier_metadata.get_by_tier(tier).await?;
605
606        let mut candidates: Vec<_> = all_metadata
607            .into_iter()
608            .map(|m| (m.message_id.clone(), m.retention_score()))
609            .collect();
610
611        // Sort by retention score (lowest first = demote first)
612        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
613
614        Ok(candidates
615            .into_iter()
616            .take(count)
617            .map(|(id, _)| id)
618            .collect())
619    }
620
621    /// Get statistics about tier distribution
622    pub async fn get_stats(&self) -> Result<TieredMemoryStats> {
623        let hot_count = self.tier_metadata.count_by_tier(MemoryTier::Hot).await?;
624        let warm_count = self.warm.count().await?;
625        let cold_count = self.cold.count().await?;
626        let mental_model_count = self.mental_model.count().await.unwrap_or(0);
627        let total_tracked = self.tier_metadata.count().await?;
628
629        Ok(TieredMemoryStats {
630            hot_count,
631            warm_count,
632            cold_count,
633            mental_model_count,
634            total_tracked,
635        })
636    }
637
638    /// Fallback summarization without LLM
639    pub fn fallback_summarize(&self, content: &str) -> String {
640        let words: Vec<&str> = content.split_whitespace().collect();
641        if words.len() <= 75 {
642            content.to_string()
643        } else {
644            format!("{}...", words[..75].join(" "))
645        }
646    }
647
648    /// Create a fallback fact from a summary
649    pub fn fallback_fact(&self, summary: &MessageSummary) -> KeyFact {
650        KeyFact {
651            fact_id: Uuid::new_v4().to_string(),
652            original_message_ids: vec![summary.original_message_id.clone()],
653            conversation_id: summary.conversation_id.clone(),
654            fact: summary.summary.clone(),
655            fact_type: FactType::Other,
656            created_at: Utc::now().timestamp(),
657        }
658    }
659
660    // ── Mental model tier ────────────────────────────────────────────────────
661
662    /// Synthesise and store a new mental model from a set of source fact IDs.
663    ///
664    /// Returns the new model's ID.
665    ///
666    /// The table is created on first call; subsequent calls reuse it.
667    pub async fn synthesize_mental_model(
668        &mut self,
669        fact_ids: &[String],
670        model_text: String,
671        model_type: ModelType,
672        conversation_id: String,
673    ) -> Result<String> {
674        self.mental_model.ensure_table().await?;
675
676        let mut model =
677            MentalModel::new(model_text, model_type, conversation_id, fact_ids.to_vec());
678        model.evidence_count = fact_ids.len() as u32;
679        let id = model.model_id.clone();
680        self.mental_model.add(model).await?;
681        Ok(id)
682    }
683
684    /// Semantic search over the mental model tier.
685    pub async fn search_mental_models(
686        &self,
687        query: &str,
688        limit: usize,
689    ) -> Result<Vec<TieredSearchResult>> {
690        let raw = self.mental_model.search(query, limit).await?;
691        Ok(raw
692            .into_iter()
693            .map(|(model, score)| TieredSearchResult {
694                content: model.model_text.clone(),
695                score,
696                tier: MemoryTier::MentalModel,
697                original_message_id: model.source_fact_ids.first().cloned(),
698                metadata: None,
699                multi_factor_score: None,
700            })
701            .collect())
702    }
703}
704
/// Snapshot of how many entries each memory tier holds.
#[derive(Debug, Clone)]
pub struct TieredMemoryStats {
    /// Entries currently in the hot tier.
    pub hot_count: usize,
    /// Entries currently in the warm tier.
    pub warm_count: usize,
    /// Entries currently in the cold tier.
    pub cold_count: usize,
    /// Synthesised mental models currently stored.
    pub mental_model_count: usize,
    /// Entries tracked by tier metadata, across all tiers.
    pub total_tracked: usize,
}
719
720#[cfg(test)]
721mod tests {
722    use super::*;
723
724    // ── MultiFactorScore ───────────────────────────────────────────────────
725
726    #[test]
727    fn test_multi_factor_score_weights_sum_to_one() {
728        // weights: 0.50 + 0.30 + 0.20 = 1.0
729        let score = MultiFactorScore::compute(1.0, 1.0, 1.0);
730        assert!(
731            (score.combined - 1.0).abs() < 1e-6,
732            "all-one inputs should yield combined=1"
733        );
734    }
735
736    #[test]
737    fn test_multi_factor_score_zero_inputs() {
738        let score = MultiFactorScore::compute(0.0, 0.0, 0.0);
739        assert_eq!(score.combined, 0.0);
740    }
741
742    #[test]
743    fn test_recency_factor_fresh_entry() {
744        // An entry accessed 0 hours ago should have recency ≈ 1.0
745        let r = MultiFactorScore::recency_from_hours(0.0);
746        assert!((r - 1.0).abs() < 1e-6);
747    }
748
749    #[test]
750    fn test_recency_factor_decays_over_time() {
751        let r_now = MultiFactorScore::recency_from_hours(0.0);
752        let r_day = MultiFactorScore::recency_from_hours(24.0);
753        let r_week = MultiFactorScore::recency_from_hours(168.0);
754        assert!(
755            r_now > r_day,
756            "fresh entry must score higher than 1-day-old"
757        );
758        assert!(
759            r_day > r_week,
760            "1-day-old must score higher than 1-week-old"
761        );
762        assert!(r_week > 0.0, "recency factor must remain positive");
763    }
764
765    #[test]
766    fn test_high_similarity_low_recency_can_be_beaten_by_balanced_entry() {
767        // High similarity but stale (1 week old, no importance)
768        let stale =
769            MultiFactorScore::compute(0.95, MultiFactorScore::recency_from_hours(168.0), 0.0);
770        // Moderate similarity but recent and important
771        let fresh = MultiFactorScore::compute(0.70, MultiFactorScore::recency_from_hours(1.0), 0.9);
772        // The balanced entry should edge ahead
773        assert!(
774            fresh.combined > stale.combined,
775            "fresh important entry ({:.3}) should beat stale high-similarity entry ({:.3})",
776            fresh.combined,
777            stale.combined
778        );
779    }
780
781    // ── Tier demotion / promotion ─────────────────────────────────────────
782
783    #[test]
784    fn test_tier_demotion() {
785        assert_eq!(MemoryTier::Hot.demote(), Some(MemoryTier::Warm));
786        assert_eq!(MemoryTier::Warm.demote(), Some(MemoryTier::Cold));
787        assert_eq!(MemoryTier::Cold.demote(), Some(MemoryTier::MentalModel));
788        assert_eq!(MemoryTier::MentalModel.demote(), None);
789    }
790
791    #[test]
792    fn test_tier_promotion() {
793        assert_eq!(MemoryTier::Hot.promote(), None);
794        assert_eq!(MemoryTier::Warm.promote(), Some(MemoryTier::Hot));
795        assert_eq!(MemoryTier::Cold.promote(), Some(MemoryTier::Warm));
796        assert_eq!(MemoryTier::MentalModel.promote(), Some(MemoryTier::Cold));
797    }
798
799    #[test]
800    fn test_tier_metadata_retention_score() {
801        let mut meta = TierMetadata::new("test-1".to_string(), 0.8);
802
803        // High importance should give higher score
804        let score1 = meta.retention_score();
805        assert!(score1 > 0.0);
806
807        // Recording access should maintain or increase score
808        meta.record_access();
809        let score2 = meta.retention_score();
810        assert!(score2 >= score1 * 0.9); // Allow some variance due to time
811    }
812
813    #[test]
814    fn test_default_config() {
815        let config = TieredMemoryConfig::default();
816        assert_eq!(config.hot_retention_hours, 24);
817        assert_eq!(config.warm_retention_hours, 168);
818        assert!(config.hot_importance_threshold > 0.0);
819        assert!(config.session_ttl_secs.is_none());
820    }
821
822    #[test]
823    fn test_config_with_session_ttl() {
824        let config = TieredMemoryConfig {
825            session_ttl_secs: Some(3600),
826            ..TieredMemoryConfig::default()
827        };
828        assert_eq!(config.session_ttl_secs, Some(3600));
829    }
830
831    // ── MemoryAuthority ───────────────────────────────────────────────────
832
833    #[test]
834    fn test_memory_authority_default() {
835        assert_eq!(MemoryAuthority::default(), MemoryAuthority::Session);
836    }
837
838    #[test]
839    fn test_memory_authority_round_trip() {
840        for auth in [
841            MemoryAuthority::Ephemeral,
842            MemoryAuthority::Session,
843            MemoryAuthority::Canonical,
844        ] {
845            assert_eq!(MemoryAuthority::parse(auth.as_str()), auth);
846        }
847    }
848
849    #[test]
850    fn test_memory_authority_unknown_defaults_to_session() {
851        assert_eq!(MemoryAuthority::parse("bogus"), MemoryAuthority::Session);
852    }
853
854    #[test]
855    fn test_tier_metadata_default_authority() {
856        let meta = TierMetadata::new("m-1".to_string(), 0.5);
857        assert_eq!(meta.authority, MemoryAuthority::Session);
858    }
859
860    #[test]
861    fn test_tier_metadata_with_authority() {
862        let meta = TierMetadata::with_authority("m-2".to_string(), 0.9, MemoryAuthority::Canonical);
863        assert_eq!(meta.authority, MemoryAuthority::Canonical);
864        assert_eq!(meta.importance, 0.9);
865    }
866
867    #[test]
868    fn test_canonical_write_token_is_crate_private() {
869        // CanonicalWriteToken::new() is pub(crate) — this test being inside
870        // the same crate confirms we can construct it; external crates cannot.
871        let _token = CanonicalWriteToken::new();
872    }
873
874    // ── Feature 2: Temporal scoring ─────────────────────────────────────
875
876    #[test]
877    fn test_detect_temporal_query_empty() {
878        assert_eq!(detect_temporal_query(""), 0.0);
879    }
880
881    #[test]
882    fn test_detect_temporal_query_no_keywords() {
883        assert_eq!(detect_temporal_query("how does authentication work?"), 0.0);
884    }
885
886    #[test]
887    fn test_detect_temporal_query_single_keyword() {
888        let score = detect_temporal_query("what is the latest approach?");
889        assert!(score > 0.0, "expected score > 0 for 'latest'");
890    }
891
892    #[test]
893    fn test_detect_temporal_query_dense() {
894        let score = detect_temporal_query("what was the latest change today?");
895        assert!(score > 0.0);
896    }
897
898    #[test]
899    fn test_detect_temporal_query_max_clamp() {
900        // Five keywords — result must be <= 1.0
901        let score = detect_temporal_query("recent latest current today now new");
902        assert!(score <= 1.0, "score must not exceed 1.0");
903        assert!(score > 0.0);
904    }
905
906    #[test]
907    fn test_compute_with_weights_sum_normalised() {
908        // Weights sum to 1.0 — combined should equal the weighted sum, clamped.
909        let sim_w = 0.4_f32;
910        let rec_w = 0.4_f32;
911        let imp_w = 0.2_f32;
912        let score = MultiFactorScore::compute_with_weights(0.8, 0.9, 0.6, sim_w, rec_w, imp_w);
913        let expected = (0.8 * sim_w + 0.9 * rec_w + 0.6 * imp_w).clamp(0.0, 1.0);
914        assert!((score.combined - expected).abs() < 1e-5);
915    }
916
917    #[test]
918    fn test_compute_with_weights_matches_compute_for_default_weights() {
919        let a = MultiFactorScore::compute(0.7, 0.8, 0.5);
920        let b = MultiFactorScore::compute_with_weights(
921            0.7,
922            0.8,
923            0.5,
924            SIMILARITY_WEIGHT,
925            RECENCY_WEIGHT,
926            IMPORTANCE_WEIGHT,
927        );
928        assert!((a.combined - b.combined).abs() < 1e-5);
929    }
930
931    #[test]
932    fn test_temporal_config_defaults() {
933        let cfg = TieredMemoryConfig::default();
934        assert_eq!(cfg.temporal_boost, 0.3);
935        assert!(!cfg.fast_decay);
936    }
937
938    #[test]
939    fn test_fast_decay_rate_higher_than_normal() {
940        // Fast decay should make old items score lower than normal decay.
941        let hours = 48.0_f32;
942        let normal = MultiFactorScore::recency_from_hours(hours);
943        let fast = MultiFactorScore::recency_from_hours_fast(hours);
944        assert!(
945            fast < normal,
946            "fast decay should produce lower recency for old items"
947        );
948    }
949}