// oxios_memory/memory/manager/ops.rs

//! Advanced memory operations — semantic search, HNSW rebuild, tier management.
use std::cmp::Ordering;

use anyhow::Result;

use crate::memory::auto_protect::AutoProtector;
use crate::memory::hnsw_memory_index::{HnswMemoryIndex, SemanticHit};
use crate::memory::storage::MemoryStorageExt;
use crate::memory::types::{MemoryEntry, MemoryTier, MemoryType};

use super::MemoryManager;
12impl MemoryManager {
13    /// Semantic search using HNSW index.
14    ///
15    /// Unlike `search()` which uses brute-force cosine similarity over the
16    /// in-memory HashMap, `semantic_search()` uses the HNSW approximate
17    /// nearest neighbor index for sub-linear time complexity.
18    pub async fn semantic_search(
19        &self,
20        query: &str,
21        memory_type: Option<MemoryType>,
22        limit: usize,
23        hnsw_index: &HnswMemoryIndex,
24    ) -> Result<Vec<SemanticHit>> {
25        // Skip if index is empty
26        if hnsw_index.is_empty() {
27            tracing::debug!("HNSW index empty, falling back to keyword search");
28            return self
29                .keyword_search(query, memory_type, limit)
30                .await
31                .map(|entries| {
32                    entries
33                        .into_iter()
34                        .map(|entry| SemanticHit {
35                            entry,
36                            distance: 0.0,
37                            similarity: 0.0,
38                        })
39                        .collect()
40                });
41        }
42
43        // Generate embedding for query
44        let query_vector = self.embedding.embed(query).await?;
45        let query_f32 = match query_vector.to_f32_dense() {
46            Some(v) => v,
47            None => {
48                tracing::debug!("Query embedding is sparse, falling back to keyword search");
49                return self
50                    .keyword_search(query, memory_type, limit)
51                    .await
52                    .map(|entries| {
53                        entries
54                            .into_iter()
55                            .map(|entry| SemanticHit {
56                                entry,
57                                distance: 0.0,
58                                similarity: 0.0,
59                            })
60                            .collect()
61                    });
62            }
63        };
64
65        // Search HNSW index
66        let raw_hits = hnsw_index.search(&query_f32, limit * 2)?;
67
68        // Determine which memory types to search
69        let types: &[MemoryType] = match memory_type {
70            Some(ref t) => std::slice::from_ref(t),
71            None => MemoryType::all(),
72        };
73
74        // Load entries and build results
75        let mut results = Vec::new();
76        for (id, distance) in raw_hits {
77            for mt in types {
78                if let Ok(Some(mut entry)) = self
79                    .storage
80                    .load_json::<MemoryEntry>(mt.category(), &id)
81                    .await
82                {
83                    AutoProtector::record_access(&mut entry, "");
84
85                    let similarity = 1.0 - distance;
86                    results.push(SemanticHit {
87                        entry,
88                        distance,
89                        similarity,
90                    });
91                    break;
92                }
93            }
94            if results.len() >= limit {
95                break;
96            }
97        }
98
99        // Sort by similarity descending
100        results.sort_by(|a, b| {
101            b.similarity
102                .partial_cmp(&a.similarity)
103                .unwrap_or(std::cmp::Ordering::Equal)
104        });
105
106        tracing::debug!(
107            query = %query,
108            hits = results.len(),
109            "Semantic search complete"
110        );
111
112        // Fall back if no results
113        if results.is_empty() {
114            return self
115                .keyword_search(query, memory_type, limit)
116                .await
117                .map(|entries| {
118                    entries
119                        .into_iter()
120                        .map(|entry| SemanticHit {
121                            entry,
122                            distance: 0.0,
123                            similarity: 0.0,
124                        })
125                        .collect()
126                });
127        }
128
129        Ok(results)
130    }
131
132    /// Rebuild the HNSW index from all stored memories.
133    ///
134    /// Call this at startup or after bulk operations.
135    pub async fn rebuild_hnsw_index(&self, hnsw_index: &HnswMemoryIndex) -> Result<usize> {
136        let mut count = 0;
137
138        for mt in MemoryType::all() {
139            if let Ok(names) = self.storage.list_category(mt.category()).await {
140                for name in names {
141                    if let Ok(Some(entry)) = self
142                        .storage
143                        .load_json::<MemoryEntry>(mt.category(), &name)
144                        .await
145                    {
146                        let vector = self.embedding.embed(&entry.content).await?;
147                        if let Some(f32_vec) = vector.to_f32_dense() {
148                            if let Err(e) = hnsw_index.add_entry(&entry.id, &f32_vec) {
149                                tracing::warn!(
150                                    id = %entry.id,
151                                    error = %e,
152                                    "Failed to add entry to HNSW index"
153                                );
154                                continue;
155                            }
156                            count += 1;
157                        }
158                    }
159                }
160            }
161        }
162
163        tracing::info!(entries = count, "HNSW index rebuilt");
164        Ok(count)
165    }
166
167    // ------------------------------------------------------------------
168    // RFC-008: Tier-aware and new memory operations
169    // ------------------------------------------------------------------
170
171    /// List memories by tier (loads all types, filters by tier field).
172    pub async fn list_by_tier(&self, tier: MemoryTier, limit: usize) -> Result<Vec<MemoryEntry>> {
173        #[cfg(feature = "sqlite-memory")]
174        if let Some(ref sqlite) = self.sqlite_store {
175            return sqlite.list_by_tier(tier, limit);
176        }
177
178        let mut results = Vec::new();
179        for mt in MemoryType::all() {
180            if let Ok(entries) = self.list(*mt, limit).await {
181                for entry in entries {
182                    if entry.tier == tier {
183                        results.push(entry);
184                    }
185                }
186            }
187            if results.len() >= limit {
188                break;
189            }
190        }
191        results.truncate(limit);
192        Ok(results)
193    }
194
195    /// Get a memory entry by ID (searches all types).
196    pub async fn get_by_id(&self, id: &str) -> Result<Option<MemoryEntry>> {
197        for mt in MemoryType::all() {
198            if let Ok(Some(entry)) = self.get(id, *mt).await {
199                return Ok(Some(entry));
200            }
201        }
202        Ok(None)
203    }
204
205    /// Load a memory entry by reference string (ID or category/id).
206    pub async fn load_by_reference(&self, reference: &str) -> Result<Option<MemoryEntry>> {
207        // Try as direct ID first
208        if let Ok(Some(entry)) = self.get_by_id(reference).await {
209            return Ok(Some(entry));
210        }
211        // Try as category/name format
212        if let Some((cat, name)) = reference.split_once('/') {
213            if let Ok(Some(entry)) = self.storage.load_json::<MemoryEntry>(cat, name).await {
214                return Ok(Some(entry));
215            }
216        }
217        Ok(None)
218    }
219
220    /// Select memories by manifest (keyword matching against content).
221    pub async fn select_by_manifest(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
222        self.keyword_search(query, None, limit).await
223    }
224
225    /// Build the Hot tier context for agent prompt injection.
226    pub async fn build_hot_context(&self, token_budget: usize) -> Result<String> {
227        let hot_entries = self.list_by_tier(MemoryTier::Hot, 50).await?;
228
229        let mut context_parts = Vec::new();
230        let mut char_budget = token_budget * 4;
231
232        for entry in &hot_entries {
233            let line = format!("- [{}] {}", entry.memory_type.label(), entry.content);
234            if line.len() > char_budget {
235                break;
236            }
237            char_budget -= line.len();
238            context_parts.push(line);
239        }
240
241        if context_parts.is_empty() {
242            Ok(String::new())
243        } else {
244            Ok(format!("## Active Context\n\n{}", context_parts.join("\n")))
245        }
246    }
247
248    /// Build full context: hot context + proactive recall blended into system prompt.
249    pub async fn build_full_context(
250        &self,
251        _query: &str,
252        system_prompt: &str,
253        token_budget: usize,
254    ) -> Result<String> {
255        let hot_ctx = self.build_hot_context(token_budget).await?;
256
257        if hot_ctx.is_empty() {
258            return Ok(system_prompt.to_string());
259        }
260
261        Ok(format!("{system_prompt}\n\n{hot_ctx}"))
262    }
263
264    /// Shift a memory entry between tiers.
265    pub async fn shift_tier(&self, id: &str, from: MemoryTier, to: MemoryTier) -> Result<()> {
266        if let Ok(Some(mut entry)) = self.get_by_id(id).await {
267            if entry.tier == from {
268                entry.tier = to;
269                self.remember(entry).await?;
270            }
271        }
272        Ok(())
273    }
274
275    /// Pin a memory (set permanent protection).
276    pub async fn pin(&self, id: &str) -> Result<()> {
277        if let Ok(Some(mut entry)) = self.get_by_id(id).await {
278            entry.pinned = true;
279            entry.protection = crate::memory::types::ProtectionLevel::Permanent;
280            self.remember(entry).await?;
281        }
282        Ok(())
283    }
284
285    /// Unpin a memory (revert to auto-computed protection).
286    pub async fn unpin(&self, id: &str) -> Result<()> {
287        if let Ok(Some(mut entry)) = self.get_by_id(id).await {
288            entry.pinned = false;
289            // Recompute protection
290            let protector = crate::memory::auto_protect::AutoProtector::default_protector();
291            entry.protection = protector.compute_protection(&entry);
292            self.remember(entry).await?;
293        }
294        Ok(())
295    }
296
297    /// Set importance for a memory entry.
298    pub async fn set_importance(&self, id: &str, importance: f32) -> Result<()> {
299        if let Ok(Some(mut entry)) = self.get_by_id(id).await {
300            entry.importance = importance.clamp(0.0, 1.0);
301            self.remember(entry).await?;
302        }
303        Ok(())
304    }
305
306    /// Recompute decay scores for all entries.
307    ///
308    /// Returns the number of entries updated.
309    pub async fn recompute_all_decay(&self, multiplier: f32) -> Result<usize> {
310        let engine = crate::memory::decay::DecayEngine::new(multiplier);
311        let now = chrono::Utc::now();
312        let mut count = 0;
313
314        for mt in MemoryType::all() {
315            if let Ok(entries) = self.list(*mt, 1_000_000).await {
316                for mut entry in entries {
317                    let new_decay = engine.compute_decay(&entry, now);
318                    if (entry.decay_score - new_decay).abs() > 0.001 {
319                        entry.decay_score = new_decay;
320                        self.remember(entry).await?;
321                        count += 1;
322                    }
323                }
324            }
325        }
326
327        Ok(count)
328    }
329
330    /// Immediate Hot overflow handling.
331    pub async fn immediate_hot_overflow(&self, hot_max: usize) -> Result<usize> {
332        let hot_entries = self.list_by_tier(MemoryTier::Hot, hot_max * 2).await?;
333        if hot_entries.len() <= hot_max {
334            return Ok(0);
335        }
336
337        let overflow = hot_entries.len() - hot_max;
338        let mut candidates: Vec<MemoryEntry> = hot_entries
339            .into_iter()
340            .filter(|e| e.protection < crate::memory::types::ProtectionLevel::High && !e.pinned)
341            .collect();
342
343        candidates.sort_by(|a, b| {
344            a.protection.cmp(&b.protection).then(
345                a.decay_score
346                    .partial_cmp(&b.decay_score)
347                    .unwrap_or(std::cmp::Ordering::Equal),
348            )
349        });
350
351        let mut demoted = 0;
352        for entry in candidates.into_iter().take(overflow) {
353            self.shift_tier(&entry.id, MemoryTier::Hot, MemoryTier::Warm)
354                .await?;
355            demoted += 1;
356        }
357
358        Ok(demoted)
359    }
360}