use std::collections::HashMap;
#[cfg(all(feature = "tensor", feature = "tensor-gpu"))]
use dfdx::prelude::*;
#[cfg(feature = "rand")]
use rand::{random, Rng};
/// Hyper-parameters for differentiable graph-structure optimisation.
#[derive(Debug, Clone)]
pub struct GradientConfig {
    /// Temperature used when converting edge logits to probabilities
    /// (`sigmoid(logits / temperature)`); must be non-zero.
    pub temperature: f64,
    /// Whether straight-through-estimator corrections are recorded and applied.
    pub use_ste: bool,
    /// Step size for edge-logit gradient updates.
    pub edge_learning_rate: f64,
    /// Step size intended for node parameters.
    /// NOTE(review): not read by the edge-update code in this module.
    pub node_learning_rate: f64,
    /// Weight of the sparsity term added to edge gradients; `<= 0` disables it.
    pub sparsity_weight: f64,
    /// Weight of the smoothness term added to edge gradients; `<= 0` disables it.
    pub smoothness_weight: f64,
}

impl Default for GradientConfig {
    /// Temperature 1.0, STE enabled, edge LR 0.01, node LR 0.001, no regularisers.
    fn default() -> Self {
        Self {
            temperature: 1.0,
            use_ste: true,
            edge_learning_rate: 0.01,
            node_learning_rate: 0.001,
            sparsity_weight: 0.0,
            smoothness_weight: 0.0,
        }
    }
}

impl GradientConfig {
    /// Creates a config with the given temperature, STE flag and learning rates.
    ///
    /// Both regulariser weights start at 0 (disabled).
    pub fn new(temperature: f64, use_ste: bool, edge_lr: f64, node_lr: f64) -> Self {
        Self {
            temperature,
            use_ste,
            edge_learning_rate: edge_lr,
            node_learning_rate: node_lr,
            ..Self::default()
        }
    }

    /// Returns the config with `sparsity_weight` set to `weight`.
    pub fn with_sparsity(mut self, weight: f64) -> Self {
        self.sparsity_weight = weight;
        self
    }

    /// Returns the config with `smoothness_weight` set to `weight`.
    pub fn with_smoothness(mut self, weight: f64) -> Self {
        self.smoothness_weight = weight;
        self
    }

    /// Returns the config with `edge_learning_rate` set to `lr`.
    pub fn with_edge_learning_rate(mut self, lr: f64) -> Self {
        self.edge_learning_rate = lr;
        self
    }

    /// Returns the config with `node_learning_rate` set to `lr`
    /// (builder counterpart of [`GradientConfig::with_edge_learning_rate`]).
    pub fn with_node_learning_rate(mut self, lr: f64) -> Self {
        self.node_learning_rate = lr;
        self
    }
}
/// Classification of an edge update reported inside a `StructureEdit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeEditOp {
/// The edit policy's `should_add_edge` check fired for this edge.
Add,
/// The edit policy's `should_remove_edge` check fired for this edge.
Remove,
/// The probability was updated without an add/remove classification.
Modify,
}
/// Classification of a node update, carried by `EditOperation::NodeEdit`.
/// NOTE(review): nothing in this file currently constructs node edits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeEditOp {
/// A node was added.
Add,
/// A node was removed.
Remove,
/// A node was changed in place.
Modify,
}
/// Record of a single structural change applied to a graph.
#[derive(Debug, Clone)]
pub struct StructureEdit {
/// Which element was edited and how.
pub operation: EditOperation,
/// Gradient value that drove the edit.
pub gradient: f64,
/// Value (edge probability, for edge edits) before the edit.
pub before: f64,
/// Value (edge probability, for edge edits) after the edit.
pub after: f64,
}
/// Target and kind of a `StructureEdit`.
#[derive(Debug, Clone)]
pub enum EditOperation {
/// Edge `(src, dst)` edited with the given operation.
EdgeEdit(usize, usize, EdgeEditOp),
/// Node with the given id edited with the given operation.
NodeEdit(usize, NodeEditOp),
}
/// A candidate edge whose existence is controlled by learnable logits.
#[derive(Debug, Clone)]
pub struct DifferentiableEdge {
    /// Source node index.
    pub src: usize,
    /// Destination node index.
    pub dst: usize,
    /// Log-odds parameter; the soft probability is `sigmoid(logits / temperature)`.
    pub logits: f64,
    /// Soft probability as of construction or the most recent `discretize`.
    pub probability: f64,
    /// Hard decision: `probability > 0.5`.
    pub exists: bool,
    /// Last gradient passed to [`DifferentiableEdge::update_logits`], if any.
    pub gradient: Option<f64>,
}

impl DifferentiableEdge {
    /// Creates the edge `src -> dst` with initial existence probability
    /// `init_probability`.
    ///
    /// The stored probability is clamped to `[0, 1]` so it always agrees with
    /// the logits; the logits themselves are derived from a slightly tighter
    /// range (see `prob_to_logits`) so they stay finite.
    pub fn new(src: usize, dst: usize, init_probability: f64) -> Self {
        let probability = init_probability.clamp(0.0, 1.0);
        Self {
            src,
            dst,
            logits: Self::prob_to_logits(probability),
            probability,
            exists: probability > 0.5,
            gradient: None,
        }
    }

    /// Converts a probability to log-odds `ln(p / (1 - p))`, clamping `p` to
    /// `[1e-7, 1 - 1e-7]` so the result is finite.
    fn prob_to_logits(prob: f64) -> f64 {
        let p = prob.clamp(1e-7, 1.0 - 1e-7);
        (p / (1.0 - p)).ln()
    }

    /// Tempered logistic sigmoid `1 / (1 + exp(-logits / temperature))`.
    ///
    /// A zero temperature with zero logits yields NaN.
    fn logits_to_prob(logits: f64, temperature: f64) -> f64 {
        1.0 / (1.0 + (-logits / temperature).exp())
    }

    /// Recomputes `probability` from the logits at `temperature` and sets
    /// `exists` to `probability > 0.5`.
    ///
    /// The STE flag is accepted for call-site compatibility only; no
    /// straight-through correction is computed here.
    fn discretize(&mut self, temperature: f64, _use_ste: bool) {
        self.probability = Self::logits_to_prob(self.logits, temperature);
        self.exists = self.probability > 0.5;
    }

    /// Gradient-descent step on the logits: `logits -= learning_rate * gradient`.
    ///
    /// Records `gradient`; `probability`/`exists` are only refreshed by the
    /// next `discretize`.
    pub fn update_logits(&mut self, gradient: f64, learning_rate: f64) {
        self.logits -= learning_rate * gradient;
        self.gradient = Some(gradient);
    }
}
/// A node with a learnable existence probability and an optional feature payload.
#[derive(Debug, Clone)]
pub struct DifferentiableNode<T = Vec<f64>> {
    /// Node index.
    pub id: usize,
    /// Probability that the node exists, kept within `[0, 1]` by updates.
    pub existence_prob: f64,
    /// Optional node features.
    pub features: Option<T>,
    /// Last gradient passed to `update_existence`, if any.
    pub existence_gradient: Option<f64>,
    /// Gradient with respect to the features, if one has been stored.
    pub feature_gradient: Option<T>,
}

impl<T: Clone> DifferentiableNode<T> {
    /// Builds node `id` that certainly exists (`existence_prob == 1.0`) with no
    /// gradients recorded yet.
    pub fn new(id: usize, features: Option<T>) -> Self {
        Self {
            id,
            features,
            existence_prob: 1.0,
            existence_gradient: None,
            feature_gradient: None,
        }
    }

    /// Moves `existence_prob` by `learning_rate * gradient`, clamps it to
    /// `[0, 1]`, and remembers `gradient`.
    ///
    /// NOTE(review): this *adds* the scaled gradient (ascent), whereas
    /// `DifferentiableEdge::update_logits` subtracts it (descent) — confirm
    /// the intended sign convention.
    pub fn update_existence(&mut self, gradient: f64, learning_rate: f64) {
        self.existence_gradient = Some(gradient);
        let step = learning_rate * gradient;
        self.existence_prob = (self.existence_prob + step).clamp(0.0, 1.0);
    }
}
/// A graph whose edge structure is learnable.
///
/// Every candidate edge carries logits; its soft probability is
/// `sigmoid(logits / temperature)` and [`DifferentiableGraph::discretize`]
/// turns that into a hard existence decision (`probability > 0.5`). Node
/// indices are expected to lie in `0..num_nodes`.
#[derive(Debug, Clone)]
pub struct DifferentiableGraph<T = Vec<f64>> {
    /// Number of nodes; dense matrices are `num_nodes x num_nodes`.
    num_nodes: usize,
    /// Learnable edges keyed by `(src, dst)`.
    edges: HashMap<(usize, usize), DifferentiableEdge>,
    /// Per-node state, filled by `init_nodes`.
    nodes: HashMap<usize, DifferentiableNode<T>>,
    /// Temperature, learning rates and regulariser weights.
    config: GradientConfig,
    /// Length of the annealing schedule; `0` disables annealing.
    annealing_steps: usize,
    /// Number of `anneal_temperature` calls so far.
    current_step: usize,
    /// Whether straight-through corrections are recorded and applied.
    /// Mirrors `config.use_ste`; every setter updates both.
    use_ste: bool,
    /// `hard - soft` per edge, rebuilt by every `discretize` call.
    ste_corrections: HashMap<(usize, usize), f64>,
    /// Temperature the annealing schedule decays from.
    anneal_base_temperature: f64,
}

impl<T: Clone + Default> DifferentiableGraph<T> {
    /// Creates a graph with `num_nodes` nodes, no edges and the default
    /// [`GradientConfig`].
    pub fn new(num_nodes: usize) -> Self {
        Self::with_config(num_nodes, GradientConfig::default())
    }

    /// Creates a graph with `num_nodes` nodes, no edges and the given config.
    ///
    /// The STE flag is taken from `config.use_ste`, and `config.temperature`
    /// becomes the starting point of any annealing schedule.
    pub fn with_config(num_nodes: usize, config: GradientConfig) -> Self {
        Self {
            num_nodes,
            edges: HashMap::new(),
            nodes: HashMap::new(),
            annealing_steps: 0,
            current_step: 0,
            use_ste: config.use_ste,
            ste_corrections: HashMap::new(),
            anneal_base_temperature: config.temperature,
            config,
        }
    }

    /// Inserts (or replaces) a node entry for every index in `0..num_nodes`,
    /// each holding its own clone of `features`.
    pub fn init_nodes(&mut self, features: Option<T>) {
        for i in 0..self.num_nodes {
            self.nodes
                .insert(i, DifferentiableNode::new(i, features.clone()));
        }
    }

    /// Adds (or replaces) the learnable edge `src -> dst` with initial
    /// probability `init_prob`.
    ///
    /// Indices are not validated; edges outside `0..num_nodes` are skipped by
    /// the dense-matrix and graph-conversion methods.
    pub fn add_learnable_edge(&mut self, src: usize, dst: usize, init_prob: f64) {
        self.edges
            .insert((src, dst), DifferentiableEdge::new(src, dst, init_prob));
    }

    /// Removes edge `src -> dst`, returning it if it existed.
    pub fn remove_edge(&mut self, src: usize, dst: usize) -> Option<DifferentiableEdge> {
        // Drop any pending STE correction so it cannot be applied to an edge
        // that is later re-added under the same key.
        self.ste_corrections.remove(&(src, dst));
        self.edges.remove(&(src, dst))
    }

    /// Soft probability of `src -> dst` as of creation or the last `discretize`.
    pub fn get_edge_probability(&self, src: usize, dst: usize) -> Option<f64> {
        self.edges.get(&(src, dst)).map(|e| e.probability)
    }

    /// Hard existence flag of `src -> dst` as of creation or the last `discretize`.
    pub fn get_edge_exists(&self, src: usize, dst: usize) -> Option<bool> {
        self.edges.get(&(src, dst)).map(|e| e.exists)
    }

    /// Dense `num_nodes x num_nodes` matrix of soft edge probabilities
    /// (0.0 where no learnable edge exists).
    pub fn get_probability_matrix(&self) -> Vec<Vec<f64>> {
        self.dense_matrix(|edge| Some(edge.probability))
    }

    /// Dense `num_nodes x num_nodes` 0/1 matrix of edges whose `exists` flag is set.
    pub fn get_adjacency_matrix(&self) -> Vec<Vec<f64>> {
        self.dense_matrix(|edge| edge.exists.then_some(1.0))
    }

    /// Builds a zero-filled dense matrix and writes `value(edge)` at
    /// `[src][dst]` for every in-range edge for which it returns `Some`.
    fn dense_matrix(&self, value: impl Fn(&DifferentiableEdge) -> Option<f64>) -> Vec<Vec<f64>> {
        let n = self.num_nodes;
        let mut matrix = vec![vec![0.0; n]; n];
        for (&(src, dst), edge) in &self.edges {
            // Edge indices are unvalidated; skip rather than panic on out-of-range ones.
            if src >= n || dst >= n {
                continue;
            }
            if let Some(v) = value(edge) {
                matrix[src][dst] = v;
            }
        }
        matrix
    }

    /// Advances the annealing schedule by one step.
    ///
    /// When annealing is enabled the temperature becomes
    /// `base * exp(-3 * step / annealing_steps)`, floored at 0.1 (or at `base`
    /// if that is lower). `base` is the temperature in effect at construction
    /// or when [`DifferentiableGraph::with_temperature_annealing`] was called.
    /// The step counter advances even when annealing is disabled.
    pub fn anneal_temperature(&mut self) {
        if self.annealing_steps > 0 {
            const DECAY_RATE: f64 = 3.0;
            const MIN_TEMPERATURE: f64 = 0.1;
            let progress = self.current_step as f64 / self.annealing_steps as f64;
            let annealed = self.anneal_base_temperature * (-DECAY_RATE * progress).exp();
            // Never let the floor push the temperature above its starting value.
            let floor = MIN_TEMPERATURE.min(self.anneal_base_temperature);
            self.config.temperature = annealed.max(floor);
        }
        self.current_step += 1;
    }

    /// Enables exponential temperature annealing over `steps` steps, starting
    /// from the currently configured temperature.
    pub fn with_temperature_annealing(mut self, steps: usize) -> Self {
        self.annealing_steps = steps;
        self.anneal_base_temperature = self.config.temperature;
        self
    }

    /// Recomputes every edge's soft probability at the current temperature and
    /// thresholds it at 0.5.
    ///
    /// With STE enabled, stores `hard - soft` per edge, where `soft` is the
    /// probability just recomputed here (not the stale pre-call value).
    pub fn discretize(&mut self) {
        let temperature = self.config.temperature;
        let use_ste = self.use_ste;
        let corrections = &mut self.ste_corrections;
        corrections.clear();
        for (&key, edge) in self.edges.iter_mut() {
            edge.discretize(temperature, use_ste);
            if use_ste {
                let hard = if edge.exists { 1.0 } else { 0.0 };
                corrections.insert(key, hard - edge.probability);
            }
        }
    }

    /// Converts per-edge loss gradients (w.r.t. edge probability) into logit
    /// gradients, adding STE, sparsity and smoothness terms.
    ///
    /// Only edges present in both the graph and `loss_gradients` get an entry.
    /// The base term applies the chain rule `dp/dlogits = p (1 - p) / T`.
    ///
    /// NOTE(review): the STE correction (`hard - soft`) and the smoothness term
    /// are added without the chain-rule factor. With the descent update in
    /// `update_structure`, the STE term moves logits *away* from the hard
    /// decision. The sparsity term `w * signum(logits)` pushes logits toward 0
    /// (probability 0.5) from either side, and `signum(0.0) == 1.0`.
    /// Confirm these are intended.
    pub fn compute_structure_gradients(
        &mut self,
        loss_gradients: &HashMap<(usize, usize), f64>,
    ) -> HashMap<(usize, usize), f64> {
        let temperature = self.config.temperature;
        let mut gradients = HashMap::new();
        for (&(src, dst), edge) in &self.edges {
            if let Some(&loss_grad) = loss_gradients.get(&(src, dst)) {
                let prob = edge.probability;
                let d_prob_d_logits = prob * (1.0 - prob) / temperature;
                let mut logits_gradient = loss_grad * d_prob_d_logits;
                if self.use_ste {
                    if let Some(&ste_correction) = self.ste_corrections.get(&(src, dst)) {
                        logits_gradient += ste_correction;
                    }
                }
                let sparse_grad = if self.config.sparsity_weight > 0.0 {
                    self.config.sparsity_weight * edge.logits.signum()
                } else {
                    0.0
                };
                let smooth_grad = if self.config.smoothness_weight > 0.0 {
                    self.compute_smoothness_gradient(src, dst, prob) * self.config.smoothness_weight
                } else {
                    0.0
                };
                gradients.insert((src, dst), logits_gradient + sparse_grad + smooth_grad);
            }
        }
        gradients
    }

    /// Gradient of `sum (p - p_other)^2` w.r.t. `p` (probability space), summed
    /// over edges that share `src` (other destination) or share `dst` (other
    /// source). Scans all edges, so one call is O(|E|).
    fn compute_smoothness_gradient(&self, src: usize, dst: usize, prob: f64) -> f64 {
        let mut gradient = 0.0;
        for (&(s, d), other_edge) in &self.edges {
            let other_prob = other_edge.probability;
            if s == src && d != dst {
                gradient += 2.0 * (prob - other_prob);
            }
            if d == dst && s != src {
                gradient += 2.0 * (prob - other_prob);
            }
        }
        gradient
    }

    /// Applies a gradient-descent step with `edge_learning_rate` to the logits of
    /// every edge named in `gradients`. Probabilities refresh on the next `discretize`.
    pub fn update_structure(&mut self, gradients: &HashMap<(usize, usize), f64>) {
        let learning_rate = self.config.edge_learning_rate;
        for (key, &gradient) in gradients {
            if let Some(edge) = self.edges.get_mut(key) {
                edge.update_logits(gradient, learning_rate);
            }
        }
    }

    /// One full step: `discretize`, compute gradients, update logits, anneal.
    /// Returns the logit gradients that were applied.
    pub fn optimization_step(
        &mut self,
        loss_gradients: HashMap<(usize, usize), f64>,
    ) -> HashMap<(usize, usize), f64> {
        self.discretize();
        let gradients = self.compute_structure_gradients(&loss_gradients);
        self.update_structure(&gradients);
        self.anneal_temperature();
        gradients
    }

    /// All learnable edges, in unspecified order.
    pub fn get_learnable_edges(&self) -> Vec<&DifferentiableEdge> {
        self.edges.values().collect()
    }

    /// Number of learnable edges.
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }

    /// Number of nodes.
    pub fn num_nodes(&self) -> usize {
        self.num_nodes
    }

    /// Current configuration.
    pub fn config(&self) -> &GradientConfig {
        &self.config
    }

    /// Replaces the configuration, keeping the STE flag in sync with `config.use_ste`.
    pub fn set_config(&mut self, config: GradientConfig) {
        self.use_ste = config.use_ste;
        self.config = config;
    }

    /// Current temperature.
    pub fn temperature(&self) -> f64 {
        self.config.temperature
    }

    /// Sets the temperature. While annealing is enabled, the next
    /// `anneal_temperature` call overwrites it.
    pub fn set_temperature(&mut self, temp: f64) {
        self.config.temperature = temp;
    }

    /// Iterates over `((src, dst), edge)` pairs in unspecified order.
    pub fn edges(&self) -> impl Iterator<Item = (&(usize, usize), &DifferentiableEdge)> {
        self.edges.iter()
    }

    /// Builds a concrete graph containing every node and, with weight 1.0,
    /// every edge whose `exists` flag is set.
    ///
    /// If the target graph rejects a node, edges touching it are skipped
    /// instead of being attached to a fabricated index. Edge-insertion errors
    /// are ignored (best effort).
    pub fn to_graph(&self) -> crate::graph::Graph<usize, f64> {
        use crate::graph::traits::GraphOps;
        use crate::graph::Graph;
        use crate::node::NodeIndex;
        let mut graph: crate::graph::Graph<usize, f64> =
            Graph::with_capacity(self.num_nodes, self.edges.len());
        let node_indices: Vec<Option<NodeIndex>> =
            (0..self.num_nodes).map(|i| graph.add_node(i).ok()).collect();
        let index_of = |i: usize| node_indices.get(i).copied().flatten();
        for (&(src, dst), edge) in &self.edges {
            if !edge.exists {
                continue;
            }
            if let (Some(s), Some(d)) = (index_of(src), index_of(dst)) {
                let _ = graph.add_edge(s, d, 1.0);
            }
        }
        graph
    }

    /// Like [`DifferentiableGraph::to_graph`], but with typed node and edge payloads.
    ///
    /// Missing node types default to `OperatorType::Custom { name: "node_<i>" }`.
    /// Missing edge weights default to a 1-element tensor named
    /// `"edge_<src>_to_<dst>"`. Nodes rejected by the target graph are skipped,
    /// together with their edges.
    #[cfg(feature = "transformer")]
    pub fn to_graph_with_types(
        &self,
        node_types: &std::collections::HashMap<usize, crate::transformer::optimization::switch::OperatorType>,
        edge_weights: &std::collections::HashMap<(usize, usize), crate::transformer::optimization::switch::WeightTensor>,
    ) -> crate::graph::Graph<crate::transformer::optimization::switch::OperatorType, crate::transformer::optimization::switch::WeightTensor> {
        use crate::graph::traits::GraphOps;
        use crate::graph::Graph;
        use crate::node::NodeIndex;
        use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
        let mut graph: Graph<OperatorType, WeightTensor> =
            Graph::with_capacity(self.num_nodes, self.edges.len());
        let node_indices: Vec<Option<NodeIndex>> = (0..self.num_nodes)
            .map(|i| {
                let node_type = node_types
                    .get(&i)
                    .cloned()
                    .unwrap_or_else(|| OperatorType::Custom { name: format!("node_{}", i) });
                graph.add_node(node_type).ok()
            })
            .collect();
        let index_of = |i: usize| node_indices.get(i).copied().flatten();
        for (&(src, dst), edge) in &self.edges {
            if !edge.exists {
                continue;
            }
            if let (Some(s), Some(d)) = (index_of(src), index_of(dst)) {
                let weight = edge_weights
                    .get(&(src, dst))
                    .cloned()
                    .unwrap_or_else(|| WeightTensor::new(
                        format!("edge_{}_to_{}", src, dst),
                        vec![1.0],
                        vec![1],
                    ));
                let _ = graph.add_edge(s, d, weight);
            }
        }
        graph
    }

    /// Builds a differentiable graph with `graph.node_count()` nodes.
    ///
    /// With `init_probs`, exactly those edges are created with their given
    /// probabilities. Otherwise every neighbour relation of `graph` becomes an
    /// edge with probability 1.0. Node indices are taken from
    /// `NodeIndex::index()`.
    pub fn from_graph<U, V>(
        graph: &crate::graph::Graph<U, V>,
        init_probs: Option<HashMap<(usize, usize), f64>>,
    ) -> DifferentiableGraph<()>
    where
        U: Clone,
        V: Clone,
    {
        use crate::graph::traits::{GraphBase, GraphQuery};
        let num_nodes = graph.node_count();
        let mut diff_graph = DifferentiableGraph::new(num_nodes);
        if let Some(probs) = init_probs {
            for ((src, dst), prob) in probs {
                diff_graph.add_learnable_edge(src, dst, prob);
            }
        } else {
            for node in graph.nodes() {
                let src_idx = node.index().index();
                for neighbor in graph.neighbors(node.index()) {
                    diff_graph.add_learnable_edge(src_idx, neighbor.index(), 1.0);
                }
            }
        }
        diff_graph
    }

    /// Builds a differentiable graph mirroring every neighbour relation of
    /// `graph`, each edge starting at probability `init_prob` (default 1.0).
    pub fn from_graph_with_prob<U, V>(
        graph: &crate::graph::Graph<U, V>,
        init_prob: Option<f64>,
    ) -> DifferentiableGraph<()>
    where
        U: Clone,
        V: Clone,
    {
        use crate::graph::traits::{GraphBase, GraphQuery};
        let num_nodes = graph.node_count();
        let mut diff_graph = DifferentiableGraph::new(num_nodes);
        let prob = init_prob.unwrap_or(1.0);
        for node in graph.nodes() {
            let src_idx = node.index().index();
            for neighbor in graph.neighbors(node.index()) {
                diff_graph.add_learnable_edge(src_idx, neighbor.index(), prob);
            }
        }
        diff_graph
    }

    /// Enables or disables straight-through corrections (graph flag and config).
    pub fn set_ste(&mut self, use_ste: bool) {
        self.use_ste = use_ste;
        self.config.use_ste = use_ste;
    }

    /// STE corrections recorded by the last `discretize` call.
    pub fn get_ste_corrections(&self) -> &HashMap<(usize, usize), f64> {
        &self.ste_corrections
    }
}
/// Draws relaxed one-hot samples from a categorical distribution using the
/// Gumbel-softmax trick.
///
/// Without the `rand` feature the Gumbel noise is a fixed constant, so
/// sampling reduces to a deterministic tempered softmax.
#[derive(Debug, Clone)]
pub struct GumbelSoftmaxSampler {
    temperature: f64,
}

impl GumbelSoftmaxSampler {
    /// Uniform draws are kept inside `[EPS, 1 - EPS]` so `-ln(-ln u)` stays finite.
    const UNIFORM_EPS: f64 = 1e-7;

    /// Creates a sampler with the given softmax temperature (should be > 0).
    pub fn new(temperature: f64) -> Self {
        Self { temperature }
    }

    /// Returns `softmax((logits + g) / temperature)` with fresh Gumbel noise `g`.
    /// An empty input yields an empty output.
    pub fn sample_soft(&self, logits: &[f64]) -> Vec<f64> {
        let noise: Vec<f64> = logits.iter().map(|_| self.gumbel_sample()).collect();
        self.softmax_with_noise(logits, &noise)
    }

    /// Returns a one-hot vector at the argmax of a fresh soft sample.
    pub fn sample_hard(&self, logits: &[f64]) -> Vec<f64> {
        Self::one_hot_argmax(&self.sample_soft(logits))
    }

    /// Returns `(hard, soft)` for straight-through estimation.
    ///
    /// Both come from the *same* noise draw: `hard` is the one-hot argmax of
    /// `soft`.
    pub fn sample_ste(&self, logits: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let soft = self.sample_soft(logits);
        let hard = Self::one_hot_argmax(&soft);
        (hard, soft)
    }

    /// One standard Gumbel sample (a constant when `rand` is disabled).
    fn gumbel_sample(&self) -> f64 {
        #[cfg(feature = "rand")]
        let u: f64 = random::<f64>();
        #[cfg(not(feature = "rand"))]
        let u: f64 = 0.5;
        Self::gumbel_from_uniform(u)
    }

    /// Sets the softmax temperature.
    pub fn set_temperature(&mut self, temp: f64) {
        self.temperature = temp;
    }

    /// One standard Gumbel sample drawn from the caller's RNG.
    #[cfg(feature = "rand")]
    pub fn gumbel_sample_with_rng<R: Rng>(&self, rng: &mut R) -> f64 {
        let u: f64 = rng.gen_range(Self::UNIFORM_EPS..1.0 - Self::UNIFORM_EPS);
        Self::gumbel_from_uniform(u)
    }

    /// [`GumbelSoftmaxSampler::sample_soft`] using the caller's RNG.
    #[cfg(feature = "rand")]
    pub fn sample_soft_with_rng(&self, logits: &[f64], rng: &mut impl Rng) -> Vec<f64> {
        let noise: Vec<f64> = logits
            .iter()
            .map(|_| self.gumbel_sample_with_rng(rng))
            .collect();
        self.softmax_with_noise(logits, &noise)
    }

    /// Inverse-CDF transform `-ln(-ln u)`. `u` is clamped away from 0 and 1,
    /// because `u == 0` would otherwise give `-inf`.
    fn gumbel_from_uniform(u: f64) -> f64 {
        let u = u.clamp(Self::UNIFORM_EPS, 1.0 - Self::UNIFORM_EPS);
        -(-u.ln()).ln()
    }

    /// Numerically stable tempered softmax of `logits + noise` (the maximum is
    /// subtracted before exponentiating).
    fn softmax_with_noise(&self, logits: &[f64], noise: &[f64]) -> Vec<f64> {
        let perturbed: Vec<f64> = logits.iter().zip(noise).map(|(&l, &g)| l + g).collect();
        let max_value = perturbed.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = perturbed
            .iter()
            .map(|&x| ((x - max_value) / self.temperature).exp())
            .collect();
        let total: f64 = exps.iter().sum();
        exps.into_iter().map(|e| e / total).collect()
    }

    /// One-hot vector at the index of the largest value.
    ///
    /// Uses `total_cmp`, so NaN entries cannot cause a panic; ties pick the
    /// last maximum.
    fn one_hot_argmax(values: &[f64]) -> Vec<f64> {
        let mut one_hot = vec![0.0; values.len()];
        if let Some(idx) = values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
        {
            one_hot[idx] = 1.0;
        }
        one_hot
    }
}
/// Strategy that classifies edge updates and computes new edge probabilities.
pub trait EdgeEditPolicy: Send + Sync {
    /// Returns `true` if this `(gradient, current_prob)` pair counts as adding the edge.
    fn should_add_edge(&self, gradient: f64, current_prob: f64) -> bool;
    /// Returns `true` if this `(gradient, current_prob)` pair counts as removing the edge.
    fn should_remove_edge(&self, gradient: f64, current_prob: f64) -> bool;
    /// Returns the updated probability after applying `gradient` with `learning_rate`.
    fn update_probability(&self, current_prob: f64, gradient: f64, learning_rate: f64) -> f64;
}

/// Edit policy based on fixed gradient thresholds and probability bounds.
#[derive(Debug, Clone)]
pub struct ThresholdEditPolicy {
    /// A gradient strictly above this, on an edge with probability `< 0.5`, is an addition.
    pub add_threshold: f64,
    /// A gradient strictly below this, on an edge with probability `> 0.5`, is a removal.
    pub remove_threshold: f64,
    /// Lower bound for updated probabilities.
    pub min_prob: f64,
    /// Upper bound for updated probabilities.
    pub max_prob: f64,
}

impl Default for ThresholdEditPolicy {
    /// Thresholds `±0.1`; probabilities bounded to `[0.01, 0.99]`.
    fn default() -> Self {
        Self {
            add_threshold: 0.1,
            remove_threshold: -0.1,
            min_prob: 0.01,
            max_prob: 0.99,
        }
    }
}

impl EdgeEditPolicy for ThresholdEditPolicy {
    fn should_add_edge(&self, gradient: f64, current_prob: f64) -> bool {
        gradient > self.add_threshold && current_prob < 0.5
    }

    fn should_remove_edge(&self, gradient: f64, current_prob: f64) -> bool {
        gradient < self.remove_threshold && current_prob > 0.5
    }

    /// `current_prob + learning_rate * gradient`, bounded by `[min_prob, max_prob]`.
    ///
    /// Never panics: if the bounds are inverted, `max_prob` wins, and a NaN
    /// bound is ignored. A NaN result is returned unchanged.
    fn update_probability(&self, current_prob: f64, gradient: f64, learning_rate: f64) -> f64 {
        let new_prob = current_prob + learning_rate * gradient;
        if new_prob.is_nan() {
            return new_prob;
        }
        // `f64::clamp` panics when min > max or a bound is NaN, and the bounds
        // are public fields, so use the non-panicking max/min pair instead.
        new_prob.max(self.min_prob).min(self.max_prob)
    }
}
/// Stores per-edge and per-node gradients and accumulates momentum
/// velocities for edges.
#[derive(Debug, Default, Clone)]
pub struct GradientRecorder {
    /// Most recently recorded gradient per edge.
    edge_gradients: HashMap<(usize, usize), f64>,
    /// Most recently recorded gradient per node.
    node_gradients: HashMap<usize, f64>,
    /// Momentum velocity per edge; survives `clear`, reset by `reset`.
    edge_velocities: HashMap<(usize, usize), f64>,
    /// Momentum coefficient applied to the previous velocity.
    momentum: f64,
}

impl GradientRecorder {
    /// Creates an empty recorder with the given momentum coefficient.
    pub fn new(momentum: f64) -> Self {
        Self {
            momentum,
            ..Self::default()
        }
    }

    /// Records (overwrites) the gradient for edge `src -> dst`.
    pub fn record_edge_gradient(&mut self, src: usize, dst: usize, gradient: f64) {
        self.edge_gradients.insert((src, dst), gradient);
    }

    /// Records (overwrites) the gradient for node `node_id`.
    pub fn record_node_gradient(&mut self, node_id: usize, gradient: f64) {
        self.node_gradients.insert(node_id, gradient);
    }

    /// Recorded gradient for edge `src -> dst`, if any.
    pub fn get_edge_gradient(&self, src: usize, dst: usize) -> Option<f64> {
        self.edge_gradients.get(&(src, dst)).copied()
    }

    /// Recorded gradient for node `node_id`, if any.
    pub fn get_node_gradient(&self, node_id: usize) -> Option<f64> {
        self.node_gradients.get(&node_id).copied()
    }

    /// All recorded edge gradients.
    pub fn get_all_edge_gradients(&self) -> &HashMap<(usize, usize), f64> {
        &self.edge_gradients
    }

    /// Updates and returns classical-momentum velocities
    /// `v = momentum * v_prev + g` for every edge with a recorded gradient.
    ///
    /// Velocities of edges without a recorded gradient are left untouched.
    pub fn apply_momentum(&mut self) -> HashMap<(usize, usize), f64> {
        let momentum = self.momentum;
        let velocities = &mut self.edge_velocities;
        self.edge_gradients
            .iter()
            .map(|(&edge, &grad)| {
                let velocity = velocities.entry(edge).or_insert(0.0);
                *velocity = momentum * *velocity + grad;
                (edge, *velocity)
            })
            .collect()
    }

    /// Forgets recorded gradients but keeps momentum velocities.
    pub fn clear(&mut self) {
        self.edge_gradients.clear();
        self.node_gradients.clear();
    }

    /// Forgets recorded gradients and momentum velocities.
    pub fn reset(&mut self) {
        self.clear();
        self.edge_velocities.clear();
    }
}
pub struct GraphTransformer<T> {
policy: Box<dyn EdgeEditPolicy>,
recorder: GradientRecorder,
_marker: std::marker::PhantomData<T>,
}
impl<T: Clone + Default> GraphTransformer<T> {
pub fn new(policy: Box<dyn EdgeEditPolicy>) -> Self {
Self {
policy,
recorder: GradientRecorder::new(0.9),
_marker: std::marker::PhantomData,
}
}
pub fn transform(&mut self, graph: &mut DifferentiableGraph<T>) -> Vec<StructureEdit> {
let mut edits = Vec::new();
let momentum_gradients = self.recorder.apply_momentum();
for ((src, dst), edge) in &mut graph.edges {
if let Some(&gradient) = momentum_gradients.get(&(*src, *dst)) {
let before = edge.probability;
if self.policy.should_remove_edge(gradient, edge.probability) {
let new_prob = self.policy.update_probability(
edge.probability,
gradient,
graph.config.edge_learning_rate,
);
let after = new_prob;
edge.probability = new_prob;
edge.exists = new_prob > 0.5;
edits.push(StructureEdit {
operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Remove),
gradient,
before,
after,
});
}
else if self.policy.should_add_edge(gradient, edge.probability) {
let new_prob = self.policy.update_probability(
edge.probability,
gradient,
graph.config.edge_learning_rate,
);
let after = new_prob;
edge.probability = new_prob;
edge.exists = new_prob > 0.5;
edits.push(StructureEdit {
operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Add),
gradient,
before,
after,
});
}
else {
let new_prob = self.policy.update_probability(
edge.probability,
gradient,
graph.config.edge_learning_rate,
);
let after = new_prob;
edge.probability = new_prob;
edge.exists = new_prob > 0.5;
edits.push(StructureEdit {
operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Modify),
gradient,
before,
after,
});
}
}
}
edits
}
pub fn record_gradients(&mut self, gradients: &HashMap<(usize, usize), f64>) {
for ((src, dst), &grad) in gradients {
self.recorder.record_edge_gradient(*src, *dst, grad);
}
}
}
#[cfg(test)]
mod tests {
// Unit tests for the differentiable-structure primitives in this module.
use super::*;
// Probability 0.5 maps to zero logits; `update_logits` subtracts
// `learning_rate * gradient`, so a negative gradient raises the logits.
#[test]
fn test_differentiable_edge() {
let mut edge = DifferentiableEdge::new(0, 1, 0.5);
assert_eq!(edge.src, 0);
assert_eq!(edge.dst, 1);
assert!((edge.logits - 0.0).abs() < 1e-6); assert!((edge.probability - 0.5).abs() < 1e-6);
edge.update_logits(-0.1, 0.01); assert!(edge.logits > 0.0);
}
// The probability matrix reports stored probabilities. After `discretize`,
// an edge exists only when its probability is strictly above 0.5, so the
// 0.5 edge is absent.
#[test]
fn test_differentiable_graph() {
let mut graph = DifferentiableGraph::<Vec<f64>>::new(4);
graph.add_learnable_edge(0, 1, 0.5);
graph.add_learnable_edge(1, 2, 0.8);
graph.add_learnable_edge(2, 3, 0.3);
assert_eq!(graph.num_edges(), 3);
assert_eq!(graph.num_nodes(), 4);
let prob_matrix = graph.get_probability_matrix();
assert!((prob_matrix[0][1] - 0.5).abs() < 1e-6);
assert!((prob_matrix[1][2] - 0.8).abs() < 1e-6);
graph.discretize();
assert!(!graph.get_edge_exists(0, 1).unwrap());
assert!(graph.get_edge_exists(1, 2).unwrap()); assert!(!graph.get_edge_exists(2, 3).unwrap()); }
// With STE disabled and no regularisers, the logit gradient is
// loss_grad * p(1-p)/T, so it keeps the sign of the loss gradient.
#[test]
fn test_structure_gradient_computation() {
let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
graph.add_learnable_edge(0, 1, 0.5);
graph.add_learnable_edge(1, 2, 0.8);
graph.set_ste(false);
let mut loss_gradients = HashMap::new();
loss_gradients.insert((0, 1), 0.5); loss_gradients.insert((1, 2), -0.3);
let gradients = graph.compute_structure_gradients(&loss_gradients);
assert!(gradients.contains_key(&(0, 1)));
assert!(gradients.contains_key(&(1, 2)));
assert!(*gradients.get(&(0, 1)).unwrap() > 0.0);
assert!(*gradients.get(&(1, 2)).unwrap() < 0.0);
}
// Soft samples sum to 1; hard samples are one-hot; STE returns both.
#[test]
fn test_gumbel_softmax_sampler() {
let sampler = GumbelSoftmaxSampler::new(1.0);
let logits = vec![1.0, 2.0, 3.0];
let soft = sampler.sample_soft(&logits);
assert_eq!(soft.len(), 3);
assert!((soft.iter().sum::<f64>() - 1.0).abs() < 1e-5);
let hard = sampler.sample_hard(&logits);
assert_eq!(hard.len(), 3);
assert_eq!(hard.iter().filter(|&&x| x > 0.5).count(), 1);
let (hard_ste, soft_ste) = sampler.sample_ste(&logits);
assert_eq!(hard_ste.len(), 3);
assert_eq!(soft_ste.len(), 3);
}
// Default thresholds are +0.1 / -0.1; the update is p + lr * g, clamped.
#[test]
fn test_threshold_edit_policy() {
let policy = ThresholdEditPolicy::default();
assert!(policy.should_add_edge(0.2, 0.3)); assert!(!policy.should_add_edge(0.05, 0.3));
assert!(policy.should_remove_edge(-0.2, 0.7)); assert!(!policy.should_remove_edge(-0.05, 0.7));
let new_prob = policy.update_probability(0.5, 0.1, 0.01);
assert!((new_prob - 0.501).abs() < 1e-6);
}
// Velocity v = momentum * v_prev + g; `clear()` keeps velocities, so the
// second round builds on the first.
#[test]
fn test_gradient_recorder_with_momentum() {
let mut recorder = GradientRecorder::new(0.9);
recorder.record_edge_gradient(0, 1, 0.5);
recorder.record_edge_gradient(1, 2, -0.3);
let momentum_grads = recorder.apply_momentum();
assert!((momentum_grads.get(&(0, 1)).unwrap() - 0.5).abs() < 1e-6);
assert!((momentum_grads.get(&(1, 2)).unwrap() + 0.3).abs() < 1e-6);
recorder.clear();
recorder.record_edge_gradient(0, 1, 0.6);
recorder.record_edge_gradient(1, 2, -0.2);
let momentum_grads2 = recorder.apply_momentum();
let expected_01 = 0.9 * 0.5 + 0.6;
let expected_12 = 0.9 * (-0.3) + (-0.2);
assert!((momentum_grads2.get(&(0, 1)).unwrap() - expected_01).abs() < 1e-6);
assert!((momentum_grads2.get(&(1, 2)).unwrap() - expected_12).abs() < 1e-6);
}
// A full step returns a gradient for every edge with a loss gradient.
// Annealing is off by default, so the temperature stays at 1.0.
#[test]
fn test_optimization_step() {
let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
graph.add_learnable_edge(0, 1, 0.5);
graph.add_learnable_edge(1, 2, 0.8);
let mut loss_gradients = HashMap::new();
loss_gradients.insert((0, 1), 0.5);
loss_gradients.insert((1, 2), -0.3);
let gradients = graph.optimization_step(loss_gradients);
assert!(gradients.contains_key(&(0, 1)));
assert!(gradients.contains_key(&(1, 2)));
assert!(graph.temperature() <= 1.0);
}
// Numerical robustness: gradients stay finite at low temperature and at
// probabilities near 0 and 1.
#[test]
fn test_gradient_computation_with_low_temperature() {
let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
graph.add_learnable_edge(0, 1, 0.5);
graph.config.temperature = 0.1;
let mut loss_gradients = HashMap::new();
loss_gradients.insert((0, 1), 1.0);
let gradients = graph.compute_structure_gradients(&loss_gradients);
for &grad in gradients.values() {
assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
}
}
#[test]
fn test_gradient_computation_with_zero_probability() {
let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
graph.add_learnable_edge(0, 1, 1e-7);
let mut loss_gradients = HashMap::new();
loss_gradients.insert((0, 1), 1.0);
let gradients = graph.compute_structure_gradients(&loss_gradients);
for &grad in gradients.values() {
assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
}
}
#[test]
fn test_gradient_computation_with_one_probability() {
let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
graph.add_learnable_edge(0, 1, 1.0 - 1e-7);
let mut loss_gradients = HashMap::new();
loss_gradients.insert((0, 1), 1.0);
let gradients = graph.compute_structure_gradients(&loss_gradients);
for &grad in gradients.values() {
assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
}
}
// Smoke test: with smoothness enabled, every edge still gets a gradient.
#[test]
fn test_smoothness_gradient_computation() {
let mut graph = DifferentiableGraph::<Vec<f64>>::with_config(
4,
GradientConfig::new(1.0, true, 0.01, 0.01).with_smoothness(0.1),
);
graph.add_learnable_edge(0, 1, 0.8);
graph.add_learnable_edge(0, 2, 0.2);
graph.add_learnable_edge(0, 3, 0.5);
let mut loss_gradients = HashMap::new();
loss_gradients.insert((0, 1), -0.5);
loss_gradients.insert((0, 2), -0.5);
loss_gradients.insert((0, 3), -0.5);
let gradients = graph.compute_structure_gradients(&loss_gradients);
assert!(gradients.contains_key(&(0, 1)));
assert!(gradients.contains_key(&(0, 2)));
assert!(gradients.contains_key(&(0, 3)));
}
// With zero loss gradients and no prior `discretize()` (hence no STE
// corrections), each gradient equals sparsity_weight * signum(logits).
#[test]
fn test_sparsity_gradient_computation() {
let mut graph = DifferentiableGraph::<Vec<f64>>::with_config(
3,
GradientConfig::new(1.0, true, 0.01, 0.01).with_sparsity(0.1),
);
graph.add_learnable_edge(0, 1, 0.5);
graph.add_learnable_edge(1, 2, 0.5);
if let Some(edge) = graph.edges.get_mut(&(0, 1)) {
edge.logits = 2.0; }
if let Some(edge) = graph.edges.get_mut(&(1, 2)) {
edge.logits = -2.0; }
let mut loss_gradients = HashMap::new();
loss_gradients.insert((0, 1), 0.0); loss_gradients.insert((1, 2), 0.0);
let gradients = graph.compute_structure_gradients(&loss_gradients);
assert!(*gradients.get(&(0, 1)).unwrap() > 0.0);
assert!(*gradients.get(&(1, 2)).unwrap() < 0.0);
}
// STE correction is hard - soft: 1 - 0.6 for the existing edge and
// 0 - 0.4 for the absent one.
#[test]
fn test_ste_correction() {
let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
graph.add_learnable_edge(0, 1, 0.6); graph.add_learnable_edge(1, 2, 0.4);
graph.discretize();
let corrections = graph.get_ste_corrections();
assert!((corrections.get(&(0, 1)).unwrap() - 0.4).abs() < 0.01);
assert!((corrections.get(&(1, 2)).unwrap() + 0.4).abs() < 0.01);
}
// Three rounds of a constant gradient 1.0 with momentum 0.9: 1, 1.9, 2.71.
#[test]
fn test_momentum_classical() {
let mut recorder = GradientRecorder::new(0.9);
recorder.record_edge_gradient(0, 1, 1.0);
let momentum_grads_1 = recorder.apply_momentum();
assert!((momentum_grads_1.get(&(0, 1)).unwrap() - 1.0).abs() < 1e-6);
recorder.clear();
recorder.record_edge_gradient(0, 1, 1.0);
let momentum_grads_2 = recorder.apply_momentum();
assert!((momentum_grads_2.get(&(0, 1)).unwrap() - 1.9).abs() < 1e-6);
recorder.clear();
recorder.record_edge_gradient(0, 1, 1.0);
let momentum_grads_3 = recorder.apply_momentum();
assert!((momentum_grads_3.get(&(0, 1)).unwrap() - 2.71).abs() < 1e-6);
}
// `to_graph` keeps all nodes and only edges whose probability is above 0.5
// after discretisation.
#[test]
fn test_graph_conversion() {
use crate::graph::traits::{GraphBase, GraphQuery};
let mut diff_graph = DifferentiableGraph::<()>::new(4);
diff_graph.add_learnable_edge(0, 1, 0.8);
diff_graph.add_learnable_edge(1, 2, 0.3);
diff_graph.add_learnable_edge(2, 3, 0.9);
diff_graph.discretize();
let graph = diff_graph.to_graph();
assert_eq!(graph.node_count(), 4);
let nodes: Vec<_> = graph.nodes().collect();
assert_eq!(nodes.len(), 4);
let n0 = nodes[0].index();
let n1 = nodes[1].index();
let n2 = nodes[2].index();
let n3 = nodes[3].index();
assert!(graph.has_edge(n0, n1)); assert!(!graph.has_edge(n1, n2)); assert!(graph.has_edge(n2, n3)); }
// Without explicit probabilities, every neighbour relation of the source
// graph becomes a learnable edge.
#[test]
fn test_from_graph() {
use crate::graph::builders::GraphBuilder;
let graph = GraphBuilder::directed()
.with_nodes(vec![(0, ()), (1, ()), (2, ()), (3, ())])
.with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)])
.build()
.unwrap();
let diff_graph = DifferentiableGraph::<()>::from_graph(&graph, None);
assert_eq!(diff_graph.num_nodes(), 4);
assert_eq!(diff_graph.num_edges(), 3);
assert!(diff_graph.get_edge_probability(0, 1).is_some());
assert!(diff_graph.get_edge_probability(1, 2).is_some());
assert!(diff_graph.get_edge_probability(2, 3).is_some());
}
}