// pylon_plugin/builtin/vector_search.rs

//! In-memory vector search plugin.
//!
//! Stores embeddings (`Vec<f32>`) per `(entity, row_id)` and exposes
//! nearest-neighbour search via cosine similarity. The plugin itself does *not*
//! compute embeddings — callers pass pre-computed vectors via `index()`
//! (typically from an OpenAI/Anthropic/local embedding model). For production
//! scale, move to a dedicated vector store (pgvector, Qdrant, Turso libsql
//! vector); this is the "good enough for thousands of rows" implementation.
//!
//! Why not store vectors directly in SQLite? SQLite has no first-class vector
//! support, and naive blob storage means re-decoding every row per query. An
//! in-memory index is far faster for small/medium datasets and survives
//! restarts via a snapshot-on-write to a JSON file (see `with_persist_path`).
//!
//! Search complexity is O(n * d) per query. With 10k rows and 1024-dim vectors
//! that's ~10M float ops per query — well under 10ms on commodity hardware.

use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Mutex, MutexGuard, PoisonError};

use serde::{Deserialize, Serialize};

use crate::Plugin;
use pylon_auth::AuthContext;
use serde_json::Value;

28#[derive(Debug, Clone, Serialize, Deserialize)]
29struct VectorRow {
30    entity: String,
31    row_id: String,
32    vector: Vec<f32>,
33    /// L2 norm of the vector, cached so we don't recompute per query.
34    norm: f32,
35    /// Optional payload returned with hits (e.g. preview text).
36    metadata: Option<Value>,
37}
38
39#[derive(Debug, Clone)]
40pub struct VectorHit {
41    pub entity: String,
42    pub row_id: String,
43    pub score: f32,
44    pub metadata: Option<Value>,
45}
46
47pub struct VectorSearchPlugin {
48    rows: Mutex<HashMap<(String, String), VectorRow>>,
49    persist_path: Option<PathBuf>,
50}
51
52impl VectorSearchPlugin {
53    pub fn new() -> Self {
54        Self {
55            rows: Mutex::new(HashMap::new()),
56            persist_path: None,
57        }
58    }
59
60    /// Persist the index to disk on every write. Useful for restart safety.
61    /// Loads existing data from `path` if it exists.
62    pub fn with_persist_path(mut self, path: impl Into<PathBuf>) -> Self {
63        let path = path.into();
64        if let Ok(bytes) = std::fs::read(&path) {
65            if let Ok(rows) = serde_json::from_slice::<Vec<VectorRow>>(&bytes) {
66                let mut map = self.rows.lock().unwrap();
67                for r in rows {
68                    map.insert((r.entity.clone(), r.row_id.clone()), r);
69                }
70            }
71        }
72        self.persist_path = Some(path);
73        self
74    }
75
76    /// Upsert a vector for `(entity, row_id)`.
77    pub fn index(&self, entity: &str, row_id: &str, vector: Vec<f32>, metadata: Option<Value>) {
78        let norm = l2_norm(&vector);
79        let row = VectorRow {
80            entity: entity.into(),
81            row_id: row_id.into(),
82            vector,
83            norm,
84            metadata,
85        };
86        {
87            let mut map = self.rows.lock().unwrap();
88            map.insert((entity.into(), row_id.into()), row);
89        }
90        self.persist();
91    }
92
93    /// Find the k most similar vectors. Optionally restrict to one entity.
94    pub fn search(&self, query: &[f32], k: usize, entity_filter: Option<&str>) -> Vec<VectorHit> {
95        if query.is_empty() {
96            return vec![];
97        }
98        let q_norm = l2_norm(query);
99        if q_norm == 0.0 {
100            return vec![];
101        }
102
103        let map = self.rows.lock().unwrap();
104        let mut hits: Vec<VectorHit> = map
105            .values()
106            .filter(|r| {
107                entity_filter.map(|e| e == r.entity).unwrap_or(true)
108                    && r.vector.len() == query.len()
109                    && r.norm > 0.0
110            })
111            .map(|r| {
112                let dot: f32 = r.vector.iter().zip(query.iter()).map(|(a, b)| a * b).sum();
113                let score = dot / (r.norm * q_norm);
114                VectorHit {
115                    entity: r.entity.clone(),
116                    row_id: r.row_id.clone(),
117                    score,
118                    metadata: r.metadata.clone(),
119                }
120            })
121            .collect();
122
123        // Highest score first.
124        hits.sort_by(|a, b| {
125            b.score
126                .partial_cmp(&a.score)
127                .unwrap_or(std::cmp::Ordering::Equal)
128        });
129        hits.truncate(k);
130        hits
131    }
132
133    /// Number of indexed vectors.
134    pub fn len(&self) -> usize {
135        self.rows.lock().unwrap().len()
136    }
137
138    /// True when the index is empty.
139    pub fn is_empty(&self) -> bool {
140        self.rows.lock().unwrap().is_empty()
141    }
142
143    fn persist(&self) {
144        let Some(path) = self.persist_path.as_ref() else {
145            return;
146        };
147        let map = self.rows.lock().unwrap();
148        let rows: Vec<&VectorRow> = map.values().collect();
149        if let Ok(bytes) = serde_json::to_vec(&rows) {
150            let _ = std::fs::write(path, bytes);
151        }
152    }
153}
154
155impl Default for VectorSearchPlugin {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl Plugin for VectorSearchPlugin {
162    fn name(&self) -> &str {
163        "vector_search"
164    }
165
166    fn after_delete(&self, entity: &str, id: &str, _auth: &AuthContext) {
167        let mut map = self.rows.lock().unwrap();
168        map.remove(&(entity.into(), id.into()));
169        drop(map);
170        self.persist();
171    }
172}
173
/// Euclidean (L2) length of `v`; an empty slice has length zero.
fn l2_norm(v: &[f32]) -> f32 {
    let squared: f32 = v.iter().copied().map(|x| x * x).sum();
    f32::sqrt(squared)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory removed on drop, so a failing assertion does not leak it.
    struct TempDir(std::path::PathBuf);

    impl TempDir {
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir()
                .join(format!("pylon_vec_{}_{}", tag, std::process::id()));
            // Clear leftovers from an earlier run that happened to reuse this pid.
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).unwrap();
            TempDir(dir)
        }

        fn file(&self, name: &str) -> std::path::PathBuf {
            self.0.join(name)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn search_returns_most_similar() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "a", vec![1.0, 0.0, 0.0], None);
        p.index("Doc", "b", vec![0.0, 1.0, 0.0], None);
        p.index("Doc", "c", vec![0.9, 0.1, 0.0], None);

        let hits = p.search(&[1.0, 0.0, 0.0], 2, None);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].row_id, "a");
        assert_eq!(hits[1].row_id, "c");
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn entity_filter_restricts_results() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "1", vec![1.0, 0.0], None);
        p.index("Note", "2", vec![1.0, 0.0], None);
        let hits = p.search(&[1.0, 0.0], 5, Some("Note"));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity, "Note");
    }

    #[test]
    fn upsert_replaces_previous_vector() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "x", vec![1.0, 0.0], None);
        p.index("Doc", "x", vec![0.0, 1.0], None);
        assert_eq!(p.len(), 1);
        let hits = p.search(&[0.0, 1.0], 1, None);
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn dimension_mismatch_excluded() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "a", vec![1.0, 0.0], None);
        p.index("Doc", "b", vec![1.0, 0.0, 0.0], None);
        let hits = p.search(&[1.0, 0.0], 5, None);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].row_id, "a");
    }

    #[test]
    fn zero_k_returns_nothing() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "a", vec![1.0, 0.0], None);
        assert!(p.search(&[1.0, 0.0], 0, None).is_empty());
    }

    #[test]
    fn k_larger_than_matches_returns_all_matches() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "a", vec![1.0, 0.0], None);
        p.index("Doc", "b", vec![0.0, 1.0], None);
        let hits = p.search(&[1.0, 0.0], 10, None);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].row_id, "a");
    }

    #[test]
    fn delete_via_plugin_hook_removes_row() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "a", vec![1.0, 0.0], None);
        assert_eq!(p.len(), 1);
        p.after_delete("Doc", "a", &AuthContext::anonymous());
        assert!(p.is_empty());
    }

    #[test]
    fn deleting_unknown_row_keeps_index() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "a", vec![1.0, 0.0], None);
        p.after_delete("Doc", "missing", &AuthContext::anonymous());
        p.after_delete("Note", "a", &AuthContext::anonymous());
        assert_eq!(p.len(), 1);
    }

    #[test]
    fn persist_round_trip() {
        let dir = TempDir::new("round_trip");
        let path = dir.file("vec.json");

        let p1 = VectorSearchPlugin::new().with_persist_path(&path);
        p1.index(
            "Doc",
            "x",
            vec![0.5, 0.5],
            Some(serde_json::json!({"t": "hi"})),
        );
        drop(p1);

        let p2 = VectorSearchPlugin::new().with_persist_path(&path);
        assert_eq!(p2.len(), 1);
        let hits = p2.search(&[0.5, 0.5], 1, None);
        assert_eq!(hits[0].row_id, "x");
        assert_eq!(hits[0].metadata.as_ref().unwrap()["t"], "hi");
    }

    #[test]
    fn persisted_snapshot_reflects_deletes() {
        let dir = TempDir::new("delete");
        let path = dir.file("vec.json");

        let p1 = VectorSearchPlugin::new().with_persist_path(&path);
        p1.index("Doc", "a", vec![1.0, 0.0], None);
        p1.index("Doc", "b", vec![0.0, 1.0], None);
        p1.after_delete("Doc", "a", &AuthContext::anonymous());
        drop(p1);

        let p2 = VectorSearchPlugin::new().with_persist_path(&path);
        assert_eq!(p2.len(), 1);
        let hits = p2.search(&[0.0, 1.0], 5, None);
        assert_eq!(hits[0].row_id, "b");
    }

    #[test]
    fn empty_query_returns_nothing() {
        let p = VectorSearchPlugin::new();
        p.index("Doc", "a", vec![1.0, 0.0], None);
        assert!(p.search(&[], 5, None).is_empty());
        assert!(p.search(&[0.0, 0.0], 5, None).is_empty());
    }
}