use scirs2_core::random::{Rng, RngExt};
/// A trained skip-gram model: one input ("center") vector and one output
/// ("context") vector per node, each of length `dim`.
#[derive(Debug, Clone)]
pub struct SkipGramEmbedding {
    pub dim: usize,
    pub embeddings: Vec<Vec<f64>>,
    pub context_embeddings: Vec<Vec<f64>>,
}

impl SkipGramEmbedding {
    /// Input embedding of `node`; empty slice when `node` is out of range.
    pub fn get_embedding(&self, node: usize) -> &[f64] {
        match self.embeddings.get(node) {
            Some(v) => v.as_slice(),
            None => &[],
        }
    }

    /// Context embedding of `node`; empty slice when `node` is out of range.
    pub fn get_context_embedding(&self, node: usize) -> &[f64] {
        match self.context_embeddings.get(node) {
            Some(v) => v.as_slice(),
            None => &[],
        }
    }

    /// Number of nodes covered by the model.
    pub fn n_nodes(&self) -> usize {
        self.embeddings.len()
    }

    /// Cosine similarity between the input embeddings of `a` and `b`.
    /// Returns 0.0 for out-of-range nodes or zero-norm vectors.
    pub fn cosine_similarity(&self, a: usize, b: usize) -> f64 {
        let ea = self.get_embedding(a);
        let eb = self.get_embedding(b);
        if ea.is_empty() || eb.is_empty() {
            return 0.0;
        }
        let dot: f64 = ea.iter().zip(eb).map(|(x, y)| x * y).sum();
        let norm_a = ea.iter().fold(0.0f64, |acc, x| acc + x * x).sqrt();
        let norm_b = eb.iter().fold(0.0f64, |acc, x| acc + x * x).sqrt();
        // Guard against division by zero for all-zero vectors.
        if norm_a <= 0.0 || norm_b <= 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }
}
/// Train skip-gram embeddings with negative sampling (SGNS) over
/// pre-generated random walks.
///
/// * `walks`   - node-id sequences (e.g. node2vec walks); ids `>= n_nodes` are skipped.
/// * `n_nodes` - number of nodes; sizes both embedding tables.
/// * `dim`     - embedding dimensionality.
/// * `window`  - symmetric context-window radius around each center position.
/// * `epochs`  - number of passes over `walks`.
/// * `lr`      - initial learning rate; decayed linearly, floored at `lr * 0.0001`.
pub fn train(
    walks: &[Vec<usize>],
    n_nodes: usize,
    dim: usize,
    window: usize,
    epochs: usize,
    lr: f64,
) -> SkipGramEmbedding {
    let mut rng = scirs2_core::random::rng();
    // word2vec-style init: input vectors uniform in (-0.5/dim, 0.5/dim),
    // context vectors start at zero.
    let mut emb: Vec<Vec<f64>> = (0..n_nodes)
        .map(|_| {
            (0..dim)
                .map(|_| (rng.random::<f64>() - 0.5) / dim as f64)
                .collect()
        })
        .collect();
    let mut ctx_emb: Vec<Vec<f64>> = vec![vec![0.0f64; dim]; n_nodes];
    // Node occurrence counts feed the unigram^0.75 negative-sampling distribution.
    let mut freq = vec![0usize; n_nodes];
    for walk in walks {
        for &v in walk {
            if v < n_nodes {
                freq[v] += 1;
            }
        }
    }
    let noise_dist = build_noise_distribution(&freq);
    let total_walks = walks.len() * epochs;
    let neg_k = 5usize; // negative samples per (center, context) pair
    for epoch in 0..epochs {
        for (walk_idx, walk) in walks.iter().enumerate() {
            // Linear learning-rate decay over all walk visits, with a small floor;
            // max(1) avoids 0/0 when `walks` is empty.
            let step = epoch * walks.len() + walk_idx;
            let progress = step as f64 / total_walks.max(1) as f64;
            let current_lr = (lr * (1.0 - progress)).max(lr * 0.0001);
            // BUG FIX: this destructuring pattern was corrupted to `¢er`
            // (mojibake of `&center`), which did not compile.
            for (pos, &center) in walk.iter().enumerate() {
                if center >= n_nodes {
                    continue;
                }
                let win_start = pos.saturating_sub(window);
                let win_end = (pos + window + 1).min(walk.len());
                for ctx_pos in win_start..win_end {
                    if ctx_pos == pos {
                        continue;
                    }
                    let ctx = walk[ctx_pos];
                    if ctx >= n_nodes {
                        continue;
                    }
                    // Positive pair: push sigmoid(emb[center] . ctx_emb[ctx]) toward 1.
                    let dot = dot_product(&emb[center], &ctx_emb[ctx], dim);
                    let sig = sigmoid(dot);
                    let g_pos = current_lr * (1.0 - sig);
                    // Accumulate the center gradient separately so every context
                    // update sees the *pre-update* center vector.
                    let mut grad_center = vec![0.0f64; dim];
                    for d in 0..dim {
                        grad_center[d] += g_pos * ctx_emb[ctx][d];
                        ctx_emb[ctx][d] += g_pos * emb[center][d];
                    }
                    // Negative samples: push sigmoid of random noise pairs toward 0.
                    for _ in 0..neg_k {
                        let neg = sample_noise(&noise_dist, &mut rng, n_nodes);
                        if neg == center || neg == ctx {
                            continue;
                        }
                        let dot_neg = dot_product(&emb[center], &ctx_emb[neg], dim);
                        let sig_neg = sigmoid(dot_neg);
                        let g_neg = current_lr * (-sig_neg);
                        for d in 0..dim {
                            grad_center[d] += g_neg * ctx_emb[neg][d];
                            ctx_emb[neg][d] += g_neg * emb[center][d];
                        }
                    }
                    // Apply the accumulated center update once per context pair.
                    for d in 0..dim {
                        emb[center][d] += grad_center[d];
                    }
                }
            }
        }
    }
    SkipGramEmbedding {
        dim,
        embeddings: emb,
        context_embeddings: ctx_emb,
    }
}
/// Build a Walker/Vose alias table for O(1) sampling of nodes with
/// probability proportional to `freq[i]^0.75` (the word2vec noise
/// distribution).
///
/// Returns `(prob, alias)`: draw a bucket `i` uniformly, keep `i` with
/// probability `prob[i]`, otherwise take `alias[i]`. Both vectors are
/// empty when `freq` is empty.
fn build_noise_distribution(freq: &[usize]) -> (Vec<f64>, Vec<usize>) {
    if freq.is_empty() {
        return (Vec::new(), Vec::new());
    }
    let n = freq.len();
    // 0.75-powered counts; the epsilon keeps zero-frequency nodes sampleable.
    let powered: Vec<f64> = freq.iter().map(|&f| (f as f64).powf(0.75) + 1e-10).collect();
    let total: f64 = powered.iter().sum();
    // Normalize so the mean bucket weight is exactly 1.0.
    let mut work: Vec<f64> = powered.iter().map(|p| p / total * n as f64).collect();
    let mut prob = vec![0.0f64; n];
    let mut alias = vec![0usize; n];
    let mut small = Vec::new();
    let mut large = Vec::new();
    for (i, &w) in work.iter().enumerate() {
        if w < 1.0 {
            small.push(i);
        } else {
            large.push(i);
        }
    }
    // Pair each under-full bucket with an over-full donor; the donor keeps
    // whatever weight is left over and is re-classified.
    loop {
        let (s, l) = match (small.pop(), large.pop()) {
            (Some(s), Some(l)) => (s, l),
            // One side ran dry: restore the stray pop and stop pairing.
            (Some(s), None) => {
                small.push(s);
                break;
            }
            (None, Some(l)) => {
                large.push(l);
                break;
            }
            (None, None) => break,
        };
        prob[s] = work[s];
        alias[s] = l;
        work[l] = (work[l] + work[s]) - 1.0;
        if work[l] < 1.0 {
            small.push(l);
        } else {
            large.push(l);
        }
    }
    // Leftover buckets are (numerically) exactly full: they always keep themselves.
    for &i in large.iter().chain(small.iter()) {
        prob[i] = 1.0;
    }
    (prob, alias)
}
/// Draw one node id from the alias table built by `build_noise_distribution`.
/// Falls back to a uniform draw over `0..n_nodes` when the table is empty.
fn sample_noise(
    noise: &(Vec<f64>, Vec<usize>),
    rng: &mut impl Rng,
    n_nodes: usize,
) -> usize {
    let (prob, alias) = noise;
    if prob.is_empty() {
        // max(1) keeps the range non-empty even for n_nodes == 0.
        return rng.random_range(0..n_nodes.max(1));
    }
    // Alias method: pick a bucket uniformly, then keep it with probability
    // prob[bucket], otherwise jump to its alias.
    let bucket = rng.random_range(0..prob.len());
    let keep = rng.random::<f64>() < prob[bucket];
    if keep {
        bucket
    } else {
        alias[bucket]
    }
}
#[inline]
/// Dot product of the first `dim` components of `a` and `b`
/// (the shorter slice truncates the sum).
fn dot_product(a: &[f64], b: &[f64], dim: usize) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b).take(dim) {
        acc += x * y;
    }
    acc
}
#[inline]
/// Logistic sigmoid with clamped tails so downstream gradient terms never
/// see exactly 0.0 or 1.0.
fn sigmoid(x: f64) -> f64 {
    match x {
        v if v > 6.0 => 1.0 - 1e-10,
        v if v < -6.0 => 1e-10,
        v => 1.0 / (1.0 + (-v).exp()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: four short walks over node ids 0..4.
    fn make_walks() -> Vec<Vec<usize>> {
        vec![
            vec![0, 1, 2, 1, 0],
            vec![1, 2, 3, 2, 1],
            vec![0, 1, 3, 2, 0],
            vec![2, 3, 0, 1, 2],
        ]
    }

    // Training yields a model with the requested node count and dimensionality.
    #[test]
    fn test_train_produces_embeddings() {
        let walks = make_walks();
        let emb = train(&walks, 4, 8, 2, 2, 0.025);
        assert_eq!(emb.n_nodes(), 4);
        assert_eq!(emb.dim, 8);
    }

    // Every in-range node gets a vector of exactly `dim` components.
    #[test]
    fn test_get_embedding_length() {
        let walks = make_walks();
        let emb = train(&walks, 4, 16, 2, 1, 0.025);
        for v in 0..4 {
            assert_eq!(
                emb.get_embedding(v).len(),
                16,
                "node {v} embedding should have length 16"
            );
        }
    }

    // Out-of-range lookups return an empty slice instead of panicking.
    #[test]
    fn test_get_embedding_out_of_range() {
        let walks = make_walks();
        let emb = train(&walks, 4, 8, 2, 1, 0.025);
        let e = emb.get_embedding(99);
        assert!(e.is_empty(), "out-of-range node should return empty slice");
    }

    // cos(v, v) must be ~1 for any trained non-zero vector.
    #[test]
    fn test_cosine_similarity_self() {
        let walks = make_walks();
        let emb = train(&walks, 4, 8, 2, 2, 0.025);
        let sim = emb.cosine_similarity(0, 0);
        assert!(
            (sim - 1.0).abs() < 1e-6,
            "cosine similarity with self should be ~1.0, got {sim}"
        );
    }

    // All pairwise similarities stay within the mathematical range [-1, 1].
    #[test]
    fn test_cosine_similarity_range() {
        let walks = make_walks();
        let emb = train(&walks, 4, 8, 2, 2, 0.025);
        for a in 0..4 {
            for b in 0..4 {
                let sim = emb.cosine_similarity(a, b);
                assert!(
                    (-1.0 - 1e-6..=1.0 + 1e-6).contains(&sim),
                    "cosine similarity should be in [-1,1], got {sim}"
                );
            }
        }
    }

    // Training on no walks still allocates randomly initialized embeddings.
    #[test]
    fn test_train_empty_walks() {
        let emb = train(&[], 4, 8, 2, 1, 0.025);
        assert_eq!(emb.n_nodes(), 4);
        let e = emb.get_embedding(0);
        assert_eq!(e.len(), 8);
    }

    // Alias-table construction: per-bucket keep-probabilities are valid.
    #[test]
    fn test_noise_distribution() {
        let freq = vec![10usize, 5, 2, 1];
        let (prob, alias) = build_noise_distribution(&freq);
        assert_eq!(prob.len(), 4);
        assert_eq!(alias.len(), 4);
        for &p in &prob {
            assert!((0.0..=1.0 + 1e-9).contains(&p), "prob should be in [0,1]");
        }
    }
}