use crate::error::{RecsysError, RecsysResult};
use crate::handle::LcgRng;
/// Normalises `x` to zero mean and unit (population) variance, then applies
/// the element-wise gain `g` and shift `b`.
///
/// `1e-5` is added to the variance before the square root. The output has
/// the length of the shortest of `x`, `g` and `b`; an empty `x` yields an
/// empty vector.
fn layer_norm(x: &[f32], g: &[f32], b: &[f32]) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let centred_sq = |v: f32| (v - mean).powi(2);
    let variance = x.iter().copied().map(centred_sq).sum::<f32>() / n;
    let inv_std = 1.0 / (variance + 1e-5).sqrt();

    let mut out = Vec::with_capacity(x.len().min(g.len()).min(b.len()));
    for ((&xi, &gi), &bi) in x.iter().zip(g).zip(b) {
        out.push((xi - mean) * inv_std * gi + bi);
    }
    out
}
/// Replaces the contents of `v` with their softmax, in place.
///
/// The maximum element is subtracted before exponentiating so large scores
/// do not overflow, and `1e-10` is added to the normaliser to guard the
/// division.
fn softmax_inplace(v: &mut [f32]) {
    let peak = v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0_f32;
    v.iter_mut().for_each(|e| {
        *e = (*e - peak).exp();
        total += *e;
    });
    let norm = 1.0 / (total + 1e-10);
    v.iter_mut().for_each(|e| *e *= norm);
}
/// Weights of one post-LayerNorm transformer encoder layer.
///
/// All matrices are stored row-major as flat `Vec<f32>` buffers.
#[derive(Debug, Clone)]
pub struct BertLayer {
    /// Query projection, `emb_dim × emb_dim`.
    pub wq: Vec<f32>,
    /// Key projection, `emb_dim × emb_dim`.
    pub wk: Vec<f32>,
    /// Value projection, `emb_dim × emb_dim`.
    pub wv: Vec<f32>,
    /// Attention output projection, `emb_dim × emb_dim`.
    pub wo: Vec<f32>,
    /// First feed-forward matrix, one `emb_dim`-wide row per hidden unit.
    pub w1: Vec<f32>,
    /// First feed-forward bias, one entry per hidden unit.
    pub b1: Vec<f32>,
    /// Second feed-forward matrix, one hidden-width row per output dimension.
    pub w2: Vec<f32>,
    /// Second feed-forward bias, `emb_dim` entries.
    pub b2: Vec<f32>,
    /// LayerNorm gain after the attention residual.
    pub ln1_g: Vec<f32>,
    /// LayerNorm shift after the attention residual.
    pub ln1_b: Vec<f32>,
    /// LayerNorm gain after the feed-forward residual.
    pub ln2_g: Vec<f32>,
    /// LayerNorm shift after the feed-forward residual.
    pub ln2_b: Vec<f32>,
}
impl BertLayer {
    /// Creates a randomly initialised encoder layer for `emb_dim`-wide rows.
    ///
    /// * `wq`, `wk`, `wv`, `wo` (`emb_dim × emb_dim`) are draws from
    ///   `rng.next_normal()` scaled by `1/sqrt(emb_dim)`.
    /// * `w1` (`ffn_dim × emb_dim`) and `w2` (`emb_dim × ffn_dim`), with
    ///   `ffn_dim = 4 * emb_dim`, are scaled by `sqrt(2 / fan_in)` of their own
    ///   input width.
    /// * Biases start at zero, LayerNorm gains at one and shifts at zero.
    ///
    /// Random draws are taken from `rng` in the order
    /// `wq`, `wk`, `wv`, `wo`, `w1`, `w2`.
    pub fn new(emb_dim: usize, rng: &mut LcgRng) -> Self {
        let ffn_dim = 4 * emb_dim;
        let attn_sc = (1.0 / emb_dim as f32).sqrt();
        // w1 consumes an emb_dim-wide row, so its fan-in is emb_dim.
        let w1_sc = (2.0 / emb_dim as f32).sqrt();
        // w2 consumes the ffn_dim-wide hidden activation, so its fan-in is
        // ffn_dim; scaling by emb_dim would make its variance 4x too large.
        let w2_sc = (2.0 / ffn_dim as f32).sqrt();

        let mut draw = |len: usize, scale: f32| -> Vec<f32> {
            (0..len).map(|_| rng.next_normal() * scale).collect()
        };
        let wq = draw(emb_dim * emb_dim, attn_sc);
        let wk = draw(emb_dim * emb_dim, attn_sc);
        let wv = draw(emb_dim * emb_dim, attn_sc);
        let wo = draw(emb_dim * emb_dim, attn_sc);
        let w1 = draw(ffn_dim * emb_dim, w1_sc);
        let w2 = draw(emb_dim * ffn_dim, w2_sc);

        Self {
            wq,
            wk,
            wv,
            wo,
            w1,
            b1: vec![0.0_f32; ffn_dim],
            w2,
            b2: vec![0.0_f32; emb_dim],
            ln1_g: vec![1.0_f32; emb_dim],
            ln1_b: vec![0.0_f32; emb_dim],
            ln2_g: vec![1.0_f32; emb_dim],
            ln2_b: vec![0.0_f32; emb_dim],
        }
    }
}
/// BERT4Rec model state: embedding tables plus a stack of encoder layers.
///
/// Embedding tables are flat row-major `Vec<f32>` buffers with `emb_dim`
/// values per row.
pub struct Bert4Rec {
/// Number of items in the vocabulary; valid item ids are `0..n_items`.
pub n_items: usize,
/// Width of every embedding and hidden state.
pub emb_dim: usize,
/// Number of attention heads requested at construction.
pub n_heads: usize,
/// Number of encoder layers requested at construction.
pub n_layers: usize,
/// Item embedding table, one row per item id (`n_items × emb_dim`).
pub item_emb: Vec<f32>,
/// Positional embedding table (`max_seq_len × emb_dim`).
pub pos_emb: Vec<f32>,
/// Embedding substituted for masked positions (`emb_dim` values).
pub mask_emb: Vec<f32>,
/// Encoder layers, applied in order.
pub attn_layers: Vec<BertLayer>,
}
/// Sentinel item id marking a masked position. `mask_sequence` emits it,
/// `embed_sequence` replaces it with `mask_emb`, and `forward_masked`
/// accepts it alongside ids below `n_items`.
const MASK_TOKEN: usize = usize::MAX;
impl Bert4Rec {
    /// Builds a randomly initialised BERT4Rec encoder.
    ///
    /// Item (`n_items × emb_dim`), positional (`max_seq_len × emb_dim`) and
    /// mask (`emb_dim`) embeddings are draws from `rng.next_normal()` scaled by
    /// `1/sqrt(emb_dim)`. They are followed by `n_layers` layers from
    /// [`BertLayer::new`].
    ///
    /// Attention uses `n_heads` heads when `n_heads` is non-zero and divides
    /// `emb_dim`. Otherwise it runs as a single head over the full embedding.
    /// `max_seq_len == 0` is accepted and disables positional embeddings.
    ///
    /// # Errors
    /// * `RecsysError::InvalidNumItems` when `n_items == 0`.
    /// * `RecsysError::InvalidEmbeddingDim` when `emb_dim == 0`.
    pub fn new(
        n_items: usize,
        emb_dim: usize,
        n_heads: usize,
        n_layers: usize,
        max_seq_len: usize,
        rng: &mut LcgRng,
    ) -> RecsysResult<Self> {
        if n_items == 0 {
            return Err(RecsysError::InvalidNumItems { n: n_items });
        }
        if emb_dim == 0 {
            return Err(RecsysError::InvalidEmbeddingDim { d: emb_dim });
        }
        let sc = (1.0 / emb_dim as f32).sqrt();
        // RNG draw order: items, positions, mask token, then layers. Reordering
        // changes the weights produced from a given rng state.
        let item_emb: Vec<f32> = (0..n_items * emb_dim)
            .map(|_| rng.next_normal() * sc)
            .collect();
        let pos_emb: Vec<f32> = (0..max_seq_len * emb_dim)
            .map(|_| rng.next_normal() * sc)
            .collect();
        let mask_emb: Vec<f32> = (0..emb_dim).map(|_| rng.next_normal() * sc).collect();
        let attn_layers: Vec<BertLayer> = (0..n_layers)
            .map(|_| BertLayer::new(emb_dim, rng))
            .collect();
        Ok(Self {
            n_items,
            emb_dim,
            n_heads,
            n_layers,
            item_emb,
            pos_emb,
            mask_emb,
            attn_layers,
        })
    }

    /// Returns a copy of `item_ids` where each id is independently replaced by
    /// `MASK_TOKEN` when `rng.next_f32() < mask_ratio`.
    ///
    /// Exactly one random number is drawn per input id, in order.
    pub fn mask_sequence(
        &self,
        item_ids: &[usize],
        mask_ratio: f32,
        rng: &mut LcgRng,
    ) -> Vec<usize> {
        item_ids
            .iter()
            .map(|&id| {
                if rng.next_f32() < mask_ratio {
                    MASK_TOKEN
                } else {
                    id
                }
            })
            .collect()
    }

    /// Builds the `seq_len × emb_dim` row-major input matrix.
    ///
    /// Each row is the token embedding plus the positional embedding for that
    /// position. `MASK_TOKEN` and any id `>= n_items` use `mask_emb`.
    /// Positions past the last positional row reuse that row. When the model
    /// has no positional rows, only the token embedding is used.
    fn embed_sequence(&self, masked_ids: &[usize]) -> Vec<f32> {
        let d = self.emb_dim;
        // Index of the last positional row, if there is one. The checked ops
        // keep an empty table (max_seq_len == 0) from underflowing.
        let last_pos = self
            .pos_emb
            .len()
            .checked_div(d)
            .and_then(|rows| rows.checked_sub(1));
        let mut h = vec![0.0_f32; masked_ids.len() * d];
        for (pos, &id) in masked_ids.iter().enumerate() {
            let item_e: &[f32] = if id != MASK_TOKEN && id < self.n_items {
                &self.item_emb[id * d..(id + 1) * d]
            } else {
                &self.mask_emb
            };
            let row = &mut h[pos * d..(pos + 1) * d];
            match last_pos {
                Some(last) => {
                    let start = pos.min(last) * d;
                    let pos_e = &self.pos_emb[start..start + d];
                    for ((out, &ie), &pe) in row.iter_mut().zip(item_e).zip(pos_e) {
                        *out = ie + pe;
                    }
                }
                None => {
                    for (out, &ie) in row.iter_mut().zip(item_e) {
                        *out = ie;
                    }
                }
            }
        }
        h
    }

    /// Multi-head scaled dot-product self-attention over `seq_len` rows of
    /// `h`, before the output projection.
    ///
    /// Q/K/V are split column-wise into equal-width heads. Each head attends
    /// with scale `1/sqrt(head_dim)`, and head outputs are written back into
    /// their own columns. If `n_heads` is 0 or does not divide `emb_dim`, a
    /// single head spanning all columns is used.
    fn self_attention(&self, h: &[f32], layer: &BertLayer, seq_len: usize) -> Vec<f32> {
        let d = self.emb_dim;
        let n_heads = if self.n_heads > 0 && d % self.n_heads == 0 {
            self.n_heads
        } else {
            1
        };
        let head_dim = d / n_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();
        let q = matmul_rows(h, &layer.wq, seq_len, d, d);
        let k = matmul_rows(h, &layer.wk, seq_len, d, d);
        let v = matmul_rows(h, &layer.wv, seq_len, d, d);

        let mut out = vec![0.0_f32; seq_len * d];
        let mut scores = vec![0.0_f32; seq_len];
        for head in 0..n_heads {
            let lo = head * head_dim;
            // Columns of `row` that belong to this head.
            let cols = |row: usize| (row * d + lo)..(row * d + lo + head_dim);
            for i in 0..seq_len {
                let q_i = &q[cols(i)];
                for (j, s) in scores.iter_mut().enumerate() {
                    *s = q_i
                        .iter()
                        .zip(&k[cols(j)])
                        .map(|(&qi, &kj)| qi * kj)
                        .sum::<f32>()
                        * scale;
                }
                softmax_inplace(&mut scores);
                let out_i = &mut out[cols(i)];
                for (j, &a) in scores.iter().enumerate() {
                    for (o, &vj) in out_i.iter_mut().zip(&v[cols(j)]) {
                        *o += a * vj;
                    }
                }
            }
        }
        out
    }

    /// Position-wise feed-forward network for a single row `x`.
    ///
    /// Computes `W2 · relu(W1 · x + b1) + b2`. The hidden width is taken from
    /// `layer.b1`.
    fn feed_forward(x: &[f32], layer: &BertLayer) -> Vec<f32> {
        let d = x.len();
        let ffn_dim = layer.b1.len();
        let mid: Vec<f32> = (0..ffn_dim)
            .map(|o| {
                let pre = layer.b1[o]
                    + layer.w1[o * d..(o + 1) * d]
                        .iter()
                        .zip(x)
                        .map(|(&w, &xi)| w * xi)
                        .sum::<f32>();
                // Explicit comparison (not f32::max) so NaN propagates.
                if pre < 0.0 {
                    0.0
                } else {
                    pre
                }
            })
            .collect();
        (0..d)
            .map(|o| {
                layer.b2[o]
                    + layer.w2[o * ffn_dim..(o + 1) * ffn_dim]
                        .iter()
                        .zip(&mid)
                        .map(|(&w, &mi)| w * mi)
                        .sum::<f32>()
            })
            .collect()
    }

    /// Runs one post-LayerNorm encoder layer over `seq_len` rows of `h`.
    ///
    /// Computes `LN1(h + Wo·attn(h))`, then `LN2(x + ffn(x))` on that result,
    /// and returns a new `seq_len × emb_dim` buffer.
    fn apply_layer(&self, h: &[f32], layer: &BertLayer, seq_len: usize) -> Vec<f32> {
        let d = self.emb_dim;
        let attn = self.self_attention(h, layer, seq_len);
        let proj = matmul_rows(&attn, &layer.wo, seq_len, d, d);

        let mut h_attn = vec![0.0_f32; seq_len * d];
        for pos in 0..seq_len {
            let (a, b) = (pos * d, (pos + 1) * d);
            let res: Vec<f32> = h[a..b]
                .iter()
                .zip(&proj[a..b])
                .map(|(&hv, &pv)| hv + pv)
                .collect();
            let normed = layer_norm(&res, &layer.ln1_g, &layer.ln1_b);
            h_attn[a..b].copy_from_slice(&normed);
        }

        let mut h_ffn = vec![0.0_f32; seq_len * d];
        for pos in 0..seq_len {
            let (a, b) = (pos * d, (pos + 1) * d);
            let x = &h_attn[a..b];
            let out = Self::feed_forward(x, layer);
            let res2: Vec<f32> = x.iter().zip(&out).map(|(&hv, &ov)| hv + ov).collect();
            let normed2 = layer_norm(&res2, &layer.ln2_g, &layer.ln2_b);
            h_ffn[a..b].copy_from_slice(&normed2);
        }
        h_ffn
    }

    /// Encodes a possibly masked sequence and scores every item at every
    /// position.
    ///
    /// Returns `seq_len` rows of `n_items` logits. Each logit is the dot
    /// product of the final hidden state at that position with the item's row
    /// of `item_emb`, so input and output embeddings are shared.
    ///
    /// # Errors
    /// * `RecsysError::EmptyInput` if `masked_ids` is empty.
    /// * `RecsysError::UnknownItem` for the first id that is neither
    ///   `MASK_TOKEN` nor below `n_items`.
    pub fn forward_masked(&self, masked_ids: &[usize]) -> RecsysResult<Vec<Vec<f32>>> {
        if masked_ids.is_empty() {
            return Err(RecsysError::EmptyInput);
        }
        if let Some(&id) = masked_ids
            .iter()
            .find(|&&id| id != MASK_TOKEN && id >= self.n_items)
        {
            return Err(RecsysError::UnknownItem { id });
        }
        let seq_len = masked_ids.len();
        let d = self.emb_dim;
        let mut h = self.embed_sequence(masked_ids);
        for layer in &self.attn_layers {
            h = self.apply_layer(&h, layer, seq_len);
        }
        let logits: Vec<Vec<f32>> = (0..seq_len)
            .map(|pos| {
                let h_pos = &h[pos * d..(pos + 1) * d];
                (0..self.n_items)
                    .map(|item| {
                        self.item_emb[item * d..(item + 1) * d]
                            .iter()
                            .zip(h_pos)
                            .map(|(&e, &q)| e * q)
                            .sum()
                    })
                    .collect()
            })
            .collect();
        Ok(logits)
    }
}
/// Multiplies each of the `n` rows of `x` (width `d_in`) by the transpose of
/// `w`, where `w` holds `d_out` row-major rows of width `d_in`.
///
/// Returns an `n × d_out` row-major buffer whose entry `(r, c)` is the dot
/// product of row `r` of `x` with row `c` of `w`.
fn matmul_rows(x: &[f32], w: &[f32], n: usize, d_in: usize, d_out: usize) -> Vec<f32> {
    (0..n)
        .flat_map(move |r| {
            (0..d_out).map(move |c| {
                let w_row = &w[c * d_in..(c + 1) * d_in];
                let x_row = &x[r * d_in..(r + 1) * d_in];
                w_row
                    .iter()
                    .zip(x_row)
                    .map(|(&wi, &xi)| wi * xi)
                    .sum::<f32>()
            })
        })
        .collect()
}