use crate::error::{RecsysError, RecsysResult};
use crate::handle::LcgRng;
/// Logistic function `σ(x) = 1 / (1 + e^{-x})`, mapping any logit into `[0, 1]`.
///
/// For very negative `x`, `e^{-x}` overflows to `+∞` and the result is `0.0`
/// rather than NaN.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let neg_exp = (-x).exp();
    (1.0 + neg_exp).recip()
}
/// Numerically stable binary cross-entropy evaluated on a raw logit.
///
/// Mathematically this is `-(y·ln σ(z) + (1 − y)·ln(1 − σ(z)))`. It is computed
/// with the identity `max(z, 0) − z·y + ln(1 + e^{−|z|})`, so `exp` is only ever
/// applied to a non-positive argument and cannot overflow.
///
/// The earlier two-term form overflowed for `|z| ≳ 88`. It then evaluated
/// `0 · ∞ = NaN` for confidently-correct predictions (e.g. `z = 100, y = 1`).
///
/// `label` is expected in `[0, 1]`; soft labels are handled by the same formula.
#[inline]
fn bce_loss(logit: f32, label: f32) -> f32 {
    // ln(1 + e^{-|z|}) is in (0, ln 2], so it never overflows.
    let softplus_tail = (-logit.abs()).exp().ln_1p();
    logit.max(0.0) - logit * label + softplus_tail
}
/// Derivative of the binary cross-entropy (`bce_loss`) with respect to the
/// logit, i.e. `σ(z) − y`.
#[inline]
fn bce_grad(logit: f32, label: f32) -> f32 {
    let prob = sigmoid(logit);
    prob - label
}
/// Inner product of two vectors.
///
/// Elements are paired with `zip`, so if the slices differ in length the
/// surplus trailing elements of the longer one are ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(lhs, rhs)| lhs * rhs)
        .sum::<f32>()
}
/// Hyper-parameters for `UltraGcn`.
///
/// `UltraGcn::new` validates the following:
/// - it rejects a zero `embed_dim`;
/// - it rejects a negative `lambda` or `neg_weight`;
/// - it requires `lr` to be positive.
#[derive(Debug, Clone)]
pub struct UltraGcnConfig {
    /// Length of every user and item embedding vector.
    pub embed_dim: usize,
    /// Presumably the constraint-loss weight λ from UltraGCN.
    /// NOTE(review): validated as non-negative by `UltraGcn::new` but not read
    /// by `train_step` in this file — confirm intent.
    pub lambda: f32,
    /// NOTE(review): neither validated nor read anywhere in this file;
    /// presumably the item–item constraint weight γ — confirm intent.
    pub gamma: f32,
    /// Multiplier applied to each negative sample's BCE loss and gradient.
    pub neg_weight: f32,
    /// Number of negative items drawn per positive edge in `train_step`.
    pub n_neg: usize,
    /// SGD step size used for both user and item updates.
    pub lr: f32,
}
/// UltraGCN-style embedding recommender.
///
/// It is trained by SGD on weighted binary cross-entropy with uniformly sampled
/// negatives.
///
/// Embeddings are stored row-major in flat vectors: row `k` occupies indices
/// `k * embed_dim .. (k + 1) * embed_dim`.
pub struct UltraGcn {
    /// Number of users; valid user ids are `0..n_users`.
    /// NOTE(review): public and mutable — changing it after construction breaks
    /// the embedding-length invariant.
    pub n_users: usize,
    /// Number of items; valid item ids are `0..n_items`.
    pub n_items: usize,
    /// Length of each embedding row (copied from `UltraGcnConfig::embed_dim`).
    pub embed_dim: usize,
    /// `n_users * embed_dim` user embedding values, one row per user.
    user_emb: Vec<f32>,
    /// `n_items * embed_dim` item embedding values, one row per item.
    item_emb: Vec<f32>,
    /// `(user, item, weight)` triples produced by `compute_omega`, one per input
    /// edge. The vector is empty until then, and training then uses weight 1.0.
    omega: Vec<(u32, u32, f32)>,
    /// Hyper-parameters supplied at construction.
    cfg: UltraGcnConfig,
}
impl UltraGcn {
    /// Builds a model with randomly initialised embeddings.
    ///
    /// Every embedding entry is `rng.next_normal() * sqrt(1 / embed_dim)`. All
    /// user rows are drawn before any item row, so a given RNG state always
    /// reproduces the same model.
    ///
    /// # Errors
    ///
    /// * `RecsysError::InvalidNumUsers` if `n_users` is zero.
    /// * `RecsysError::InvalidNumItems` if `n_items` is zero.
    /// * `RecsysError::InvalidEmbeddingDim` if `cfg.embed_dim` is zero.
    /// * `RecsysError::InvalidLambda` if `cfg.lambda` is negative or NaN.
    /// * `RecsysError::InvalidLossWeight` if `cfg.neg_weight` is negative or
    ///   non-finite, or if `cfg.lr` is not a finite positive number.
    pub fn new(
        n_users: usize,
        n_items: usize,
        cfg: UltraGcnConfig,
        rng: &mut LcgRng,
    ) -> RecsysResult<Self> {
        if n_users == 0 {
            return Err(RecsysError::InvalidNumUsers { n: n_users });
        }
        if n_items == 0 {
            return Err(RecsysError::InvalidNumItems { n: n_items });
        }
        if cfg.embed_dim == 0 {
            return Err(RecsysError::InvalidEmbeddingDim { d: 0 });
        }
        // NaN compares false against everything, so a plain `< 0.0` check
        // would let it through and poison every later loss.
        if cfg.lambda.is_nan() || cfg.lambda < 0.0 {
            return Err(RecsysError::InvalidLambda { val: cfg.lambda });
        }
        if !cfg.neg_weight.is_finite() || cfg.neg_weight < 0.0 {
            return Err(RecsysError::InvalidLossWeight { w: cfg.neg_weight });
        }
        // NOTE(review): the learning rate reuses `InvalidLossWeight` because no
        // dedicated variant is visible here; consider adding one.
        if !cfg.lr.is_finite() || cfg.lr <= 0.0 {
            return Err(RecsysError::InvalidLossWeight { w: cfg.lr });
        }
        let d = cfg.embed_dim;
        let scale = (1.0 / d as f32).sqrt();
        let user_emb: Vec<f32> = (0..n_users * d)
            .map(|_| rng.next_normal() * scale)
            .collect();
        let item_emb: Vec<f32> = (0..n_items * d)
            .map(|_| rng.next_normal() * scale)
            .collect();
        Ok(Self {
            n_users,
            n_items,
            embed_dim: d,
            user_emb,
            item_emb,
            omega: Vec::new(),
            cfg,
        })
    }

    /// Checks every edge and returns an error for the first out-of-range id.
    ///
    /// Within each edge the user id is checked before the item id.
    fn validate_edges(&self, edges: &[(usize, usize)]) -> RecsysResult<()> {
        for &(u, i) in edges {
            if u >= self.n_users {
                return Err(RecsysError::UnknownUser { id: u });
            }
            if i >= self.n_items {
                return Err(RecsysError::UnknownItem { id: i });
            }
        }
        Ok(())
    }

    /// Computes the per-edge weights that scale the positive loss in
    /// `train_step`.
    ///
    /// Each edge `(u, i)` gets `1 / sqrt(deg(u) * deg(i))`, with degrees counted
    /// over `edges` (a duplicated edge is counted every time it appears). One
    /// triple is stored per input edge, in input order, replacing any previous
    /// weights. The model is left untouched when an error is returned.
    ///
    /// NOTE(review): the UltraGCN paper defines the constraint weight as
    /// `(1 / d_u) * sqrt((d_u + 1) / (d_i + 1))`. This uses symmetric GCN
    /// normalisation instead — confirm that is intended.
    ///
    /// # Errors
    ///
    /// * `RecsysError::EmptyInteraction` if `edges` is empty.
    /// * `RecsysError::UnknownUser` / `RecsysError::UnknownItem` for the first
    ///   out-of-range id.
    pub fn compute_omega(&mut self, edges: &[(usize, usize)]) -> RecsysResult<()> {
        if edges.is_empty() {
            return Err(RecsysError::EmptyInteraction);
        }
        self.validate_edges(edges)?;
        let mut deg_u = vec![0u32; self.n_users];
        let mut deg_i = vec![0u32; self.n_items];
        for &(u, i) in edges {
            deg_u[u] += 1;
            deg_i[i] += 1;
        }
        self.omega = edges
            .iter()
            .map(|&(u, i)| {
                // Both degrees are >= 1 because this very edge was counted above.
                let w = 1.0 / (deg_u[u] as f32 * deg_i[i] as f32).sqrt();
                // NOTE(review): ids above u32::MAX would be truncated here.
                (u as u32, i as u32, w)
            })
            .collect();
        Ok(())
    }

    /// Predicted interaction probability `σ(e_user · e_item)`, in `[0, 1]`.
    ///
    /// # Errors
    ///
    /// `RecsysError::UnknownUser` / `RecsysError::UnknownItem` if either id is
    /// out of range (user checked first).
    pub fn score(&self, user: usize, item: usize) -> RecsysResult<f32> {
        if user >= self.n_users {
            return Err(RecsysError::UnknownUser { id: user });
        }
        if item >= self.n_items {
            return Err(RecsysError::UnknownItem { id: item });
        }
        let d = self.embed_dim;
        let logit = dot(
            &self.user_emb[user * d..(user + 1) * d],
            &self.item_emb[item * d..(item + 1) * d],
        );
        Ok(sigmoid(logit))
    }

    /// Draws an item id different from `pos_item`.
    ///
    /// It makes up to nine uniform draws from `0..n_items`. If every draw hits
    /// `pos_item`, it falls back to `(pos_item + 1) % n_items`. The caller must
    /// guarantee `n_items >= 2`.
    fn sample_negative(rng: &mut LcgRng, n_items: usize, pos_item: usize) -> usize {
        let mut candidate = rng.next_usize(n_items);
        for _ in 0..8 {
            if candidate != pos_item {
                return candidate;
            }
            candidate = rng.next_usize(n_items);
        }
        if candidate == pos_item {
            (pos_item + 1) % n_items
        } else {
            candidate
        }
    }

    /// Applies one SGD term to `item` and accumulates the user gradient.
    ///
    /// For each dimension it does `user_grad += g * item` using the
    /// not-yet-updated item value, then `item -= lr * g * user`.
    fn sgd_item_term(item: &mut [f32], user: &[f32], user_grad: &mut [f32], g: f32, lr: f32) {
        for ((ug, ie), &ue) in user_grad.iter_mut().zip(item.iter_mut()).zip(user) {
            *ug += g * *ie;
            *ie -= lr * g * ue;
        }
    }

    /// Runs one SGD pass over `pos_edges` and returns the mean loss per edge.
    ///
    /// For each positive edge `(u, i)` the loss is
    /// `ω(u, i) · BCE(e_u · e_i, 1) + neg_weight · Σ_j BCE(e_u · e_j, 0)`.
    /// The sum runs over `cfg.n_neg` negatives `j`, drawn uniformly but never
    /// equal to `i`. `ω(u, i)` is the weight `compute_omega` stored for that
    /// pair (the first one if the pair repeats), or `1.0` if none is stored.
    ///
    /// Each item row is updated right after its own term. The user row is
    /// updated once per edge with the gradient accumulated over all its terms.
    ///
    /// NOTE(review): negatives are only guaranteed to differ from `i`, not from
    /// the user's other positives. `cfg.lambda` and `cfg.gamma` are not used.
    ///
    /// # Errors
    ///
    /// * `RecsysError::EmptyInteraction` if `pos_edges` is empty.
    /// * `RecsysError::NoNegativeAvailable` if there is only one item.
    /// * `RecsysError::UnknownUser` / `RecsysError::UnknownItem` for the first
    ///   out-of-range id.
    ///
    /// All checks run before any embedding is modified.
    pub fn train_step(
        &mut self,
        pos_edges: &[(usize, usize)],
        rng: &mut LcgRng,
    ) -> RecsysResult<f32> {
        if pos_edges.is_empty() {
            return Err(RecsysError::EmptyInteraction);
        }
        if self.n_items <= 1 {
            // NOTE(review): always reports user 0, whichever edges were passed.
            return Err(RecsysError::NoNegativeAvailable { user: 0 });
        }
        self.validate_edges(pos_edges)?;

        let d = self.embed_dim;
        let lr = self.cfg.lr;
        let neg_weight = self.cfg.neg_weight;
        let n_neg = self.cfg.n_neg;
        let n_items = self.n_items;

        // Index omega once per call. A linear scan per edge made every step
        // O(|pos_edges| * |omega|). `or_insert` keeps the first weight of a
        // repeated pair, matching the previous first-match scan.
        let mut omega_by_pair: std::collections::HashMap<(usize, usize), f32> =
            std::collections::HashMap::with_capacity(self.omega.len());
        for &(ou, oi, ow) in &self.omega {
            omega_by_pair.entry((ou as usize, oi as usize)).or_insert(ow);
        }

        // One gradient buffer reused across edges instead of one allocation each.
        let mut user_grad = vec![0.0_f32; d];
        let mut total_loss = 0.0_f32;
        for &(u, pos_i) in pos_edges {
            let w_ui = omega_by_pair.get(&(u, pos_i)).copied().unwrap_or(1.0);
            user_grad.fill(0.0);
            // The user row stays fixed until every term for this edge is done.
            let user_vec = &self.user_emb[u * d..(u + 1) * d];

            let pos_vec = &mut self.item_emb[pos_i * d..(pos_i + 1) * d];
            let logit_pos = dot(user_vec, pos_vec);
            total_loss += w_ui * bce_loss(logit_pos, 1.0);
            let g_pos = w_ui * bce_grad(logit_pos, 1.0);
            Self::sgd_item_term(pos_vec, user_vec, &mut user_grad, g_pos, lr);

            for _ in 0..n_neg {
                let neg_j = Self::sample_negative(rng, n_items, pos_i);
                let neg_vec = &mut self.item_emb[neg_j * d..(neg_j + 1) * d];
                let logit_neg = dot(user_vec, neg_vec);
                total_loss += neg_weight * bce_loss(logit_neg, 0.0);
                let g_neg = neg_weight * bce_grad(logit_neg, 0.0);
                Self::sgd_item_term(neg_vec, user_vec, &mut user_grad, g_neg, lr);
            }

            for (ue, ug) in self.user_emb[u * d..(u + 1) * d]
                .iter_mut()
                .zip(&user_grad)
            {
                *ue -= lr * ug;
            }
        }
        Ok(total_loss / pos_edges.len() as f32)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handle::LcgRng;

    /// Deterministic RNG used to initialise models in most tests.
    fn make_rng() -> LcgRng {
        LcgRng::new(42)
    }

    /// Small, valid configuration shared by the tests.
    fn default_cfg() -> UltraGcnConfig {
        UltraGcnConfig {
            embed_dim: 8,
            lambda: 0.5,
            gamma: 1.0,
            neg_weight: 1.5,
            n_neg: 3,
            lr: 0.01,
        }
    }

    /// 5-user × 10-item model built from `default_cfg`.
    fn small_model(rng: &mut LcgRng) -> UltraGcn {
        UltraGcn::new(5, 10, default_cfg(), rng).expect("model construction should succeed")
    }

    #[test]
    fn construction_succeeds() {
        let mut rng = make_rng();
        let model = small_model(&mut rng);
        assert_eq!(model.n_users, 5);
        assert_eq!(model.n_items, 10);
        assert_eq!(model.embed_dim, 8);
    }

    #[test]
    fn score_in_unit_interval() {
        let mut rng = make_rng();
        let model = small_model(&mut rng);
        let s = model.score(0, 0).expect("score should succeed");
        assert!(
            (0.0..=1.0).contains(&s),
            "UltraGCN score must be in [0, 1], got {s}"
        );
    }

    #[test]
    fn err_score_unknown_user() {
        let mut rng = make_rng();
        let model = small_model(&mut rng);
        assert!(matches!(
            model.score(999, 0),
            Err(RecsysError::UnknownUser { .. })
        ));
    }

    #[test]
    fn err_score_unknown_item() {
        let mut rng = make_rng();
        let model = small_model(&mut rng);
        assert!(matches!(
            model.score(0, 999),
            Err(RecsysError::UnknownItem { .. })
        ));
    }

    #[test]
    fn compute_omega_succeeds() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        let edges = vec![(0, 0), (0, 1), (1, 0), (2, 2)];
        model
            .compute_omega(&edges)
            .expect("compute_omega should succeed");
        assert_eq!(model.omega.len(), 4);
    }

    #[test]
    fn omega_values_finite_and_positive() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        let edges = vec![(0, 0), (0, 1), (1, 2), (2, 0), (3, 3)];
        model
            .compute_omega(&edges)
            .expect("compute_omega should succeed");
        for &(_, _, w) in &model.omega {
            assert!(
                w.is_finite() && w > 0.0,
                "omega weight must be finite > 0, got {w}"
            );
        }
    }

    #[test]
    fn omega_uses_symmetric_degree_normalisation() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        // deg(u0) = 2, deg(u1) = deg(u2) = 1; deg(i0) = 2, deg(i1) = deg(i2) = 1.
        let edges = vec![(0, 0), (0, 1), (1, 0), (2, 2)];
        model
            .compute_omega(&edges)
            .expect("compute_omega should succeed");
        let inv_sqrt2 = 1.0 / 2.0_f32.sqrt();
        let expected = [0.5_f32, inv_sqrt2, inv_sqrt2, 1.0];
        assert_eq!(model.omega.len(), expected.len());
        for (&(u, i, w), (&(eu, ei), &want)) in model
            .omega
            .iter()
            .zip(edges.iter().zip(expected.iter()))
        {
            assert_eq!((u as usize, i as usize), (eu, ei));
            assert!(
                (w - want).abs() < 1e-6,
                "edge ({eu}, {ei}): expected {want}, got {w}"
            );
        }
    }

    #[test]
    fn err_compute_omega_empty() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        assert!(matches!(
            model.compute_omega(&[]),
            Err(RecsysError::EmptyInteraction)
        ));
    }

    #[test]
    fn train_step_returns_finite_loss() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        let edges = vec![(0, 0), (0, 1), (1, 2), (2, 3), (3, 4)];
        model
            .compute_omega(&edges)
            .expect("compute_omega should succeed");
        let mut train_rng = LcgRng::new(99);
        let loss = model
            .train_step(&edges, &mut train_rng)
            .expect("train_step should succeed");
        assert!(loss.is_finite(), "training loss must be finite, got {loss}");
    }

    #[test]
    fn err_train_step_empty_edges() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        let mut train_rng = LcgRng::new(1);
        assert!(matches!(
            model.train_step(&[], &mut train_rng),
            Err(RecsysError::EmptyInteraction)
        ));
    }

    #[test]
    fn err_train_step_unknown_user() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        let mut train_rng = LcgRng::new(1);
        assert!(matches!(
            model.train_step(&[(999, 0)], &mut train_rng),
            Err(RecsysError::UnknownUser { .. })
        ));
    }

    #[test]
    fn err_train_step_unknown_item() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        let mut train_rng = LcgRng::new(1);
        assert!(matches!(
            model.train_step(&[(0, 999)], &mut train_rng),
            Err(RecsysError::UnknownItem { .. })
        ));
    }

    #[test]
    fn err_train_step_single_item() {
        let mut rng = make_rng();
        let mut model =
            UltraGcn::new(2, 1, default_cfg(), &mut rng).expect("model construction should succeed");
        let mut train_rng = LcgRng::new(1);
        assert!(matches!(
            model.train_step(&[(0, 0)], &mut train_rng),
            Err(RecsysError::NoNegativeAvailable { .. })
        ));
    }

    #[test]
    fn err_invalid_embed_dim() {
        let mut rng = make_rng();
        let cfg = UltraGcnConfig {
            embed_dim: 0,
            lambda: 0.5,
            gamma: 1.0,
            neg_weight: 1.0,
            n_neg: 2,
            lr: 0.01,
        };
        assert!(matches!(
            UltraGcn::new(4, 8, cfg, &mut rng),
            Err(RecsysError::InvalidEmbeddingDim { .. })
        ));
    }

    #[test]
    fn err_negative_lambda() {
        let mut rng = make_rng();
        let cfg = UltraGcnConfig {
            lambda: -1.0,
            ..default_cfg()
        };
        assert!(matches!(
            UltraGcn::new(4, 8, cfg, &mut rng),
            Err(RecsysError::InvalidLambda { .. })
        ));
    }

    #[test]
    fn err_negative_neg_weight() {
        let mut rng = make_rng();
        let cfg = UltraGcnConfig {
            neg_weight: -0.5,
            ..default_cfg()
        };
        assert!(matches!(
            UltraGcn::new(4, 8, cfg, &mut rng),
            Err(RecsysError::InvalidLossWeight { .. })
        ));
    }

    /// Only checks numerical sanity over several steps. With random negatives,
    /// a strict decrease between two individual steps is not guaranteed, so it
    /// is deliberately not asserted.
    #[test]
    fn losses_stay_finite_over_consecutive_steps() {
        let mut rng = LcgRng::new(13);
        let mut model = small_model(&mut rng);
        let edges = vec![(0, 0), (0, 1), (1, 2), (2, 3)];
        model
            .compute_omega(&edges)
            .expect("compute_omega should succeed");
        let mut train_rng = LcgRng::new(77);
        for step in 0..5 {
            let loss = model
                .train_step(&edges, &mut train_rng)
                .expect("train_step should succeed");
            assert!(loss.is_finite(), "loss at step {step} must be finite, got {loss}");
        }
    }

    #[test]
    fn training_is_deterministic_for_fixed_seeds() {
        let edges = vec![(0, 0), (0, 1), (1, 2), (2, 3)];
        let run = || {
            let mut rng = LcgRng::new(13);
            let mut model = small_model(&mut rng);
            model
                .compute_omega(&edges)
                .expect("compute_omega should succeed");
            let mut train_rng = LcgRng::new(77);
            let loss = model
                .train_step(&edges, &mut train_rng)
                .expect("train_step should succeed");
            let s = model.score(0, 0).expect("score should succeed");
            (loss.to_bits(), s.to_bits())
        };
        assert_eq!(run(), run());
    }

    #[test]
    fn train_step_moves_embeddings() {
        let mut rng = make_rng();
        let mut model = small_model(&mut rng);
        let before = model.score(0, 0).expect("score should succeed");
        let mut train_rng = LcgRng::new(5);
        model
            .train_step(&[(0, 0)], &mut train_rng)
            .expect("train_step should succeed");
        let after = model.score(0, 0).expect("score should succeed");
        assert_ne!(before, after, "a training step must change the trained pair's score");
    }

    #[test]
    fn err_n_users_zero() {
        let mut rng = make_rng();
        assert!(matches!(
            UltraGcn::new(0, 8, default_cfg(), &mut rng),
            Err(RecsysError::InvalidNumUsers { .. })
        ));
    }

    #[test]
    fn err_n_items_zero() {
        let mut rng = make_rng();
        assert!(matches!(
            UltraGcn::new(4, 0, default_cfg(), &mut rng),
            Err(RecsysError::InvalidNumItems { .. })
        ));
    }
}