oxios_memory/memory/
hnsw_memory_index.rs1use std::collections::HashMap;
4use std::path::PathBuf;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use anyhow::Result;
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10
11use super::l2_normalize_f32;
12use super::HnswIndex;
13use super::MemoryEntry;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct SemanticHit {
18 pub entry: MemoryEntry,
20 pub distance: f32,
22 pub similarity: f32,
24}
25
26pub struct HnswMemoryIndex {
31 index: RwLock<HnswIndex>,
33 key_to_id: RwLock<HashMap<u64, String>>,
35 id_to_key: RwLock<HashMap<String, u64>>,
37 next_key: AtomicU64,
39 persist_path: Option<PathBuf>,
41}
42
43impl std::fmt::Debug for HnswMemoryIndex {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("HnswMemoryIndex")
46 .field("size", &self.len())
47 .field("dimensions", &self.index.read().dimensions())
48 .finish()
49 }
50}
51
52impl HnswMemoryIndex {
53 pub fn new(dimensions: usize, capacity: usize, persist_path: Option<PathBuf>) -> Result<Self> {
60 let index = HnswIndex::new(dimensions, capacity)?;
61 Ok(Self {
62 index: RwLock::new(index),
63 key_to_id: RwLock::new(HashMap::new()),
64 id_to_key: RwLock::new(HashMap::new()),
65 next_key: AtomicU64::new(1), persist_path,
67 })
68 }
69
70 pub fn restore_or_new(
72 dimensions: usize,
73 capacity: usize,
74 persist_path: Option<PathBuf>,
75 ) -> Result<Self> {
76 if let Some(ref path) = persist_path {
77 let index_path = path.join("memory.usearch");
78 let mapping_path = path.join("key_map.json");
79
80 if index_path.exists() && mapping_path.exists() {
81 tracing::info!(path = %index_path.display(), "Restoring HNSW index from disk");
82
83 if let Ok(index) = HnswIndex::load(&index_path) {
84 if let Ok(data) = std::fs::read_to_string(&mapping_path) {
85 if let Ok((k2i, i2k)) = serde_json::from_str::<(
86 HashMap<u64, String>,
87 HashMap<String, u64>,
88 )>(&data)
89 {
90 let max_key = k2i.keys().max().copied().unwrap_or(0);
91 return Ok(Self {
92 index: RwLock::new(index),
93 key_to_id: RwLock::new(k2i),
94 id_to_key: RwLock::new(i2k),
95 next_key: AtomicU64::new(max_key + 1),
96 persist_path,
97 });
98 }
99 }
100 }
101
102 tracing::warn!("Failed to restore HNSW index, creating new one");
103 }
104 }
105
106 Self::new(dimensions, capacity, persist_path)
107 }
108
109 fn get_or_create_key(&self, id: &str) -> u64 {
111 {
113 let i2k = self.id_to_key.read();
114 if let Some(&key) = i2k.get(id) {
115 return key;
116 }
117 }
118
119 let mut i2k = self.id_to_key.write();
121 let mut k2i = self.key_to_id.write();
122
123 if let Some(&key) = i2k.get(id) {
125 return key;
126 }
127
128 let key = self.next_key.fetch_add(1, Ordering::Relaxed);
129 i2k.insert(id.to_string(), key);
130 k2i.insert(key, id.to_string());
131 key
132 }
133
134 pub fn add_entry(&self, id: &str, vector: &[f32]) -> Result<()> {
136 let key = self.get_or_create_key(id);
137 let mut normalized = vector.to_vec();
138 l2_normalize_f32(&mut normalized);
139 self.index.write().add(key, &normalized)?;
140 Ok(())
141 }
142
143 pub fn remove_entry(&self, id: &str) -> Result<()> {
145 let key = {
146 let i2k = self.id_to_key.read();
147 i2k.get(id).copied()
148 };
149 if let Some(key) = key {
150 self.index.write().remove(key)?;
151 let mut k2i = self.key_to_id.write();
152 let mut i2k = self.id_to_key.write();
153 k2i.remove(&key);
154 i2k.remove(id);
155 }
156 Ok(())
157 }
158
159 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
163 let mut normalized = query.to_vec();
164 l2_normalize_f32(&mut normalized);
165
166 let raw = self.index.read().search(&normalized, k)?;
167 let k2i = self.key_to_id.read();
168
169 let results = raw
170 .into_iter()
171 .filter_map(|(key, dist)| k2i.get(&key).map(|id| (id.clone(), dist)))
172 .collect();
173
174 Ok(results)
175 }
176
177 pub fn len(&self) -> usize {
179 self.index.read().len()
180 }
181
182 pub fn is_empty(&self) -> bool {
184 self.index.read().is_empty()
185 }
186
187 pub fn persist(&self) -> Result<()> {
189 if let Some(ref path) = self.persist_path {
190 std::fs::create_dir_all(path)?;
191
192 let index_path = path.join("memory.usearch");
193 let mapping_path = path.join("key_map.json");
194
195 self.index.read().save(&index_path)?;
197
198 let k2i = self.key_to_id.read();
200 let i2k = self.id_to_key.read();
201 let data = serde_json::to_string(&(k2i.clone(), &*i2k))?;
202 std::fs::write(&mapping_path, data)?;
203
204 tracing::debug!(path = %path.display(), entries = self.len(), "HNSW index persisted");
205 }
206 Ok(())
207 }
208}
209
210