use crate::error::{RecsysError, RecsysResult};
use crate::handle::LcgRng;
/// Selects the parameterisation used by every cross layer of a [`Dcn`].
///
/// In both variants a layer maps the running state `x_l` (and the original
/// input `x0`) to a new state of the same length `input_dim`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossKind {
/// One weight vector of length `input_dim` per layer. The layer computes
/// `x0[i] * dot(x_l, w) + b[i] + x_l[i]` for every component `i`.
Vector,
/// One `input_dim * input_dim` weight matrix per layer, stored row-major.
/// The layer computes `x0[i] * (dot(W[i], x_l) + b[i]) + x_l[i]`.
MatrixV2,
}
/// Construction parameters for a [`Dcn`]; validated by `Dcn::new`.
#[derive(Debug, Clone)]
pub struct DcnConfig {
/// Length of the input feature vector. `Dcn::new` rejects `0`.
pub input_dim: usize,
/// Number of stacked cross layers. `0` is accepted, in which case the cross
/// branch returns its input unchanged.
pub n_cross_layers: usize,
/// Output width of each deep (fully connected) layer, in order. May be
/// empty (no deep branch); `Dcn::new` rejects any entry equal to `0`.
pub deep_dims: Vec<usize>,
/// Which cross-layer formulation to use.
pub kind: CrossKind,
}
/// A model with a cross branch and a deep branch whose outputs are
/// concatenated and fed to a single sigmoid output unit.
pub struct Dcn {
// Configuration the model was built from.
config: DcnConfig,
// One weight buffer per cross layer: `input_dim` values for
// `CrossKind::Vector`, `input_dim * input_dim` (row-major) for `MatrixV2`.
cross_w: Vec<Vec<f32>>,
// One bias vector of length `input_dim` per cross layer (initialised to 0).
cross_b: Vec<Vec<f32>>,
// `(weights, bias)` per deep layer; weights are row-major with one row of
// `in_dim` values per output unit, bias has one value per output unit.
deep_layers: Vec<(Vec<f32>, Vec<f32>)>,
// Output-head weights: first `input_dim` entries apply to the cross output,
// the remaining `deep_out_dim` entries to the deep output.
head_w: Vec<f32>,
// Output-head bias (initialised to 0.0).
head_b: f32,
// Width of the last deep layer, or 0 when there are no deep layers.
deep_out_dim: usize,
}
/// Logistic function `1 / (1 + e^(-x))`.
///
/// Large negative inputs make the exponential overflow to `+inf`, whose
/// reciprocal is `0.0`, so the result stays finite for every finite `x`.
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
impl Dcn {
    /// Builds a model from `config`, drawing every initial weight from `rng`.
    ///
    /// Weights are drawn in a fixed order: all cross layers, then all deep
    /// layers, then the output head. Each draw is `rng.next_normal()` times a
    /// per-group scale: `sqrt(1 / input_dim)` for cross layers,
    /// `sqrt(2 / in_dim)` for a deep layer with `in_dim` inputs, and
    /// `sqrt(1 / head_in)` for the head. All biases start at zero.
    ///
    /// # Errors
    ///
    /// * [`RecsysError::InvalidConfig`] if `input_dim` is `0`, or if the
    ///   number of weights of a cross matrix or deep layer does not fit in
    ///   `usize`.
    /// * [`RecsysError::InvalidEmbeddingDim`] if any entry of `deep_dims` is `0`.
    pub fn new(config: DcnConfig, rng: &mut LcgRng) -> RecsysResult<Self> {
        let d = config.input_dim;
        if d == 0 {
            return Err(RecsysError::InvalidConfig {
                msg: "input_dim must be > 0".to_string(),
            });
        }
        if config.deep_dims.contains(&0) {
            return Err(RecsysError::InvalidEmbeddingDim { d: 0 });
        }

        let scale = (1.0 / d as f32).sqrt();
        let mut cross_w = Vec::with_capacity(config.n_cross_layers);
        let mut cross_b = Vec::with_capacity(config.n_cross_layers);
        for _ in 0..config.n_cross_layers {
            // The size is computed inside the loop so that a configuration
            // with zero cross layers never evaluates `d * d` at all.
            let w_len = match config.kind {
                CrossKind::Vector => d,
                CrossKind::MatrixV2 => Self::checked_weight_count(d, d)?,
            };
            cross_w.push(Self::random_weights(rng, w_len, scale));
            cross_b.push(vec![0.0_f32; d]);
        }

        let mut deep_layers = Vec::with_capacity(config.deep_dims.len());
        let mut in_dim = d;
        for &out_dim in &config.deep_dims {
            // A wrapped product would yield an undersized buffer and a slice
            // panic later in `deep_forward`; report it as a config error.
            let n_weights = Self::checked_weight_count(out_dim, in_dim)?;
            let sc = (2.0 / in_dim as f32).sqrt();
            let w = Self::random_weights(rng, n_weights, sc);
            deep_layers.push((w, vec![0.0_f32; out_dim]));
            in_dim = out_dim;
        }
        // Width of the last deep layer, or 0 when there is no deep branch.
        let deep_out_dim = config.deep_dims.last().copied().unwrap_or(0);

        // The head consumes the cross output (length `d`) followed by the
        // deep output (length `deep_out_dim`).
        let head_in = d + deep_out_dim;
        let head_scale = (1.0 / head_in as f32).sqrt();
        let head_w = Self::random_weights(rng, head_in, head_scale);

        Ok(Self {
            config,
            cross_w,
            cross_b,
            deep_layers,
            head_w,
            head_b: 0.0,
            deep_out_dim,
        })
    }

    /// Runs the cross branch on `x` and returns a vector of length `input_dim`.
    ///
    /// Starting from `x_l = x0 = x`, every layer replaces `x_l` with
    /// * `Vector`:   `x0[i] * dot(x_l, w) + b[i] + x_l[i]`
    /// * `MatrixV2`: `x0[i] * (b[i] + dot(W[i], x_l)) + x_l[i]`
    ///
    /// With zero cross layers the input is returned unchanged.
    ///
    /// # Errors
    ///
    /// [`RecsysError::DimensionMismatch`] if `x.len() != input_dim`.
    pub fn cross_forward(&self, x: &[f32]) -> RecsysResult<Vec<f32>> {
        let d = self.validate_input_len(x)?;
        let x0 = x;
        let mut x_l = x0.to_vec();
        for (w, b) in self.cross_w.iter().zip(self.cross_b.iter()) {
            let next: Vec<f32> = match self.config.kind {
                CrossKind::Vector => {
                    // A single scalar interaction term shared by all outputs.
                    let s: f32 = x_l.iter().zip(w.iter()).map(|(&a, &c)| a * c).sum();
                    (0..d).map(|i| x0[i] * s + b[i] + x_l[i]).collect()
                }
                CrossKind::MatrixV2 => (0..d)
                    .map(|i| {
                        // Row `i` of the row-major `d x d` weight matrix.
                        let row = &w[i * d..(i + 1) * d];
                        x0[i] * Self::affine(b[i], row, &x_l) + x_l[i]
                    })
                    .collect(),
            };
            x_l = next;
        }
        Ok(x_l)
    }

    /// Runs the deep branch on `x`.
    ///
    /// Each layer computes `max(0, bias + W * input)`, including the last
    /// one. Returns an empty vector when the model has no deep layers,
    /// otherwise a vector as wide as the last deep layer.
    ///
    /// # Errors
    ///
    /// [`RecsysError::DimensionMismatch`] if `x.len() != input_dim`.
    pub fn deep_forward(&self, x: &[f32]) -> RecsysResult<Vec<f32>> {
        self.validate_input_len(x)?;
        if self.deep_layers.is_empty() {
            return Ok(Vec::new());
        }
        let mut cur = x.to_vec();
        for (w, b) in &self.deep_layers {
            let in_dim = cur.len();
            let next: Vec<f32> = b
                .iter()
                .enumerate()
                .map(|(o, &bias)| {
                    // Row `o` of the row-major `out_dim x in_dim` weights.
                    let row = &w[o * in_dim..(o + 1) * in_dim];
                    Self::affine(bias, row, &cur).max(0.0)
                })
                .collect();
            cur = next;
        }
        Ok(cur)
    }

    /// Full forward pass: the sigmoid of a linear head applied to the cross
    /// output followed by the deep output.
    ///
    /// # Errors
    ///
    /// [`RecsysError::DimensionMismatch`] if `x.len() != input_dim`.
    pub fn forward(&self, x: &[f32]) -> RecsysResult<f32> {
        let cross_out = self.cross_forward(x)?;
        let deep_out = self.deep_forward(x)?;
        debug_assert_eq!(deep_out.len(), self.deep_out_dim);
        // Head weights are laid out as [cross part | deep part], so walking
        // the chained outputs with a running index addresses them directly.
        let logit = cross_out
            .iter()
            .chain(deep_out.iter())
            .enumerate()
            .fold(self.head_b, |acc, (i, &v)| acc + self.head_w[i] * v);
        Ok(sigmoid(logit))
    }

    /// Number of cross layers the model was configured with.
    pub fn n_cross_layers(&self) -> usize {
        self.config.n_cross_layers
    }

    /// Width of the deep branch output (0 when there are no deep layers).
    pub fn deep_out_dim(&self) -> usize {
        self.deep_out_dim
    }

    /// Cross-layer formulation the model was configured with.
    pub fn kind(&self) -> CrossKind {
        self.config.kind
    }

    /// Returns `input_dim` if `x` has exactly that length, otherwise a
    /// `DimensionMismatch` error carrying both lengths.
    fn validate_input_len(&self, x: &[f32]) -> RecsysResult<usize> {
        let expected = self.config.input_dim;
        if x.len() == expected {
            Ok(expected)
        } else {
            Err(RecsysError::DimensionMismatch {
                expected,
                got: x.len(),
            })
        }
    }

    /// Computes `rows * cols`, reporting `usize` overflow as `InvalidConfig`.
    fn checked_weight_count(rows: usize, cols: usize) -> RecsysResult<usize> {
        rows.checked_mul(cols)
            .ok_or_else(|| RecsysError::InvalidConfig {
                msg: format!("weight matrix of shape {rows}x{cols} overflows usize"),
            })
    }

    /// Draws `len` values of `rng.next_normal() * scale`, in order.
    fn random_weights(rng: &mut LcgRng, len: usize, scale: f32) -> Vec<f32> {
        (0..len).map(|_| rng.next_normal() * scale).collect()
    }

    /// Computes `bias + dot(row, x)`, accumulating left to right from `bias`.
    fn affine(bias: f32, row: &[f32], x: &[f32]) -> f32 {
        row.iter()
            .zip(x.iter())
            .fold(bias, |acc, (&w, &v)| acc + w * v)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with a fixed input width of 6 for the tests below.
    fn cfg(kind: CrossKind, n_cross: usize, deep: Vec<usize>) -> DcnConfig {
        DcnConfig {
            input_dim: 6,
            n_cross_layers: n_cross,
            deep_dims: deep,
            kind,
        }
    }

    #[test]
    fn vector_forward_in_unit_interval() {
        let mut rng = LcgRng::new(1);
        let model =
            Dcn::new(cfg(CrossKind::Vector, 2, vec![8, 4]), &mut rng).expect("model must build");
        let x: Vec<f32> = (0..6).map(|_| rng.next_f32()).collect();
        let p = model.forward(&x).expect("forward must succeed");
        assert!((0.0..=1.0).contains(&p), "prob {p} not in [0,1]");
    }

    #[test]
    fn matrix_v2_forward_in_unit_interval() {
        let mut rng = LcgRng::new(2);
        let model =
            Dcn::new(cfg(CrossKind::MatrixV2, 3, vec![8, 4]), &mut rng).expect("model must build");
        let x: Vec<f32> = (0..6).map(|_| rng.next_f32()).collect();
        let p = model.forward(&x).expect("forward must succeed");
        assert!((0.0..=1.0).contains(&p), "prob {p} not in [0,1]");
    }

    #[test]
    fn cross_output_has_input_dim() {
        let mut rng = LcgRng::new(3);
        let model =
            Dcn::new(cfg(CrossKind::Vector, 4, vec![]), &mut rng).expect("model must build");
        let x = vec![0.5_f32; 6];
        let out = model.cross_forward(&x).expect("cross must succeed");
        assert_eq!(out.len(), 6);
    }

    #[test]
    fn zero_cross_layers_is_identity() {
        let mut rng = LcgRng::new(4);
        let model =
            Dcn::new(cfg(CrossKind::Vector, 0, vec![]), &mut rng).expect("model must build");
        let x = vec![0.1_f32, -0.2, 0.3, 0.4, -0.5, 0.6];
        let out = model.cross_forward(&x).expect("cross must succeed");
        // `zip` stops at the shorter side, so pin the length first; otherwise
        // an empty or truncated output would pass the loop vacuously.
        assert_eq!(out.len(), x.len());
        for (a, b) in out.iter().zip(x.iter()) {
            assert!((a - b).abs() < 1e-6, "expected identity, got {a} vs {b}");
        }
    }

    #[test]
    fn deep_output_dim_matches_last_hidden() {
        let mut rng = LcgRng::new(5);
        let model = Dcn::new(cfg(CrossKind::MatrixV2, 1, vec![10, 7, 3]), &mut rng)
            .expect("model must build");
        assert_eq!(model.deep_out_dim(), 3);
        let out = model
            .deep_forward(&[0.2_f32; 6])
            .expect("deep must succeed");
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn empty_deep_returns_empty() {
        let mut rng = LcgRng::new(6);
        let model =
            Dcn::new(cfg(CrossKind::Vector, 2, vec![]), &mut rng).expect("model must build");
        assert_eq!(model.deep_out_dim(), 0);
        let out = model
            .deep_forward(&[0.3_f32; 6])
            .expect("deep must succeed");
        assert!(out.is_empty());
    }

    #[test]
    fn deep_relu_nonnegative() {
        let mut rng = LcgRng::new(7);
        let model =
            Dcn::new(cfg(CrossKind::Vector, 1, vec![12, 8]), &mut rng).expect("model must build");
        let x: Vec<f32> = (0..6)
            .map(|i| if i % 2 == 0 { -1.0 } else { 1.0 })
            .collect();
        let out = model.deep_forward(&x).expect("deep must succeed");
        // `all` is vacuously true on an empty slice, so pin the width too.
        assert_eq!(out.len(), 8);
        assert!(out.iter().all(|&v| v >= 0.0), "ReLU output must be >= 0");
    }

    #[test]
    fn forward_finite_extreme_inputs() {
        let mut rng = LcgRng::new(8);
        let model =
            Dcn::new(cfg(CrossKind::MatrixV2, 2, vec![8]), &mut rng).expect("model must build");
        let x = vec![1e3_f32; 6];
        let p = model.forward(&x).expect("forward must succeed");
        assert!(p.is_finite(), "prob must be finite for extreme input");
        assert!((0.0..=1.0).contains(&p));
    }

    #[test]
    fn dimension_mismatch_cross_errors() {
        let mut rng = LcgRng::new(9);
        let model =
            Dcn::new(cfg(CrossKind::Vector, 1, vec![4]), &mut rng).expect("model must build");
        let err = model.cross_forward(&[1.0, 2.0, 3.0]);
        assert!(matches!(err, Err(RecsysError::DimensionMismatch { .. })));
    }

    #[test]
    fn dimension_mismatch_deep_errors() {
        let mut rng = LcgRng::new(10);
        let model =
            Dcn::new(cfg(CrossKind::MatrixV2, 1, vec![4]), &mut rng).expect("model must build");
        let err = model.deep_forward(&[1.0, 2.0]);
        assert!(matches!(err, Err(RecsysError::DimensionMismatch { .. })));
    }

    #[test]
    fn zero_input_dim_rejected() {
        let mut rng = LcgRng::new(11);
        let err = Dcn::new(
            DcnConfig {
                input_dim: 0,
                n_cross_layers: 1,
                deep_dims: vec![4],
                kind: CrossKind::Vector,
            },
            &mut rng,
        );
        assert!(matches!(err, Err(RecsysError::InvalidConfig { .. })));
    }

    #[test]
    fn zero_deep_hidden_rejected() {
        let mut rng = LcgRng::new(12);
        let err = Dcn::new(
            DcnConfig {
                input_dim: 6,
                n_cross_layers: 1,
                deep_dims: vec![0],
                kind: CrossKind::Vector,
            },
            &mut rng,
        );
        assert!(matches!(err, Err(RecsysError::InvalidEmbeddingDim { .. })));
    }

    #[test]
    fn n_cross_layers_reported() {
        let mut rng = LcgRng::new(13);
        let model =
            Dcn::new(cfg(CrossKind::MatrixV2, 5, vec![4]), &mut rng).expect("model must build");
        assert_eq!(model.n_cross_layers(), 5);
        assert_eq!(model.kind(), CrossKind::MatrixV2);
    }

    #[test]
    fn vector_single_layer_matches_formula() {
        let mut rng = LcgRng::new(14);
        let mut model =
            Dcn::new(cfg(CrossKind::Vector, 1, vec![]), &mut rng).expect("model must build");
        // With w = e_0 the shared scalar is dot(x0, w) = x0[0] = 2, so each
        // output is x0[i] * 2 + 0.5 + x0[i] = 3 * x0[i] + 0.5.
        model.cross_w[0] = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        model.cross_b[0] = vec![0.5; 6];
        let x0 = vec![2.0_f32, 1.0, 1.0, 1.0, 1.0, 1.0];
        let out = model.cross_forward(&x0).expect("cross must succeed");
        assert!((out[0] - (3.0 * 2.0 + 0.5)).abs() < 1e-5, "got {}", out[0]);
        assert!((out[1] - (3.0 * 1.0 + 0.5)).abs() < 1e-5, "got {}", out[1]);
    }

    #[test]
    fn matrix_identity_weight_doubles_residual() {
        let mut rng = LcgRng::new(15);
        let mut model =
            Dcn::new(cfg(CrossKind::MatrixV2, 1, vec![]), &mut rng).expect("model must build");
        let d = 6;
        // Identity weight matrix and zero bias: W * x0 + b == x0, so the
        // layer output is x0[i] * x0[i] + x0[i].
        let mut w = vec![0.0_f32; d * d];
        for i in 0..d {
            w[i * d + i] = 1.0;
        }
        model.cross_w[0] = w;
        model.cross_b[0] = vec![0.0; d];
        let x0 = vec![0.5_f32, 1.0, 2.0, 0.0, -1.0, 3.0];
        let out = model.cross_forward(&x0).expect("cross must succeed");
        for i in 0..d {
            let expected = x0[i] * x0[i] + x0[i];
            assert!(
                (out[i] - expected).abs() < 1e-5,
                "i={i}: got {} want {expected}",
                out[i]
            );
        }
    }
}