Skip to main content

alice_core/memory/
domain.rs

//! Domain types for the memory subsystem.
2
3use serde::{Deserialize, Serialize};
4
5use crate::memory::error::MemoryValidationError;
6
7/// Importance level for memory entries.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9pub enum MemoryImportance {
10    /// Low signal.
11    Low,
12    /// Default level.
13    Medium,
14    /// High signal.
15    High,
16}
17
18impl MemoryImportance {
19    /// Serialize to persistence-friendly string.
20    #[must_use]
21    pub const fn as_str(&self) -> &'static str {
22        match self {
23            Self::Low => "low",
24            Self::Medium => "medium",
25            Self::High => "high",
26        }
27    }
28
29    /// Deserialize from storage string.
30    #[must_use]
31    pub fn from_db(value: &str) -> Self {
32        match value {
33            "low" => Self::Low,
34            "high" => Self::High,
35            _ => Self::Medium,
36        }
37    }
38}
39
40/// Persisted memory record.
41#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
42pub struct MemoryEntry {
43    /// Stable entry identifier.
44    pub id: String,
45    /// Conversation session identifier.
46    pub session_id: String,
47    /// Topic key.
48    pub topic: String,
49    /// Compact summary used for retrieval.
50    pub summary: String,
51    /// Full excerpt or concatenated turn content.
52    pub raw_excerpt: String,
53    /// Searchable keywords.
54    pub keywords: Vec<String>,
55    /// Importance signal.
56    pub importance: MemoryImportance,
57    /// Optional vector embedding.
58    pub embedding: Option<Vec<f32>>,
59    /// Unix epoch milliseconds.
60    pub created_at_epoch_ms: i64,
61}
62
63/// Query used for turn recall.
64#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
65pub struct RecallQuery {
66    /// Optional session filter.
67    pub session_id: Option<String>,
68    /// Free-form query text.
69    pub text: String,
70    /// Optional embedding used for vector retrieval.
71    pub query_embedding: Option<Vec<f32>>,
72    /// Max number of results.
73    pub limit: usize,
74}
75
76/// Weighted recall hit returned by hybrid retrieval.
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
78pub struct RecallHit {
79    /// Matched entry.
80    pub entry: MemoryEntry,
81    /// Normalized BM25 score in `[0, 1]`.
82    pub bm25_score: f32,
83    /// Normalized vector similarity in `[0, 1]`.
84    pub vector_score: Option<f32>,
85    /// Final fused score in `[0, 1]`.
86    pub final_score: f32,
87}
88
89/// Hybrid rank fusion weights.
90#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
91pub struct HybridWeights {
92    /// BM25 contribution.
93    pub bm25: f32,
94    /// Vector contribution.
95    pub vector: f32,
96}
97
98impl HybridWeights {
99    /// Build validated weights where each component is in `[0, 1]` and sum is positive.
100    pub fn new(bm25: f32, vector: f32) -> Result<Self, MemoryValidationError> {
101        if !(0.0..=1.0).contains(&bm25) || !(0.0..=1.0).contains(&vector) {
102            return Err(MemoryValidationError::InvalidHybridWeights { bm25, vector });
103        }
104        let total = bm25 + vector;
105        if total <= f32::EPSILON {
106            return Err(MemoryValidationError::InvalidHybridWeights { bm25, vector });
107        }
108        Ok(Self { bm25: bm25 / total, vector: vector / total })
109    }
110}
111
112impl Default for HybridWeights {
113    fn default() -> Self {
114        Self { bm25: 0.3, vector: 0.7 }
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// All `MemoryImportance` variants return the expected persistence string.
    #[test]
    fn memory_importance_as_str() {
        assert_eq!(MemoryImportance::Low.as_str(), "low");
        assert_eq!(MemoryImportance::Medium.as_str(), "medium");
        assert_eq!(MemoryImportance::High.as_str(), "high");
    }

    /// Round-tripping through `as_str` then `from_db` preserves every variant.
    #[test]
    fn memory_importance_from_db_roundtrip() {
        for variant in [MemoryImportance::Low, MemoryImportance::Medium, MemoryImportance::High] {
            assert_eq!(MemoryImportance::from_db(variant.as_str()), variant);
        }
        // Unknown strings fall back to Medium.
        assert_eq!(MemoryImportance::from_db("unknown"), MemoryImportance::Medium);
    }

    /// `MemoryEntry` fields are stored exactly as provided.
    #[test]
    fn memory_entry_construction() {
        let entry = MemoryEntry {
            id: "id-1".to_string(),
            session_id: "sess-1".to_string(),
            topic: "greetings".to_string(),
            summary: "hello world".to_string(),
            raw_excerpt: "raw".to_string(),
            keywords: vec!["hello".to_string()],
            importance: MemoryImportance::High,
            embedding: None,
            created_at_epoch_ms: 1_000,
        };
        assert_eq!(entry.id, "id-1");
        assert_eq!(entry.session_id, "sess-1");
        assert_eq!(entry.topic, "greetings");
        // Check every field so the test matches its stated contract.
        assert_eq!(entry.summary, "hello world");
        assert_eq!(entry.raw_excerpt, "raw");
        assert_eq!(entry.keywords, vec!["hello".to_string()]);
        assert_eq!(entry.importance, MemoryImportance::High);
        assert!(entry.embedding.is_none());
        assert_eq!(entry.created_at_epoch_ms, 1_000);
    }

    /// `RecallQuery` handles unicode and special characters without panicking.
    #[test]
    fn recall_query_with_special_chars() {
        let query = RecallQuery {
            session_id: Some("s-\u{1F600}".to_string()),
            text: "\u{4F60}\u{597D} hello <>&\"'".to_string(),
            query_embedding: None,
            limit: 10,
        };
        assert!(query.text.contains('\u{4F60}'));
        assert_eq!(query.limit, 10);
    }

    /// Default `HybridWeights` components sum to 1.0.
    #[test]
    fn hybrid_weights_default_sums_to_one() {
        let w = HybridWeights::default();
        let sum = w.bm25 + w.vector;
        assert!((sum - 1.0).abs() < f32::EPSILON);
    }

    /// `HybridWeights::new` rescales accepted weights so they sum to 1.0.
    #[test]
    fn hybrid_weights_new_normalizes() {
        match HybridWeights::new(1.0, 1.0) {
            Ok(w) => {
                assert!((w.bm25 - 0.5).abs() < f32::EPSILON);
                assert!((w.vector - 0.5).abs() < f32::EPSILON);
            }
            Err(_) => panic!("equal in-range weights must be accepted"),
        }
        match HybridWeights::new(0.5, 0.0) {
            Ok(w) => {
                assert!((w.bm25 - 1.0).abs() < f32::EPSILON);
                assert!(w.vector.abs() < f32::EPSILON);
            }
            Err(_) => panic!("a single positive in-range weight must be accepted"),
        }
    }

    /// `HybridWeights::new` rejects components outside `[0, 1]`.
    #[test]
    fn hybrid_weights_new_rejects_out_of_range() {
        assert!(HybridWeights::new(-0.1, 0.5).is_err());
        assert!(HybridWeights::new(0.5, 1.5).is_err());
    }

    /// `HybridWeights::new` rejects weights whose sum is not positive.
    #[test]
    fn hybrid_weights_new_rejects_zero_sum() {
        assert!(HybridWeights::new(0.0, 0.0).is_err());
    }

    /// `HybridWeights::new` rejects NaN and infinite components.
    #[test]
    fn hybrid_weights_new_rejects_non_finite() {
        assert!(HybridWeights::new(f32::NAN, 0.5).is_err());
        assert!(HybridWeights::new(0.5, f32::INFINITY).is_err());
    }
}