use crate::network::LearnedManifold;
use crate::simd_ops::{cosine_similarity_simd, euclidean_distance_simd};
use exo_core::{ManifoldConfig, Pattern, Result, SearchResult};
use parking_lot::RwLock;
use std::sync::Arc;
pub struct GradientDescentRetriever {
_network: Arc<RwLock<LearnedManifold>>,
_config: ManifoldConfig,
}
impl GradientDescentRetriever {
pub fn new(network: Arc<RwLock<LearnedManifold>>, config: ManifoldConfig) -> Self {
Self {
_network: network,
_config: config,
}
}
pub fn retrieve(
&self,
query: &[f32],
k: usize,
patterns: &Arc<RwLock<Vec<Pattern>>>,
) -> Result<Vec<SearchResult>> {
let patterns = patterns.read();
let mut results = Vec::with_capacity(patterns.len());
for pattern in patterns.iter() {
let similarity = cosine_similarity_simd(query, &pattern.embedding);
let distance = euclidean_distance_simd(query, &pattern.embedding);
results.push(SearchResult {
pattern: pattern.clone(),
score: similarity,
distance,
});
}
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
results.truncate(k);
Ok(results)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::simd_ops::{cosine_similarity_simd, euclidean_distance_simd};
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0, 0.0];
assert!((cosine_similarity_simd(&a, &b) - 1.0).abs() < 1e-5);
let c = vec![1.0, 0.0, 0.0, 0.0];
let d = vec![0.0, 1.0, 0.0, 0.0];
assert!((cosine_similarity_simd(&c, &d) - 0.0).abs() < 1e-5);
}
#[test]
fn test_euclidean_distance() {
let a = vec![0.0, 0.0, 0.0, 0.0];
let b = vec![3.0, 4.0, 0.0, 0.0];
assert!((euclidean_distance_simd(&a, &b) - 5.0).abs() < 1e-5);
}
}