use crate::simd;
use crate::RetrieveError;
/// Projects fixed-width `f32` vectors down to a smaller number of dimensions.
///
/// A reducer starts out unfitted (`projection_matrix` is `None`); the
/// projection matrix is filled in by `fit` and consumed by `transform`.
#[derive(Debug, Clone)]
pub struct DimensionalityReducer {
    /// Number of components in each input vector.
    original_dim: usize,
    /// Number of components in each projected vector.
    intermediate_dim: usize,
    /// One row per output dimension, each row holding `original_dim` weights.
    /// `None` until the reducer has been fitted.
    projection_matrix: Option<Vec<Vec<f32>>>,
}
impl DimensionalityReducer {
    /// Creates an unfitted reducer mapping `original_dim`-wide vectors to
    /// `intermediate_dim`-wide vectors.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::Other` when either dimension is zero, or when
    /// `intermediate_dim` is not strictly smaller than `original_dim`.
    pub fn new(original_dim: usize, intermediate_dim: usize) -> Result<Self, RetrieveError> {
        if original_dim == 0 || intermediate_dim == 0 {
            return Err(RetrieveError::Other(
                "Dimensions must be greater than 0".to_string(),
            ));
        }
        if intermediate_dim >= original_dim {
            return Err(RetrieveError::Other(
                "Intermediate dimension must be less than original".to_string(),
            ));
        }
        Ok(Self {
            original_dim,
            intermediate_dim,
            projection_matrix: None,
        })
    }

    /// Builds the projection matrix and stores it, replacing any earlier one.
    ///
    /// `vectors` is read as `num_vectors` consecutive rows of `original_dim`
    /// values each; values past that point are ignored.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::Other` when `vectors` holds fewer than
    /// `num_vectors * original_dim` values (including the case where that
    /// product does not fit in `usize`).
    pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
        self.check_input_len(vectors, num_vectors)?;
        self.projection_matrix = Some(self.compute_projection(vectors, num_vectors)?);
        Ok(())
    }

    /// Projects `num_vectors` input vectors through the fitted matrix.
    ///
    /// The result is a flat buffer of `num_vectors` rows; each row holds one
    /// `simd::dot(input, matrix_row)` value per row of the projection matrix.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::Other` when the reducer has not been fitted, or
    /// when `vectors` holds fewer than `num_vectors * original_dim` values.
    pub fn transform(
        &self,
        vectors: &[f32],
        num_vectors: usize,
    ) -> Result<Vec<f32>, RetrieveError> {
        let matrix = self
            .projection_matrix
            .as_ref()
            .ok_or_else(|| RetrieveError::Other("Reducer not fitted".to_string()))?;
        // Validate up front so short input is reported as an error instead of
        // panicking on the slice in `get_vector`.
        self.check_input_len(vectors, num_vectors)?;

        // `new` guarantees intermediate_dim < original_dim, so after the length
        // check this product cannot exceed `vectors.len()`.
        let mut reduced = Vec::with_capacity(num_vectors * self.intermediate_dim);
        for i in 0..num_vectors {
            let vec = self.get_vector(vectors, i);
            // The matrix built by `compute_projection` has exactly
            // `intermediate_dim` rows, so this appends one full output row.
            reduced.extend(matrix.iter().map(|row| simd::dot(vec, row)));
        }
        Ok(reduced)
    }

    /// Builds a random projection: `intermediate_dim` rows of `original_dim`
    /// weights, each weight drawn as `rng.random::<f32>() * 2.0 - 1.0` and each
    /// row then scaled to unit Euclidean length (rows with zero norm are left
    /// untouched).
    ///
    /// The training data is not consulted; the parameters are kept so the
    /// signature stays stable for existing callers.
    fn compute_projection(
        &self,
        _vectors: &[f32],
        _num_vectors: usize,
    ) -> Result<Vec<Vec<f32>>, RetrieveError> {
        use rand::Rng;

        let mut rng = rand::rng();
        let mut projection = Vec::with_capacity(self.intermediate_dim);
        for _ in 0..self.intermediate_dim {
            let mut row: Vec<f32> = (0..self.original_dim)
                .map(|_| rng.random::<f32>() * 2.0 - 1.0)
                .collect();
            let norm = row.iter().map(|&x| x * x).sum::<f32>().sqrt();
            // Guard against dividing by zero for a degenerate all-zero row.
            if norm > 0.0 {
                for val in &mut row {
                    *val /= norm;
                }
            }
            projection.push(row);
        }
        Ok(projection)
    }

    /// Checks that `vectors` holds at least `num_vectors * original_dim` values.
    ///
    /// The product is computed with `checked_mul`; a count that overflows
    /// `usize` can never be satisfied by a real slice, so it is reported the
    /// same way as short input.
    fn check_input_len(&self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
        let fits = matches!(
            num_vectors.checked_mul(self.original_dim),
            Some(needed) if vectors.len() >= needed
        );
        if fits {
            Ok(())
        } else {
            Err(RetrieveError::Other("Insufficient vectors".to_string()))
        }
    }

    /// Returns the `idx`-th `original_dim`-wide row of the flat `vectors` buffer.
    ///
    /// # Panics
    ///
    /// Panics if the row lies outside `vectors`; callers validate the buffer
    /// length with `check_input_len` first.
    fn get_vector<'a>(&self, vectors: &'a [f32], idx: usize) -> &'a [f32] {
        let start = idx * self.original_dim;
        let end = start + self.original_dim;
        &vectors[start..end]
    }
}