use serde::{Deserialize, Serialize};
use crate::chunker::Chunk;
use crate::distance::Distance;
use crate::error::RagError;
use crate::metadata_filter::MetadataFilter;
/// Returns the dot product (sum of element-wise products) of `a` and `b`.
///
/// Slices of different lengths are not an error: they simply yield `0.0`.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
    } else {
        0.0
    }
}
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not greater than `1e-10` (including all-zero and empty
/// vectors) are left untouched to avoid dividing by (nearly) zero.
#[inline]
pub fn l2_normalize(v: &mut [f32]) {
    let magnitude = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude > 1e-10 {
        v.iter_mut().for_each(|component| *component /= magnitude);
    }
}
/// Cosine similarity of `a` and `b`.
///
/// This is their dot product divided by the product of their Euclidean norms. The
/// result is clamped to `[-1.0, 1.0]` to absorb floating-point rounding error.
///
/// Inputs do NOT need to be pre-normalized. Returns `0.0` in any of these cases,
/// since the angle is undefined there:
/// - the slices are empty;
/// - the slices have different lengths;
/// - either vector has a norm of at most `1e-10` (the same cutoff
///   [`l2_normalize`] uses).
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, sq_a, sq_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, sq_a, sq_b), (&x, &y)| {
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        });
    let (norm_a, norm_b) = (sq_a.sqrt(), sq_b.sqrt());
    if norm_a <= 1e-10 || norm_b <= 1e-10 {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// A stored vector together with the chunk it is associated with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEntry {
    /// Identifier returned by `VectorStore::insert` (the store's entry count at insertion time).
    pub id: usize,
    /// The stored vector. `VectorStore::insert` L2-normalizes it first when the store's
    /// distance is Cosine, DotProduct or Angular.
    pub vector: Vec<f32>,
    /// Chunk associated with this vector; a clone of it is returned in search results.
    pub chunk: Chunk,
}
/// A single hit returned by the `VectorStore` search methods.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Score produced by `Distance::to_score`; search results are ordered from highest to lowest.
    pub score: f32,
    /// Clone of the matching entry's chunk.
    pub chunk: Chunk,
    /// Id of the matching `VectorEntry`.
    pub id: usize,
}
/// In-memory collection of fixed-dimension vectors, searched by scanning every entry.
///
/// Note: the derived `Default` yields `dim == 0`, so only empty vectors can be inserted
/// into a default-constructed store.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct VectorStore {
    /// Stored entries, in insertion order (or as supplied to `set_entries`).
    entries: Vec<VectorEntry>,
    /// Length every inserted vector and every query must have.
    dim: usize,
    /// Metric used to compare vectors. `#[serde(default)]` lets serialized data that
    /// lacks this field deserialize with `Distance::default()`.
    #[serde(default)]
    distance: Distance,
}
impl VectorStore {
    /// Creates an empty store for `dim`-dimensional vectors using `Distance::default()`.
    pub fn new(dim: usize) -> Self {
        Self::new_with_distance(dim, Distance::default())
    }

    /// Creates an empty store for `dim`-dimensional vectors compared with `distance`.
    pub fn new_with_distance(dim: usize, distance: Distance) -> Self {
        Self {
            entries: Vec::new(),
            dim,
            distance,
        }
    }

    /// Returns `true` for the metrics (Cosine, DotProduct, Angular) whose stored vectors
    /// and queries are L2-normalized before comparison.
    ///
    /// Both `insert` and the search path use this one predicate, so they cannot drift
    /// apart.
    fn normalizes_vectors(&self) -> bool {
        matches!(
            self.distance,
            Distance::Cosine | Distance::DotProduct | Distance::Angular
        )
    }

    /// Validates and stores `vector` alongside `chunk`, returning the new entry's id.
    ///
    /// The id is the number of entries held before this insert. For Cosine, DotProduct
    /// and Angular distances the vector is L2-normalized before it is stored.
    ///
    /// # Errors
    ///
    /// * [`RagError::DimensionMismatch`] if `vector.len()` differs from [`Self::dim`].
    /// * [`RagError::NonFinite`] if any component is NaN or infinite.
    pub fn insert(&mut self, mut vector: Vec<f32>, chunk: Chunk) -> Result<usize, RagError> {
        if vector.len() != self.dim {
            return Err(RagError::DimensionMismatch {
                expected: self.dim,
                got: vector.len(),
            });
        }
        if vector.iter().any(|x| !x.is_finite()) {
            return Err(RagError::NonFinite);
        }
        if self.normalizes_vectors() {
            l2_normalize(&mut vector);
        }
        // NOTE(review): ids are positional. If `set_entries` installed non-contiguous
        // ids, this can collide with an existing id.
        let id = self.entries.len();
        self.entries.push(VectorEntry { id, vector, chunk });
        Ok(id)
    }

    /// Returns up to `top_k` entries scoring highest against `query`, best first.
    ///
    /// Returns an empty `Vec` in any of these cases:
    /// - the store is empty;
    /// - `top_k == 0`;
    /// - `query.len()` differs from [`Self::dim`];
    /// - `query` contains a NaN or infinite value.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
        self.search_with_threshold(query, top_k, f32::NEG_INFINITY)
    }

    /// Like [`Self::search`], but drops results whose score is below `min_score`.
    ///
    /// A NaN `min_score` matches nothing, because every `>=` comparison with NaN is false.
    pub fn search_with_threshold(
        &self,
        query: &[f32],
        top_k: usize,
        min_score: f32,
    ) -> Vec<SearchResult> {
        self.scored(query, top_k, min_score, None)
    }

    /// Like [`Self::search`], but only considers entries whose chunk metadata matches
    /// `filter`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `filter.validate()` if the filter is invalid.
    pub fn search_filtered(
        &self,
        query: &[f32],
        top_k: usize,
        filter: &MetadataFilter,
    ) -> Result<Vec<SearchResult>, RagError> {
        filter.validate()?;
        Ok(self.scored(query, top_k, f32::NEG_INFINITY, Some(filter)))
    }

    /// Shared implementation behind every search method.
    ///
    /// Scores each entry (optionally restricted by `filter`) against `query`. It keeps
    /// scores `>= min_score` and returns the best `top_k`, highest score first.
    fn scored(
        &self,
        query: &[f32],
        top_k: usize,
        min_score: f32,
        filter: Option<&MetadataFilter>,
    ) -> Vec<SearchResult> {
        if self.entries.is_empty()
            || top_k == 0
            || query.len() != self.dim
            || query.iter().any(|x| !x.is_finite())
        {
            return Vec::new();
        }

        // Prepare the query the same way `insert` prepared the stored vectors.
        let mut prepared = query.to_vec();
        if self.normalizes_vectors() {
            l2_normalize(&mut prepared);
        }

        // Each candidate is (score, position in `self.entries`). We record the
        // position rather than `entry.id` because ids only equal positions when every
        // entry came from `insert`. Entries installed via `set_entries` or serde may
        // carry arbitrary ids, and indexing by them would return the wrong chunk or
        // panic.
        let mut scored: Vec<(f32, usize)> = Vec::with_capacity(self.entries.len());
        for (pos, entry) in self.entries.iter().enumerate() {
            if let Some(f) = filter {
                if !f.matches(&entry.chunk.metadata) {
                    continue;
                }
            }
            let raw = match self.distance.compute(&prepared, &entry.vector) {
                Ok(v) => v,
                // Best effort: an entry that cannot be compared is skipped rather than
                // failing the whole search. Presumably this happens on a length
                // mismatch from data installed via `set_entries` or deserialization.
                Err(_) => continue,
            };
            let score = self.distance.to_score(raw);
            // `>=` is false for NaN, so NaN scores are dropped here and never reach
            // the ordering step below.
            if score >= min_score {
                scored.push((score, pos));
            }
        }

        // Highest score first. `total_cmp` is a guaranteed total order, which the
        // sort and select routines require.
        let by_score_desc = |a: &(f32, usize), b: &(f32, usize)| b.0.total_cmp(&a.0);
        if scored.len() > top_k {
            // Partition so the best `top_k` candidates come first (O(n)), then sort
            // only those instead of every candidate. `top_k >= 1` here, because of
            // the early return above.
            scored.select_nth_unstable_by(top_k - 1, by_score_desc);
            scored.truncate(top_k);
        }
        scored.sort_unstable_by(by_score_desc);

        scored
            .into_iter()
            .map(|(score, pos)| {
                let entry = &self.entries[pos];
                SearchResult {
                    score,
                    chunk: entry.chunk.clone(),
                    id: entry.id,
                }
            })
            .collect()
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the store holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes every entry. The dimension and distance are kept.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Approximate memory footprint of the stored entries, in bytes.
    ///
    /// For each entry it counts:
    /// - the vector's element bytes;
    /// - the length of its chunk text;
    /// - one `VectorEntry` header.
    ///
    /// Other chunk fields (such as metadata) and spare `Vec` capacity are not counted.
    pub fn memory_usage_bytes(&self) -> usize {
        self.entries
            .iter()
            .map(|e| {
                e.vector.len() * std::mem::size_of::<f32>()
                    + e.chunk.text.len()
                    + std::mem::size_of::<VectorEntry>()
            })
            .sum()
    }

    /// Length every inserted vector and every query must have.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Metric used to compare vectors.
    pub fn distance(&self) -> Distance {
        self.distance
    }

    /// Read-only view of all stored entries, in storage order.
    pub(crate) fn entries(&self) -> &[VectorEntry] {
        &self.entries
    }

    /// Replaces all entries wholesale.
    ///
    /// Ids need not match positions; search resolves hits by position.
    /// NOTE(review): no dimension, finiteness or normalization checks are performed
    /// here, unlike `insert`. Callers are presumably responsible for that; confirm.
    pub(crate) fn set_entries(&mut self, entries: Vec<VectorEntry>) {
        self.entries = entries;
    }
}