use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Linear field over the inputs `x`, `y` and `t`; `eval` computes
/// `w · [x, y, t, 1]` row by row.
#[derive(Debug, Clone)]
pub struct LinearCondField {
/// Weight matrix. `new_zeros(d)` allocates it with shape `(d, 2 * d + 2)`.
/// `eval` reads columns `0..d` as weights on `x`, columns `d..2 * d` as
/// weights on `y`, column `2 * d` as the weight on `t`, and column
/// `2 * d + 1` as the bias.
pub(crate) w: Array2<f32>,
}
impl LinearCondField {
pub fn new_zeros(d: usize) -> Self {
Self {
w: Array2::zeros((d, 2 * d + 2)),
}
}
pub fn w(&self) -> &Array2<f32> {
&self.w
}
pub(crate) fn d(&self) -> usize {
self.w.nrows()
}
pub fn eval(
&self,
x: &ArrayView1<f32>,
t: f32,
y: &ArrayView1<f32>,
) -> crate::Result<Array1<f32>> {
let d = self.d();
if x.len() != d || y.len() != d {
return Err(crate::Error::Shape("x and y must have length d"));
}
let mut out = Array1::<f32>::zeros(d);
for i in 0..d {
let mut s = 0.0f32;
for k in 0..d {
s += self.w[[i, k]] * x[k];
}
for k in 0..d {
s += self.w[[i, d + k]] * y[k];
}
s += self.w[[i, 2 * d]] * t;
s += self.w[[i, 2 * d + 1]] * 1.0;
out[i] = s;
}
Ok(out)
}
pub fn sgd_step(
&mut self,
x: &ArrayView1<f32>,
t: f32,
y: &ArrayView1<f32>,
u: &ArrayView1<f32>,
lr: f32,
) -> crate::Result<()> {
let d = self.d();
if x.len() != d || y.len() != d || u.len() != d {
return Err(crate::Error::Shape("x, y, and u must have length d"));
}
let pred = self.eval(x, t, y)?;
let mut r = Array1::<f32>::zeros(d);
for i in 0..d {
r[i] = pred[i] - u[i];
}
for i in 0..d {
let ri = r[i];
for k in 0..d {
self.w[[i, k]] -= lr * ri * x[k];
}
for k in 0..d {
self.w[[i, d + k]] -= lr * ri * y[k];
}
self.w[[i, 2 * d]] -= lr * ri * t;
self.w[[i, 2 * d + 1]] -= lr * ri * 1.0;
}
Ok(())
}
pub fn mse_batch(
&self,
xs: &ArrayView2<f32>,
ts: &[f32],
ys: &ArrayView2<f32>,
us: &ArrayView2<f32>,
) -> crate::Result<f32> {
let n = xs.nrows();
let d = xs.ncols();
if ys.nrows() != n || ys.ncols() != d || us.nrows() != n || us.ncols() != d || ts.len() != n
{
return Err(crate::Error::Shape("batch dimensions must agree"));
}
let mut s: f64 = 0.0;
for i in 0..n {
let pred = self.eval(&xs.row(i), ts[i], &ys.row(i))?;
for k in 0..d {
let r = (pred[k] - us[[i, k]]) as f64;
s += r * r;
}
}
Ok((s / (n as f64 * d as f64)) as f32)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squared element-wise differences between `pred` and `target`.
    fn squared_error(pred: &Array1<f32>, target: &Array1<f32>) -> f32 {
        pred.iter()
            .zip(target.iter())
            .map(|(p, q)| (p - q).powi(2))
            .sum()
    }

    /// With all weights at zero, every output component must be exactly zero.
    #[test]
    fn zero_field_evals_to_zero() {
        let field = LinearCondField::new_zeros(3);
        let x = Array1::from(vec![1.0f32, 2.0, 3.0]);
        let y = Array1::from(vec![4.0f32, 5.0, 6.0]);
        let out = field.eval(&x.view(), 0.5, &y.view()).unwrap();
        for i in 0..3 {
            assert_eq!(
                out[i], 0.0,
                "zero field must produce zero output at dim {i}"
            );
        }
    }

    /// Checks `eval` against a hand-computed value in one dimension.
    #[test]
    fn eval_matches_manual_1d() {
        let mut field = LinearCondField::new_zeros(1);
        // Columns in order: weight on x, weight on y, weight on t, bias.
        for (col, weight) in [2.0f32, 3.0, 4.0, 5.0].iter().enumerate() {
            field.w[[0, col]] = *weight;
        }
        let x = Array1::from(vec![1.0f32]);
        let y = Array1::from(vec![2.0f32]);
        let t = 0.5f32;
        let out = field.eval(&x.view(), t, &y.view()).unwrap();
        // 2 * 1 + 3 * 2 + 4 * 0.5 + 5 = 15
        assert!((out[0] - 15.0).abs() < 1e-6, "got {}", out[0]);
    }

    /// Repeated SGD steps on one fixed sample must shrink the squared error.
    #[test]
    fn sgd_step_reduces_loss_on_constant_target() {
        let mut field = LinearCondField::new_zeros(2);
        let x = Array1::from(vec![1.0f32, 0.0]);
        let y = Array1::from(vec![0.0f32, 1.0]);
        let u = Array1::from(vec![3.0f32, -2.0]);
        let t = 0.5f32;
        let lr = 0.01f32;

        let loss_before = squared_error(&field.eval(&x.view(), t, &y.view()).unwrap(), &u);
        for _ in 0..100 {
            field
                .sgd_step(&x.view(), t, &y.view(), &u.view(), lr)
                .unwrap();
        }
        let loss_after = squared_error(&field.eval(&x.view(), t, &y.view()).unwrap(), &u);

        assert!(
            loss_after < loss_before * 0.01,
            "SGD should reduce loss substantially: before={loss_before} after={loss_after}"
        );
    }

    /// A zero field predicts zero, so the MSE is the mean of the squared targets.
    #[test]
    fn mse_batch_zero_field_equals_target_norm() {
        let field = LinearCondField::new_zeros(2);
        let xs = Array2::from_shape_vec((2, 2), vec![1.0f32, 0.0, 0.0, 1.0]).unwrap();
        let ys = Array2::from_shape_vec((2, 2), vec![0.0f32, 1.0, 1.0, 0.0]).unwrap();
        let us = Array2::from_shape_vec((2, 2), vec![3.0f32, 4.0, 1.0, 2.0]).unwrap();
        let ts = vec![0.3f32, 0.7];
        let mse = field
            .mse_batch(&xs.view(), &ts, &ys.view(), &us.view())
            .unwrap();
        // (9 + 16 + 1 + 4) / 4 = 7.5
        assert!((mse - 7.5).abs() < 1e-5, "expected 7.5, got {mse}");
    }
}