Skip to main content

oxios_memory/memory/
hnsw_memory_index.rs

//! HNSW index manager for memory entries.
2
3use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use anyhow::Result;
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10
11use super::l2_normalize_f32;
12use super::HnswIndex;
13use super::MemoryEntry;
14
15/// Result of a semantic search hit.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct SemanticHit {
18    /// Memory entry.
19    pub entry: MemoryEntry,
20    /// Cosine distance (0.0 = identical).
21    pub distance: f32,
22    /// Cosine similarity (1.0 = identical).
23    pub similarity: f32,
24}
25
26/// HNSW index manager for memory entries.
27///
28/// Maintains a mapping from u64 keys to String IDs, and the HNSW index
29/// itself. Thread-safe via `RwLock`.
30pub struct HnswMemoryIndex {
31    /// The HNSW index.
32    index: RwLock<HnswIndex>,
33    /// Map: u64 key → String memory ID.
34    key_to_id: RwLock<HashMap<u64, String>>,
35    /// Map: String memory ID → u64 key.
36    id_to_key: RwLock<HashMap<String, u64>>,
37    /// Next key counter.
38    next_key: AtomicU64,
39    /// Base path for index persistence.
40    persist_path: Option<PathBuf>,
41}
42
43impl std::fmt::Debug for HnswMemoryIndex {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("HnswMemoryIndex")
46            .field("size", &self.len())
47            .field("dimensions", &self.index.read().dimensions())
48            .finish()
49    }
50}
51
52impl HnswMemoryIndex {
53    /// Create a new HNSW memory index.
54    ///
55    /// # Arguments
56    /// * `dimensions` — Embedding vector dimensions.
57    /// * `capacity` — Initial capacity hint.
58    /// * `persist_path` — Optional directory for index file persistence.
59    pub fn new(dimensions: usize, capacity: usize, persist_path: Option<PathBuf>) -> Result<Self> {
60        let index = HnswIndex::new(dimensions, capacity)?;
61        Ok(Self {
62            index: RwLock::new(index),
63            key_to_id: RwLock::new(HashMap::new()),
64            id_to_key: RwLock::new(HashMap::new()),
65            next_key: AtomicU64::new(1), // 0 is used by usearch as sentinel
66            persist_path,
67        })
68    }
69
70    /// Try to restore from disk, fall back to new index.
71    pub fn restore_or_new(
72        dimensions: usize,
73        capacity: usize,
74        persist_path: Option<PathBuf>,
75    ) -> Result<Self> {
76        if let Some(ref path) = persist_path {
77            let index_path = path.join("memory.usearch");
78            let mapping_path = path.join("key_map.json");
79
80            if index_path.exists() && mapping_path.exists() {
81                tracing::info!(path = %index_path.display(), "Restoring HNSW index from disk");
82
83                if let Ok(index) = HnswIndex::load(&index_path) {
84                    if let Ok(data) = std::fs::read_to_string(&mapping_path) {
85                        if let Ok((k2i, i2k)) = serde_json::from_str::<(
86                            HashMap<u64, String>,
87                            HashMap<String, u64>,
88                        )>(&data)
89                        {
90                            let max_key = k2i.keys().max().copied().unwrap_or(0);
91                            return Ok(Self {
92                                index: RwLock::new(index),
93                                key_to_id: RwLock::new(k2i),
94                                id_to_key: RwLock::new(i2k),
95                                next_key: AtomicU64::new(max_key + 1),
96                                persist_path,
97                            });
98                        }
99                    }
100                }
101
102                tracing::warn!("Failed to restore HNSW index, creating new one");
103            }
104        }
105
106        Self::new(dimensions, capacity, persist_path)
107    }
108
109    /// Get or create a u64 key for a String ID.
110    fn get_or_create_key(&self, id: &str) -> u64 {
111        // Fast path: check read lock
112        {
113            let i2k = self.id_to_key.read();
114            if let Some(&key) = i2k.get(id) {
115                return key;
116            }
117        }
118
119        // Slow path: write lock
120        let mut i2k = self.id_to_key.write();
121        let mut k2i = self.key_to_id.write();
122
123        // Double-check after acquiring write lock
124        if let Some(&key) = i2k.get(id) {
125            return key;
126        }
127
128        let key = self.next_key.fetch_add(1, Ordering::Relaxed);
129        i2k.insert(id.to_string(), key);
130        k2i.insert(key, id.to_string());
131        key
132    }
133
134    /// Add an entry to the HNSW index.
135    pub fn add_entry(&self, id: &str, vector: &[f32]) -> Result<()> {
136        let key = self.get_or_create_key(id);
137        let mut normalized = vector.to_vec();
138        l2_normalize_f32(&mut normalized);
139        self.index.write().add(key, &normalized)?;
140        Ok(())
141    }
142
143    /// Remove an entry from the index.
144    pub fn remove_entry(&self, id: &str) -> Result<()> {
145        let key = {
146            let i2k = self.id_to_key.read();
147            i2k.get(id).copied()
148        };
149        if let Some(key) = key {
150            self.index.write().remove(key)?;
151            let mut k2i = self.key_to_id.write();
152            let mut i2k = self.id_to_key.write();
153            k2i.remove(&key);
154            i2k.remove(id);
155        }
156        Ok(())
157    }
158
159    /// Search for k nearest neighbors.
160    ///
161    /// Returns (String ID, distance) pairs.
162    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
163        let mut normalized = query.to_vec();
164        l2_normalize_f32(&mut normalized);
165
166        let raw = self.index.read().search(&normalized, k)?;
167        let k2i = self.key_to_id.read();
168
169        let results = raw
170            .into_iter()
171            .filter_map(|(key, dist)| k2i.get(&key).map(|id| (id.clone(), dist)))
172            .collect();
173
174        Ok(results)
175    }
176
177    /// Number of entries in the index.
178    pub fn len(&self) -> usize {
179        self.index.read().len()
180    }
181
182    /// Whether the index is empty.
183    pub fn is_empty(&self) -> bool {
184        self.index.read().is_empty()
185    }
186
187    /// Save the index and key mappings to disk.
188    pub fn persist(&self) -> Result<()> {
189        if let Some(ref path) = self.persist_path {
190            std::fs::create_dir_all(path)?;
191
192            let index_path = path.join("memory.usearch");
193            let mapping_path = path.join("key_map.json");
194
195            // Save index
196            self.index.read().save(&index_path)?;
197
198            // Save key mappings
199            let k2i = self.key_to_id.read();
200            let i2k = self.id_to_key.read();
201            let data = serde_json::to_string(&(k2i.clone(), &*i2k))?;
202            std::fs::write(&mapping_path, data)?;
203
204            tracing::debug!(path = %path.display(), entries = self.len(), "HNSW index persisted");
205        }
206        Ok(())
207    }
208}
209
// ---------------------------------------------------------------------------
// Semantic search on MemoryManager
// ---------------------------------------------------------------------------