1use std::collections::HashMap;
19use std::path::PathBuf;
20use std::sync::Mutex;
21
22use serde::{Deserialize, Serialize};
23
24use crate::Plugin;
25use pylon_auth::AuthContext;
26use serde_json::Value;
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29struct VectorRow {
30 entity: String,
31 row_id: String,
32 vector: Vec<f32>,
33 norm: f32,
35 metadata: Option<Value>,
37}
38
39#[derive(Debug, Clone)]
40pub struct VectorHit {
41 pub entity: String,
42 pub row_id: String,
43 pub score: f32,
44 pub metadata: Option<Value>,
45}
46
47pub struct VectorSearchPlugin {
48 rows: Mutex<HashMap<(String, String), VectorRow>>,
49 persist_path: Option<PathBuf>,
50}
51
52impl VectorSearchPlugin {
53 pub fn new() -> Self {
54 Self {
55 rows: Mutex::new(HashMap::new()),
56 persist_path: None,
57 }
58 }
59
60 pub fn with_persist_path(mut self, path: impl Into<PathBuf>) -> Self {
63 let path = path.into();
64 if let Ok(bytes) = std::fs::read(&path) {
65 if let Ok(rows) = serde_json::from_slice::<Vec<VectorRow>>(&bytes) {
66 let mut map = self.rows.lock().unwrap();
67 for r in rows {
68 map.insert((r.entity.clone(), r.row_id.clone()), r);
69 }
70 }
71 }
72 self.persist_path = Some(path);
73 self
74 }
75
76 pub fn index(&self, entity: &str, row_id: &str, vector: Vec<f32>, metadata: Option<Value>) {
78 let norm = l2_norm(&vector);
79 let row = VectorRow {
80 entity: entity.into(),
81 row_id: row_id.into(),
82 vector,
83 norm,
84 metadata,
85 };
86 {
87 let mut map = self.rows.lock().unwrap();
88 map.insert((entity.into(), row_id.into()), row);
89 }
90 self.persist();
91 }
92
93 pub fn search(&self, query: &[f32], k: usize, entity_filter: Option<&str>) -> Vec<VectorHit> {
95 if query.is_empty() {
96 return vec![];
97 }
98 let q_norm = l2_norm(query);
99 if q_norm == 0.0 {
100 return vec![];
101 }
102
103 let map = self.rows.lock().unwrap();
104 let mut hits: Vec<VectorHit> = map
105 .values()
106 .filter(|r| {
107 entity_filter.map(|e| e == r.entity).unwrap_or(true)
108 && r.vector.len() == query.len()
109 && r.norm > 0.0
110 })
111 .map(|r| {
112 let dot: f32 = r.vector.iter().zip(query.iter()).map(|(a, b)| a * b).sum();
113 let score = dot / (r.norm * q_norm);
114 VectorHit {
115 entity: r.entity.clone(),
116 row_id: r.row_id.clone(),
117 score,
118 metadata: r.metadata.clone(),
119 }
120 })
121 .collect();
122
123 hits.sort_by(|a, b| {
125 b.score
126 .partial_cmp(&a.score)
127 .unwrap_or(std::cmp::Ordering::Equal)
128 });
129 hits.truncate(k);
130 hits
131 }
132
133 pub fn len(&self) -> usize {
135 self.rows.lock().unwrap().len()
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.rows.lock().unwrap().is_empty()
141 }
142
143 fn persist(&self) {
144 let Some(path) = self.persist_path.as_ref() else {
145 return;
146 };
147 let map = self.rows.lock().unwrap();
148 let rows: Vec<&VectorRow> = map.values().collect();
149 if let Ok(bytes) = serde_json::to_vec(&rows) {
150 let _ = std::fs::write(path, bytes);
151 }
152 }
153}
154
155impl Default for VectorSearchPlugin {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
161impl Plugin for VectorSearchPlugin {
162 fn name(&self) -> &str {
163 "vector_search"
164 }
165
166 fn after_delete(&self, entity: &str, id: &str, _auth: &AuthContext) {
167 let mut map = self.rows.lock().unwrap();
168 map.remove(&(entity.into(), id.into()));
169 drop(map);
170 self.persist();
171 }
172}
173
/// Euclidean (L2) length of `v`: the square root of the sum of squared
/// components. An empty slice has length zero.
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// The closest vectors come back best-first and the result is capped at `k`.
    #[test]
    fn search_returns_most_similar() {
        let plugin = VectorSearchPlugin::new();
        for (id, vector) in [
            ("a", vec![1.0, 0.0, 0.0]),
            ("b", vec![0.0, 1.0, 0.0]),
            ("c", vec![0.9, 0.1, 0.0]),
        ] {
            plugin.index("Doc", id, vector, None);
        }

        let results = plugin.search(&[1.0, 0.0, 0.0], 2, None);
        let ids: Vec<&str> = results.iter().map(|hit| hit.row_id.as_str()).collect();
        assert_eq!(ids, ["a", "c"]);
        assert!(results[0].score > results[1].score);
    }

    /// Only rows of the requested entity are returned.
    #[test]
    fn entity_filter_restricts_results() {
        let plugin = VectorSearchPlugin::new();
        plugin.index("Doc", "1", vec![1.0, 0.0], None);
        plugin.index("Note", "2", vec![1.0, 0.0], None);

        let results = plugin.search(&[1.0, 0.0], 5, Some("Note"));
        match results.as_slice() {
            [only] => assert_eq!(only.entity, "Note"),
            other => panic!("expected exactly one hit, got {}", other.len()),
        }
    }

    /// Indexing the same key twice keeps one row holding the newer vector.
    #[test]
    fn upsert_replaces_previous_vector() {
        let plugin = VectorSearchPlugin::new();
        plugin.index("Doc", "x", vec![1.0, 0.0], None);
        plugin.index("Doc", "x", vec![0.0, 1.0], None);
        assert_eq!(plugin.len(), 1);

        let top = plugin.search(&[0.0, 1.0], 1, None);
        let distance_from_one = (top[0].score - 1.0).abs();
        assert!(distance_from_one < 1e-5);
    }

    /// Rows whose dimension differs from the query are skipped.
    #[test]
    fn dimension_mismatch_excluded() {
        let plugin = VectorSearchPlugin::new();
        plugin.index("Doc", "a", vec![1.0, 0.0], None);
        plugin.index("Doc", "b", vec![1.0, 0.0, 0.0], None);

        let results = plugin.search(&[1.0, 0.0], 5, None);
        match results.as_slice() {
            [only] => assert_eq!(only.row_id, "a"),
            other => panic!("expected exactly one hit, got {}", other.len()),
        }
    }

    /// The `after_delete` hook removes the matching row from the index.
    #[test]
    fn delete_via_plugin_hook_removes_row() {
        let plugin = VectorSearchPlugin::new();
        plugin.index("Doc", "a", vec![1.0, 0.0], None);
        assert_eq!(plugin.len(), 1);

        let auth = AuthContext::anonymous();
        plugin.after_delete("Doc", "a", &auth);
        assert!(plugin.is_empty());
    }

    /// Rows written by one instance are loaded by a fresh instance that points
    /// at the same file.
    #[test]
    fn persist_round_trip() {
        let dir = std::env::temp_dir().join(format!("pylon_vec_{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("vec.json");

        {
            let writer = VectorSearchPlugin::new().with_persist_path(&path);
            writer.index(
                "Doc",
                "x",
                vec![0.5, 0.5],
                Some(serde_json::json!({"t": "hi"})),
            );
        }

        let reader = VectorSearchPlugin::new().with_persist_path(&path);
        assert_eq!(reader.len(), 1);
        let results = reader.search(&[0.5, 0.5], 1, None);
        let first = &results[0];
        assert_eq!(first.row_id, "x");
        let metadata = first.metadata.as_ref().unwrap();
        assert_eq!(metadata["t"], "hi");

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Queries with no components or zero length never match anything.
    #[test]
    fn empty_query_returns_nothing() {
        let plugin = VectorSearchPlugin::new();
        plugin.index("Doc", "a", vec![1.0, 0.0], None);

        let no_components: &[f32] = &[];
        let zero_vector: &[f32] = &[0.0, 0.0];
        assert!(plugin.search(no_components, 5, None).is_empty());
        assert!(plugin.search(zero_vector, 5, None).is_empty());
    }
}