alopex_sql/executor/evaluator/
vector_ops.rs1use std::str::FromStr;
2
3use thiserror::Error;
4
5#[derive(Debug, Error, Clone, PartialEq, Eq)]
7pub enum VectorError {
8 #[error("argument count mismatch: expected 3, got {actual}")]
10 ArgumentCountMismatch { actual: usize },
11
12 #[error("type mismatch: first argument must be VECTOR column")]
14 TypeMismatch,
15
16 #[error("invalid vector literal: {reason}")]
18 InvalidVectorLiteral { reason: String },
19
20 #[error("invalid metric '{metric}': {reason}")]
22 InvalidMetric { metric: String, reason: String },
23
24 #[error("dimension mismatch: column has {expected} dimensions, query has {actual}")]
26 DimensionMismatch { expected: usize, actual: usize },
27
28 #[error("zero-norm vector cannot be used for cosine similarity")]
30 ZeroNormVector,
31}
32
/// Scoring metric understood by `vector_similarity` / `vector_distance`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VectorMetric {
    /// Cosine similarity (parsed from `"cosine"`).
    Cosine,
    /// Euclidean distance (parsed from `"l2"`).
    L2,
    /// Inner (dot) product (parsed from `"inner"`).
    Inner,
}
40
41impl VectorMetric {
42 pub fn parse(s: &str) -> Result<Self, VectorError> {
44 let normalized = s.trim().to_lowercase();
45 match normalized.as_str() {
46 "cosine" => Ok(Self::Cosine),
47 "l2" => Ok(Self::L2),
48 "inner" => Ok(Self::Inner),
49 "" => Err(VectorError::InvalidMetric {
50 metric: s.to_string(),
51 reason: "empty metric string".into(),
52 }),
53 _ => Err(VectorError::InvalidMetric {
54 metric: s.to_string(),
55 reason: format!("expected 'cosine', 'l2', or 'inner', got '{}'", normalized),
56 }),
57 }
58 }
59}
60
61impl FromStr for VectorMetric {
62 type Err = VectorError;
63
64 fn from_str(s: &str) -> Result<Self, Self::Err> {
65 VectorMetric::parse(s)
66 }
67}
68
69pub fn vector_similarity(
74 column_value: &[f32],
75 query_vector: &[f32],
76 metric: VectorMetric,
77) -> Result<f64, VectorError> {
78 validate_dimensions(column_value, query_vector)?;
79
80 match metric {
81 VectorMetric::Cosine => compute_cosine_similarity(column_value, query_vector),
82 VectorMetric::L2 => compute_l2_distance(column_value, query_vector),
83 VectorMetric::Inner => compute_inner_product(column_value, query_vector),
84 }
85}
86
87pub fn vector_distance(
89 column_value: &[f32],
90 query_vector: &[f32],
91 metric: VectorMetric,
92) -> Result<f64, VectorError> {
93 vector_similarity(column_value, query_vector, metric)
94}
95
/// Number of components (dimensions) in `vector`.
pub fn vector_dims(vector: &[f32]) -> usize {
    <[f32]>::len(vector)
}
99
/// Euclidean (L2) norm `sqrt(sum(v * v))` of `vector`; `0.0` for an empty slice.
///
/// Components are widened to `f64` *before* squaring and summing. Doing the
/// arithmetic in `f32` would overflow to infinity once a component exceeds
/// about `1.8e19`, flush components below about `4e-23` to zero, and lose
/// precision when summing many dimensions — the `f64` return type can hold
/// all of these exactly or nearly so.
pub fn vector_norm(vector: &[f32]) -> f64 {
    vector
        .iter()
        .map(|&v| f64::from(v) * f64::from(v))
        .sum::<f64>()
        .sqrt()
}
104
105fn validate_dimensions(a: &[f32], b: &[f32]) -> Result<(), VectorError> {
106 if a.len() != b.len() {
107 return Err(VectorError::DimensionMismatch {
108 expected: a.len(),
109 actual: b.len(),
110 });
111 }
112 Ok(())
113}
114
115fn compute_cosine_similarity(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
116 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
117 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
118 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
119
120 if norm_a == 0.0 || norm_b == 0.0 {
121 return Err(VectorError::ZeroNormVector);
122 }
123
124 Ok((dot / (norm_a * norm_b)) as f64)
125}
126
127fn compute_l2_distance(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
128 let sum_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
129 Ok(sum_sq.sqrt() as f64)
130}
131
132fn compute_inner_product(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
133 let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
134 Ok(sum as f64)
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vector_metric_from_str_trims_and_lowercases() {
        assert_eq!(
            VectorMetric::parse(" COSINE ").unwrap(),
            VectorMetric::Cosine
        );
        assert_eq!(VectorMetric::parse("l2").unwrap(), VectorMetric::L2);
        assert_eq!(VectorMetric::parse("Inner").unwrap(), VectorMetric::Inner);
    }

    #[test]
    fn vector_metric_from_str_empty_rejected() {
        let err = VectorMetric::parse("").unwrap_err();
        assert!(matches!(
            err,
            VectorError::InvalidMetric { reason, .. } if reason.contains("empty")
        ));
    }

    #[test]
    fn vector_metric_from_str_whitespace_only_rejected() {
        // Trimming happens before the empty check, so blanks count as empty;
        // the error still reports the raw, untrimmed input.
        let err = VectorMetric::parse("  \t ").unwrap_err();
        assert_eq!(
            err,
            VectorError::InvalidMetric {
                metric: "  \t ".to_string(),
                reason: "empty metric string".to_string(),
            }
        );
    }

    #[test]
    fn vector_metric_from_str_unknown_rejected() {
        let err = VectorMetric::parse("minkowski").unwrap_err();
        assert!(matches!(
            err,
            VectorError::InvalidMetric { reason, .. } if reason.contains("expected 'cosine', 'l2', or 'inner'")
        ));
    }

    #[test]
    fn vector_metric_unknown_reports_normalized_name() {
        let err = VectorMetric::parse(" Minkowski ").unwrap_err();
        assert_eq!(
            err,
            VectorError::InvalidMetric {
                metric: " Minkowski ".to_string(),
                reason: "expected 'cosine', 'l2', or 'inner', got 'minkowski'".to_string(),
            }
        );
    }

    #[test]
    fn vector_metric_from_str_trait_parse() {
        let m: VectorMetric = "cosine".parse().unwrap();
        assert_eq!(m, VectorMetric::Cosine);
        assert!("bogus".parse::<VectorMetric>().is_err());
    }

    #[test]
    fn cosine_similarity_basic() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        let v = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap();
        assert!((v - 0.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_parallel() {
        let a = [1.0_f32, 1.0];
        let b = [2.0_f32, 2.0];
        let v = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap();
        assert!((v - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a = [1.0_f32, 2.0];
        let b = [-1.0_f32, -2.0];
        let v = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap();
        assert!((v + 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_norm_error() {
        let a = [0.0_f32, 0.0];
        let b = [1.0_f32, 1.0];
        let err = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap_err();
        assert!(matches!(err, VectorError::ZeroNormVector));
    }

    #[test]
    fn cosine_similarity_zero_norm_query_error() {
        let a = [1.0_f32, 1.0];
        let b = [0.0_f32, 0.0];
        let err = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap_err();
        assert_eq!(err, VectorError::ZeroNormVector);
    }

    #[test]
    fn l2_distance_basic() {
        let a = [0.0_f32, 0.0];
        let b = [3.0_f32, 4.0];
        let v = vector_similarity(&a, &b, VectorMetric::L2).unwrap();
        assert!((v - 5.0).abs() < 1e-6);
    }

    #[test]
    fn l2_distance_identical_is_zero() {
        let a = [1.5_f32, -2.0, 3.25];
        let v = vector_similarity(&a, &a, VectorMetric::L2).unwrap();
        assert_eq!(v, 0.0);
    }

    #[test]
    fn inner_product_basic() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let v = vector_similarity(&a, &b, VectorMetric::Inner).unwrap();
        assert!((v - 32.0).abs() < 1e-6);
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        let err = vector_similarity(&a, &b, VectorMetric::L2).unwrap_err();
        assert!(matches!(
            err,
            VectorError::DimensionMismatch {
                expected: 2,
                actual: 3
            }
        ));
    }

    #[test]
    fn dimension_mismatch_checked_before_zero_norm() {
        // A zero column vector with the wrong length must report the length
        // problem, because dimensions are validated before any metric runs.
        let a = [0.0_f32, 0.0];
        let b = [1.0_f32, 2.0, 3.0];
        let err = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap_err();
        assert_eq!(
            err,
            VectorError::DimensionMismatch {
                expected: 2,
                actual: 3
            }
        );
    }

    #[test]
    fn empty_vectors() {
        assert_eq!(vector_similarity(&[], &[], VectorMetric::L2).unwrap(), 0.0);
        assert_eq!(vector_similarity(&[], &[], VectorMetric::Inner).unwrap(), 0.0);
        assert_eq!(
            vector_similarity(&[], &[], VectorMetric::Cosine).unwrap_err(),
            VectorError::ZeroNormVector
        );
    }

    #[test]
    fn vector_distance_alias() {
        let a = [1.0_f32, 2.0];
        let b = [3.0_f32, 4.0];
        let sim = vector_similarity(&a, &b, VectorMetric::Inner).unwrap();
        let dist = vector_distance(&a, &b, VectorMetric::Inner).unwrap();
        assert!((sim - dist).abs() < 1e-6);
    }

    #[test]
    fn vector_dims_counts_components() {
        assert_eq!(vector_dims(&[]), 0);
        assert_eq!(vector_dims(&[1.0, 2.0, 3.0]), 3);
    }

    #[test]
    fn vector_norm_basic() {
        assert!((vector_norm(&[3.0, 4.0]) - 5.0).abs() < 1e-6);
        assert_eq!(vector_norm(&[]), 0.0);
    }
}