Skip to main content

alopex_sql/executor/evaluator/
vector_ops.rs

1use std::str::FromStr;
2
3use thiserror::Error;
4
/// Errors raised by vector operations.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum VectorError {
    /// The number of arguments differs from what was expected.
    #[error("argument count mismatch: expected 3, got {actual}")]
    ArgumentCountMismatch { actual: usize },

    /// The first argument is not a vector column.
    #[error("type mismatch: first argument must be VECTOR column")]
    TypeMismatch,

    /// The vector literal is malformed.
    #[error("invalid vector literal: {reason}")]
    InvalidVectorLiteral { reason: String },

    /// The metric specification is invalid.
    #[error("invalid metric '{metric}': {reason}")]
    InvalidMetric { metric: String, reason: String },

    /// The vector dimensions do not match.
    #[error("dimension mismatch: column has {expected} dimensions, query has {actual}")]
    DimensionMismatch { expected: usize, actual: usize },

    /// A zero-norm vector was passed to cosine similarity.
    #[error("zero-norm vector cannot be used for cosine similarity")]
    ZeroNormVector,
}
32
/// Metric used to compute a vector similarity or distance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VectorMetric {
    /// Cosine similarity; spelled `"cosine"` when parsed from text.
    Cosine,
    /// Euclidean (L2) distance; spelled `"l2"` when parsed from text.
    L2,
    /// Inner (dot) product; spelled `"inner"` when parsed from text.
    Inner,
}
40
41impl VectorMetric {
42    /// 文字列をメトリクスに変換する(前後空白除去・小文字化)。
43    pub fn parse(s: &str) -> Result<Self, VectorError> {
44        let normalized = s.trim().to_lowercase();
45        match normalized.as_str() {
46            "cosine" => Ok(Self::Cosine),
47            "l2" => Ok(Self::L2),
48            "inner" => Ok(Self::Inner),
49            "" => Err(VectorError::InvalidMetric {
50                metric: s.to_string(),
51                reason: "empty metric string".into(),
52            }),
53            _ => Err(VectorError::InvalidMetric {
54                metric: s.to_string(),
55                reason: format!("expected 'cosine', 'l2', or 'inner', got '{}'", normalized),
56            }),
57        }
58    }
59}
60
61impl FromStr for VectorMetric {
62    type Err = VectorError;
63
64    fn from_str(s: &str) -> Result<Self, Self::Err> {
65        VectorMetric::parse(s)
66    }
67}
68
69/// ベクトル類似度/距離を計算する。
70///
71/// - 内部計算: f32(メモリ効率)
72/// - 返却値: f64(SQL DOUBLE として返す想定)
73pub fn vector_similarity(
74    column_value: &[f32],
75    query_vector: &[f32],
76    metric: VectorMetric,
77) -> Result<f64, VectorError> {
78    validate_dimensions(column_value, query_vector)?;
79
80    match metric {
81        VectorMetric::Cosine => compute_cosine_similarity(column_value, query_vector),
82        VectorMetric::L2 => compute_l2_distance(column_value, query_vector),
83        VectorMetric::Inner => compute_inner_product(column_value, query_vector),
84    }
85}
86
/// Alias of [`vector_similarity`]; distance metrics use the very same implementation.
///
/// All three arguments are forwarded unchanged and the result (value or error) of
/// `vector_similarity` is returned as-is.
pub fn vector_distance(
    column_value: &[f32],
    query_vector: &[f32],
    metric: VectorMetric,
) -> Result<f64, VectorError> {
    vector_similarity(column_value, query_vector, metric)
}
95
/// Returns the number of elements (dimensions) in `vector`.
pub fn vector_dims(vector: &[f32]) -> usize {
    vector.len()
}
99
/// Returns the Euclidean (L2) norm of `vector` as an `f64`.
///
/// Each component is widened to `f64` before it is squared and summed. Squaring in
/// `f32` overflows to infinity for components above roughly `1.8e19` and underflows
/// to zero for components below roughly `1e-23`, even though the true norm is
/// comfortably representable in the `f64` return type. Widening first also avoids
/// accumulating `f32` rounding error over high-dimensional vectors.
///
/// An empty slice yields a norm of zero.
pub fn vector_norm(vector: &[f32]) -> f64 {
    vector
        .iter()
        .map(|&component| {
            let wide = f64::from(component);
            wide * wide
        })
        .sum::<f64>()
        .sqrt()
}
104
105fn validate_dimensions(a: &[f32], b: &[f32]) -> Result<(), VectorError> {
106    if a.len() != b.len() {
107        return Err(VectorError::DimensionMismatch {
108            expected: a.len(),
109            actual: b.len(),
110        });
111    }
112    Ok(())
113}
114
115fn compute_cosine_similarity(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
116    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
117    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
118    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
119
120    if norm_a == 0.0 || norm_b == 0.0 {
121        return Err(VectorError::ZeroNormVector);
122    }
123
124    Ok((dot / (norm_a * norm_b)) as f64)
125}
126
127fn compute_l2_distance(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
128    let sum_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
129    Ok(sum_sq.sqrt() as f64)
130}
131
132fn compute_inner_product(a: &[f32], b: &[f32]) -> Result<f64, VectorError> {
133    let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
134    Ok(sum as f64)
135}
136
// Unit tests covering metric-name parsing and each metric kernel through the
// public `vector_similarity` / `vector_distance` entry points.
#[cfg(test)]
mod tests {
    use super::*;

    // Surrounding whitespace and letter case are ignored when parsing a metric name.
    #[test]
    fn vector_metric_from_str_trims_and_lowercases() {
        assert_eq!(
            VectorMetric::parse(" COSINE ").unwrap(),
            VectorMetric::Cosine
        );
        assert_eq!(VectorMetric::parse("l2").unwrap(), VectorMetric::L2);
        assert_eq!(VectorMetric::parse("Inner").unwrap(), VectorMetric::Inner);
    }

    // An empty name is rejected with a reason that mentions "empty".
    #[test]
    fn vector_metric_from_str_empty_rejected() {
        let err = VectorMetric::parse("").unwrap_err();
        assert!(matches!(
            err,
            VectorError::InvalidMetric { reason, .. } if reason.contains("empty")
        ));
    }

    // An unknown name is rejected with a reason listing the accepted names.
    #[test]
    fn vector_metric_from_str_unknown_rejected() {
        let err = VectorMetric::parse("minkowski").unwrap_err();
        assert!(matches!(
            err,
            VectorError::InvalidMetric { reason, .. } if reason.contains("expected 'cosine', 'l2', or 'inner'")
        ));
    }

    // The `FromStr` implementation makes `str::parse` usable for metrics.
    #[test]
    fn vector_metric_from_str_trait_parse() {
        let m: VectorMetric = "cosine".parse().unwrap();
        assert_eq!(m, VectorMetric::Cosine);
    }

    // Orthogonal vectors have a cosine similarity of 0.
    #[test]
    fn cosine_similarity_basic() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        let v = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap();
        assert!((v - 0.0).abs() < 1e-6);
    }

    // Parallel vectors of different magnitude have a cosine similarity of 1.
    #[test]
    fn cosine_similarity_parallel() {
        let a = [1.0_f32, 1.0];
        let b = [2.0_f32, 2.0];
        let v = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap();
        assert!((v - 1.0).abs() < 1e-6);
    }

    // A zero vector is rejected for cosine similarity.
    #[test]
    fn cosine_similarity_zero_norm_error() {
        let a = [0.0_f32, 0.0];
        let b = [1.0_f32, 1.0];
        let err = vector_similarity(&a, &b, VectorMetric::Cosine).unwrap_err();
        assert!(matches!(err, VectorError::ZeroNormVector));
    }

    // 3-4-5 triangle: the distance from the origin to (3, 4) is 5.
    #[test]
    fn l2_distance_basic() {
        let a = [0.0_f32, 0.0];
        let b = [3.0_f32, 4.0];
        let v = vector_similarity(&a, &b, VectorMetric::L2).unwrap();
        assert!((v - 5.0).abs() < 1e-6);
    }

    // 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn inner_product_basic() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        let v = vector_similarity(&a, &b, VectorMetric::Inner).unwrap();
        assert!((v - 32.0).abs() < 1e-6);
    }

    // Slices of different length are rejected, reporting both lengths.
    #[test]
    fn dimension_mismatch_rejected() {
        let a = [1.0_f32, 2.0];
        let b = [1.0_f32, 2.0, 3.0];
        let err = vector_similarity(&a, &b, VectorMetric::L2).unwrap_err();
        assert!(matches!(
            err,
            VectorError::DimensionMismatch {
                expected: 2,
                actual: 3
            }
        ));
    }

    // `vector_distance` returns the same value as `vector_similarity`.
    #[test]
    fn vector_distance_alias() {
        let a = [1.0_f32, 2.0];
        let b = [3.0_f32, 4.0];
        let sim = vector_similarity(&a, &b, VectorMetric::Inner).unwrap();
        let dist = vector_distance(&a, &b, VectorMetric::Inner).unwrap();
        assert!((sim - dist).abs() < 1e-6);
    }
}