Skip to main content

ext_vector/
lib.rs

1//! ext-vector: HNSW vector index with SIMD distance kernels.
2//!
3//! Provides approximate nearest neighbor search through three procedures:
4//! - `vector.build(dim, metric?)` → initializes HNSW index (metric: "l2" or "cosine")
5//! - `vector.add(id, vector_csv)` → inserts a vector, returns `{status: STRING}`
6//! - `vector.search(query_csv, k)` → ANN results `{id: INT64, distance: DOUBLE}`
7//!
8//! Uses NEON intrinsics on aarch64 and scalar fallback (auto-vectorized to AVX2) elsewhere.
9//! Staging buffer accumulates vectors before bulk-inserting into the live HNSW index.
10
11pub mod distance;
12pub mod hnsw;
13pub mod staging;
14
15use std::collections::HashMap;
16use std::sync::Mutex;
17
18use kyu_extension::{Extension, ProcColumn, ProcParam, ProcRow, ProcedureSignature};
19use kyu_types::{LogicalType, TypedValue};
20use smol_str::SmolStr;
21
22use crate::distance::DistanceMetric;
23use crate::hnsw::{HnswConfig, HnswIndex};
24use crate::staging::StagingBuffer;
25
26/// Vector search state: optional HNSW index + staging buffer.
27struct VectorState {
28    index: Option<HnswIndex>,
29    staging: StagingBuffer,
30    metric: DistanceMetric,
31    next_id: usize,
32}
33
34impl Default for VectorState {
35    fn default() -> Self {
36        Self {
37            index: None,
38            staging: StagingBuffer::new(),
39            metric: DistanceMetric::L2,
40            next_id: 0,
41        }
42    }
43}
44
45/// Vector search extension.
46pub struct VectorExtension {
47    state: Mutex<VectorState>,
48}
49
50impl VectorExtension {
51    pub fn new() -> Self {
52        Self {
53            state: Mutex::new(VectorState::default()),
54        }
55    }
56}
57
58impl Default for VectorExtension {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl Extension for VectorExtension {
65    fn name(&self) -> &str {
66        "vector"
67    }
68
69    fn needs_graph(&self) -> bool {
70        false
71    }
72
73    fn procedures(&self) -> Vec<ProcedureSignature> {
74        vec![
75            ProcedureSignature {
76                name: "build".into(),
77                params: vec![
78                    ProcParam {
79                        name: "dim".into(),
80                        type_desc: "INT64".into(),
81                    },
82                    ProcParam {
83                        name: "metric".into(),
84                        type_desc: "STRING".into(),
85                    },
86                ],
87                columns: vec![ProcColumn {
88                    name: "status".into(),
89                    data_type: LogicalType::String,
90                }],
91            },
92            ProcedureSignature {
93                name: "add".into(),
94                params: vec![
95                    ProcParam {
96                        name: "id".into(),
97                        type_desc: "INT64".into(),
98                    },
99                    ProcParam {
100                        name: "vector_csv".into(),
101                        type_desc: "STRING".into(),
102                    },
103                ],
104                columns: vec![ProcColumn {
105                    name: "status".into(),
106                    data_type: LogicalType::String,
107                }],
108            },
109            ProcedureSignature {
110                name: "search".into(),
111                params: vec![
112                    ProcParam {
113                        name: "query_csv".into(),
114                        type_desc: "STRING".into(),
115                    },
116                    ProcParam {
117                        name: "k".into(),
118                        type_desc: "INT64".into(),
119                    },
120                ],
121                columns: vec![
122                    ProcColumn {
123                        name: "id".into(),
124                        data_type: LogicalType::Int64,
125                    },
126                    ProcColumn {
127                        name: "distance".into(),
128                        data_type: LogicalType::Double,
129                    },
130                ],
131            },
132        ]
133    }
134
135    fn execute(
136        &self,
137        procedure: &str,
138        args: &[String],
139        _adjacency: &HashMap<i64, Vec<(i64, f64)>>,
140    ) -> Result<Vec<ProcRow>, String> {
141        let mut state = self.state.lock().map_err(|e| format!("lock error: {e}"))?;
142
143        match procedure {
144            "build" => {
145                let dim: usize = args
146                    .first()
147                    .ok_or("vector.build requires dim argument")?
148                    .parse()
149                    .map_err(|_| "dim must be a positive integer")?;
150                if dim == 0 {
151                    return Err("dim must be > 0".into());
152                }
153
154                let metric = match args.get(1).map(|s| s.to_lowercase()).as_deref() {
155                    Some("cosine") => DistanceMetric::Cosine,
156                    _ => DistanceMetric::L2,
157                };
158
159                state.index = Some(HnswIndex::new(
160                    dim,
161                    HnswConfig {
162                        metric,
163                        ..HnswConfig::default()
164                    },
165                ));
166                state.staging = StagingBuffer::new();
167                state.metric = metric;
168                state.next_id = 0;
169
170                Ok(vec![vec![TypedValue::String(SmolStr::new(format!(
171                    "built dim={dim} metric={metric:?}"
172                )))]])
173            }
174
175            "add" => {
176                let VectorState {
177                    index,
178                    staging,
179                    next_id,
180                    ..
181                } = &mut *state;
182                let index = index.as_mut().ok_or("call vector.build first")?;
183
184                let ext_id: usize = args
185                    .first()
186                    .ok_or("vector.add requires id argument")?
187                    .parse()
188                    .map_err(|_| "id must be a non-negative integer")?;
189
190                let csv = args
191                    .get(1)
192                    .ok_or("vector.add requires vector_csv argument")?;
193                let vector: Vec<f32> = csv
194                    .split(',')
195                    .map(|s| {
196                        s.trim()
197                            .parse::<f32>()
198                            .map_err(|_| format!("invalid float in vector: '{}'", s.trim()))
199                    })
200                    .collect::<Result<_, _>>()?;
201
202                let needs_flush = staging.add(ext_id, vector);
203                if needs_flush {
204                    staging.flush(index);
205                }
206
207                *next_id = (*next_id).max(ext_id + 1);
208
209                Ok(vec![vec![TypedValue::String(SmolStr::new("ok"))]])
210            }
211
212            "search" => {
213                let VectorState { index, staging, .. } = &mut *state;
214                let index = index.as_mut().ok_or("call vector.build first")?;
215
216                let csv = args
217                    .first()
218                    .ok_or("vector.search requires query_csv argument")?;
219                let query: Vec<f32> = csv
220                    .split(',')
221                    .map(|s| {
222                        s.trim()
223                            .parse::<f32>()
224                            .map_err(|_| format!("invalid float in query: '{}'", s.trim()))
225                    })
226                    .collect::<Result<_, _>>()?;
227
228                let k: usize = args.get(1).and_then(|s| s.parse().ok()).unwrap_or(10);
229
230                // Flush pending before search for consistency.
231                if staging.pending_count() > 0 {
232                    staging.flush(index);
233                }
234
235                let results = index.search(&query, k, k.max(50));
236
237                Ok(results
238                    .into_iter()
239                    .map(|(id, dist)| {
240                        vec![
241                            TypedValue::Int64(id as i64),
242                            TypedValue::Double(dist as f64),
243                        ]
244                    })
245                    .collect())
246            }
247
248            _ => Err(format!("unknown procedure: {procedure}")),
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty adjacency map; the vector extension ignores the graph argument.
    fn empty_adj() -> HashMap<i64, Vec<(i64, f64)>> {
        HashMap::new()
    }

    /// The extension reports its name, graph requirement and procedure count.
    #[test]
    fn extension_metadata() {
        let ext = VectorExtension::new();
        assert_eq!(ext.name(), "vector");
        assert!(!ext.needs_graph());
        assert_eq!(ext.procedures().len(), 3);
    }

    /// End-to-end L2 flow: build, add three vectors, search for an exact match.
    #[test]
    fn build_add_search() {
        let ext = VectorExtension::new();
        let adj = empty_adj();

        // Build index.
        let result = ext
            .execute("build", &["3".into(), "l2".into()], &adj)
            .unwrap();
        assert_eq!(result.len(), 1);

        // Add vectors.
        ext.execute("add", &["0".into(), "1.0,0.0,0.0".into()], &adj)
            .unwrap();
        ext.execute("add", &["1".into(), "0.0,1.0,0.0".into()], &adj)
            .unwrap();
        ext.execute("add", &["2".into(), "0.9,0.1,0.0".into()], &adj)
            .unwrap();

        // Search.
        let results = ext
            .execute("search", &["1.0,0.0,0.0".into(), "2".into()], &adj)
            .unwrap();
        assert!(!results.is_empty());
        // Nearest should be vector 0 (identical).
        assert_eq!(results[0][0], TypedValue::Int64(0));
    }

    /// `search` must fail until an index has been built.
    #[test]
    fn search_without_build() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        let result = ext.execute("search", &["1.0,0.0".into(), "5".into()], &adj);
        assert!(result.is_err());
    }

    /// `add` must fail until an index has been built.
    #[test]
    fn add_without_build() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        let result = ext.execute("add", &["0".into(), "1.0,0.0".into()], &adj);
        assert!(result.is_err());
    }

    /// Procedure names outside build/add/search are rejected.
    #[test]
    fn unknown_procedure() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        assert!(ext.execute("nonexistent", &[], &adj).is_err());
    }

    /// `build` rejects a missing or zero dimension.
    #[test]
    fn build_rejects_missing_or_zero_dim() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        assert!(ext.execute("build", &[], &adj).is_err());
        assert!(ext.execute("build", &["0".into()], &adj).is_err());
    }

    /// `add` rejects a vector containing a non-numeric field.
    #[test]
    fn add_rejects_malformed_vector() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        ext.execute("build", &["3".into(), "l2".into()], &adj)
            .unwrap();
        let result = ext.execute("add", &["0".into(), "1.0,abc,0.0".into()], &adj);
        assert!(result.is_err());
    }

    /// `search` rejects a query containing a non-numeric field.
    #[test]
    fn search_rejects_malformed_query() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        ext.execute("build", &["3".into(), "l2".into()], &adj)
            .unwrap();
        let result = ext.execute("search", &["1.0,,0.0".into(), "2".into()], &adj);
        assert!(result.is_err());
    }

    /// End-to-end cosine flow: the identical vector comes back first with ~0 distance.
    #[test]
    fn cosine_search() {
        let ext = VectorExtension::new();
        let adj = empty_adj();

        ext.execute("build", &["3".into(), "cosine".into()], &adj)
            .unwrap();
        ext.execute("add", &["0".into(), "1.0,0.0,0.0".into()], &adj)
            .unwrap();
        ext.execute("add", &["1".into(), "0.0,1.0,0.0".into()], &adj)
            .unwrap();
        ext.execute("add", &["2".into(), "0.9,0.1,0.0".into()], &adj)
            .unwrap();

        let results = ext
            .execute("search", &["1.0,0.0,0.0".into(), "3".into()], &adj)
            .unwrap();
        assert_eq!(results.len(), 3);
        // Vector 0 should be closest (cosine distance ~0).
        assert_eq!(results[0][0], TypedValue::Int64(0));
        // Fail loudly if the distance column is not a Double, rather than
        // skipping the distance check altogether.
        match &results[0][1] {
            TypedValue::Double(d) => {
                assert!(*d < 0.01, "cosine distance to identical = {d}");
            }
            other => panic!("distance column should be a Double, got {other:?}"),
        }
    }
}
344}