1pub mod distance;
12pub mod hnsw;
13pub mod staging;
14
15use std::collections::HashMap;
16use std::sync::Mutex;
17
18use kyu_extension::{Extension, ProcColumn, ProcParam, ProcRow, ProcedureSignature};
19use kyu_types::{LogicalType, TypedValue};
20use smol_str::SmolStr;
21
22use crate::distance::DistanceMetric;
23use crate::hnsw::{HnswConfig, HnswIndex};
24use crate::staging::StagingBuffer;
25
26struct VectorState {
28 index: Option<HnswIndex>,
29 staging: StagingBuffer,
30 metric: DistanceMetric,
31 next_id: usize,
32}
33
34impl Default for VectorState {
35 fn default() -> Self {
36 Self {
37 index: None,
38 staging: StagingBuffer::new(),
39 metric: DistanceMetric::L2,
40 next_id: 0,
41 }
42 }
43}
44
45pub struct VectorExtension {
47 state: Mutex<VectorState>,
48}
49
50impl VectorExtension {
51 pub fn new() -> Self {
52 Self {
53 state: Mutex::new(VectorState::default()),
54 }
55 }
56}
57
58impl Default for VectorExtension {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl Extension for VectorExtension {
65 fn name(&self) -> &str {
66 "vector"
67 }
68
69 fn needs_graph(&self) -> bool {
70 false
71 }
72
73 fn procedures(&self) -> Vec<ProcedureSignature> {
74 vec![
75 ProcedureSignature {
76 name: "build".into(),
77 params: vec![
78 ProcParam {
79 name: "dim".into(),
80 type_desc: "INT64".into(),
81 },
82 ProcParam {
83 name: "metric".into(),
84 type_desc: "STRING".into(),
85 },
86 ],
87 columns: vec![ProcColumn {
88 name: "status".into(),
89 data_type: LogicalType::String,
90 }],
91 },
92 ProcedureSignature {
93 name: "add".into(),
94 params: vec![
95 ProcParam {
96 name: "id".into(),
97 type_desc: "INT64".into(),
98 },
99 ProcParam {
100 name: "vector_csv".into(),
101 type_desc: "STRING".into(),
102 },
103 ],
104 columns: vec![ProcColumn {
105 name: "status".into(),
106 data_type: LogicalType::String,
107 }],
108 },
109 ProcedureSignature {
110 name: "search".into(),
111 params: vec![
112 ProcParam {
113 name: "query_csv".into(),
114 type_desc: "STRING".into(),
115 },
116 ProcParam {
117 name: "k".into(),
118 type_desc: "INT64".into(),
119 },
120 ],
121 columns: vec![
122 ProcColumn {
123 name: "id".into(),
124 data_type: LogicalType::Int64,
125 },
126 ProcColumn {
127 name: "distance".into(),
128 data_type: LogicalType::Double,
129 },
130 ],
131 },
132 ]
133 }
134
135 fn execute(
136 &self,
137 procedure: &str,
138 args: &[String],
139 _adjacency: &HashMap<i64, Vec<(i64, f64)>>,
140 ) -> Result<Vec<ProcRow>, String> {
141 let mut state = self
142 .state
143 .lock()
144 .map_err(|e| format!("lock error: {e}"))?;
145
146 match procedure {
147 "build" => {
148 let dim: usize = args
149 .first()
150 .ok_or("vector.build requires dim argument")?
151 .parse()
152 .map_err(|_| "dim must be a positive integer")?;
153 if dim == 0 {
154 return Err("dim must be > 0".into());
155 }
156
157 let metric = match args.get(1).map(|s| s.to_lowercase()).as_deref() {
158 Some("cosine") => DistanceMetric::Cosine,
159 _ => DistanceMetric::L2,
160 };
161
162 state.index = Some(HnswIndex::new(dim, HnswConfig {
163 metric,
164 ..HnswConfig::default()
165 }));
166 state.staging = StagingBuffer::new();
167 state.metric = metric;
168 state.next_id = 0;
169
170 Ok(vec![vec![TypedValue::String(SmolStr::new(format!(
171 "built dim={dim} metric={metric:?}"
172 )))]])
173 }
174
175 "add" => {
176 let VectorState { index, staging, next_id, .. } = &mut *state;
177 let index = index.as_mut().ok_or("call vector.build first")?;
178
179 let ext_id: usize = args
180 .first()
181 .ok_or("vector.add requires id argument")?
182 .parse()
183 .map_err(|_| "id must be a non-negative integer")?;
184
185 let csv = args.get(1).ok_or("vector.add requires vector_csv argument")?;
186 let vector: Vec<f32> = csv
187 .split(',')
188 .map(|s| {
189 s.trim()
190 .parse::<f32>()
191 .map_err(|_| format!("invalid float in vector: '{}'", s.trim()))
192 })
193 .collect::<Result<_, _>>()?;
194
195 let needs_flush = staging.add(ext_id, vector);
196 if needs_flush {
197 staging.flush(index);
198 }
199
200 *next_id = (*next_id).max(ext_id + 1);
201
202 Ok(vec![vec![TypedValue::String(SmolStr::new("ok"))]])
203 }
204
205 "search" => {
206 let VectorState { index, staging, .. } = &mut *state;
207 let index = index.as_mut().ok_or("call vector.build first")?;
208
209 let csv = args.first().ok_or("vector.search requires query_csv argument")?;
210 let query: Vec<f32> = csv
211 .split(',')
212 .map(|s| {
213 s.trim()
214 .parse::<f32>()
215 .map_err(|_| format!("invalid float in query: '{}'", s.trim()))
216 })
217 .collect::<Result<_, _>>()?;
218
219 let k: usize = args
220 .get(1)
221 .and_then(|s| s.parse().ok())
222 .unwrap_or(10);
223
224 if staging.pending_count() > 0 {
226 staging.flush(index);
227 }
228
229 let results = index.search(&query, k, k.max(50));
230
231 Ok(results
232 .into_iter()
233 .map(|(id, dist)| {
234 vec![
235 TypedValue::Int64(id as i64),
236 TypedValue::Double(dist as f64),
237 ]
238 })
239 .collect())
240 }
241
242 _ => Err(format!("unknown procedure: {procedure}")),
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    fn empty_adj() -> HashMap<i64, Vec<(i64, f64)>> {
        HashMap::new()
    }

    /// Extracts a `Double` cell and fails the test on any other variant, so a
    /// type regression cannot make a distance assertion pass vacuously.
    fn as_f64(value: &TypedValue) -> f64 {
        match value {
            TypedValue::Double(d) => *d,
            other => panic!("expected Double, got {other:?}"),
        }
    }

    #[test]
    fn extension_metadata() {
        let ext = VectorExtension::new();
        assert_eq!(ext.name(), "vector");
        assert!(!ext.needs_graph());
        assert_eq!(ext.procedures().len(), 3);
    }

    #[test]
    fn build_add_search() {
        let ext = VectorExtension::new();
        let adj = empty_adj();

        let result = ext.execute("build", &["3".into(), "l2".into()], &adj).unwrap();
        assert_eq!(result.len(), 1);

        ext.execute("add", &["0".into(), "1.0,0.0,0.0".into()], &adj).unwrap();
        ext.execute("add", &["1".into(), "0.0,1.0,0.0".into()], &adj).unwrap();
        ext.execute("add", &["2".into(), "0.9,0.1,0.0".into()], &adj).unwrap();

        let results = ext
            .execute("search", &["1.0,0.0,0.0".into(), "2".into()], &adj)
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0][0], TypedValue::Int64(0));
        let d = as_f64(&results[0][1]);
        assert!(d < 1e-6, "L2 distance to identical = {d}");
    }

    #[test]
    fn search_without_build() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        let result = ext.execute("search", &["1.0,0.0".into(), "5".into()], &adj);
        assert!(result.is_err());
    }

    #[test]
    fn add_without_build() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        let result = ext.execute("add", &["0".into(), "1.0,0.0".into()], &adj);
        assert!(result.is_err());
    }

    #[test]
    fn unknown_procedure() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        assert!(ext.execute("nonexistent", &[], &adj).is_err());
    }

    #[test]
    fn build_rejects_bad_dim() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        assert!(ext.execute("build", &["0".into(), "l2".into()], &adj).is_err());
        assert!(ext.execute("build", &["abc".into(), "l2".into()], &adj).is_err());
        assert!(ext.execute("build", &[], &adj).is_err());
    }

    #[test]
    fn add_rejects_bad_arguments() {
        let ext = VectorExtension::new();
        let adj = empty_adj();
        ext.execute("build", &["3".into(), "l2".into()], &adj).unwrap();
        assert!(ext.execute("add", &["-1".into(), "1.0,0.0,0.0".into()], &adj).is_err());
        assert!(ext.execute("add", &["0".into(), "1.0,abc,0.0".into()], &adj).is_err());
        assert!(ext.execute("add", &["0".into()], &adj).is_err());
    }

    #[test]
    fn cosine_search() {
        let ext = VectorExtension::new();
        let adj = empty_adj();

        ext.execute("build", &["3".into(), "cosine".into()], &adj).unwrap();
        ext.execute("add", &["0".into(), "1.0,0.0,0.0".into()], &adj).unwrap();
        ext.execute("add", &["1".into(), "0.0,1.0,0.0".into()], &adj).unwrap();
        ext.execute("add", &["2".into(), "0.9,0.1,0.0".into()], &adj).unwrap();

        let results = ext
            .execute("search", &["1.0,0.0,0.0".into(), "3".into()], &adj)
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0][0], TypedValue::Int64(0));
        let d = as_f64(&results[0][1]);
        assert!(d < 0.01, "cosine distance to identical = {d}");
    }
}