1pub mod distance;
12pub mod hnsw;
13pub mod staging;
14
15use std::collections::HashMap;
16use std::sync::Mutex;
17
18use kyu_extension::{Extension, ProcColumn, ProcParam, ProcRow, ProcedureSignature};
19use kyu_types::{LogicalType, TypedValue};
20use smol_str::SmolStr;
21
22use crate::distance::DistanceMetric;
23use crate::hnsw::{HnswConfig, HnswIndex};
24use crate::staging::StagingBuffer;
25
26struct VectorState {
28 index: Option<HnswIndex>,
29 staging: StagingBuffer,
30 metric: DistanceMetric,
31 next_id: usize,
32}
33
34impl Default for VectorState {
35 fn default() -> Self {
36 Self {
37 index: None,
38 staging: StagingBuffer::new(),
39 metric: DistanceMetric::L2,
40 next_id: 0,
41 }
42 }
43}
44
45pub struct VectorExtension {
47 state: Mutex<VectorState>,
48}
49
50impl VectorExtension {
51 pub fn new() -> Self {
52 Self {
53 state: Mutex::new(VectorState::default()),
54 }
55 }
56}
57
58impl Default for VectorExtension {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl Extension for VectorExtension {
65 fn name(&self) -> &str {
66 "vector"
67 }
68
69 fn needs_graph(&self) -> bool {
70 false
71 }
72
73 fn procedures(&self) -> Vec<ProcedureSignature> {
74 vec![
75 ProcedureSignature {
76 name: "build".into(),
77 params: vec![
78 ProcParam {
79 name: "dim".into(),
80 type_desc: "INT64".into(),
81 },
82 ProcParam {
83 name: "metric".into(),
84 type_desc: "STRING".into(),
85 },
86 ],
87 columns: vec![ProcColumn {
88 name: "status".into(),
89 data_type: LogicalType::String,
90 }],
91 },
92 ProcedureSignature {
93 name: "add".into(),
94 params: vec![
95 ProcParam {
96 name: "id".into(),
97 type_desc: "INT64".into(),
98 },
99 ProcParam {
100 name: "vector_csv".into(),
101 type_desc: "STRING".into(),
102 },
103 ],
104 columns: vec![ProcColumn {
105 name: "status".into(),
106 data_type: LogicalType::String,
107 }],
108 },
109 ProcedureSignature {
110 name: "search".into(),
111 params: vec![
112 ProcParam {
113 name: "query_csv".into(),
114 type_desc: "STRING".into(),
115 },
116 ProcParam {
117 name: "k".into(),
118 type_desc: "INT64".into(),
119 },
120 ],
121 columns: vec![
122 ProcColumn {
123 name: "id".into(),
124 data_type: LogicalType::Int64,
125 },
126 ProcColumn {
127 name: "distance".into(),
128 data_type: LogicalType::Double,
129 },
130 ],
131 },
132 ]
133 }
134
135 fn execute(
136 &self,
137 procedure: &str,
138 args: &[String],
139 _adjacency: &HashMap<i64, Vec<(i64, f64)>>,
140 ) -> Result<Vec<ProcRow>, String> {
141 let mut state = self.state.lock().map_err(|e| format!("lock error: {e}"))?;
142
143 match procedure {
144 "build" => {
145 let dim: usize = args
146 .first()
147 .ok_or("vector.build requires dim argument")?
148 .parse()
149 .map_err(|_| "dim must be a positive integer")?;
150 if dim == 0 {
151 return Err("dim must be > 0".into());
152 }
153
154 let metric = match args.get(1).map(|s| s.to_lowercase()).as_deref() {
155 Some("cosine") => DistanceMetric::Cosine,
156 _ => DistanceMetric::L2,
157 };
158
159 state.index = Some(HnswIndex::new(
160 dim,
161 HnswConfig {
162 metric,
163 ..HnswConfig::default()
164 },
165 ));
166 state.staging = StagingBuffer::new();
167 state.metric = metric;
168 state.next_id = 0;
169
170 Ok(vec![vec![TypedValue::String(SmolStr::new(format!(
171 "built dim={dim} metric={metric:?}"
172 )))]])
173 }
174
175 "add" => {
176 let VectorState {
177 index,
178 staging,
179 next_id,
180 ..
181 } = &mut *state;
182 let index = index.as_mut().ok_or("call vector.build first")?;
183
184 let ext_id: usize = args
185 .first()
186 .ok_or("vector.add requires id argument")?
187 .parse()
188 .map_err(|_| "id must be a non-negative integer")?;
189
190 let csv = args
191 .get(1)
192 .ok_or("vector.add requires vector_csv argument")?;
193 let vector: Vec<f32> = csv
194 .split(',')
195 .map(|s| {
196 s.trim()
197 .parse::<f32>()
198 .map_err(|_| format!("invalid float in vector: '{}'", s.trim()))
199 })
200 .collect::<Result<_, _>>()?;
201
202 let needs_flush = staging.add(ext_id, vector);
203 if needs_flush {
204 staging.flush(index);
205 }
206
207 *next_id = (*next_id).max(ext_id + 1);
208
209 Ok(vec![vec![TypedValue::String(SmolStr::new("ok"))]])
210 }
211
212 "search" => {
213 let VectorState { index, staging, .. } = &mut *state;
214 let index = index.as_mut().ok_or("call vector.build first")?;
215
216 let csv = args
217 .first()
218 .ok_or("vector.search requires query_csv argument")?;
219 let query: Vec<f32> = csv
220 .split(',')
221 .map(|s| {
222 s.trim()
223 .parse::<f32>()
224 .map_err(|_| format!("invalid float in query: '{}'", s.trim()))
225 })
226 .collect::<Result<_, _>>()?;
227
228 let k: usize = args.get(1).and_then(|s| s.parse().ok()).unwrap_or(10);
229
230 if staging.pending_count() > 0 {
232 staging.flush(index);
233 }
234
235 let results = index.search(&query, k, k.max(50));
236
237 Ok(results
238 .into_iter()
239 .map(|(id, dist)| {
240 vec![
241 TypedValue::Int64(id as i64),
242 TypedValue::Double(dist as f64),
243 ]
244 })
245 .collect())
246 }
247
248 _ => Err(format!("unknown procedure: {procedure}")),
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    fn empty_adj() -> HashMap<i64, Vec<(i64, f64)>> {
        HashMap::new()
    }

    /// Converts string literals into the owned argument list `execute` expects.
    fn strings(xs: &[&str]) -> Vec<String> {
        xs.iter().map(|s| s.to_string()).collect()
    }

    /// Builds a 3-d index with `metric` and adds three fixed vectors:
    /// id 0 = x-axis, id 1 = y-axis, id 2 = mostly x-axis.
    fn populated(metric: &str) -> VectorExtension {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        ext.execute("build", &strings(&["3", metric]), &adj).unwrap();
        ext.execute("add", &strings(&["0", "1.0,0.0,0.0"]), &adj)
            .unwrap();
        ext.execute("add", &strings(&["1", "0.0,1.0,0.0"]), &adj)
            .unwrap();
        ext.execute("add", &strings(&["2", "0.9,0.1,0.0"]), &adj)
            .unwrap();
        ext
    }

    #[test]
    fn extension_metadata() {
        let ext = VectorExtension::new();
        assert_eq!(ext.name(), "vector");
        assert!(!ext.needs_graph());
        assert_eq!(ext.procedures().len(), 3);
    }

    #[test]
    fn build_reports_status() {
        let ext = VectorExtension::new();
        let result = ext
            .execute("build", &strings(&["3", "l2"]), &empty_adj())
            .unwrap();
        assert_eq!(result.len(), 1);
        match &result[0][0] {
            TypedValue::String(s) => assert!(s.starts_with("built dim=3"), "status = {s}"),
            other => panic!("expected status string, got {other:?}"),
        }
    }

    #[test]
    fn build_rejects_invalid_dim() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        assert!(ext.execute("build", &strings(&["0"]), &adj).is_err());
        assert!(ext.execute("build", &strings(&["three"]), &adj).is_err());
        assert!(ext.execute("build", &[], &adj).is_err());
    }

    #[test]
    fn build_add_search() {
        let ext = populated("l2");
        let results = ext
            .execute("search", &strings(&["1.0,0.0,0.0", "2"]), &empty_adj())
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0][0], TypedValue::Int64(0));
        assert_eq!(results[1][0], TypedValue::Int64(2));
    }

    #[test]
    fn search_defaults_k_when_omitted() {
        let ext = populated("l2");
        let results = ext
            .execute("search", &strings(&["1.0,0.0,0.0"]), &empty_adj())
            .unwrap();
        // Default k exceeds the three stored vectors, so all come back.
        assert_eq!(results.len(), 3);
        assert_eq!(results[0][0], TypedValue::Int64(0));
    }

    #[test]
    fn add_rejects_malformed_input() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        ext.execute("build", &strings(&["3", "l2"]), &adj).unwrap();

        let err = ext
            .execute("add", &strings(&["0", "1.0,abc,0.0"]), &adj)
            .unwrap_err();
        assert!(err.contains("invalid float in vector"), "err = {err}");

        assert!(ext.execute("add", &strings(&["0"]), &adj).is_err());
        assert!(ext
            .execute("add", &strings(&["-1", "1.0,0.0,0.0"]), &adj)
            .is_err());
    }

    #[test]
    fn search_without_build() {
        let ext = VectorExtension::new();
        let result = ext.execute("search", &strings(&["1.0,0.0", "5"]), &empty_adj());
        assert!(result.is_err());
    }

    #[test]
    fn add_without_build() {
        let ext = VectorExtension::new();
        let result = ext.execute("add", &strings(&["0", "1.0,0.0"]), &empty_adj());
        assert!(result.is_err());
    }

    #[test]
    fn unknown_procedure() {
        let ext = VectorExtension::new();
        assert!(ext.execute("nonexistent", &[], &empty_adj()).is_err());
    }

    #[test]
    fn cosine_search() {
        let ext = populated("cosine");
        let results = ext
            .execute("search", &strings(&["1.0,0.0,0.0", "3"]), &empty_adj())
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0][0], TypedValue::Int64(0));
        // Fail loudly if the distance column has the wrong type instead of
        // silently skipping the check.
        match results[0][1] {
            TypedValue::Double(d) => assert!(d < 0.01, "cosine distance to identical = {d}"),
            ref other => panic!("expected Double distance, got {other:?}"),
        }
    }
}