use crate::ai::vector_store::{
InMemoryVectorStore, SimilarityMetric, VectorStore, VectorStoreConfig,
};
use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::ArrayView1;
use std::sync::Arc;
pub fn compute_similarity(a: &[f32], b: &[f32], metric: SimilarityMetric) -> Result<f32> {
if a.len() != b.len() {
return Err(anyhow!("Vector dimension mismatch"));
}
let a_arr = ArrayView1::from(a);
let b_arr = ArrayView1::from(b);
match metric {
SimilarityMetric::Cosine => {
let dot_product = a_arr.dot(&b_arr);
let norm_a = a_arr.dot(&a_arr).sqrt();
let norm_b = b_arr.dot(&b_arr).sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
Ok(0.0)
} else {
Ok(dot_product / (norm_a * norm_b))
}
}
SimilarityMetric::Euclidean => {
let diff = &a_arr - &b_arr;
let distance = diff.dot(&diff).sqrt();
Ok(1.0 / (1.0 + distance))
}
SimilarityMetric::Manhattan => {
let diff = &a_arr - &b_arr;
let distance = diff.mapv(f32::abs).sum();
Ok(1.0 / (1.0 + distance))
}
SimilarityMetric::DotProduct => Ok(a_arr.dot(&b_arr)),
SimilarityMetric::Jaccard => {
let a_binary = a_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
let b_binary = b_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
let intersection: u32 = (&a_binary * &b_binary).sum();
let union: u32 = a_binary
.iter()
.zip(b_binary.iter())
.map(|(x, y)| if *x > 0 || *y > 0 { 1 } else { 0 })
.sum();
if union == 0 {
Ok(0.0)
} else {
Ok(intersection as f32 / union as f32)
}
}
SimilarityMetric::Hamming => {
let a_binary = a_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
let b_binary = b_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
let differences: u32 = a_binary
.iter()
.zip(b_binary.iter())
.map(|(x, y)| if x != y { 1 } else { 0 })
.sum();
Ok(1.0 - (differences as f32 / a.len() as f32))
}
}
}
pub fn compute_similarities_batch(
query: &[f32],
candidates: &[&[f32]],
metric: SimilarityMetric,
) -> Result<Vec<f32>> {
if candidates.is_empty() {
return Ok(Vec::new());
}
for candidate in candidates {
if candidate.len() != query.len() {
return Err(anyhow!("Vector dimension mismatch in batch"));
}
}
let query_arr = ArrayView1::from(query);
let query_norm = match metric {
SimilarityMetric::Cosine => {
let norm = query_arr.dot(&query_arr).sqrt();
if norm == 0.0 {
return Ok(vec![0.0; candidates.len()]);
}
norm
}
_ => 1.0,
};
if candidates.len() > 100 {
use rayon::prelude::*;
let results: Vec<f32> = candidates
.par_iter()
.map(|candidate| {
let c_arr = ArrayView1::from(*candidate);
match metric {
SimilarityMetric::Cosine => {
let dot = query_arr.dot(&c_arr);
let c_norm = c_arr.dot(&c_arr).sqrt();
if c_norm == 0.0 {
0.0
} else {
dot / (query_norm * c_norm)
}
}
SimilarityMetric::Euclidean => {
let diff = &query_arr - &c_arr;
let dist = diff.dot(&diff).sqrt();
1.0 / (1.0 + dist)
}
SimilarityMetric::Manhattan => {
let diff = &query_arr - &c_arr;
let dist = diff.mapv(f32::abs).sum();
1.0 / (1.0 + dist)
}
SimilarityMetric::DotProduct => query_arr.dot(&c_arr),
_ => compute_similarity(query, candidate, metric).unwrap_or(0.0),
}
})
.collect();
Ok(results)
} else {
candidates
.iter()
.map(|candidate| compute_similarity(query, candidate, metric))
.collect()
}
}
pub fn create_vector_store(config: &VectorStoreConfig) -> Result<Arc<dyn VectorStore>> {
Ok(Arc::new(InMemoryVectorStore::new(config.clone())))
}