use crate::error::{RecsysError, RecsysResult};
use crate::handle::LcgRng;
/// One non-zero input of a sample fed to [`Ffm`].
///
/// An entry names the field it belongs to, the feature index within the model's
/// feature space, and the numeric value that multiplies both the feature's
/// linear weight and every pairwise interaction it takes part in.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct FfmEntry {
    /// Field index; the model rejects values `>= n_fields`.
    pub field: usize,
    /// Feature index; the model rejects values `>= n_features`.
    pub feature: usize,
    /// Feature value (the tests use `1.0` for indicator-style inputs).
    pub value: f32,
}
/// Hyper-parameters for building an [`Ffm`].
///
/// `Default` supplies `dim = 4`, `lr = 0.1` and `lambda = 2e-5` but leaves
/// `n_fields` and `n_features` at zero. `Ffm::new` rejects zero for both, so
/// callers must set them explicitly, e.g.
/// `FfmConfig { n_fields: 3, n_features: 100, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct FfmConfig {
    /// Number of fields; every entry's `field` must be below this.
    pub n_fields: usize,
    /// Number of features; every entry's `feature` must be below this.
    pub n_features: usize,
    /// Length of each latent factor vector.
    pub dim: usize,
    /// SGD step size.
    pub lr: f32,
    /// L2 penalty applied to linear weights and latent factors (not the bias).
    pub lambda: f32,
}
impl Default for FfmConfig {
    /// Default hyper-parameters with zero fields and zero features.
    fn default() -> Self {
        let (dim, lr, lambda) = (4, 0.1, 2e-5);
        Self {
            n_fields: 0,
            n_features: 0,
            dim,
            lr,
            lambda,
        }
    }
}
/// Field-aware factorization machine trained by SGD on the logistic loss.
pub struct Ffm {
    /// Global bias term.
    w0: f32,
    /// Per-feature linear weights, `n_features` long.
    w: Vec<f32>,
    /// Latent factors: one `dim`-long vector per (feature, field) pair, stored
    /// contiguously starting at `(feature * n_fields + field) * dim`.
    v: Vec<f32>,
    /// Hyper-parameters the model was built with.
    cfg: FfmConfig,
}
/// Logistic function `1 / (1 + e^{-x})`, mapping a logit to a probability.
fn sigmoid(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    denom.recip()
}
impl Ffm {
    /// Builds a model from `cfg`, drawing the initial latent factors from `rng`.
    ///
    /// The bias and all linear weights start at zero. Each latent factor entry is
    /// `rng.next_f32() * sqrt(1 / dim)`, drawn in storage order.
    ///
    /// # Errors
    ///
    /// - [`RecsysError::InvalidConfig`] if `n_fields` or `n_features` is zero, if
    ///   `lr` is not a finite value greater than zero, or if the latent table size
    ///   `n_features * n_fields * dim` overflows `usize`.
    /// - [`RecsysError::InvalidEmbeddingDim`] if `dim` is zero.
    /// - [`RecsysError::InvalidLambda`] if `lambda` is negative, infinite or NaN.
    pub fn new(cfg: FfmConfig, rng: &mut LcgRng) -> RecsysResult<Self> {
        if cfg.n_fields == 0 {
            return Err(RecsysError::InvalidConfig {
                msg: "n_fields must be > 0".to_string(),
            });
        }
        if cfg.n_features == 0 {
            return Err(RecsysError::InvalidConfig {
                msg: "n_features must be > 0".to_string(),
            });
        }
        if cfg.dim == 0 {
            return Err(RecsysError::InvalidEmbeddingDim { d: cfg.dim });
        }
        // Phrased as "not (valid)" so NaN, which fails every comparison, is rejected
        // too; a NaN or infinite step size would corrupt every weight on step one.
        if !(cfg.lr.is_finite() && cfg.lr > 0.0) {
            return Err(RecsysError::InvalidConfig {
                msg: format!("lr must be finite and > 0, got {}", cfg.lr),
            });
        }
        if !(cfg.lambda.is_finite() && cfg.lambda >= 0.0) {
            return Err(RecsysError::InvalidLambda { val: cfg.lambda });
        }
        // Checked so an oversized config is reported instead of wrapping (release)
        // or panicking (debug) and leaving a table too small for `v_offset`.
        let v_len = cfg
            .n_features
            .checked_mul(cfg.n_fields)
            .and_then(|n| n.checked_mul(cfg.dim))
            .ok_or_else(|| RecsysError::InvalidConfig {
                msg: format!(
                    "latent table size {} * {} * {} overflows usize",
                    cfg.n_features, cfg.n_fields, cfg.dim
                ),
            })?;
        let scale = (1.0 / cfg.dim as f32).sqrt();
        let v: Vec<f32> = (0..v_len).map(|_| rng.next_f32() * scale).collect();
        let w = vec![0.0_f32; cfg.n_features];
        Ok(Self { cfg, w0: 0.0, w, v })
    }

    /// Start index in `v` of the latent vector that `feature` uses when it
    /// interacts with a feature belonging to `field`.
    #[inline]
    fn v_offset(&self, feature: usize, field: usize) -> usize {
        (feature * self.cfg.n_fields + field) * self.cfg.dim
    }

    /// The `dim`-long latent vector of `feature` specialised for `field`.
    #[inline]
    fn v_slice(&self, feature: usize, field: usize) -> &[f32] {
        let base = self.v_offset(feature, field);
        &self.v[base..base + self.cfg.dim]
    }

    /// Accepts exactly `-1.0` or `+1.0`; anything else (including NaN) is an error.
    fn validate_label(label: f32) -> RecsysResult<()> {
        if label == 1.0 || label == -1.0 {
            Ok(())
        } else {
            Err(RecsysError::InvalidConfig {
                msg: format!("label must be -1 or +1, got {label}"),
            })
        }
    }

    /// Checks that every entry's field, feature and value are usable by this model.
    ///
    /// # Errors
    ///
    /// - [`RecsysError::InvalidConfig`] if a `field` is `>= n_fields`, or if a
    ///   `value` is NaN or infinite (during training such a value would spread
    ///   into every weight it touches).
    /// - [`RecsysError::ItemOutOfBounds`] if a `feature` is `>= n_features`.
    fn validate(&self, sample: &[FfmEntry]) -> RecsysResult<()> {
        for e in sample {
            if e.field >= self.cfg.n_fields {
                return Err(RecsysError::InvalidConfig {
                    msg: format!("field {} >= n_fields {}", e.field, self.cfg.n_fields),
                });
            }
            if e.feature >= self.cfg.n_features {
                return Err(RecsysError::ItemOutOfBounds {
                    idx: e.feature,
                    n: self.cfg.n_features,
                });
            }
            if !e.value.is_finite() {
                return Err(RecsysError::InvalidConfig {
                    msg: format!("feature {} has non-finite value {}", e.feature, e.value),
                });
            }
        }
        Ok(())
    }

    /// Raw logit for `sample`.
    ///
    /// Computes `w0 + Σ w[f]·x + Σ_{a<b} ⟨v[f_a, field_b], v[f_b, field_a]⟩·x_a·x_b`.
    /// An empty sample yields the bias alone.
    ///
    /// # Errors
    ///
    /// - [`RecsysError::InvalidConfig`] if an entry's field is out of range or its
    ///   value is not finite.
    /// - [`RecsysError::ItemOutOfBounds`] if an entry's feature is out of range.
    pub fn raw(&self, sample: &[FfmEntry]) -> RecsysResult<f32> {
        self.validate(sample)?;
        let mut acc = self.w0;
        for e in sample {
            acc += self.w[e.feature] * e.value;
        }
        for (a, ea) in sample.iter().enumerate() {
            for eb in &sample[a + 1..] {
                // Field awareness: each feature uses the latent vector dedicated
                // to its partner's field.
                let vi = self.v_slice(ea.feature, eb.field);
                let vj = self.v_slice(eb.feature, ea.field);
                let dot: f32 = vi.iter().zip(vj).map(|(x, y)| x * y).sum();
                acc += dot * ea.value * eb.value;
            }
        }
        Ok(acc)
    }

    /// Probability in `[0, 1]` obtained by applying the logistic function to
    /// [`raw`](Self::raw).
    ///
    /// # Errors
    ///
    /// Same as [`raw`](Self::raw).
    pub fn predict(&self, sample: &[FfmEntry]) -> RecsysResult<f32> {
        Ok(sigmoid(self.raw(sample)?))
    }

    /// Performs one SGD step on `sample` and returns the logistic loss measured
    /// *before* the update.
    ///
    /// `label` must be `-1.0` or `+1.0`. The loss is `ln(1 + e^{-label·ŷ})`, where
    /// `ŷ` is [`raw`](Self::raw) for the current weights. Latent factors are
    /// updated pair by pair in place, so later pairs of the same sample see the
    /// already-updated vectors of earlier pairs.
    ///
    /// # Errors
    ///
    /// [`RecsysError::InvalidConfig`] for a label other than ±1, plus every error
    /// [`raw`](Self::raw) can return. On error no parameter is modified.
    pub fn train_step(&mut self, sample: &[FfmEntry], label: f32) -> RecsysResult<f32> {
        Self::validate_label(label)?;
        // `raw` validates the sample, so no separate check is needed here.
        let yhat = self.raw(sample)?;
        let margin = label * yhat;
        // d(loss)/d(ŷ) = -label · sigmoid(-margin).
        let kappa = -label / (1.0 + margin.exp());
        let lr = self.cfg.lr;
        let lambda = self.cfg.lambda;
        let k = self.cfg.dim;
        self.w0 -= lr * kappa;
        for e in sample {
            let g = lambda * self.w[e.feature] + kappa * e.value;
            self.w[e.feature] -= lr * g;
        }
        for (a, ea) in sample.iter().enumerate() {
            for eb in &sample[a + 1..] {
                let base_i = self.v_offset(ea.feature, eb.field);
                let base_j = self.v_offset(eb.feature, ea.field);
                let scale = kappa * ea.value * eb.value;
                for d in 0..k {
                    let vi = self.v[base_i + d];
                    let vj = self.v[base_j + d];
                    self.v[base_i + d] = vi - lr * (lambda * vi + scale * vj);
                    self.v[base_j + d] = vj - lr * (lambda * vj + scale * vi);
                }
            }
        }
        // Logistic loss ln(1 + e^z) with z = -margin, evaluated as
        // max(z, 0) + ln(1 + e^{-|z|}). The naive form overflows `exp` to +inf once
        // z exceeds ~88 in f32, reporting an infinite loss for a finite logit.
        let z = -margin;
        Ok(z.max(0.0) + (-z.abs()).exp().ln_1p())
    }

    /// Trains for `n_epochs` passes over `data` and returns the mean per-row loss
    /// of the final epoch (`0.0` when `n_epochs` is zero).
    ///
    /// Each epoch visits the rows in a new order produced by a Fisher–Yates-style
    /// shuffle driven by `rng` (assumes `next_usize(n)` returns a value `< n`).
    ///
    /// # Errors
    ///
    /// - [`RecsysError::EmptyInteraction`] if `data` is empty.
    /// - Any label or sample error [`train_step`](Self::train_step) would report.
    ///   Every row is checked before the first update, so a bad row leaves the
    ///   model untouched.
    pub fn fit(
        &mut self,
        data: &[(Vec<FfmEntry>, f32)],
        n_epochs: usize,
        rng: &mut LcgRng,
    ) -> RecsysResult<f32> {
        if data.is_empty() {
            return Err(RecsysError::EmptyInteraction);
        }
        for (sample, label) in data {
            Self::validate_label(*label)?;
            self.validate(sample)?;
        }
        let mut order: Vec<usize> = (0..data.len()).collect();
        let mut last_mean = 0.0_f32;
        for _ in 0..n_epochs {
            for i in (1..order.len()).rev() {
                let j = rng.next_usize(i + 1);
                order.swap(i, j);
            }
            let mut sum = 0.0_f32;
            for &idx in &order {
                let (sample, label) = &data[idx];
                sum += self.train_step(sample, *label)?;
            }
            last_mean = sum / data.len() as f32;
        }
        Ok(last_mean)
    }

    /// Number of fields the model was configured with.
    pub fn n_fields(&self) -> usize {
        self.cfg.n_fields
    }

    /// Number of features the model was configured with.
    pub fn n_features(&self) -> usize {
        self.cfg.n_features
    }

    /// Length of each latent factor vector.
    pub fn dim(&self) -> usize {
        self.cfg.dim
    }

    /// Total trainable scalars: bias + linear weights + latent factors.
    pub fn n_params(&self) -> usize {
        1 + self.w.len() + self.v.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn entry(field: usize, feature: usize) -> FfmEntry {
        FfmEntry {
            field,
            feature,
            value: 1.0,
        }
    }

    fn base_cfg() -> FfmConfig {
        FfmConfig {
            n_fields: 3,
            n_features: 12,
            dim: 4,
            lr: 0.1,
            lambda: 1e-5,
        }
    }

    /// Plain dot product used to recompute interaction terms independently.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Relative closeness with an absolute floor of 1e-6.
    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() <= 1e-6 * b.abs().max(1.0)
    }

    #[test]
    fn build_ok_and_param_count() {
        let mut rng = LcgRng::new(1);
        let m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        assert_eq!(m.n_fields(), 3);
        assert_eq!(m.n_features(), 12);
        assert_eq!(m.dim(), 4);
        assert_eq!(m.n_params(), 1 + 12 + 12 * 3 * 4);
    }

    #[test]
    fn predict_in_unit_interval() {
        let mut rng = LcgRng::new(2);
        let m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let sample = vec![entry(0, 0), entry(1, 4), entry(2, 9)];
        let p = m.predict(&sample).expect("predict must succeed");
        assert!((0.0..=1.0).contains(&p), "prob {p} not in [0,1]");
    }

    #[test]
    fn raw_empty_sample_is_bias() {
        let mut rng = LcgRng::new(3);
        let m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let r = m.raw(&[]).expect("raw must succeed");
        assert!((r - 0.0).abs() < 1e-7, "empty-sample logit must equal w0=0");
    }

    #[test]
    fn out_of_range_feature_errors() {
        let mut rng = LcgRng::new(4);
        let m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let err = m.raw(&[entry(0, 99)]);
        assert!(matches!(err, Err(RecsysError::ItemOutOfBounds { .. })));
    }

    #[test]
    fn out_of_range_field_errors() {
        let mut rng = LcgRng::new(5);
        let m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let err = m.raw(&[entry(9, 0)]);
        assert!(matches!(err, Err(RecsysError::InvalidConfig { .. })));
    }

    #[test]
    fn invalid_label_rejected() {
        let mut rng = LcgRng::new(6);
        let mut m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let err = m.train_step(&[entry(0, 0)], 0.0);
        assert!(matches!(err, Err(RecsysError::InvalidConfig { .. })));
    }

    #[test]
    fn train_step_returns_finite_loss() {
        let mut rng = LcgRng::new(7);
        let mut m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let sample = vec![entry(0, 0), entry(1, 5), entry(2, 10)];
        let loss = m.train_step(&sample, 1.0).expect("step must succeed");
        assert!(loss.is_finite() && loss >= 0.0, "loss {loss} invalid");
    }

    #[test]
    fn single_sample_loss_decreases() {
        let mut rng = LcgRng::new(8);
        let mut m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let sample = vec![entry(0, 1), entry(1, 4), entry(2, 8)];
        let first = m.train_step(&sample, 1.0).expect("step");
        for _ in 0..200 {
            m.train_step(&sample, 1.0).expect("step");
        }
        let last = m.train_step(&sample, 1.0).expect("step");
        assert!(last < first, "loss should decrease: {first} -> {last}");
    }

    #[test]
    fn separable_dataset_learns_direction() {
        let mut rng = LcgRng::new(9);
        let mut m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let pos = vec![entry(0, 0), entry(1, 3), entry(2, 6)];
        let neg = vec![entry(0, 1), entry(1, 4), entry(2, 7)];
        let data = vec![(pos.clone(), 1.0_f32), (neg.clone(), -1.0_f32)];
        m.fit(&data, 300, &mut rng).expect("fit must succeed");
        let p_pos = m.predict(&pos).expect("predict");
        let p_neg = m.predict(&neg).expect("predict");
        assert!(
            p_pos > p_neg,
            "positive {p_pos} should exceed negative {p_neg}"
        );
    }

    #[test]
    fn fit_empty_dataset_errors() {
        let mut rng = LcgRng::new(10);
        let mut m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let err = m.fit(&[], 5, &mut rng);
        assert!(matches!(err, Err(RecsysError::EmptyInteraction)));
    }

    #[test]
    fn fit_returns_finite_mean_loss() {
        let mut rng = LcgRng::new(11);
        let mut m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let data = vec![
            (vec![entry(0, 0), entry(1, 5)], 1.0_f32),
            (vec![entry(0, 2), entry(2, 9)], -1.0_f32),
            (vec![entry(1, 4), entry(2, 11)], 1.0_f32),
        ];
        let mean = m.fit(&data, 10, &mut rng).expect("fit must succeed");
        assert!(mean.is_finite() && mean >= 0.0, "mean loss {mean} invalid");
    }

    #[test]
    fn zero_fields_rejected() {
        let mut rng = LcgRng::new(12);
        let mut cfg = base_cfg();
        cfg.n_fields = 0;
        let err = Ffm::new(cfg, &mut rng);
        assert!(matches!(err, Err(RecsysError::InvalidConfig { .. })));
    }

    #[test]
    fn zero_dim_rejected() {
        let mut rng = LcgRng::new(13);
        let mut cfg = base_cfg();
        cfg.dim = 0;
        let err = Ffm::new(cfg, &mut rng);
        assert!(matches!(err, Err(RecsysError::InvalidEmbeddingDim { .. })));
    }

    #[test]
    fn field_aware_uses_distinct_vectors() {
        let mut rng = LcgRng::new(14);
        let m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        // Same two features; only the second entry's field changes, so feature 0
        // must switch to the latent vector dedicated to the new partner field.
        let s1 = vec![entry(0, 0), entry(1, 5)];
        let s2 = vec![entry(0, 0), entry(2, 5)];
        let r1 = m.raw(&s1).expect("raw");
        let r2 = m.raw(&s2).expect("raw");
        assert!(r1.is_finite() && r2.is_finite());
        // Linear weights and bias start at zero, so each logit is exactly the
        // interaction term built from the field-specific vectors.
        let want1 = dot(m.v_slice(0, 1), m.v_slice(5, 0));
        let want2 = dot(m.v_slice(0, 2), m.v_slice(5, 0));
        assert!(close(r1, want1), "{r1} vs {want1}");
        assert!(close(r2, want2), "{r2} vs {want2}");
    }

    #[test]
    fn value_scaling_affects_interaction() {
        let mut rng = LcgRng::new(15);
        let m = Ffm::new(base_cfg(), &mut rng).expect("must build");
        let s1 = vec![
            FfmEntry {
                field: 0,
                feature: 0,
                value: 1.0,
            },
            FfmEntry {
                field: 1,
                feature: 5,
                value: 1.0,
            },
        ];
        let s2 = vec![
            FfmEntry {
                field: 0,
                feature: 0,
                value: 2.0,
            },
            FfmEntry {
                field: 1,
                feature: 5,
                value: 1.0,
            },
        ];
        let r1 = m.raw(&s1).expect("raw");
        let r2 = m.raw(&s2).expect("raw");
        assert!(r1.is_finite() && r2.is_finite());
        // With zero initial linear weights each logit is dot * x_a * x_b, so
        // doubling x_a must exactly double the logit.
        let want = dot(m.v_slice(0, 1), m.v_slice(5, 0));
        assert!(close(r1, want), "{r1} vs {want}");
        assert!(close(r2, 2.0 * want), "{r2} vs {}", 2.0 * want);
    }
}