1use serde::{Deserialize, Serialize};
4use std::str::FromStr;
5
6use crate::{Error, Result};
7pub mod columnar;
8pub mod flat;
9pub mod hnsw;
10pub mod simd;
11
12pub use columnar::{
14 key_layout as vector_key_layout, AppendResult, SearchStats, VectorSearchParams,
15 VectorSearchResult, VectorSegment, VectorStoreConfig, VectorStoreManager,
16};
17pub use hnsw::{HnswConfig, HnswIndex, HnswSearchResult, HnswStats};
18pub use simd::{select_kernel, DistanceKernel, ScalarKernel};
19
20#[cfg(all(test, not(target_arch = "wasm32")))]
21mod disk;
22
23#[cfg(all(test, not(target_arch = "wasm32")))]
24mod integration;
25
/// Summary returned by a vector deletion.
///
/// The derived `Default` is the empty summary: zero vectors deleted and no
/// segment identifiers recorded.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DeleteResult {
    /// How many vectors the deletion removed.
    pub vectors_deleted: u64,
    /// Identifiers of the segments the deletion modified.
    pub segments_modified: Vec<u64>,
}
36
/// Outcome of compacting one vector segment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactionResult {
    /// Identifier of the segment that was compacted.
    pub old_segment_id: u64,
    /// Identifier of the segment produced by compaction, if one was produced.
    ///
    /// NOTE(review): presumably `None` when no replacement segment is written
    /// (e.g. every vector was removed); confirm against the compactor.
    pub new_segment_id: Option<u64>,
    /// How many vectors compaction dropped.
    pub vectors_removed: u64,
    /// Storage reclaimed by compaction.
    ///
    /// NOTE(review): the unit is not visible here; presumably bytes. Verify.
    pub space_reclaimed: u64,
}
51
52#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
54pub enum Metric {
55 Cosine,
57 L2,
59 InnerProduct,
61}
62
63impl Metric {
64 pub fn as_str(&self) -> &'static str {
66 match self {
67 Metric::Cosine => "cosine",
68 Metric::L2 => "l2",
69 Metric::InnerProduct => "inner",
70 }
71 }
72}
73
74impl FromStr for Metric {
75 type Err = Error;
76
77 fn from_str(s: &str) -> Result<Self> {
78 match s.to_ascii_lowercase().as_str() {
79 "cosine" => Ok(Metric::Cosine),
80 "l2" => Ok(Metric::L2),
81 "inner" | "inner_product" | "innerproduct" => Ok(Metric::InnerProduct),
82 other => Err(Error::UnsupportedMetric {
83 metric: other.to_string(),
84 }),
85 }
86 }
87}
88
89#[derive(Clone, Copy, Debug, PartialEq, Eq)]
91pub struct VectorType {
92 dim: usize,
93 metric: Metric,
94}
95
96impl VectorType {
97 pub fn new(dim: usize, metric: Metric) -> Self {
99 Self { dim, metric }
100 }
101
102 pub fn dim(&self) -> usize {
104 self.dim
105 }
106
107 pub fn metric(&self) -> Metric {
109 self.metric
110 }
111
112 pub fn validate(&self, vector: &[f32]) -> Result<()> {
114 validate_dimensions(self.dim, vector.len())
115 }
116
117 pub fn score(&self, query: &[f32], item: &[f32]) -> Result<f32> {
119 self.validate(query)?;
120 self.validate(item)?;
121 score(self.metric, query, item)
122 }
123}
124
125pub fn validate_dimensions(expected: usize, actual: usize) -> Result<()> {
127 if expected != actual {
128 return Err(Error::DimensionMismatch { expected, actual });
129 }
130 Ok(())
131}
132
133pub fn score(metric: Metric, query: &[f32], item: &[f32]) -> Result<f32> {
139 validate_dimensions(query.len(), item.len())?;
140
141 match metric {
142 Metric::Cosine => {
143 let dot = query
144 .iter()
145 .zip(item.iter())
146 .map(|(a, b)| a * b)
147 .sum::<f32>();
148 let q_norm = query.iter().map(|v| v * v).sum::<f32>().sqrt();
149 let i_norm = item.iter().map(|v| v * v).sum::<f32>().sqrt();
150
151 if q_norm == 0.0 || i_norm == 0.0 {
152 return Ok(0.0);
153 }
154
155 Ok(dot / (q_norm * i_norm))
156 }
157 Metric::L2 => {
158 let dist = query
159 .iter()
160 .zip(item.iter())
161 .map(|(a, b)| {
162 let d = a - b;
163 d * d
164 })
165 .sum::<f32>()
166 .sqrt();
167 Ok(-dist)
168 }
169 Metric::InnerProduct => Ok(query
170 .iter()
171 .zip(item.iter())
172 .map(|(a, b)| a * b)
173 .sum::<f32>()),
174 }
175}
176
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    #[test]
    fn rejects_dimension_mismatch() {
        let vt = VectorType::new(3, Metric::Cosine);
        let err = vt.validate(&[1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 3,
                actual: 2
            }
        ));

        let err = score(Metric::L2, &[1.0, 2.0], &[1.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 2,
                actual: 1
            }
        ));
    }

    #[test]
    fn vector_type_exposes_config_and_checks_item() {
        let vt = VectorType::new(3, Metric::L2);
        assert_eq!(vt.dim(), 3);
        assert_eq!(vt.metric(), Metric::L2);

        // A well-formed query must not hide a malformed item.
        let err = vt.score(&[1.0, 2.0, 3.0], &[1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 3,
                actual: 2
            }
        ));
    }

    #[test]
    fn computes_cosine() {
        let vt = VectorType::new(3, Metric::Cosine);
        let s = vt.score(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).unwrap();
        assert_eq!(s, 0.0);

        let s = vt.score(&[1.0, 1.0, 0.0], &[1.0, 1.0, 0.0]).unwrap();
        assert!((s - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_with_zero_vector_is_zero() {
        let s = score(Metric::Cosine, &[0.0, 0.0], &[1.0, 2.0]).unwrap();
        assert_eq!(s, 0.0);
        let s = score(Metric::Cosine, &[1.0, 2.0], &[0.0, 0.0]).unwrap();
        assert_eq!(s, 0.0);
    }

    #[test]
    fn computes_l2_as_negative_distance() {
        let s = score(Metric::L2, &[0.0, 0.0], &[3.0, 4.0]).unwrap();
        assert!((s + 5.0).abs() < 1e-6);

        // Identical vectors are at distance zero, the best possible L2 score.
        let s = score(Metric::L2, &[1.0, 2.0], &[1.0, 2.0]).unwrap();
        assert_eq!(s, 0.0);
    }

    #[test]
    fn computes_inner_product() {
        let s = score(Metric::InnerProduct, &[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(s, 32.0);
    }

    #[test]
    fn parses_metric_from_str() {
        assert_eq!(Metric::from_str("cosine").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("inner_product").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("InnerProduct").unwrap(),
            Metric::InnerProduct
        );

        let err = Metric::from_str("chebyshev").unwrap_err();
        assert!(matches!(err, Error::UnsupportedMetric { .. }));
    }

    #[test]
    fn metric_names_round_trip() {
        for &m in &[Metric::Cosine, Metric::L2, Metric::InnerProduct] {
            assert_eq!(Metric::from_str(m.as_str()).unwrap(), m);
            assert_eq!(
                Metric::from_str(&m.as_str().to_ascii_uppercase()).unwrap(),
                m
            );
        }
    }
}