1use serde::{Deserialize, Serialize};
4use std::str::FromStr;
5
6use crate::{Error, Result};
7#[cfg(feature = "async")]
8pub mod async_vector;
9pub mod columnar;
10pub mod flat;
11pub mod hnsw;
12pub mod simd;
13
14#[cfg(feature = "tokio")]
16pub use async_vector::AsyncVectorStoreAdapter;
17#[cfg(feature = "async")]
18pub use async_vector::{AsyncHnswIndex, AsyncVectorStore};
19pub use columnar::{
20 key_layout as vector_key_layout, AppendResult, SearchStats, VectorSearchParams,
21 VectorSearchResult, VectorSegment, VectorStoreConfig, VectorStoreManager,
22};
23pub use hnsw::{HnswConfig, HnswIndex, HnswSearchResult, HnswStats};
24pub use simd::{select_kernel, DistanceKernel, ScalarKernel};
25
26#[cfg(all(test, not(target_arch = "wasm32")))]
27mod disk;
28
29#[cfg(all(test, not(target_arch = "wasm32")))]
30mod integration;
31
/// Result of a vector delete operation.
///
/// NOTE(review): this file only declares the type; it is filled in elsewhere
/// (presumably by the columnar store). Field descriptions follow the field
/// names — confirm against the producer.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct DeleteResult {
    /// Number of vectors deleted.
    pub vectors_deleted: u64,
    /// Identifiers of the segments the delete touched (presumably segment ids).
    pub segments_modified: Vec<u64>,
}
42
/// Result of compacting a single segment.
///
/// NOTE(review): produced outside this file; field descriptions follow the
/// field names — confirm against the producer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CompactionResult {
    /// Identifier of the segment that was compacted.
    pub old_segment_id: u64,
    /// Identifier of the replacement segment, if one was written.
    /// NOTE(review): presumably `None` when compaction left nothing to keep — confirm.
    pub new_segment_id: Option<u64>,
    /// Number of vectors dropped by the compaction.
    pub vectors_removed: u64,
    /// Amount of storage freed.
    /// NOTE(review): the unit (bytes?) is not visible here — confirm.
    pub space_reclaimed: u64,
}
57
/// Similarity metric used to compare two vectors.
///
/// Parsed from text through its `FromStr` impl and rendered with
/// [`Metric::as_str`]. See [`score`] for how each metric is evaluated.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Metric {
    /// Cosine similarity: dot product divided by the product of the norms.
    Cosine,
    /// Euclidean (L2) distance; [`score`] returns it negated so closer
    /// vectors produce larger values.
    L2,
    /// Raw dot product.
    InnerProduct,
}
68
69impl Metric {
70 pub fn as_str(&self) -> &'static str {
72 match self {
73 Metric::Cosine => "cosine",
74 Metric::L2 => "l2",
75 Metric::InnerProduct => "inner",
76 }
77 }
78}
79
80impl FromStr for Metric {
81 type Err = Error;
82
83 fn from_str(s: &str) -> Result<Self> {
84 match s.to_ascii_lowercase().as_str() {
85 "cosine" => Ok(Metric::Cosine),
86 "l2" => Ok(Metric::L2),
87 "inner" | "inner_product" | "innerproduct" => Ok(Metric::InnerProduct),
88 other => Err(Error::UnsupportedMetric {
89 metric: other.to_string(),
90 }),
91 }
92 }
93}
94
/// Pairs a fixed dimensionality with the [`Metric`] used to compare vectors
/// of that size.
///
/// Fields are private; read them through [`VectorType::dim`] and
/// [`VectorType::metric`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VectorType {
    /// Required number of components per vector.
    dim: usize,
    /// Metric used by [`VectorType::score`].
    metric: Metric,
}
101
102impl VectorType {
103 pub fn new(dim: usize, metric: Metric) -> Self {
105 Self { dim, metric }
106 }
107
108 pub fn dim(&self) -> usize {
110 self.dim
111 }
112
113 pub fn metric(&self) -> Metric {
115 self.metric
116 }
117
118 pub fn validate(&self, vector: &[f32]) -> Result<()> {
120 validate_dimensions(self.dim, vector.len())
121 }
122
123 pub fn score(&self, query: &[f32], item: &[f32]) -> Result<f32> {
125 self.validate(query)?;
126 self.validate(item)?;
127 score(self.metric, query, item)
128 }
129}
130
131pub fn validate_dimensions(expected: usize, actual: usize) -> Result<()> {
133 if expected != actual {
134 return Err(Error::DimensionMismatch { expected, actual });
135 }
136 Ok(())
137}
138
139pub fn score(metric: Metric, query: &[f32], item: &[f32]) -> Result<f32> {
145 validate_dimensions(query.len(), item.len())?;
146
147 match metric {
148 Metric::Cosine => {
149 let dot = query
150 .iter()
151 .zip(item.iter())
152 .map(|(a, b)| a * b)
153 .sum::<f32>();
154 let q_norm = query.iter().map(|v| v * v).sum::<f32>().sqrt();
155 let i_norm = item.iter().map(|v| v * v).sum::<f32>().sqrt();
156
157 if q_norm == 0.0 || i_norm == 0.0 {
158 return Ok(0.0);
159 }
160
161 Ok(dot / (q_norm * i_norm))
162 }
163 Metric::L2 => {
164 let dist = query
165 .iter()
166 .zip(item.iter())
167 .map(|(a, b)| {
168 let d = a - b;
169 d * d
170 })
171 .sum::<f32>()
172 .sqrt();
173 Ok(-dist)
174 }
175 Metric::InnerProduct => Ok(query
176 .iter()
177 .zip(item.iter())
178 .map(|(a, b)| a * b)
179 .sum::<f32>()),
180 }
181}
182
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    #[test]
    fn rejects_dimension_mismatch() {
        let vt = VectorType::new(3, Metric::Cosine);
        let err = vt.validate(&[1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 3,
                actual: 2
            }
        ));

        let err = score(Metric::L2, &[1.0, 2.0], &[1.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 2,
                actual: 1
            }
        ));
    }

    #[test]
    fn computes_cosine() {
        let vt = VectorType::new(3, Metric::Cosine);
        let s = vt.score(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).unwrap();
        assert_eq!(s, 0.0);

        let s = vt.score(&[1.0, 1.0, 0.0], &[1.0, 1.0, 0.0]).unwrap();
        assert!((s - 1.0).abs() < 1e-6);
    }

    #[test]
    fn computes_l2_as_negative_distance() {
        let s = score(Metric::L2, &[0.0, 0.0], &[3.0, 4.0]).unwrap();
        assert!((s + 5.0).abs() < 1e-6);
    }

    #[test]
    fn computes_inner_product() {
        let s = score(Metric::InnerProduct, &[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(s, 32.0);
    }

    #[test]
    fn parses_metric_from_str() {
        assert_eq!(Metric::from_str("cosine").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("inner_product").unwrap(),
            Metric::InnerProduct
        );

        let err = Metric::from_str("chebyshev").unwrap_err();
        assert!(matches!(err, Error::UnsupportedMetric { .. }));
    }

    /// `as_str` output must always be parseable back into the same metric.
    #[test]
    fn metric_names_round_trip() {
        for metric in [Metric::Cosine, Metric::L2, Metric::InnerProduct] {
            assert_eq!(Metric::from_str(metric.as_str()).unwrap(), metric);
        }
    }

    /// Every documented alias parses, regardless of ASCII case.
    #[test]
    fn parses_every_inner_product_alias() {
        for alias in ["inner", "innerproduct", "InnerProduct", "INNER_PRODUCT"] {
            assert_eq!(Metric::from_str(alias).unwrap(), Metric::InnerProduct);
        }
        assert_eq!(Metric::from_str("Cosine").unwrap(), Metric::Cosine);
    }

    /// The zero-norm guard in cosine scoring returns 0.0 instead of NaN.
    #[test]
    fn cosine_with_zero_norm_is_zero() {
        let s = score(Metric::Cosine, &[0.0, 0.0], &[1.0, 2.0]).unwrap();
        assert_eq!(s, 0.0);

        let s = score(Metric::Cosine, &[1.0, 2.0], &[0.0, 0.0]).unwrap();
        assert_eq!(s, 0.0);
    }

    /// `VectorType::score` checks the item as well as the query.
    #[test]
    fn vector_type_score_validates_item() {
        let vt = VectorType::new(2, Metric::InnerProduct);
        let err = vt.score(&[1.0, 2.0], &[1.0, 2.0, 3.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 2,
                actual: 3
            }
        ));
    }
}