use scirs2_core::ndarray::{Array2, Axis};

use crate::error::TextError;

/// Result type for projection operations.
pub type ProjResult<T> = Result<T, TextError>;

#[derive(Debug, Clone)]
pub struct ProjectionConfig {
    /// Input embedding dimensionality.
    pub d_in: usize,
    /// Hidden-layer width.
    pub d_hidden: usize,
    /// Output embedding dimensionality.
    pub d_out: usize,
    /// Probability of zeroing an activation during training.
    pub dropout_rate: f32,
    /// SGD step size.
    pub learning_rate: f32,
}

impl Default for ProjectionConfig {
    fn default() -> Self {
        ProjectionConfig {
            d_in: 768,
            d_hidden: 768,
            d_out: 768,
            dropout_rate: 0.1,
            learning_rate: 1e-4,
        }
    }
}
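
/// A two-layer tanh projection head trained online with an InfoNCE-style
/// contrastive objective: two dropout-perturbed forward passes over the same
/// batch act as positive pairs, and the averaged gradients are applied with
/// plain SGD. All randomness comes from an internal LCG, so runs are
/// reproducible for a given initial state.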
pub struct DifferentiableProjection {
    config: ProjectionConfig,
    /// First-layer weights, shape (d_in, d_hidden).
    w1: Array2<f32>,
    /// First-layer bias, shape (1, d_hidden) so it broadcasts over rows.
    b1: Array2<f32>,
    /// Second-layer weights, shape (d_hidden, d_out).
    w2: Array2<f32>,
    /// Second-layer bias, shape (1, d_out).
    b2: Array2<f32>,
    /// LCG state used to sample dropout masks.
    rng_state: u64,
    /// Number of completed update steps.
    steps: u64,
}

impl DifferentiableProjection {
    /// Builds a projection head with deterministic Glorot-uniform weights
    /// (fixed seeds) and zero biases.
    pub fn new(config: ProjectionConfig) -> Self {
        let d_in = config.d_in;
        let d_h = config.d_hidden;
        let d_out = config.d_out;
        let w1 = glorot_uniform(d_in, d_h, 42);
        let b1 = Array2::zeros((1, d_h));
        let w2 = glorot_uniform(d_h, d_out, 137);
        let b2 = Array2::zeros((1, d_out));
        DifferentiableProjection {
            config,
            w1,
            b1,
            w2,
            b2,
            rng_state: 0xDEAD_BEEF_1234_5678,
            steps: 0,
        }
    }

    /// One contrastive SGD step. Runs two dropout-perturbed forward passes
    /// over the same batch, treats matching rows as positive pairs, and
    /// averages the gradients from both views. Returns the InfoNCE loss.
    pub fn update_step(&mut self, embeddings: &Array2<f32>, temperature: f32) -> ProjResult<f32> {
        let n = embeddings.nrows();
        if n == 0 {
            return Ok(0.0);
        }
        if embeddings.ncols() != self.config.d_in {
            return Err(TextError::InvalidInput(format!(
                "Expected d_in={}, got {}",
                self.config.d_in,
                embeddings.ncols()
            )));
        }
        // Two stochastic views of the same batch (the dropout masks differ).
        let (ha, cache_a) = self.forward_train(embeddings);
        let (hb, cache_b) = self.forward_train(embeddings);
        let (loss, d_logits) = infonce_loss_and_grad(&ha, &hb, temperature);
        // Chain dL/d(logits) through logits = ha . hb^T / tau.
        let inv_tau = 1.0_f32 / temperature;
        let d_ha = d_logits.dot(&hb) * inv_tau;
        let d_hb = d_logits.t().dot(&ha) * inv_tau;
        let (dw1_a, db1_a, dw2_a, db2_a) = self.backward(&d_ha, &cache_a, embeddings);
        let (dw1_b, db1_b, dw2_b, db2_b) = self.backward(&d_hb, &cache_b, embeddings);
        // Average the two views' gradients and take a plain SGD step.
        let lr = self.config.learning_rate;
        self.w1 = &self.w1 - &((&dw1_a + &dw1_b) * (0.5 * lr));
        self.b1 = &self.b1 - &((&db1_a + &db1_b) * (0.5 * lr));
        self.w2 = &self.w2 - &((&dw2_a + &dw2_b) * (0.5 * lr));
        self.b2 = &self.b2 - &((&db2_a + &db2_b) * (0.5 * lr));
        self.steps += 1;
        Ok(loss)
    }

    /// Deterministic forward pass (no dropout), for serving time.
    pub fn forward_inference(&self, embeddings: &Array2<f32>) -> ProjResult<Array2<f32>> {
        if embeddings.ncols() != self.config.d_in {
            return Err(TextError::InvalidInput(format!(
                "Expected d_in={}, got {}",
                self.config.d_in,
                embeddings.ncols()
            )));
        }
        let z1 = embeddings.dot(&self.w1) + &self.b1;
        let h1 = z1.mapv(f32::tanh);
        let z2 = h1.dot(&self.w2) + &self.b2;
        Ok(z2.mapv(f32::tanh))
    }

    pub fn steps(&self) -> u64 {
        self.steps
    }

    pub fn config(&self) -> &ProjectionConfig {
        &self.config
    }

    /// Forward pass with inverted dropout: surviving activations are scaled
    /// by 1/(1 - rate) so inference needs no rescaling. Returns the output
    /// plus the intermediates needed for backprop.
    fn forward_train(&mut self, x: &Array2<f32>) -> (Array2<f32>, ForwardCache) {
        let rate = self.config.dropout_rate;
        let scale = if rate < 1.0 { 1.0 / (1.0 - rate) } else { 1.0 };
        let z1 = x.dot(&self.w1) + &self.b1;
        let mask1 = self.bernoulli_mask(z1.nrows(), z1.ncols(), rate, scale);
        let h1 = z1.mapv(f32::tanh) * &mask1;
        let z2 = h1.dot(&self.w2) + &self.b2;
        let mask2 = self.bernoulli_mask(z2.nrows(), z2.ncols(), rate, scale);
        let output = z2.mapv(f32::tanh) * &mask2;
        let cache = ForwardCache { h1, z1, mask1, z2, mask2 };
        (output, cache)
    }

    /// Backprop through both tanh layers, reusing the dropout masks from the
    /// forward pass. Uses tanh'(z) = 1 - tanh(z)^2.
    fn backward(
        &self,
        d_out: &Array2<f32>,
        cache: &ForwardCache,
        x: &Array2<f32>,
    ) -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>) {
        // Layer 2: gradient through dropout, then the tanh nonlinearity.
        let d_h2 = d_out * &cache.mask2;
        let tanh_z2 = cache.z2.mapv(f32::tanh);
        let d_z2 = &d_h2 * &tanh_z2.mapv(|t| 1.0 - t * t);
        let dw2 = cache.h1.t().dot(&d_z2);
        let db2 = d_z2.sum_axis(Axis(0)).insert_axis(Axis(0));
        // Layer 1: same pattern, propagated back through w2.
        let d_h1 = d_z2.dot(&self.w2.t()) * &cache.mask1;
        let tanh_z1 = cache.z1.mapv(f32::tanh);
        let d_z1 = &d_h1 * &tanh_z1.mapv(|t| 1.0 - t * t);
        let dw1 = x.t().dot(&d_z1);
        let db1 = d_z1.sum_axis(Axis(0)).insert_axis(Axis(0));
        (dw1, db1, dw2, db2)
    }

    /// Samples an inverted-dropout mask (0.0 with probability `rate`,
    /// otherwise `scale`) from a 64-bit LCG (Knuth's MMIX constants).
    fn bernoulli_mask(&mut self, rows: usize, cols: usize, rate: f32, scale: f32) -> Array2<f32> {
        Array2::from_shape_fn((rows, cols), |_| {
            self.rng_state = self
                .rng_state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // The shift keeps the 31 high bits, so divide by 2^31 to get a
            // uniform value in [0, 1).
            let u = (self.rng_state >> 33) as f32 / (1u64 << 31) as f32;
            if u < rate { 0.0 } else { scale }
        })
    }
}

/// Intermediates from one training forward pass, kept for backprop.
struct ForwardCache {
    h1: Array2<f32>,
    z1: Array2<f32>,
    mask1: Array2<f32>,
    z2: Array2<f32>,
    mask2: Array2<f32>,
}
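
/// InfoNCE loss over paired projections `ha` and `hb` (one row per sample).
///
/// With similarities `s[i][j] = ha[i] . hb[j]` and temperature `tau`, each
/// row is scored by cross-entropy against its diagonal (positive) entry:
///
///   L = -(1/N) * sum_i ( s[i][i]/tau - log sum_j exp(s[i][j]/tau) )
///
/// Returns the loss together with dL/d(logits) = (softmax - I) / N, which
/// the caller chains through the similarity matrix to obtain gradients with
/// respect to `ha` and `hb`.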
fn infonce_loss_and_grad(
    ha: &Array2<f32>,
    hb: &Array2<f32>,
    temperature: f32,
) -> (f32, Array2<f32>) {
    let n = ha.nrows();
    let inv_n = 1.0_f32 / n as f32;
    let inv_tau = 1.0_f32 / temperature;
    // Pairwise similarities, scaled by the temperature.
    let sim = ha.dot(&hb.t());
    let logits = sim.mapv(|v| v * inv_tau);
    let mut softmax = Array2::<f32>::zeros((n, n));
    let mut total_loss = 0.0_f32;
    for i in 0..n {
        // Numerically stable row-wise log-softmax: subtract the row max
        // before exponentiating.
        let row = logits.row(i);
        let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|&v| (v - max_val).exp()).collect();
        let sum_exp: f32 = exps.iter().sum();
        let log_sum_exp = sum_exp.ln() + max_val;
        // Row i's positive is its diagonal entry (the paired view).
        let pos_logit = logits[[i, i]];
        total_loss += -(pos_logit - log_sum_exp);
        for j in 0..n {
            softmax[[i, j]] = exps[j] / sum_exp.max(1e-30);
        }
    }
    let loss = total_loss * inv_n;
    // Cross-entropy gradient w.r.t. logits: (softmax - I), averaged over rows.
    let mut d_logits = softmax;
    for i in 0..n {
        d_logits[[i, i]] -= 1.0;
    }
    d_logits.mapv_inplace(|v| v * inv_n);
    (loss, d_logits)
}

/// Glorot (Xavier) uniform initialization: samples from U(-limit, limit)
/// with limit = sqrt(6 / (fan_in + fan_out)), driven by the same LCG as the
/// dropout masks so initialization is reproducible.
fn glorot_uniform(fan_in: usize, fan_out: usize, seed: u64) -> Array2<f32> {
    let limit = (6.0_f32 / (fan_in + fan_out) as f32).sqrt();
    let mut state = seed;
    Array2::from_shape_fn((fan_in, fan_out), |_| {
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        // The 52 high bits over 2^52 give a uniform value in [0, 1).
        let u = (state >> 12) as f64 / (1u64 << 52) as f64;
        (u as f32 * 2.0 - 1.0) * limit
    })
}

impl std::fmt::Debug for DifferentiableProjection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DifferentiableProjection")
            .field("d_in", &self.config.d_in)
            .field("d_hidden", &self.config.d_hidden)
            .field("d_out", &self.config.d_out)
            .field("steps", &self.steps)
            .finish()
    }
}
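
// A minimal usage sketch rather than an exhaustive test suite. It assumes
// `scirs2_core::ndarray::Array2` matches ndarray's `Array2` API and that
// `TextError` implements `Debug` (needed by `unwrap`); the batch shapes and
// hyperparameters below are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn train_then_infer_shapes() {
        let config = ProjectionConfig {
            d_in: 8,
            d_hidden: 16,
            d_out: 4,
            dropout_rate: 0.1,
            learning_rate: 1e-3,
        };
        let mut proj = DifferentiableProjection::new(config);
        // A tiny deterministic batch of 3 embeddings.
        let batch = Array2::from_shape_fn((3, 8), |(i, j)| ((i * 8 + j) as f32).sin());
        // A few contrastive updates; each loss should be finite.
        for _ in 0..5 {
            let loss = proj.update_step(&batch, 0.1).unwrap();
            assert!(loss.is_finite());
        }
        assert_eq!(proj.steps(), 5);
        // Inference is deterministic and maps (n, d_in) -> (n, d_out).
        let out = proj.forward_inference(&batch).unwrap();
        assert_eq!(out.dim(), (3, 4));
        // A mismatched input width is rejected.
        let bad = Array2::<f32>::zeros((2, 5));
        assert!(proj.forward_inference(&bad).is_err());
    }
}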