use super::negative_sampling::NegativeSampler;
use super::types::ContextPair;
use crate::base::{DiGraph, EdgeWeight, Graph, Node};
use crate::error::{GraphError, Result};
use scirs2_core::random::{Rng, RngExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
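/// A dense, fixed-length vector representation of a single node.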
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
pub vector: Vec<f64>,
}
impl Embedding {
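/// Creates a zero-initialized embedding with `dimensions` components.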
pub fn new(dimensions: usize) -> Self {
Embedding {
vector: vec![0.0; dimensions],
}
}
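/// Creates an embedding with components drawn uniformly from `[-0.5, 0.5)`.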
pub fn random(dimensions: usize, rng: &mut impl Rng) -> Self {
let vector: Vec<f64> = (0..dimensions)
.map(|_| rng.random_range(-0.5..0.5))
.collect();
Embedding { vector }
}
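/// Returns the number of components in the embedding.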
pub fn dimensions(&self) -> usize {
self.vector.len()
}
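/// Computes the cosine similarity with `other`, returning 0.0 if either
/// vector has zero norm. Errors if the dimensions differ.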
pub fn cosine_similarity(&self, other: &Embedding) -> Result<f64> {
let dot_product = self.dot_product(other)?;
let norm_a = self.norm();
let norm_b = other.norm();
// Treat the similarity as 0.0 when either vector is all zeros.
if norm_a == 0.0 || norm_b == 0.0 {
Ok(0.0)
} else {
Ok(dot_product / (norm_a * norm_b))
}
}
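/// Returns the Euclidean (L2) norm of the vector.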
pub fn norm(&self) -> f64 {
self.vector.iter().map(|x| x * x).sum::<f64>().sqrt()
}
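/// Rescales the vector to unit length; zero vectors are left unchanged.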
pub fn normalize(&mut self) {
let norm = self.norm();
if norm > 0.0 {
for x in &mut self.vector {
*x /= norm;
}
}
}
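/// Adds `other` to this embedding component-wise. Errors if the
/// dimensions differ.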
pub fn add(&mut self, other: &Embedding) -> Result<()> {
if self.vector.len() != other.vector.len() {
return Err(GraphError::InvalidGraph(
"Embeddings must have same dimensions".to_string(),
));
}
for (a, b) in self.vector.iter_mut().zip(other.vector.iter()) {
*a += b;
}
Ok(())
}
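/// Multiplies every component by `factor`.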
pub fn scale(&mut self, factor: f64) {
for x in &mut self.vector {
*x *= factor;
}
}
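/// Computes the dot product with `other`. Errors if the dimensions differ.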
pub fn dot_product(&self, other: &Embedding) -> Result<f64> {
if self.vector.len() != other.vector.len() {
return Err(GraphError::InvalidGraph(
"Embeddings must have same dimensions".to_string(),
));
}
let dot: f64 = self
.vector
.iter()
.zip(other.vector.iter())
.map(|(a, b)| a * b)
.sum();
Ok(dot)
}
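/// Logistic sigmoid: `1 / (1 + e^(-x))`.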
pub fn sigmoid(x: f64) -> f64 {
1.0 / (1.0 + (-x).exp())
}
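/// Applies one gradient-descent step, `v[i] -= learning_rate * gradient[i]`.
/// Callers are expected to pass the gradient of the loss.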
pub fn update_gradient(&mut self, gradient: &[f64], learning_rate: f64) {
for (emb, &grad) in self.vector.iter_mut().zip(gradient.iter()) {
*emb -= learning_rate * grad;
}
}
}
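/// A skip-gram style model that stores, for every node, both an "input"
/// embedding (`embeddings`) and a separate "output" embedding used when the
/// node appears as context (`context_embeddings`).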
#[derive(Debug)]
pub struct EmbeddingModel<N: Node> {
pub embeddings: HashMap<N, Embedding>,
pub context_embeddings: HashMap<N, Embedding>,
pub dimensions: usize,
}
impl<N: Node> EmbeddingModel<N> {
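/// Creates an empty model whose embeddings have `dimensions` components.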
pub fn new(dimensions: usize) -> Self {
EmbeddingModel {
embeddings: HashMap::new(),
context_embeddings: HashMap::new(),
dimensions,
}
}
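/// Returns the embedding for `node`, if one exists.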
pub fn get_embedding(&self, node: &N) -> Option<&Embedding> {
self.embeddings.get(node)
}
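/// Inserts or replaces the embedding for `node`. Errors if the embedding's
/// dimensionality does not match the model's.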
pub fn set_embedding(&mut self, node: N, embedding: Embedding) -> Result<()> {
if embedding.dimensions() != self.dimensions {
return Err(GraphError::InvalidGraph(
"Embedding dimensions don't match model".to_string(),
));
}
self.embeddings.insert(node, embedding);
Ok(())
}
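/// Randomly initializes input and context embeddings for every node of an
/// undirected graph.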
pub fn initialize_random<E, Ix>(&mut self, graph: &Graph<N, E, Ix>, rng: &mut impl Rng)
where
N: Clone + std::fmt::Debug,
E: EdgeWeight,
Ix: petgraph::graph::IndexType,
{
for node in graph.nodes() {
let embedding = Embedding::random(self.dimensions, rng);
let context_embedding = Embedding::random(self.dimensions, rng);
self.embeddings.insert(node.clone(), embedding);
self.context_embeddings
.insert(node.clone(), context_embedding);
}
}
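/// Randomly initializes input and context embeddings for every node of a
/// directed graph.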
pub fn initialize_random_digraph<E, Ix>(
&mut self,
graph: &DiGraph<N, E, Ix>,
rng: &mut impl Rng,
) where
N: Clone + std::fmt::Debug,
E: EdgeWeight,
Ix: petgraph::graph::IndexType,
{
for node in graph.nodes() {
let embedding = Embedding::random(self.dimensions, rng);
let context_embedding = Embedding::random(self.dimensions, rng);
self.embeddings.insert(node.clone(), embedding);
self.context_embeddings
.insert(node.clone(), context_embedding);
}
}
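/// Returns the `k` nodes most cosine-similar to `node`, ordered from most
/// to least similar. Errors if `node` has no embedding.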
pub fn most_similar(&self, node: &N, k: usize) -> Result<Vec<(N, f64)>>
where
N: Clone,
{
let target_embedding = self
.embeddings
.get(node)
.ok_or(GraphError::node_not_found("node"))?;
let mut similarities = Vec::new();
for (other_node, other_embedding) in &self.embeddings {
if other_node != node {
let similarity = target_embedding.cosine_similarity(other_embedding)?;
similarities.push((other_node.clone(), similarity));
}
}
// Sort by descending similarity; treat incomparable (NaN) values as
// equal instead of panicking.
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
similarities.truncate(k);
Ok(similarities)
}
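/// Expands random walks into (target, context) training pairs by pairing
/// each node with every other node inside a window of `window_size`
/// positions on either side.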
pub fn generate_context_pairs(
walks: &[super::types::RandomWalk<N>],
window_size: usize,
) -> Vec<ContextPair<N>>
where
N: Clone,
{
let mut pairs = Vec::new();
for walk in walks {
for (i, target) in walk.nodes.iter().enumerate() {
let start = i.saturating_sub(window_size);
let end = (i + window_size + 1).min(walk.nodes.len());
for j in start..end {
if i != j {
pairs.push(ContextPair {
target: target.clone(),
context: walk.nodes[j].clone(),
});
}
}
}
}
pairs
}
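/// Trains the model with skip-gram negative sampling: for each
/// (target, context) pair, the pair's score is pushed toward 1 while the
/// scores of `negative_samples` sampled non-context nodes are pushed toward
/// 0, using plain SGD with the given `learning_rate`.
///
/// An illustrative sketch of the surrounding pipeline (the sampler and the
/// hyperparameter values are placeholders, so this is not a doctest):
///
/// ```ignore
/// let mut model = EmbeddingModel::new(128);
/// model.initialize_random(&graph, &mut rng);
/// let pairs = EmbeddingModel::generate_context_pairs(&walks, 5);
/// model.train_skip_gram(&pairs, &sampler, 0.025, 5, &mut rng)?;
/// ```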
pub fn train_skip_gram(
&mut self,
pairs: &[ContextPair<N>],
negative_sampler: &NegativeSampler<N>,
learning_rate: f64,
negative_samples: usize,
rng: &mut impl Rng,
) -> Result<()> {
for pair in pairs {
let target_emb = self
.embeddings
.get(&pair.target)
.ok_or(GraphError::node_not_found("node"))?
.clone();
let context_emb = self
.context_embeddings
.get(&pair.context)
.ok_or(GraphError::node_not_found("node"))?
.clone();
// Positive pair: the loss gradient with respect to the score is
// sigmoid(score) - label, with label = 1 for an observed pair, so the
// descent step below pulls the target and context embeddings together.
let positive_score = target_emb.dot_product(&context_emb)?;
let positive_prob = Embedding::sigmoid(positive_score);
let positive_error = positive_prob - 1.0;
let mut target_gradient = vec![0.0; self.dimensions];
let mut context_gradient = vec![0.0; self.dimensions];
#[allow(clippy::needless_range_loop)]
for i in 0..self.dimensions {
target_gradient[i] += positive_error * context_emb.vector[i];
context_gradient[i] += positive_error * target_emb.vector[i];
}
// Draw negatives, excluding the positive pair itself.
let exclude_set: HashSet<&N> = [&pair.target, &pair.context].iter().cloned().collect();
let negatives = negative_sampler.sample_negatives(negative_samples, &exclude_set, rng);
// First pass over the negatives: accumulate the target gradient and record
// each sample's error so the update pass below does not recompute the score.
let mut negative_errors = Vec::with_capacity(negatives.len());
for negative in &negatives {
if let Some(neg_context_emb) = self.context_embeddings.get(negative) {
let negative_score = target_emb.dot_product(neg_context_emb)?;
// Label = 0 for a negative sample, so the error is the raw probability.
let negative_error = Embedding::sigmoid(negative_score);
#[allow(clippy::needless_range_loop)]
for i in 0..self.dimensions {
target_gradient[i] += negative_error * neg_context_emb.vector[i];
}
negative_errors.push((negative, negative_error));
}
}
// Second pass: descend on each negative context embedding. Kept separate
// from the pass above so the immutable borrows of `context_embeddings`
// end before it is borrowed mutably.
for (negative, negative_error) in negative_errors {
if let Some(neg_context_emb_mut) = self.context_embeddings.get_mut(negative) {
#[allow(clippy::needless_range_loop)]
for i in 0..self.dimensions {
let neg_context_grad = negative_error * target_emb.vector[i];
neg_context_emb_mut.vector[i] -= learning_rate * neg_context_grad;
}
}
}
if let Some(target_emb_mut) = self.embeddings.get_mut(&pair.target) {
target_emb_mut.update_gradient(&target_gradient, learning_rate);
}
if let Some(context_emb_mut) = self.context_embeddings.get_mut(&pair.context) {
context_emb_mut.update_gradient(&context_gradient, learning_rate);
}
}
Ok(())
}
}