use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::error::{LaurusError, Result};
/// A dense `f32` vector whose components live behind an [`Arc`].
///
/// Cloning a `Vector` only bumps the reference count of the shared buffer.
/// Equality compares the component values, not the allocation they live in.
#[derive(Clone, Debug, PartialEq)]
pub struct Vector {
    /// The vector components, shared between clones.
    pub data: Arc<Vec<f32>>,
}
/// Serializes a [`Vector`] as a bare sequence of its `f32` components;
/// the `Arc` wrapper does not appear in the output.
impl Serialize for Vector {
    fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
        let components: &[f32] = &self.data[..];
        <[f32] as Serialize>::serialize(components, serializer)
    }
}
/// Deserializes a [`Vector`] from a bare sequence of `f32` values, placing the
/// decoded buffer behind a fresh `Arc`.
impl<'de> Deserialize<'de> for Vector {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
        Vec::<f32>::deserialize(deserializer).map(|components| Self {
            data: Arc::new(components),
        })
    }
}
impl Vector {
pub fn new(data: Vec<f32>) -> Self {
Self {
data: Arc::new(data),
}
}
pub fn dimension(&self) -> usize {
self.data.len()
}
pub fn norm(&self) -> f32 {
self.data.iter().map(|x| x * x).sum::<f32>().sqrt()
}
pub fn normalize(&mut self) {
let norm = self.norm();
if norm > 0.0 {
for value in Arc::make_mut(&mut self.data) {
*value /= norm;
}
}
}
pub fn normalized(&self) -> Self {
let mut normalized = self.clone();
normalized.normalize();
normalized
}
pub fn validate_dimension(&self, expected_dim: usize) -> Result<()> {
if self.data.len() != expected_dim {
return Err(LaurusError::InvalidOperation(format!(
"Vector dimension mismatch: expected {}, got {}",
expected_dim,
self.data.len()
)));
}
Ok(())
}
pub fn is_valid(&self) -> bool {
self.data.iter().all(|x| x.is_finite())
}
pub fn norm_parallel(&self) -> f32 {
#[cfg(not(target_arch = "wasm32"))]
if self.data.len() > 10000 {
return self.data.par_iter().map(|x| x * x).sum::<f32>().sqrt();
}
self.norm()
}
pub fn normalize_parallel(&mut self) {
let norm = self.norm_parallel();
if norm > 0.0 {
let data = Arc::make_mut(&mut self.data);
#[cfg(not(target_arch = "wasm32"))]
if data.len() > 10000 {
data.par_iter_mut().for_each(|value| *value /= norm);
return;
}
for value in data.iter_mut() {
*value /= norm;
}
}
}
pub fn normalize_batch_parallel(vectors: &mut [Vector]) {
#[cfg(not(target_arch = "wasm32"))]
if vectors.len() > 10 {
vectors
.par_iter_mut()
.for_each(|vector| vector.normalize_parallel());
return;
}
for vector in vectors {
vector.normalize();
}
}
}
/// A vector with its own (unshared) `f32` buffer plus a scalar weight.
///
/// Both fields are serialized through serde's derives, in declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StoredVector {
    /// Vector components.
    pub data: Vec<f32>,
    /// Scalar weight attached to the vector; `StoredVector::new` sets it to `1.0`.
    pub weight: f32,
}
impl StoredVector {
    /// Wraps `data` with the default weight of `1.0`.
    pub fn new(data: Vec<f32>) -> Self {
        Self { weight: 1.0, data }
    }

    /// Builder-style setter: returns `self` with its weight replaced by `weight`.
    pub fn with_weight(self, weight: f32) -> Self {
        Self { weight, ..self }
    }

    /// Number of components in `data`.
    pub fn dimension(&self) -> usize {
        self.data.len()
    }

    /// Copies the components into a new [`Vector`]; the weight is not carried over.
    pub fn to_vector(&self) -> Vector {
        let components = self.data.to_vec();
        Vector::new(components)
    }
}
impl From<Vector> for StoredVector {
fn from(vector: Vector) -> Self {
let data = Arc::try_unwrap(vector.data).unwrap_or_else(|arc| (*arc).clone());
Self { data, weight: 1.0 }
}
}