alopex_sql/executor/evaluator/
vector_ops.rs1use std::str::FromStr;
2
3use thiserror::Error;
4
5#[derive(Debug, Error, Clone, PartialEq, Eq)]
7pub enum VectorError {
8 #[error("argument count mismatch: expected 3, got {actual}")]
10 ArgumentCountMismatch { actual: usize },
11
12 #[error("type mismatch: first argument must be VECTOR column")]
14 TypeMismatch,
15
16 #[error("invalid vector literal: {reason}")]
18 InvalidVectorLiteral { reason: String },
19
20 #[error("invalid metric '{metric}': {reason}")]
22 InvalidMetric { metric: String, reason: String },
23
24 #[error("dimension mismatch: column has {expected} dimensions, query has {actual}")]
26 DimensionMismatch { expected: usize, actual: usize },
27
28 #[error("zero-norm vector cannot be used for cosine similarity")]
30 ZeroNormVector,
31}
32
/// Metric used to compare two vectors.
///
/// Parse it from SQL text with [`VectorMetric::parse`] or `str::parse`. Both
/// ignore surrounding whitespace and letter case.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VectorMetric {
    /// Cosine similarity: dot product divided by the product of the norms.
    Cosine,
    /// Euclidean (L2) distance.
    L2,
    /// Inner (dot) product.
    Inner,
}
40
41impl VectorMetric {
42 pub fn parse(s: &str) -> Result<Self, VectorError> {
44 let normalized = s.trim().to_lowercase();
45 match normalized.as_str() {
46 "cosine" => Ok(Self::Cosine),
47 "l2" => Ok(Self::L2),
48 "inner" => Ok(Self::Inner),
49 "" => Err(VectorError::InvalidMetric {
50 metric: s.to_string(),
51 reason: "empty metric string".into(),
52 }),
53 _ => Err(VectorError::InvalidMetric {
54 metric: s.to_string(),
55 reason: format!("expected 'cosine', 'l2', or 'inner', got '{}'", normalized),
56 }),
57 }
58 }
59}
60
61impl FromStr for VectorMetric {
62 type Err = VectorError;
63
64 fn from_str(s: &str) -> Result<Self, Self::Err> {
65 VectorMetric::parse(s)
66 }
67}
68
69pub fn vector_similarity(
74 column_value: &[f32],
75 query_vector: &[f32],
76 metric: VectorMetric,
77) -> Result<f64, VectorError> {
78 validate_dimensions(column_value, query_vector)?;
79
80 match metric {
81 VectorMetric::Cosine => compute_cosine_similarity(column_value, query_vector),
82 VectorMetric::L2 => compute_l2_distance(column_value, query_vector),
83 VectorMetric::Inner => compute_inner_product(column_value, query_vector),
84 }
85}
86
87pub fn vector_distance(
89 column_value: &[f32],
90 query_vector: &[f32],
91 metric: VectorMetric,
92) -> Result<f64, VectorError> {
93 vector_similarity(column_value, query_vector, metric)
94}
95
96fn validate_dimensions(a: &[f32], b: &[f32]) -> Result<(), VectorError> {
97 if a.len() != b.len() {
98 return Err(VectorError::DimensionMismatch {
99 expected: a.len(),
100 actual: b.len(),
101 });
102 }
103 Ok(())
104}
105
106fn compute_cosine_similarity(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
107 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
108 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
109 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
110
111 if norm_a == 0.0 || norm_b == 0.0 {
112 return Err(VectorError::ZeroNormVector);
113 }
114
115 Ok((dot / (norm_a * norm_b)) as f64)
116}
117
118fn compute_l2_distance(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
119 let sum_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
120 Ok(sum_sq.sqrt() as f64)
121}
122
123fn compute_inner_product(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
124 let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
125 Ok(sum as f64)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by every floating-point comparison below.
    const EPS: f64 = 1e-6;

    /// Evaluates `vector_similarity`, panicking if it returns an error.
    fn score(a: &[f32], b: &[f32], metric: VectorMetric) -> f64 {
        vector_similarity(a, b, metric).expect("metric evaluation should succeed")
    }

    /// True when `actual` lies strictly within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn vector_metric_from_str_trims_and_lowercases() {
        let cases = [
            (" COSINE ", VectorMetric::Cosine),
            ("l2", VectorMetric::L2),
            ("Inner", VectorMetric::Inner),
        ];
        for (input, expected) in cases {
            assert_eq!(VectorMetric::parse(input).unwrap(), expected);
        }
    }

    #[test]
    fn vector_metric_from_str_empty_rejected() {
        match VectorMetric::parse("") {
            Err(VectorError::InvalidMetric { reason, .. }) => {
                assert!(reason.contains("empty"))
            }
            other => panic!("expected InvalidMetric, got {:?}", other),
        }
    }

    #[test]
    fn vector_metric_from_str_unknown_rejected() {
        match VectorMetric::parse("minkowski") {
            Err(VectorError::InvalidMetric { reason, .. }) => {
                assert!(reason.contains("expected 'cosine', 'l2', or 'inner'"))
            }
            other => panic!("expected InvalidMetric, got {:?}", other),
        }
    }

    #[test]
    fn vector_metric_from_str_trait_parse() {
        assert_eq!("cosine".parse::<VectorMetric>(), Ok(VectorMetric::Cosine));
    }

    #[test]
    fn cosine_similarity_basic() {
        assert!(close(score(&[1.0, 0.0], &[0.0, 1.0], VectorMetric::Cosine), 0.0));
    }

    #[test]
    fn cosine_similarity_parallel() {
        assert!(close(score(&[1.0, 1.0], &[2.0, 2.0], VectorMetric::Cosine), 1.0));
    }

    #[test]
    fn cosine_similarity_zero_norm_error() {
        let outcome = vector_similarity(&[0.0, 0.0], &[1.0, 1.0], VectorMetric::Cosine);
        assert_eq!(outcome, Err(VectorError::ZeroNormVector));
    }

    #[test]
    fn l2_distance_basic() {
        assert!(close(score(&[0.0, 0.0], &[3.0, 4.0], VectorMetric::L2), 5.0));
    }

    #[test]
    fn inner_product_basic() {
        assert!(close(
            score(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], VectorMetric::Inner),
            32.0
        ));
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let outcome = vector_similarity(&[1.0, 2.0], &[1.0, 2.0, 3.0], VectorMetric::L2);
        assert_eq!(
            outcome,
            Err(VectorError::DimensionMismatch {
                expected: 2,
                actual: 3
            })
        );
    }

    #[test]
    fn vector_distance_alias() {
        let (a, b) = ([1.0_f32, 2.0], [3.0_f32, 4.0]);
        let via_similarity = score(&a, &b, VectorMetric::Inner);
        let via_distance = vector_distance(&a, &b, VectorMetric::Inner).unwrap();
        assert!(close(via_distance, via_similarity));
    }
}