Skip to main content

ext_vector/
lib.rs

//! ext-vector: HNSW vector index with SIMD distance kernels.
//!
//! Provides approximate nearest-neighbor (ANN) search through three procedures:
//! - `vector.build(dim, metric?)` → initializes the HNSW index (metric: "l2" or "cosine";
//!   "l2" when omitted), returns `{status: STRING}`
//! - `vector.add(id, vector_csv)` → inserts a vector given as comma-separated floats,
//!   returns `{status: STRING}`
//! - `vector.search(query_csv, k?)` → ANN results `{id: INT64, distance: DOUBLE}`
//!   (`k` is 10 when omitted)
//!
//! Uses NEON intrinsics on aarch64 and a scalar fallback (auto-vectorized to AVX2) elsewhere.
//! A staging buffer accumulates vectors before bulk-inserting them into the live HNSW index.
10
11pub mod distance;
12pub mod hnsw;
13pub mod staging;
14
15use std::collections::HashMap;
16use std::sync::Mutex;
17
18use kyu_extension::{Extension, ProcColumn, ProcParam, ProcRow, ProcedureSignature};
19use kyu_types::{LogicalType, TypedValue};
20use smol_str::SmolStr;
21
22use crate::distance::DistanceMetric;
23use crate::hnsw::{HnswConfig, HnswIndex};
24use crate::staging::StagingBuffer;
25
26/// Vector search state: optional HNSW index + staging buffer.
27struct VectorState {
28    index: Option<HnswIndex>,
29    staging: StagingBuffer,
30    metric: DistanceMetric,
31    next_id: usize,
32}
33
34impl Default for VectorState {
35    fn default() -> Self {
36        Self {
37            index: None,
38            staging: StagingBuffer::new(),
39            metric: DistanceMetric::L2,
40            next_id: 0,
41        }
42    }
43}
44
45/// Vector search extension.
46pub struct VectorExtension {
47    state: Mutex<VectorState>,
48}
49
50impl VectorExtension {
51    pub fn new() -> Self {
52        Self {
53            state: Mutex::new(VectorState::default()),
54        }
55    }
56}
57
58impl Default for VectorExtension {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl Extension for VectorExtension {
65    fn name(&self) -> &str {
66        "vector"
67    }
68
69    fn needs_graph(&self) -> bool {
70        false
71    }
72
73    fn procedures(&self) -> Vec<ProcedureSignature> {
74        vec![
75            ProcedureSignature {
76                name: "build".into(),
77                params: vec![
78                    ProcParam {
79                        name: "dim".into(),
80                        type_desc: "INT64".into(),
81                    },
82                    ProcParam {
83                        name: "metric".into(),
84                        type_desc: "STRING".into(),
85                    },
86                ],
87                columns: vec![ProcColumn {
88                    name: "status".into(),
89                    data_type: LogicalType::String,
90                }],
91            },
92            ProcedureSignature {
93                name: "add".into(),
94                params: vec![
95                    ProcParam {
96                        name: "id".into(),
97                        type_desc: "INT64".into(),
98                    },
99                    ProcParam {
100                        name: "vector_csv".into(),
101                        type_desc: "STRING".into(),
102                    },
103                ],
104                columns: vec![ProcColumn {
105                    name: "status".into(),
106                    data_type: LogicalType::String,
107                }],
108            },
109            ProcedureSignature {
110                name: "search".into(),
111                params: vec![
112                    ProcParam {
113                        name: "query_csv".into(),
114                        type_desc: "STRING".into(),
115                    },
116                    ProcParam {
117                        name: "k".into(),
118                        type_desc: "INT64".into(),
119                    },
120                ],
121                columns: vec![
122                    ProcColumn {
123                        name: "id".into(),
124                        data_type: LogicalType::Int64,
125                    },
126                    ProcColumn {
127                        name: "distance".into(),
128                        data_type: LogicalType::Double,
129                    },
130                ],
131            },
132        ]
133    }
134
135    fn execute(
136        &self,
137        procedure: &str,
138        args: &[String],
139        _adjacency: &HashMap<i64, Vec<(i64, f64)>>,
140    ) -> Result<Vec<ProcRow>, String> {
141        let mut state = self
142            .state
143            .lock()
144            .map_err(|e| format!("lock error: {e}"))?;
145
146        match procedure {
147            "build" => {
148                let dim: usize = args
149                    .first()
150                    .ok_or("vector.build requires dim argument")?
151                    .parse()
152                    .map_err(|_| "dim must be a positive integer")?;
153                if dim == 0 {
154                    return Err("dim must be > 0".into());
155                }
156
157                let metric = match args.get(1).map(|s| s.to_lowercase()).as_deref() {
158                    Some("cosine") => DistanceMetric::Cosine,
159                    _ => DistanceMetric::L2,
160                };
161
162                state.index = Some(HnswIndex::new(dim, HnswConfig {
163                    metric,
164                    ..HnswConfig::default()
165                }));
166                state.staging = StagingBuffer::new();
167                state.metric = metric;
168                state.next_id = 0;
169
170                Ok(vec![vec![TypedValue::String(SmolStr::new(format!(
171                    "built dim={dim} metric={metric:?}"
172                )))]])
173            }
174
175            "add" => {
176                let VectorState { index, staging, next_id, .. } = &mut *state;
177                let index = index.as_mut().ok_or("call vector.build first")?;
178
179                let ext_id: usize = args
180                    .first()
181                    .ok_or("vector.add requires id argument")?
182                    .parse()
183                    .map_err(|_| "id must be a non-negative integer")?;
184
185                let csv = args.get(1).ok_or("vector.add requires vector_csv argument")?;
186                let vector: Vec<f32> = csv
187                    .split(',')
188                    .map(|s| {
189                        s.trim()
190                            .parse::<f32>()
191                            .map_err(|_| format!("invalid float in vector: '{}'", s.trim()))
192                    })
193                    .collect::<Result<_, _>>()?;
194
195                let needs_flush = staging.add(ext_id, vector);
196                if needs_flush {
197                    staging.flush(index);
198                }
199
200                *next_id = (*next_id).max(ext_id + 1);
201
202                Ok(vec![vec![TypedValue::String(SmolStr::new("ok"))]])
203            }
204
205            "search" => {
206                let VectorState { index, staging, .. } = &mut *state;
207                let index = index.as_mut().ok_or("call vector.build first")?;
208
209                let csv = args.first().ok_or("vector.search requires query_csv argument")?;
210                let query: Vec<f32> = csv
211                    .split(',')
212                    .map(|s| {
213                        s.trim()
214                            .parse::<f32>()
215                            .map_err(|_| format!("invalid float in query: '{}'", s.trim()))
216                    })
217                    .collect::<Result<_, _>>()?;
218
219                let k: usize = args
220                    .get(1)
221                    .and_then(|s| s.parse().ok())
222                    .unwrap_or(10);
223
224                // Flush pending before search for consistency.
225                if staging.pending_count() > 0 {
226                    staging.flush(index);
227                }
228
229                let results = index.search(&query, k, k.max(50));
230
231                Ok(results
232                    .into_iter()
233                    .map(|(id, dist)| {
234                        vec![
235                            TypedValue::Int64(id as i64),
236                            TypedValue::Double(dist as f64),
237                        ]
238                    })
239                    .collect())
240            }
241
242            _ => Err(format!("unknown procedure: {procedure}")),
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// The extension ignores adjacency, so every call gets an empty map.
    fn empty_adj() -> HashMap<i64, Vec<(i64, f64)>> {
        HashMap::new()
    }

    /// Executes `procedure` with string-literal arguments.
    fn run(ext: &VectorExtension, procedure: &str, args: &[&str]) -> Result<Vec<ProcRow>, String> {
        let owned: Vec<String> = args.iter().map(|a| a.to_string()).collect();
        ext.execute(procedure, &owned, &empty_adj())
    }

    /// Adds the three unit-ish vectors shared by the search tests.
    fn add_sample_vectors(ext: &VectorExtension) {
        run(ext, "add", &["0", "1.0,0.0,0.0"]).unwrap();
        run(ext, "add", &["1", "0.0,1.0,0.0"]).unwrap();
        run(ext, "add", &["2", "0.9,0.1,0.0"]).unwrap();
    }

    #[test]
    fn extension_metadata() {
        let ext = VectorExtension::new();
        assert_eq!(ext.name(), "vector");
        assert!(!ext.needs_graph());
        assert_eq!(ext.procedures().len(), 3);
    }

    #[test]
    fn build_add_search() {
        let ext = VectorExtension::new();

        // Build index.
        let built = run(&ext, "build", &["3", "l2"]).unwrap();
        assert_eq!(built.len(), 1);

        // Add vectors.
        add_sample_vectors(&ext);

        // Search.
        let results = run(&ext, "search", &["1.0,0.0,0.0", "2"]).unwrap();
        assert!(!results.is_empty());
        // Nearest should be vector 0 (identical).
        assert_eq!(results[0][0], TypedValue::Int64(0));
    }

    #[test]
    fn search_without_build() {
        let ext = VectorExtension::new();
        assert!(run(&ext, "search", &["1.0,0.0", "5"]).is_err());
    }

    #[test]
    fn add_without_build() {
        let ext = VectorExtension::new();
        assert!(run(&ext, "add", &["0", "1.0,0.0"]).is_err());
    }

    #[test]
    fn unknown_procedure() {
        let ext = VectorExtension::new();
        assert!(run(&ext, "nonexistent", &[]).is_err());
    }

    #[test]
    fn build_requires_dim() {
        let ext = VectorExtension::new();
        assert!(run(&ext, "build", &[]).is_err());
    }

    #[test]
    fn build_rejects_zero_dim() {
        let ext = VectorExtension::new();
        assert!(run(&ext, "build", &["0"]).is_err());
    }

    #[test]
    fn add_rejects_malformed_vector() {
        let ext = VectorExtension::new();
        run(&ext, "build", &["3"]).unwrap();
        assert!(run(&ext, "add", &["0", "1.0,abc,0.0"]).is_err());
    }

    #[test]
    fn cosine_search() {
        let ext = VectorExtension::new();

        run(&ext, "build", &["3", "cosine"]).unwrap();
        add_sample_vectors(&ext);

        let results = run(&ext, "search", &["1.0,0.0,0.0", "3"]).unwrap();
        assert_eq!(results.len(), 3);
        // Vector 0 should be closest (cosine distance ~0).
        assert_eq!(results[0][0], TypedValue::Int64(0));
        // Fail loudly if the distance column is not a DOUBLE instead of skipping the check.
        match &results[0][1] {
            TypedValue::Double(d) => assert!(*d < 0.01, "cosine distance to identical = {d}"),
            other => panic!("expected DOUBLE distance, got {other:?}"),
        }
    }
}
325}