use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
};
/// Parameters of the affine map applied by `ProjectionHead::project`.
///
/// Both buffers are plain flat vectors; their expected lengths are set by
/// `ProjectionWeights::default_init`.
pub struct ProjectionWeights {
/// Flattened weight matrix. `default_init` allocates `proj_dim * embed_dim`
/// entries, and `ProjectionHead::project` reads output row `p` as the
/// `embed_dim` consecutive entries starting at offset `p * embed_dim`.
pub weight: Vec<f32>,
/// Per-output offset. `default_init` allocates `proj_dim` entries, and
/// `ProjectionHead::project` adds entry `p` to output `p` before normalising.
pub bias: Vec<f32>,
}
impl ProjectionWeights {
    /// Draws a fresh set of projection parameters from `rng`.
    ///
    /// The weight buffer (`proj_dim * embed_dim` values) is drawn first and the
    /// bias buffer (`proj_dim` values) second, each via `LcgRng::fill_normal`;
    /// every drawn value is then multiplied by `1 / sqrt(embed_dim)`.
    ///
    /// # Errors
    /// * `VisionError::InvalidEmbedDim` when `embed_dim` is zero (checked first).
    /// * `VisionError::InvalidProjDim` when `proj_dim` is zero.
    ///
    /// `rng` is not touched when an error is returned.
    pub fn default_init(
        embed_dim: usize,
        proj_dim: usize,
        rng: &mut LcgRng,
    ) -> VisionResult<Self> {
        // The embed-dim arm comes first so it wins when both dimensions are zero.
        match (embed_dim, proj_dim) {
            (0, _) => return Err(VisionError::InvalidEmbedDim(embed_dim)),
            (_, 0) => return Err(VisionError::InvalidProjDim(proj_dim)),
            _ => {}
        }

        let fan_in_scale = (embed_dim as f32).sqrt().recip();

        // One recipe for both buffers: allocate, fill from the generator, rescale.
        let mut scaled_draw = |len: usize| {
            let mut buf = vec![0.0f32; len];
            rng.fill_normal(&mut buf);
            buf.iter_mut().for_each(|x| *x *= fan_in_scale);
            buf
        };

        // Call order matters: it fixes which generator outputs land in which buffer.
        let weight = scaled_draw(proj_dim * embed_dim);
        let bias = scaled_draw(proj_dim);

        Ok(Self { weight, bias })
    }
}
/// Affine projection followed by L2 normalisation (see `project`).
pub struct ProjectionHead {
/// Length every input vector handed to `project` must have.
pub embed_dim: usize,
/// Length of each projected output vector.
pub proj_dim: usize,
/// Weight matrix and bias used by the projection.
pub weights: ProjectionWeights,
}
impl ProjectionHead {
    /// Lower bound applied to a vector norm before dividing by it, so that
    /// all-zero vectors yield zeros instead of NaN.
    const NORM_EPS: f32 = 1e-12;

    /// Creates a head whose weights are drawn from `rng`.
    ///
    /// # Errors
    /// Propagates whatever `ProjectionWeights::default_init` returns for the
    /// given dimensions.
    pub fn new(embed_dim: usize, proj_dim: usize, rng: &mut LcgRng) -> VisionResult<Self> {
        let weights = ProjectionWeights::default_init(embed_dim, proj_dim, rng)?;
        Ok(Self {
            embed_dim,
            proj_dim,
            weights,
        })
    }

    /// Projects one embedding: `z = W·x + b`, then scales `z` to unit L2 norm.
    ///
    /// The norm is clamped from below at `1e-12`, so an all-zero `z` is
    /// returned as zeros rather than NaN.
    ///
    /// # Errors
    /// `VisionError::DimensionMismatch` when `x.len() != embed_dim`, or when
    /// the weight / bias buffers are too short for `embed_dim` / `proj_dim`
    /// (possible because the fields are public).
    pub fn project(&self, x: &[f32]) -> VisionResult<Vec<f32>> {
        if x.len() != self.embed_dim {
            return Err(VisionError::DimensionMismatch {
                expected: self.embed_dim,
                got: x.len(),
            });
        }
        self.check_weights()?;

        let mut z = vec![0.0f32; self.proj_dim];
        self.project_into(x, &mut z);
        Ok(z)
    }

    /// Projects `batch` embeddings stored back to back in `x`.
    ///
    /// Row `b` of the input is `x[b * embed_dim..(b + 1) * embed_dim]`; row `b`
    /// of the result is `out[b * proj_dim..(b + 1) * proj_dim]` and equals
    /// `project` applied to that input row.
    ///
    /// # Errors
    /// `VisionError::DimensionMismatch` when `x.len() != batch * embed_dim`
    /// (including when that product would overflow `usize`), or when the
    /// weight / bias buffers are too short.
    pub fn project_batch(&self, x: &[f32], batch: usize) -> VisionResult<Vec<f32>> {
        // Saturating: a product that overflows becomes usize::MAX, which no
        // slice length can equal, so it is reported as a mismatch instead of
        // panicking (debug) or wrapping to a small, wrong value (release).
        let expected = batch.saturating_mul(self.embed_dim);
        if x.len() != expected {
            return Err(VisionError::DimensionMismatch {
                expected,
                got: x.len(),
            });
        }
        if batch == 0 {
            return Ok(Vec::new());
        }
        self.check_weights()?;

        // Saturating for the same reason: an impossible size fails in the
        // allocator rather than wrapping to a buffer that is too short.
        let mut out = vec![0.0f32; batch.saturating_mul(self.proj_dim)];
        for b in 0..batch {
            let x_row = &x[b * self.embed_dim..(b + 1) * self.embed_dim];
            let out_row = &mut out[b * self.proj_dim..(b + 1) * self.proj_dim];
            // Write straight into the output row; no per-row temporary Vec.
            self.project_into(x_row, out_row);
        }
        Ok(out)
    }

    /// Cosine similarity `a·b / (‖a‖ ‖b‖)`.
    ///
    /// Each norm is clamped from below at `1e-12` before the division, so a
    /// zero vector gives `0.0` and the result does not depend on the overall
    /// magnitude of non-degenerate inputs.
    ///
    /// # Errors
    /// `VisionError::DimensionMismatch` when the slices differ in length.
    pub fn cosine_sim(a: &[f32], b: &[f32]) -> VisionResult<f32> {
        if a.len() != b.len() {
            return Err(VisionError::DimensionMismatch {
                expected: a.len(),
                got: b.len(),
            });
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum();
        let norm_a: f32 = a.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|&v| v * v).sum::<f32>().sqrt();
        // Clamp the norms individually. Adding an epsilon to their product
        // would swamp the denominator whenever both norms are small and pull
        // the similarity of parallel vectors towards zero.
        let denom = norm_a.max(Self::NORM_EPS) * norm_b.max(Self::NORM_EPS);
        Ok(dot / denom)
    }

    /// Checks that the weight and bias buffers are long enough to be indexed
    /// with the head's dimensions.
    fn check_weights(&self) -> VisionResult<()> {
        // A saturated product (usize::MAX) exceeds any real Vec length, so an
        // overflowing shape is rejected here as well.
        let needed_weight = self.proj_dim.saturating_mul(self.embed_dim);
        if self.weights.weight.len() < needed_weight {
            return Err(VisionError::DimensionMismatch {
                expected: needed_weight,
                got: self.weights.weight.len(),
            });
        }
        if self.weights.bias.len() < self.proj_dim {
            return Err(VisionError::DimensionMismatch {
                expected: self.proj_dim,
                got: self.weights.bias.len(),
            });
        }
        Ok(())
    }

    /// Computes the normalised projection of `x` into `out`.
    ///
    /// Callers must have verified `x.len() == embed_dim`,
    /// `out.len() == proj_dim`, and `check_weights()`.
    fn project_into(&self, x: &[f32], out: &mut [f32]) {
        for (p, (slot, &bias)) in out.iter_mut().zip(&self.weights.bias).enumerate() {
            let row_off = p * self.embed_dim;
            let row = &self.weights.weight[row_off..row_off + self.embed_dim];
            let dot: f32 = row.iter().zip(x.iter()).map(|(&w, &xi)| w * xi).sum();
            *slot = dot + bias;
        }
        let norm: f32 = out.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let inv_norm = 1.0 / norm.max(Self::NORM_EPS);
        for v in out.iter_mut() {
            *v *= inv_norm;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Builds a head from a generator seeded with `seed`; panics on invalid dims.
    fn make_head(embed_dim: usize, proj_dim: usize, seed: u64) -> ProjectionHead {
        ProjectionHead::new(embed_dim, proj_dim, &mut LcgRng::new(seed)).expect("valid head")
    }

    /// Returns `len` values drawn from a generator seeded with `seed`.
    fn random_vec(len: usize, seed: u64) -> Vec<f32> {
        let mut samples = vec![0.0f32; len];
        LcgRng::new(seed).fill_normal(&mut samples);
        samples
    }

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|&c| c * c).sum::<f32>().sqrt()
    }

    /// Row `index` of a flat buffer whose rows are `width` elements wide.
    fn row(buf: &[f32], index: usize, width: usize) -> &[f32] {
        &buf[index * width..(index + 1) * width]
    }

    // ---- ProjectionWeights::default_init ----

    #[test]
    fn weights_correct_sizes() {
        let ProjectionWeights { weight, bias } =
            ProjectionWeights::default_init(64, 128, &mut LcgRng::new(1)).expect("ok");
        assert_eq!(weight.len(), 128 * 64, "weight size mismatch");
        assert_eq!(bias.len(), 128, "bias size mismatch");
    }

    #[test]
    fn weights_finite_values() {
        let ProjectionWeights { weight, bias } =
            ProjectionWeights::default_init(64, 128, &mut LcgRng::new(2)).expect("ok");
        assert!(weight.iter().copied().all(f32::is_finite), "non-finite weights");
        assert!(bias.iter().copied().all(f32::is_finite), "non-finite bias");
    }

    #[test]
    fn weights_error_zero_embed_dim() {
        let outcome = ProjectionWeights::default_init(0, 64, &mut LcgRng::new(3));
        assert!(matches!(outcome, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn weights_error_zero_proj_dim() {
        let outcome = ProjectionWeights::default_init(64, 0, &mut LcgRng::new(4));
        assert!(matches!(outcome, Err(VisionError::InvalidProjDim(0))));
    }

    // ---- ProjectionHead::project ----

    #[test]
    fn project_output_l2_norm_approx_one() {
        let z = make_head(64, 128, 10)
            .project(&random_vec(64, 11))
            .expect("project ok");
        let norm = l2_norm(&z);
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "L2 norm of projected embedding should be ≈1.0, got {norm}"
        );
    }

    #[test]
    fn project_output_size() {
        let z = make_head(32, 64, 12)
            .project(&random_vec(32, 13))
            .expect("project ok");
        assert_eq!(z.len(), 64, "output size should be proj_dim");
    }

    #[test]
    fn project_output_finite() {
        let z = make_head(128, 64, 14)
            .project(&random_vec(128, 15))
            .expect("project ok");
        assert!(z.iter().copied().all(f32::is_finite), "output must be finite");
    }

    #[test]
    fn project_error_wrong_input_size() {
        let head = make_head(64, 128, 16);
        // Half as long as the head expects.
        let x = random_vec(32, 17);
        let r = head.project(&x);
        let is_expected_error = matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 64,
                got: 32
            })
        );
        assert!(
            is_expected_error,
            "expected DimensionMismatch(64, 32), got {:?}",
            r
        );
    }

    #[test]
    fn project_zero_input_normalises() {
        let zeros = [0.0f32; 16];
        let z = make_head(16, 32, 18).project(&zeros).expect("ok");
        let norm = l2_norm(&z);
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "zero-input projection should still yield unit norm, got {norm}"
        );
    }

    // ---- ProjectionHead::project_batch ----

    #[test]
    fn project_batch_output_size() {
        let head = make_head(32, 64, 20);
        let x = random_vec(4 * 32, 21);
        let z = head.project_batch(&x, 4).expect("batch project ok");
        assert_eq!(z.len(), 4 * 64, "batch output size mismatch");
    }

    #[test]
    fn project_batch_each_row_unit_norm() {
        let head = make_head(32, 64, 22);
        let x = random_vec(8 * 32, 23);
        let z = head.project_batch(&x, 8).expect("ok");
        for i in 0..8 {
            let norm = l2_norm(row(&z, i, 64));
            assert!(
                (norm - 1.0).abs() < 1e-5,
                "batch row {i} norm = {norm}, expected 1.0"
            );
        }
    }

    #[test]
    fn project_batch_matches_individual() {
        let head = make_head(16, 32, 24);
        let x_all = random_vec(3 * 16, 25);
        let z_batch = head.project_batch(&x_all, 3).expect("batch ok");
        for i in 0..3 {
            let z_single = head.project(row(&x_all, i, 16)).expect("single ok");
            let z_batch_row = row(&z_batch, i, 32);
            for (j, (&a, &b)) in z_single.iter().zip(z_batch_row).enumerate() {
                assert!(
                    (a - b).abs() < 1e-6,
                    "batch vs single at [{i},{j}]: {a} ≠ {b}"
                );
            }
        }
    }

    #[test]
    fn project_batch_error_wrong_total_length() {
        let head = make_head(32, 64, 26);
        // Five elements more than three full rows.
        let x = random_vec(3 * 32 + 5, 27);
        let outcome = head.project_batch(&x, 3);
        assert!(matches!(outcome, Err(VisionError::DimensionMismatch { .. })));
    }

    // ---- ProjectionHead::cosine_sim ----

    #[test]
    fn cosine_sim_unit_vector_with_self_is_one() {
        let z = make_head(32, 32, 30)
            .project(&random_vec(32, 31))
            .expect("ok");
        let sim = ProjectionHead::cosine_sim(&z, &z).expect("cosine ok");
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "cosine(v, v) should be ≈1.0 for unit-norm v, got {sim}"
        );
    }

    #[test]
    fn cosine_sim_orthogonal_vectors() {
        let e0 = [1.0f32, 0.0, 0.0, 0.0];
        let e1 = [0.0f32, 1.0, 0.0, 0.0];
        let sim = ProjectionHead::cosine_sim(&e0, &e1).expect("ok");
        assert!(
            sim.abs() < 1e-6,
            "cosine similarity of orthogonal vectors should be ≈0, got {sim}"
        );
    }

    #[test]
    fn cosine_sim_opposite_vectors() {
        let forward = [1.0f32, 0.0, 0.0];
        let backward = [-1.0f32, 0.0, 0.0];
        let sim = ProjectionHead::cosine_sim(&forward, &backward).expect("ok");
        assert!(
            (sim + 1.0).abs() < 1e-5,
            "cosine similarity of opposite vectors should be ≈-1, got {sim}"
        );
    }

    #[test]
    fn cosine_sim_range() {
        let mut rng = LcgRng::new(40);
        let allowed = (-1.0 - 1e-5)..=(1.0 + 1e-5);
        for _ in 0..50 {
            // Fresh buffers each round, filled in the order a, b.
            let mut a = vec![0.0f32; 64];
            let mut b = vec![0.0f32; 64];
            rng.fill_normal(&mut a);
            rng.fill_normal(&mut b);
            let sim = ProjectionHead::cosine_sim(&a, &b).expect("ok");
            assert!(allowed.contains(&sim), "cosine sim out of [-1,1]: {sim}");
        }
    }

    #[test]
    fn cosine_sim_error_length_mismatch() {
        let short = [1.0f32; 4];
        let long = [1.0f32; 8];
        let outcome = ProjectionHead::cosine_sim(&short, &long);
        assert!(matches!(outcome, Err(VisionError::DimensionMismatch { .. })));
    }
}