alice_core/memory/
domain.rs1use serde::{Deserialize, Serialize};
4
5use crate::memory::error::MemoryValidationError;
6
/// Coarse importance level attached to a [`MemoryEntry`].
///
/// [`MemoryImportance::as_str`] and [`MemoryImportance::from_db`] convert to and from
/// the lowercase text form (`"low"`, `"medium"`, `"high"`).
// NOTE(review): no `#[serde(rename_all = ...)]` is visible on this enum, so serde would
// (de)serialize the variant names as written (`Low`, `Medium`, `High`), which differs
// from the lowercase `as_str` form — confirm this difference is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryImportance {
    /// Lowest level; text form `"low"`.
    Low,
    /// Middle level; text form `"medium"`. `from_db` also falls back to this variant
    /// for any text it does not recognise.
    Medium,
    /// Highest level; text form `"high"`.
    High,
}
17
18impl MemoryImportance {
19 #[must_use]
21 pub const fn as_str(&self) -> &'static str {
22 match self {
23 Self::Low => "low",
24 Self::Medium => "medium",
25 Self::High => "high",
26 }
27 }
28
29 #[must_use]
31 pub fn from_db(value: &str) -> Self {
32 match value {
33 "low" => Self::Low,
34 "high" => Self::High,
35 _ => Self::Medium,
36 }
37 }
38}
39
/// A stored memory together with its metadata; returned inside [`RecallHit`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Identifier of this entry.
    pub id: String,
    /// Identifier of the session this entry is associated with.
    pub session_id: String,
    /// Topic label for the entry.
    pub topic: String,
    /// Summary text of the memory.
    pub summary: String,
    /// Raw text excerpt kept alongside the summary.
    // NOTE(review): presumably the source text the summary was derived from — confirm
    // with the code that produces entries.
    pub raw_excerpt: String,
    /// Keywords associated with the entry.
    pub keywords: Vec<String>,
    /// Importance level of the entry.
    pub importance: MemoryImportance,
    /// Optional embedding vector; `None` when the entry carries no embedding.
    // NOTE(review): the vector length is not fixed by this type — presumably it must match
    // the length of `RecallQuery::query_embedding` for vector scoring; confirm.
    pub embedding: Option<Vec<f32>>,
    /// Creation timestamp; per the field name, milliseconds since the Unix epoch.
    pub created_at_epoch_ms: i64,
}
62
/// Parameters of a recall request over stored [`MemoryEntry`] values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecallQuery {
    /// Optional session identifier.
    // NOTE(review): presumably restricts recall to that session when `Some` and searches
    // all sessions when `None` — confirm against the store implementation.
    pub session_id: Option<String>,
    /// Query text.
    pub text: String,
    /// Optional embedding of the query.
    // NOTE(review): presumably enables vector scoring when `Some`; confirm.
    pub query_embedding: Option<Vec<f32>>,
    /// Requested number of hits.
    // NOTE(review): presumably an upper bound on the number of results returned; confirm.
    pub limit: usize,
}
75
/// One recall result: the matched entry and its scores.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecallHit {
    /// The matched memory entry.
    pub entry: MemoryEntry,
    /// BM25 (lexical) score of the entry.
    pub bm25_score: f32,
    /// Vector-similarity score, or `None` when no vector score is available.
    // NOTE(review): presumably `None` when the query or the entry lacks an embedding;
    // confirm with the scoring code.
    pub vector_score: Option<f32>,
    /// Combined ranking score.
    // NOTE(review): presumably a blend of the two scores above using `HybridWeights`;
    // the combining code is not in this file — confirm.
    pub final_score: f32,
}
88
/// Relative weights for the BM25 and vector components of a hybrid score.
///
/// Values built with [`HybridWeights::new`] or [`Default`] have each weight in `0.0..=1.0`
/// and sum to (approximately) one. The fields are public, so this invariant is not
/// enforced for values constructed directly.
// NOTE(review): `Deserialize` is derived without validation, so deserialized weights can
// also violate the invariant — consider routing deserialization through `HybridWeights::new`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct HybridWeights {
    /// Weight of the BM25 component.
    pub bm25: f32,
    /// Weight of the vector-similarity component.
    pub vector: f32,
}
97
98impl HybridWeights {
99 pub fn new(bm25: f32, vector: f32) -> Result<Self, MemoryValidationError> {
101 if !(0.0..=1.0).contains(&bm25) || !(0.0..=1.0).contains(&vector) {
102 return Err(MemoryValidationError::InvalidHybridWeights { bm25, vector });
103 }
104 let total = bm25 + vector;
105 if total <= f32::EPSILON {
106 return Err(MemoryValidationError::InvalidHybridWeights { bm25, vector });
107 }
108 Ok(Self { bm25: bm25 / total, vector: vector / total })
109 }
110}
111
112impl Default for HybridWeights {
113 fn default() -> Self {
114 Self { bm25: 0.3, vector: 0.7 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds weights that must validate, failing the test otherwise.
    // `expect` is deliberately avoided: it would require `MemoryValidationError: Debug`,
    // which is not visible from this file.
    fn accepted(bm25: f32, vector: f32) -> HybridWeights {
        match HybridWeights::new(bm25, vector) {
            Ok(weights) => weights,
            Err(_) => panic!("weights ({}, {}) should be accepted", bm25, vector),
        }
    }

    #[test]
    fn memory_importance_as_str() {
        assert_eq!(MemoryImportance::Low.as_str(), "low");
        assert_eq!(MemoryImportance::Medium.as_str(), "medium");
        assert_eq!(MemoryImportance::High.as_str(), "high");
    }

    #[test]
    fn memory_importance_from_db_roundtrip() {
        for variant in [MemoryImportance::Low, MemoryImportance::Medium, MemoryImportance::High] {
            assert_eq!(MemoryImportance::from_db(variant.as_str()), variant);
        }
        assert_eq!(MemoryImportance::from_db("unknown"), MemoryImportance::Medium);
        // Matching is exact: other casings and the empty string fall back to Medium.
        assert_eq!(MemoryImportance::from_db("HIGH"), MemoryImportance::Medium);
        assert_eq!(MemoryImportance::from_db(""), MemoryImportance::Medium);
    }

    #[test]
    fn memory_entry_construction() {
        let entry = MemoryEntry {
            id: "id-1".to_string(),
            session_id: "sess-1".to_string(),
            topic: "greetings".to_string(),
            summary: "hello world".to_string(),
            raw_excerpt: "raw".to_string(),
            keywords: vec!["hello".to_string()],
            importance: MemoryImportance::High,
            embedding: None,
            created_at_epoch_ms: 1_000,
        };
        assert_eq!(entry.id, "id-1");
        assert_eq!(entry.session_id, "sess-1");
        assert_eq!(entry.topic, "greetings");
        assert_eq!(entry.summary, "hello world");
        assert_eq!(entry.raw_excerpt, "raw");
        assert_eq!(entry.keywords, vec!["hello".to_string()]);
        assert_eq!(entry.importance, MemoryImportance::High);
        assert!(entry.embedding.is_none());
        assert_eq!(entry.created_at_epoch_ms, 1_000);
    }

    #[test]
    fn recall_query_with_special_chars() {
        let query = RecallQuery {
            session_id: Some("s-\u{1F600}".to_string()),
            text: "\u{4F60}\u{597D} hello <>&\"'".to_string(),
            query_embedding: None,
            limit: 10,
        };
        assert_eq!(query.session_id.as_deref(), Some("s-\u{1F600}"));
        assert!(query.text.contains('\u{4F60}'));
        assert!(query.query_embedding.is_none());
        assert_eq!(query.limit, 10);
    }

    #[test]
    fn hybrid_weights_default_sums_to_one() {
        let w = HybridWeights::default();
        let sum = w.bm25 + w.vector;
        assert!((sum - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn hybrid_weights_new_normalises_to_unit_sum() {
        assert_eq!(accepted(0.3, 0.3), HybridWeights { bm25: 0.5, vector: 0.5 });
        assert_eq!(accepted(1.0, 0.0), HybridWeights { bm25: 1.0, vector: 0.0 });
        let w = accepted(0.2, 0.6);
        assert!((w.bm25 - 0.25).abs() < 1e-6);
        assert!((w.vector - 0.75).abs() < 1e-6);
        assert!((w.bm25 + w.vector - 1.0).abs() < 1e-6);
    }

    #[test]
    fn hybrid_weights_new_matches_default() {
        assert_eq!(accepted(0.3, 0.7), HybridWeights::default());
    }

    #[test]
    fn hybrid_weights_new_rejects_invalid_input() {
        for (bm25, vector) in [
            (-0.1, 0.5),
            (0.5, 1.5),
            (0.0, 0.0),
            (f32::NAN, 0.5),
            (0.5, f32::INFINITY),
        ] {
            assert!(
                matches!(
                    HybridWeights::new(bm25, vector),
                    Err(MemoryValidationError::InvalidHybridWeights { .. })
                ),
                "weights ({}, {}) should be rejected",
                bm25,
                vector
            );
        }
    }

    #[test]
    fn hybrid_weights_new_error_carries_inputs() {
        match HybridWeights::new(-0.25, 0.5) {
            Err(MemoryValidationError::InvalidHybridWeights { bm25, vector }) => {
                assert_eq!(bm25, -0.25);
                assert_eq!(vector, 0.5);
            }
            _ => panic!("(-0.25, 0.5) should be rejected with InvalidHybridWeights"),
        }
    }
}