use super::types::NegativeSamplingStrategy;
use crate::base::{EdgeWeight, Graph, Node};
use scirs2_core::random::{Rng, RngExt};
use std::collections::HashSet;
/// Degree-weighted sampler used to draw nodes as negative examples.
///
/// The three vectors are index-aligned: entry `i` of `frequencies` and
/// `cumulative` describes the node stored at `vocabulary[i]`.
#[derive(Debug, Clone)]
pub struct NegativeSampler<N: Node> {
/// Nodes that can be drawn, cloned from the graph the sampler was built from.
vocabulary: Vec<N>,
/// Normalized sampling weight of each node, derived from its degree.
// Populated by `new` but not read by the methods visible in this file,
// hence the lint suppression.
#[allow(dead_code)]
frequencies: Vec<f64>,
/// Running sum of `frequencies`; `sample` compares a uniform random draw
/// against these values to select a node.
cumulative: Vec<f64>,
}
impl<N: Node> NegativeSampler<N> {
    /// Exponent applied to the relative node degree; a value below 1 flattens
    /// the distribution so high-degree nodes dominate less.
    const SMOOTHING_EXPONENT: f64 = 0.75;

    /// How many draws `sample_negatives` may spend per requested negative
    /// before giving up.
    const ATTEMPTS_PER_NEGATIVE: usize = 10;

    /// Builds a sampler over every node of `graph`.
    ///
    /// Each node is weighted by `(degree / total_degree)^0.75`, and the weights
    /// are then normalized to sum to 1.
    ///
    /// Edge cases:
    /// * An empty graph yields an empty sampler (`sample` returns `None`)
    ///   instead of panicking.
    /// * A graph whose nodes all have degree 0 falls back to a uniform
    ///   distribution instead of producing NaN weights.
    pub fn new<E, Ix>(graph: &Graph<N, E, Ix>) -> Self
    where
        N: Clone + std::fmt::Debug,
        E: EdgeWeight,
        Ix: petgraph::graph::IndexType,
    {
        let vocabulary: Vec<N> = graph.nodes().into_iter().cloned().collect();
        if vocabulary.is_empty() {
            return NegativeSampler {
                vocabulary,
                frequencies: Vec::new(),
                cumulative: Vec::new(),
            };
        }

        let node_degrees: Vec<f64> = vocabulary
            .iter()
            .map(|node| graph.degree(node) as f64)
            .collect();
        let total_degree: f64 = node_degrees.iter().sum();

        let frequencies: Vec<f64> = if total_degree > 0.0 {
            let smoothed: Vec<f64> = node_degrees
                .iter()
                .map(|d| (d / total_degree).powf(Self::SMOOTHING_EXPONENT))
                .collect();
            // At least one degree is positive here, so this sum is positive.
            let total_freq: f64 = smoothed.iter().sum();
            smoothed.iter().map(|f| f / total_freq).collect()
        } else {
            // No edges at all: dividing by the zero total would give NaN for
            // every node, so treat all nodes as equally likely.
            vec![1.0 / vocabulary.len() as f64; vocabulary.len()]
        };

        // Prefix sums of the probabilities; non-decreasing because every
        // frequency is non-negative.
        let cumulative: Vec<f64> = frequencies
            .iter()
            .scan(0.0_f64, |running, &f| {
                *running += f;
                Some(*running)
            })
            .collect();

        NegativeSampler {
            vocabulary,
            frequencies,
            cumulative,
        }
    }

    /// Draws one node according to the sampler's distribution.
    ///
    /// Returns `None` only when the sampler holds no nodes; in that case `rng`
    /// is not advanced.
    pub fn sample(&self, rng: &mut impl Rng) -> Option<&N> {
        if self.vocabulary.is_empty() {
            return None;
        }
        let r = rng.random::<f64>();
        // First index whose cumulative value is >= r. The table is sorted, so
        // a binary search finds the same node a front-to-back scan would.
        let idx = self.cumulative.partition_point(|&cum_freq| cum_freq < r);
        // Rounding can leave the final cumulative value slightly below 1.0;
        // fall back to the last node when r lands past it.
        self.vocabulary
            .get(idx)
            .or_else(|| self.vocabulary.last())
    }

    /// Draws up to `count` nodes that are not contained in `exclude`.
    ///
    /// Sampling is with replacement, so the result may contain duplicates.
    /// At most `count * 10` draws are made (saturating), so fewer than `count`
    /// nodes are returned when most of the probability mass is excluded.
    pub fn sample_negatives(
        &self,
        count: usize,
        exclude: &HashSet<&N>,
        rng: &mut impl Rng,
    ) -> Vec<N> {
        let mut negatives = Vec::new();
        // Nothing can ever be drawn from an empty sampler; skip the attempts.
        if count == 0 || self.vocabulary.is_empty() {
            return negatives;
        }

        // Saturate rather than overflow for very large `count`.
        let max_attempts = count.saturating_mul(Self::ATTEMPTS_PER_NEGATIVE);
        for _ in 0..max_attempts {
            if negatives.len() >= count {
                break;
            }
            if let Some(candidate) = self.sample(rng) {
                if !exclude.contains(candidate) {
                    negatives.push(candidate.clone());
                }
            }
        }
        negatives
    }
}