alopex_sql/executor/evaluator/
vector_ops.rs

1use std::str::FromStr;
2
3use thiserror::Error;
4
5/// ベクトル演算で発生するエラー。
6#[derive(Debug, Error, Clone, PartialEq, Eq)]
7pub enum VectorError {
8    /// 引数数が想定と異なる。
9    #[error("argument count mismatch: expected 3, got {actual}")]
10    ArgumentCountMismatch { actual: usize },
11
12    /// 第一引数がベクトル列ではない。
13    #[error("type mismatch: first argument must be VECTOR column")]
14    TypeMismatch,
15
16    /// ベクトルリテラルが不正。
17    #[error("invalid vector literal: {reason}")]
18    InvalidVectorLiteral { reason: String },
19
20    /// メトリクス指定が不正。
21    #[error("invalid metric '{metric}': {reason}")]
22    InvalidMetric { metric: String, reason: String },
23
24    /// ベクトル次元が一致しない。
25    #[error("dimension mismatch: column has {expected} dimensions, query has {actual}")]
26    DimensionMismatch { expected: usize, actual: usize },
27
28    /// コサイン類似度でゼロノルムベクトルが渡された。
29    #[error("zero-norm vector cannot be used for cosine similarity")]
30    ZeroNormVector,
31}
32
/// Metric used when comparing two vectors (similarity or distance).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VectorMetric {
    /// Cosine similarity.
    Cosine,
    /// Euclidean (L2) distance.
    L2,
    /// Inner (dot) product.
    Inner,
}
40
41impl VectorMetric {
42    /// 文字列をメトリクスに変換する(前後空白除去・小文字化)。
43    pub fn parse(s: &str) -> Result<Self, VectorError> {
44        let normalized = s.trim().to_lowercase();
45        match normalized.as_str() {
46            "cosine" => Ok(Self::Cosine),
47            "l2" => Ok(Self::L2),
48            "inner" => Ok(Self::Inner),
49            "" => Err(VectorError::InvalidMetric {
50                metric: s.to_string(),
51                reason: "empty metric string".into(),
52            }),
53            _ => Err(VectorError::InvalidMetric {
54                metric: s.to_string(),
55                reason: format!("expected 'cosine', 'l2', or 'inner', got '{}'", normalized),
56            }),
57        }
58    }
59}
60
61impl FromStr for VectorMetric {
62    type Err = VectorError;
63
64    fn from_str(s: &str) -> Result<Self, Self::Err> {
65        VectorMetric::parse(s)
66    }
67}
68
69/// ベクトル類似度/距離を計算する。
70///
71/// - 内部計算: f32(メモリ効率)
72/// - 返却値: f64(SQL DOUBLE として返す想定)
73pub fn vector_similarity(
74    column_value: &[f32],
75    query_vector: &[f32],
76    metric: VectorMetric,
77) -> Result<f64, VectorError> {
78    validate_dimensions(column_value, query_vector)?;
79
80    match metric {
81        VectorMetric::Cosine => compute_cosine_similarity(column_value, query_vector),
82        VectorMetric::L2 => compute_l2_distance(column_value, query_vector),
83        VectorMetric::Inner => compute_inner_product(column_value, query_vector),
84    }
85}
86
87/// vector_similarity のエイリアス。距離メトリクスでも同一の実装を利用する。
88pub fn vector_distance(
89    column_value: &[f32],
90    query_vector: &[f32],
91    metric: VectorMetric,
92) -> Result<f64, VectorError> {
93    vector_similarity(column_value, query_vector, metric)
94}
95
96fn validate_dimensions(a: &[f32], b: &[f32]) -> Result<(), VectorError> {
97    if a.len() != b.len() {
98        return Err(VectorError::DimensionMismatch {
99            expected: a.len(),
100            actual: b.len(),
101        });
102    }
103    Ok(())
104}
105
106fn compute_cosine_similarity(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
107    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
108    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
109    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
110
111    if norm_a == 0.0 || norm_b == 0.0 {
112        return Err(VectorError::ZeroNormVector);
113    }
114
115    Ok((dot / (norm_a * norm_b)) as f64)
116}
117
118fn compute_l2_distance(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
119    let sum_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
120    Ok(sum_sq.sqrt() as f64)
121}
122
123fn compute_inner_product(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
124    let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
125    Ok(sum as f64)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f64 = 1e-6;

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Extracts the `reason` of an `InvalidMetric` error, failing the test otherwise.
    fn invalid_metric_reason(result: Result<VectorMetric, VectorError>) -> String {
        match result {
            Err(VectorError::InvalidMetric { reason, .. }) => reason,
            other => panic!("expected InvalidMetric, got {:?}", other),
        }
    }

    #[test]
    fn vector_metric_from_str_trims_and_lowercases() {
        let cases = [
            (" COSINE ", VectorMetric::Cosine),
            ("l2", VectorMetric::L2),
            ("Inner", VectorMetric::Inner),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(VectorMetric::parse(input).unwrap(), *expected);
        }
    }

    #[test]
    fn vector_metric_from_str_empty_rejected() {
        let reason = invalid_metric_reason(VectorMetric::parse(""));
        assert!(reason.contains("empty"));
    }

    #[test]
    fn vector_metric_from_str_unknown_rejected() {
        let reason = invalid_metric_reason(VectorMetric::parse("minkowski"));
        assert!(reason.contains("expected 'cosine', 'l2', or 'inner'"));
    }

    #[test]
    fn vector_metric_from_str_trait_parse() {
        let parsed = "cosine".parse::<VectorMetric>().unwrap();
        assert_eq!(parsed, VectorMetric::Cosine);
    }

    #[test]
    fn cosine_similarity_basic() {
        // Orthogonal unit vectors have a similarity of zero.
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let value = vector_similarity(&x_axis, &y_axis, VectorMetric::Cosine).unwrap();
        assert_close(value, 0.0);
    }

    #[test]
    fn cosine_similarity_parallel() {
        // Vectors pointing the same way have a similarity of one, whatever their length.
        let short = [1.0_f32, 1.0];
        let long = [2.0_f32, 2.0];
        let value = vector_similarity(&short, &long, VectorMetric::Cosine).unwrap();
        assert_close(value, 1.0);
    }

    #[test]
    fn cosine_similarity_zero_norm_error() {
        let zero = [0.0_f32, 0.0];
        let ones = [1.0_f32, 1.0];
        let result = vector_similarity(&zero, &ones, VectorMetric::Cosine);
        assert_eq!(result.unwrap_err(), VectorError::ZeroNormVector);
    }

    #[test]
    fn l2_distance_basic() {
        // Classic 3-4-5 right triangle.
        let origin = [0.0_f32, 0.0];
        let point = [3.0_f32, 4.0];
        let value = vector_similarity(&origin, &point, VectorMetric::L2).unwrap();
        assert_close(value, 5.0);
    }

    #[test]
    fn inner_product_basic() {
        // 1*4 + 2*5 + 3*6 = 32
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        let value = vector_similarity(&lhs, &rhs, VectorMetric::Inner).unwrap();
        assert_close(value, 32.0);
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let two_dims = [1.0_f32, 2.0];
        let three_dims = [1.0_f32, 2.0, 3.0];
        let result = vector_similarity(&two_dims, &three_dims, VectorMetric::L2);
        assert_eq!(
            result.unwrap_err(),
            VectorError::DimensionMismatch {
                expected: 2,
                actual: 3,
            }
        );
    }

    #[test]
    fn vector_distance_alias() {
        // Both entry points must agree for the same inputs and metric.
        let lhs = [1.0_f32, 2.0];
        let rhs = [3.0_f32, 4.0];
        let similarity = vector_similarity(&lhs, &rhs, VectorMetric::Inner).unwrap();
        let distance = vector_distance(&lhs, &rhs, VectorMetric::Inner).unwrap();
        assert_close(distance, similarity);
    }
}