oxios_memory/memory/manager/
ops.rs1use anyhow::Result;
4
5use crate::memory::auto_protect::AutoProtector;
6use crate::memory::hnsw_memory_index::{HnswMemoryIndex, SemanticHit};
7use crate::memory::storage::MemoryStorageExt;
8use crate::memory::types::{MemoryEntry, MemoryTier, MemoryType};
9
10use super::MemoryManager;
11
12impl MemoryManager {
13 pub async fn semantic_search(
19 &self,
20 query: &str,
21 memory_type: Option<MemoryType>,
22 limit: usize,
23 hnsw_index: &HnswMemoryIndex,
24 ) -> Result<Vec<SemanticHit>> {
25 if hnsw_index.is_empty() {
27 tracing::debug!("HNSW index empty, falling back to keyword search");
28 return self
29 .keyword_search(query, memory_type, limit)
30 .await
31 .map(|entries| {
32 entries
33 .into_iter()
34 .map(|entry| SemanticHit {
35 entry,
36 distance: 0.0,
37 similarity: 0.0,
38 })
39 .collect()
40 });
41 }
42
43 let query_vector = self.embedding.embed(query).await?;
45 let query_f32 = match query_vector.to_f32_dense() {
46 Some(v) => v,
47 None => {
48 tracing::debug!("Query embedding is sparse, falling back to keyword search");
49 return self
50 .keyword_search(query, memory_type, limit)
51 .await
52 .map(|entries| {
53 entries
54 .into_iter()
55 .map(|entry| SemanticHit {
56 entry,
57 distance: 0.0,
58 similarity: 0.0,
59 })
60 .collect()
61 });
62 }
63 };
64
65 let raw_hits = hnsw_index.search(&query_f32, limit * 2)?;
67
68 let types: &[MemoryType] = match memory_type {
70 Some(ref t) => std::slice::from_ref(t),
71 None => MemoryType::all(),
72 };
73
74 let mut results = Vec::new();
76 for (id, distance) in raw_hits {
77 for mt in types {
78 if let Ok(Some(mut entry)) = self
79 .storage
80 .load_json::<MemoryEntry>(mt.category(), &id)
81 .await
82 {
83 AutoProtector::record_access(&mut entry, "");
84
85 let similarity = 1.0 - distance;
86 results.push(SemanticHit {
87 entry,
88 distance,
89 similarity,
90 });
91 break;
92 }
93 }
94 if results.len() >= limit {
95 break;
96 }
97 }
98
99 results.sort_by(|a, b| {
101 b.similarity
102 .partial_cmp(&a.similarity)
103 .unwrap_or(std::cmp::Ordering::Equal)
104 });
105
106 tracing::debug!(
107 query = %query,
108 hits = results.len(),
109 "Semantic search complete"
110 );
111
112 if results.is_empty() {
114 return self
115 .keyword_search(query, memory_type, limit)
116 .await
117 .map(|entries| {
118 entries
119 .into_iter()
120 .map(|entry| SemanticHit {
121 entry,
122 distance: 0.0,
123 similarity: 0.0,
124 })
125 .collect()
126 });
127 }
128
129 Ok(results)
130 }
131
132 pub async fn rebuild_hnsw_index(&self, hnsw_index: &HnswMemoryIndex) -> Result<usize> {
136 let mut count = 0;
137
138 for mt in MemoryType::all() {
139 if let Ok(names) = self.storage.list_category(mt.category()).await {
140 for name in names {
141 if let Ok(Some(entry)) = self
142 .storage
143 .load_json::<MemoryEntry>(mt.category(), &name)
144 .await
145 {
146 let vector = self.embedding.embed(&entry.content).await?;
147 if let Some(f32_vec) = vector.to_f32_dense() {
148 if let Err(e) = hnsw_index.add_entry(&entry.id, &f32_vec) {
149 tracing::warn!(
150 id = %entry.id,
151 error = %e,
152 "Failed to add entry to HNSW index"
153 );
154 continue;
155 }
156 count += 1;
157 }
158 }
159 }
160 }
161 }
162
163 tracing::info!(entries = count, "HNSW index rebuilt");
164 Ok(count)
165 }
166
167 pub async fn list_by_tier(&self, tier: MemoryTier, limit: usize) -> Result<Vec<MemoryEntry>> {
173 #[cfg(feature = "sqlite-memory")]
174 if let Some(ref sqlite) = self.sqlite_store {
175 return sqlite.list_by_tier(tier, limit);
176 }
177
178 let mut results = Vec::new();
179 for mt in MemoryType::all() {
180 if let Ok(entries) = self.list(*mt, limit).await {
181 for entry in entries {
182 if entry.tier == tier {
183 results.push(entry);
184 }
185 }
186 }
187 if results.len() >= limit {
188 break;
189 }
190 }
191 results.truncate(limit);
192 Ok(results)
193 }
194
195 pub async fn get_by_id(&self, id: &str) -> Result<Option<MemoryEntry>> {
197 for mt in MemoryType::all() {
198 if let Ok(Some(entry)) = self.get(id, *mt).await {
199 return Ok(Some(entry));
200 }
201 }
202 Ok(None)
203 }
204
205 pub async fn load_by_reference(&self, reference: &str) -> Result<Option<MemoryEntry>> {
207 if let Ok(Some(entry)) = self.get_by_id(reference).await {
209 return Ok(Some(entry));
210 }
211 if let Some((cat, name)) = reference.split_once('/') {
213 if let Ok(Some(entry)) = self.storage.load_json::<MemoryEntry>(cat, name).await {
214 return Ok(Some(entry));
215 }
216 }
217 Ok(None)
218 }
219
220 pub async fn select_by_manifest(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
222 self.keyword_search(query, None, limit).await
223 }
224
225 pub async fn build_hot_context(&self, token_budget: usize) -> Result<String> {
227 let hot_entries = self.list_by_tier(MemoryTier::Hot, 50).await?;
228
229 let mut context_parts = Vec::new();
230 let mut char_budget = token_budget * 4;
231
232 for entry in &hot_entries {
233 let line = format!("- [{}] {}", entry.memory_type.label(), entry.content);
234 if line.len() > char_budget {
235 break;
236 }
237 char_budget -= line.len();
238 context_parts.push(line);
239 }
240
241 if context_parts.is_empty() {
242 Ok(String::new())
243 } else {
244 Ok(format!("## Active Context\n\n{}", context_parts.join("\n")))
245 }
246 }
247
248 pub async fn build_full_context(
250 &self,
251 _query: &str,
252 system_prompt: &str,
253 token_budget: usize,
254 ) -> Result<String> {
255 let hot_ctx = self.build_hot_context(token_budget).await?;
256
257 if hot_ctx.is_empty() {
258 return Ok(system_prompt.to_string());
259 }
260
261 Ok(format!("{system_prompt}\n\n{hot_ctx}"))
262 }
263
264 pub async fn shift_tier(&self, id: &str, from: MemoryTier, to: MemoryTier) -> Result<()> {
266 if let Ok(Some(mut entry)) = self.get_by_id(id).await {
267 if entry.tier == from {
268 entry.tier = to;
269 self.remember(entry).await?;
270 }
271 }
272 Ok(())
273 }
274
275 pub async fn pin(&self, id: &str) -> Result<()> {
277 if let Ok(Some(mut entry)) = self.get_by_id(id).await {
278 entry.pinned = true;
279 entry.protection = crate::memory::types::ProtectionLevel::Permanent;
280 self.remember(entry).await?;
281 }
282 Ok(())
283 }
284
285 pub async fn unpin(&self, id: &str) -> Result<()> {
287 if let Ok(Some(mut entry)) = self.get_by_id(id).await {
288 entry.pinned = false;
289 let protector = crate::memory::auto_protect::AutoProtector::default_protector();
291 entry.protection = protector.compute_protection(&entry);
292 self.remember(entry).await?;
293 }
294 Ok(())
295 }
296
297 pub async fn set_importance(&self, id: &str, importance: f32) -> Result<()> {
299 if let Ok(Some(mut entry)) = self.get_by_id(id).await {
300 entry.importance = importance.clamp(0.0, 1.0);
301 self.remember(entry).await?;
302 }
303 Ok(())
304 }
305
306 pub async fn recompute_all_decay(&self, multiplier: f32) -> Result<usize> {
310 let engine = crate::memory::decay::DecayEngine::new(multiplier);
311 let now = chrono::Utc::now();
312 let mut count = 0;
313
314 for mt in MemoryType::all() {
315 if let Ok(entries) = self.list(*mt, 1_000_000).await {
316 for mut entry in entries {
317 let new_decay = engine.compute_decay(&entry, now);
318 if (entry.decay_score - new_decay).abs() > 0.001 {
319 entry.decay_score = new_decay;
320 self.remember(entry).await?;
321 count += 1;
322 }
323 }
324 }
325 }
326
327 Ok(count)
328 }
329
330 pub async fn immediate_hot_overflow(&self, hot_max: usize) -> Result<usize> {
332 let hot_entries = self.list_by_tier(MemoryTier::Hot, hot_max * 2).await?;
333 if hot_entries.len() <= hot_max {
334 return Ok(0);
335 }
336
337 let overflow = hot_entries.len() - hot_max;
338 let mut candidates: Vec<MemoryEntry> = hot_entries
339 .into_iter()
340 .filter(|e| e.protection < crate::memory::types::ProtectionLevel::High && !e.pinned)
341 .collect();
342
343 candidates.sort_by(|a, b| {
344 a.protection.cmp(&b.protection).then(
345 a.decay_score
346 .partial_cmp(&b.decay_score)
347 .unwrap_or(std::cmp::Ordering::Equal),
348 )
349 });
350
351 let mut demoted = 0;
352 for entry in candidates.into_iter().take(overflow) {
353 self.shift_tier(&entry.id, MemoryTier::Hot, MemoryTier::Warm)
354 .await?;
355 demoted += 1;
356 }
357
358 Ok(demoted)
359 }
360}