1use std::collections::HashMap;
61
62#[cfg(all(feature = "tensor", feature = "tensor-gpu"))]
63use dfdx::prelude::*;
64
65#[cfg(feature = "rand")]
66use rand::{random, Rng};
67
/// Hyper-parameters controlling gradient-based structure learning.
#[derive(Debug, Clone)]
pub struct GradientConfig {
    /// Sigmoid/softmax relaxation temperature.
    pub temperature: f64,
    /// Whether straight-through-estimator corrections are applied.
    pub use_ste: bool,
    /// Learning rate applied to edge logits.
    pub edge_learning_rate: f64,
    /// Learning rate applied to node parameters.
    pub node_learning_rate: f64,
    /// Weight of the L1 sparsity regularizer (0.0 disables it).
    pub sparsity_weight: f64,
    /// Weight of the smoothness regularizer (0.0 disables it).
    pub smoothness_weight: f64,
}

impl Default for GradientConfig {
    fn default() -> Self {
        GradientConfig {
            temperature: 1.0,
            use_ste: true,
            edge_learning_rate: 0.01,
            node_learning_rate: 0.001,
            sparsity_weight: 0.0,
            smoothness_weight: 0.0,
        }
    }
}

impl GradientConfig {
    /// Builds a config with explicit core parameters; both regularizer
    /// weights start at zero.
    pub fn new(temperature: f64, use_ste: bool, edge_lr: f64, node_lr: f64) -> Self {
        Self {
            temperature,
            use_ste,
            edge_learning_rate: edge_lr,
            node_learning_rate: node_lr,
            ..Self::default()
        }
    }

    /// Sets the sparsity regularizer weight (builder style).
    pub fn with_sparsity(self, weight: f64) -> Self {
        Self {
            sparsity_weight: weight,
            ..self
        }
    }

    /// Sets the smoothness regularizer weight (builder style).
    pub fn with_smoothness(self, weight: f64) -> Self {
        Self {
            smoothness_weight: weight,
            ..self
        }
    }

    /// Overrides the edge learning rate (builder style).
    pub fn with_edge_learning_rate(self, lr: f64) -> Self {
        Self {
            edge_learning_rate: lr,
            ..self
        }
    }
}
129
/// Kind of edit applied to an edge during structure transformation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeEditOp {
    /// The edge is being switched on.
    Add,
    /// The edge is being switched off.
    Remove,
    /// The edge probability changed without a policy-triggered add/remove.
    Modify,
}

/// Kind of edit applied to a node during structure transformation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeEditOp {
    /// The node is being added.
    Add,
    /// The node is being removed.
    Remove,
    /// The node is being modified in place.
    Modify,
}

/// Record of a single structural edit: the operation performed, the
/// gradient that drove it, and the probability before/after the update.
#[derive(Debug, Clone)]
pub struct StructureEdit {
    pub operation: EditOperation,
    pub gradient: f64,
    pub before: f64,
    pub after: f64,
}

/// Target of a structural edit: an edge `(src, dst, op)` or a node `(id, op)`.
#[derive(Debug, Clone)]
pub enum EditOperation {
    EdgeEdit(usize, usize, EdgeEditOp),
    NodeEdit(usize, NodeEditOp),
}
173
/// An edge whose existence is a learnable, differentiable quantity.
///
/// The edge stores unconstrained `logits`; `probability` is the
/// (temperature-scaled) sigmoid of the logits, and `exists` is the hard
/// threshold of that probability at 0.5.
#[derive(Debug, Clone)]
pub struct DifferentiableEdge {
    pub src: usize,
    pub dst: usize,
    /// Unconstrained parameter; sigmoid(logits) == probability.
    pub logits: f64,
    /// Soft (relaxed) existence probability.
    pub probability: f64,
    /// Hard existence decision: probability > 0.5.
    pub exists: bool,
    /// Last gradient applied via `update_logits`, if any.
    pub gradient: Option<f64>,
}

impl DifferentiableEdge {
    /// Creates an edge from `src` to `dst` initialized at `init_probability`.
    pub fn new(src: usize, dst: usize, init_probability: f64) -> Self {
        let logits = Self::prob_to_logits(init_probability);
        Self {
            src,
            dst,
            logits,
            probability: init_probability,
            exists: init_probability > 0.5,
            gradient: None,
        }
    }

    /// Inverse sigmoid (logit); clamps the input so the log stays finite
    /// at probabilities of exactly 0 or 1.
    fn prob_to_logits(prob: f64) -> f64 {
        let p = prob.clamp(1e-7, 1.0 - 1e-7);
        (p / (1.0 - p)).ln()
    }

    /// Temperature-scaled sigmoid: 1 / (1 + exp(-logits / temperature)).
    fn logits_to_prob(logits: f64, temperature: f64) -> f64 {
        1.0 / (1.0 + (-logits / temperature).exp())
    }

    /// Recomputes `probability` from `logits` at `temperature` and refreshes
    /// the hard `exists` decision.
    ///
    /// The previous version carried an empty `if use_ste { }` block here;
    /// straight-through corrections are tracked by the owning graph, not per
    /// edge, so the flag is accepted only for call-site compatibility.
    fn discretize(&mut self, temperature: f64, _use_ste: bool) {
        let prob = Self::logits_to_prob(self.logits, temperature);
        self.probability = prob;
        self.exists = prob > 0.5;
    }

    /// Gradient-descent step on the logits (logits -= lr * grad); records
    /// the applied gradient for inspection.
    pub fn update_logits(&mut self, gradient: f64, learning_rate: f64) {
        self.logits -= learning_rate * gradient;
        self.gradient = Some(gradient);
    }
}
241
/// A node whose existence probability (and, optionally, feature vector)
/// can receive gradients.
#[derive(Debug, Clone)]
pub struct DifferentiableNode<T = Vec<f64>> {
    pub id: usize,
    pub existence_prob: f64,
    pub features: Option<T>,
    pub existence_gradient: Option<f64>,
    pub feature_gradient: Option<T>,
}

impl<T: Clone> DifferentiableNode<T> {
    /// Creates a node that initially exists with certainty (probability 1.0)
    /// and has no recorded gradients.
    pub fn new(id: usize, features: Option<T>) -> Self {
        DifferentiableNode {
            id,
            features,
            existence_prob: 1.0,
            existence_gradient: None,
            feature_gradient: None,
        }
    }

    /// Gradient-ascent step on the existence probability, clamped to [0, 1].
    /// The raw gradient is remembered for later inspection.
    pub fn update_existence(&mut self, gradient: f64, learning_rate: f64) {
        self.existence_gradient = Some(gradient);
        self.existence_prob =
            (self.existence_prob + learning_rate * gradient).clamp(0.0, 1.0);
    }
}
276
/// A graph whose edge set is learned by gradient descent on per-edge logits.
#[derive(Debug, Clone)]
pub struct DifferentiableGraph<T = Vec<f64>> {
    // Total number of nodes; node ids are 0..num_nodes.
    num_nodes: usize,
    // Learnable edges keyed by (src, dst).
    edges: HashMap<(usize, usize), DifferentiableEdge>,
    // Optional per-node state keyed by node id (populated by init_nodes).
    nodes: HashMap<usize, DifferentiableNode<T>>,
    // Hyper-parameters: temperature, learning rates, regularizer weights.
    config: GradientConfig,
    // Number of steps over which temperature anneals (0 = no annealing).
    annealing_steps: usize,
    // Steps taken so far; drives the annealing schedule.
    current_step: usize,
    // Cached copy of config.use_ste; kept in sync by set_ste.
    use_ste: bool,
    // Per-edge straight-through corrections from the last discretize().
    ste_corrections: HashMap<(usize, usize), f64>,
}
316
317impl<T: Clone + Default> DifferentiableGraph<T> {
318 pub fn new(num_nodes: usize) -> Self {
320 Self {
321 num_nodes,
322 edges: HashMap::new(),
323 nodes: HashMap::new(),
324 config: GradientConfig::default(),
325 annealing_steps: 0,
326 current_step: 0,
327 use_ste: true,
328 ste_corrections: HashMap::new(),
329 }
330 }
331
332 pub fn with_config(num_nodes: usize, config: GradientConfig) -> Self {
334 let use_ste = config.use_ste;
335 Self {
336 num_nodes,
337 edges: HashMap::new(),
338 nodes: HashMap::new(),
339 config,
340 annealing_steps: 0,
341 current_step: 0,
342 use_ste,
343 ste_corrections: HashMap::new(),
344 }
345 }
346
347 pub fn init_nodes(&mut self, features: Option<T>) {
349 for i in 0..self.num_nodes {
350 self.nodes
351 .insert(i, DifferentiableNode::new(i, features.clone()));
352 }
353 }
354
355 pub fn add_learnable_edge(&mut self, src: usize, dst: usize, init_prob: f64) {
357 let edge = DifferentiableEdge::new(src, dst, init_prob);
358 self.edges.insert((src, dst), edge);
359 }
360
361 pub fn remove_edge(&mut self, src: usize, dst: usize) -> Option<DifferentiableEdge> {
363 self.edges.remove(&(src, dst))
364 }
365
366 pub fn get_edge_probability(&self, src: usize, dst: usize) -> Option<f64> {
368 self.edges.get(&(src, dst)).map(|e| e.probability)
369 }
370
371 pub fn get_edge_exists(&self, src: usize, dst: usize) -> Option<bool> {
373 self.edges.get(&(src, dst)).map(|e| e.exists)
374 }
375
376 pub fn get_probability_matrix(&self) -> Vec<Vec<f64>> {
378 let mut matrix = vec![vec![0.0; self.num_nodes]; self.num_nodes];
379 for ((src, dst), edge) in &self.edges {
380 matrix[*src][*dst] = edge.probability;
381 }
382 matrix
383 }
384
385 pub fn get_adjacency_matrix(&self) -> Vec<Vec<f64>> {
387 let mut matrix = vec![vec![0.0; self.num_nodes]; self.num_nodes];
388 for ((src, dst), edge) in &self.edges {
389 if edge.exists {
390 matrix[*src][*dst] = 1.0;
391 }
392 }
393 matrix
394 }
395
396 pub fn anneal_temperature(&mut self) {
398 if self.annealing_steps > 0 {
399 let progress = self.current_step as f64 / self.annealing_steps as f64;
400 let k = 3.0;
402 self.config.temperature = 1.0 * (-k * progress).exp();
403 self.config.temperature = self.config.temperature.max(0.1); }
405 self.current_step += 1;
406 }
407
408 pub fn with_temperature_annealing(mut self, steps: usize) -> Self {
410 self.annealing_steps = steps;
411 self
412 }
413
    /// Recomputes every edge's soft probability and hard existence bit from
    /// its logits at the current temperature, and (when STE is enabled)
    /// records the straight-through correction `hard - soft` per edge,
    /// where `soft` is the probability from *before* this discretization.
    pub fn discretize(&mut self) {
        // Corrections are only valid for the most recent discretization.
        self.ste_corrections.clear();

        for (&(src, dst), edge) in &mut self.edges {
            let prob_before = edge.probability;
            edge.discretize(self.config.temperature, self.config.use_ste);

            // NOTE(review): this branch reads the cached `self.use_ste`
            // while the edge update reads `self.config.use_ste`; the two are
            // kept in sync by set_ste but not by set_config — confirm.
            if self.use_ste {
                let hard = if edge.exists { 1.0 } else { 0.0 };
                let ste_correction = hard - prob_before;
                self.ste_corrections.insert((src, dst), ste_correction);
            }
        }
    }
433
    /// Backpropagates per-edge loss gradients through the sigmoid relaxation
    /// and adds the configured regularizer gradients.
    ///
    /// For each edge present in `loss_gradients` the returned gradient is
    /// `dL/dp * dp/dlogits (+ STE correction) (+ sparsity) (+ smoothness)`,
    /// with `dp/dlogits = p * (1 - p) / temperature`. Edges without an
    /// entry in `loss_gradients` are skipped entirely.
    pub fn compute_structure_gradients(
        &mut self,
        loss_gradients: &HashMap<(usize, usize), f64>,
    ) -> HashMap<(usize, usize), f64> {
        let mut gradients = HashMap::new();

        for (&(src, dst), edge) in &self.edges {
            if let Some(&loss_grad) = loss_gradients.get(&(src, dst)) {
                let prob = edge.probability;
                let logits = edge.logits;

                // Sigmoid derivative, scaled by the relaxation temperature.
                let d_prob_d_logits = prob * (1.0 - prob) / self.config.temperature;
                let mut logits_gradient = loss_grad * d_prob_d_logits;

                if self.use_ste {
                    // NOTE(review): the correction (hard - soft) recorded by
                    // discretize() is ADDED directly to the gradient rather
                    // than folded into the forward value; preserved as
                    // written — confirm this is the intended STE form.
                    if let Some(&ste_correction) = self.ste_corrections.get(&(src, dst)) {
                        logits_gradient += ste_correction;
                    }
                }

                // L1-style sparsity pressure: push logits toward zero.
                let sparse_grad = if self.config.sparsity_weight > 0.0 {
                    self.config.sparsity_weight * logits.signum()
                } else {
                    0.0
                };

                // Penalize disagreement with edges sharing this src or dst.
                let smooth_grad = if self.config.smoothness_weight > 0.0 {
                    self.compute_smoothness_gradient(src, dst, prob) * self.config.smoothness_weight
                } else {
                    0.0
                };

                let total_gradient = logits_gradient + sparse_grad + smooth_grad;
                gradients.insert((src, dst), total_gradient);
            }
        }

        gradients
    }

    /// Gradient of a quadratic smoothness penalty: the sum of
    /// `2 * (p - p_other)` over every OTHER edge that shares this edge's
    /// source or destination endpoint.
    fn compute_smoothness_gradient(&self, src: usize, dst: usize, prob: f64) -> f64 {
        let mut gradient = 0.0;

        for (&(s, d), other_edge) in &self.edges {
            let other_prob = other_edge.probability;

            if s == src && d != dst {
                gradient += 2.0 * (prob - other_prob);
            }

            if d == dst && s != src {
                gradient += 2.0 * (prob - other_prob);
            }
        }

        gradient
    }
554
555 pub fn update_structure(&mut self, gradients: &HashMap<(usize, usize), f64>) {
557 for ((src, dst), &gradient) in gradients {
558 if let Some(edge) = self.edges.get_mut(&(*src, *dst)) {
559 edge.update_logits(gradient, self.config.edge_learning_rate);
560 }
561 }
562 }
563
564 pub fn optimization_step(
566 &mut self,
567 loss_gradients: HashMap<(usize, usize), f64>,
568 ) -> HashMap<(usize, usize), f64> {
569 self.discretize();
571
572 let gradients = self.compute_structure_gradients(&loss_gradients);
574
575 self.update_structure(&gradients);
577
578 self.anneal_temperature();
580
581 gradients
582 }
583
    /// All learnable edges, in arbitrary (HashMap) order.
    pub fn get_learnable_edges(&self) -> Vec<&DifferentiableEdge> {
        self.edges.values().collect()
    }

    /// Number of learnable edges.
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }

    /// Number of nodes.
    pub fn num_nodes(&self) -> usize {
        self.num_nodes
    }

    /// Borrows the current gradient configuration.
    pub fn config(&self) -> &GradientConfig {
        &self.config
    }

    /// Replaces the gradient configuration.
    ///
    /// NOTE(review): unlike `set_ste`, this does not refresh the cached
    /// `use_ste` flag, so the cache can go stale — confirm intended.
    pub fn set_config(&mut self, config: GradientConfig) {
        self.config = config;
    }

    /// Current relaxation temperature.
    pub fn temperature(&self) -> f64 {
        self.config.temperature
    }

    /// Overrides the relaxation temperature.
    pub fn set_temperature(&mut self, temp: f64) {
        self.config.temperature = temp;
    }

    /// Iterates over ((src, dst), edge) pairs in arbitrary order.
    pub fn edges(&self) -> impl Iterator<Item = (&(usize, usize), &DifferentiableEdge)> {
        self.edges.iter()
    }
623
    /// Materializes the hard (discretized) structure as a concrete
    /// `Graph<usize, f64>`: node payload is the node id, edge weight is 1.0.
    /// Only edges whose `exists` flag is set are added.
    pub fn to_graph(&self) -> crate::graph::Graph<usize, f64> {
        use crate::graph::traits::GraphOps;
        use crate::graph::Graph;
        use crate::node::NodeIndex;

        let mut graph: crate::graph::Graph<usize, f64> =
            Graph::with_capacity(self.num_nodes, self.edges.len());

        let mut node_indices: Vec<NodeIndex> = Vec::with_capacity(self.num_nodes);
        for i in 0..self.num_nodes {
            let result = graph.add_node(i);
            match result {
                Ok(idx) => node_indices.push(idx),
                Err(_) => {
                    // Best effort: synthesize an index so positions stay
                    // aligned with node ids even if insertion failed.
                    node_indices.push(NodeIndex::new(i, 0));
                }
            }
        }

        for (&(src, dst), edge) in &self.edges {
            // Bounds check guards against edges referencing unknown nodes;
            // add_edge failures are deliberately ignored (best effort).
            if edge.exists && src < node_indices.len() && dst < node_indices.len() {
                let _ = graph.add_edge(node_indices[src], node_indices[dst], 1.0);
            }
        }

        graph
    }
665
    /// Like `to_graph`, but attaches caller-supplied operator types and
    /// weight tensors. Missing entries fall back to a named `Custom` node
    /// type / a scalar 1.0 weight tensor.
    #[cfg(feature = "transformer")]
    pub fn to_graph_with_types(
        &self,
        node_types: &std::collections::HashMap<usize, crate::transformer::optimization::switch::OperatorType>,
        edge_weights: &std::collections::HashMap<(usize, usize), crate::transformer::optimization::switch::WeightTensor>,
    ) -> crate::graph::Graph<crate::transformer::optimization::switch::OperatorType, crate::transformer::optimization::switch::WeightTensor> {
        use crate::graph::traits::GraphOps;
        use crate::graph::Graph;
        use crate::node::NodeIndex;
        use crate::transformer::optimization::switch::{OperatorType, WeightTensor};

        let mut graph: Graph<OperatorType, WeightTensor> =
            Graph::with_capacity(self.num_nodes, self.edges.len());

        let mut node_indices: Vec<NodeIndex> = Vec::with_capacity(self.num_nodes);
        for i in 0..self.num_nodes {
            // Fall back to a named Custom operator when no type is given.
            let node_type = node_types.get(&i)
                .cloned()
                .unwrap_or_else(|| OperatorType::Custom { name: format!("node_{}", i) });

            let result = graph.add_node(node_type);
            match result {
                Ok(idx) => node_indices.push(idx),
                Err(_) => {
                    // Keep index positions aligned with node ids on failure.
                    node_indices.push(NodeIndex::new(i, 0));
                }
            }
        }

        for (&(src, dst), edge) in &self.edges {
            if edge.exists && src < node_indices.len() && dst < node_indices.len() {
                // Default weight: scalar tensor of 1.0 named after the edge.
                let weight = edge_weights.get(&(src, dst))
                    .cloned()
                    .unwrap_or_else(|| WeightTensor::new(
                        format!("edge_{}_to_{}", src, dst),
                        vec![1.0],
                        vec![1],
                    ));
                let _ = graph.add_edge(node_indices[src], node_indices[dst], weight);
            }
        }

        graph
    }
724
    /// Builds a differentiable graph from a concrete graph.
    ///
    /// When `init_probs` is given, ONLY the edges listed there are created
    /// (the source graph's topology is ignored for edges); otherwise every
    /// edge of `graph` becomes learnable with probability 1.0. Node payloads
    /// are dropped — the result carries `()` node state.
    pub fn from_graph<U, V>(
        graph: &crate::graph::Graph<U, V>,
        init_probs: Option<HashMap<(usize, usize), f64>>,
    ) -> DifferentiableGraph<()>
    where
        U: Clone,
        V: Clone,
    {
        use crate::graph::traits::{GraphBase, GraphQuery};

        let num_nodes = graph.node_count();
        let mut diff_graph = DifferentiableGraph::new(num_nodes);

        if let Some(probs) = init_probs {
            for ((src, dst), &prob) in &probs {
                diff_graph.add_learnable_edge(*src, *dst, prob);
            }
        } else {
            // Mirror the source topology with certainty 1.0 per edge.
            for node in graph.nodes() {
                let src_idx = node.index().index();
                for neighbor in graph.neighbors(node.index()) {
                    let dst_idx = neighbor.index();
                    diff_graph.add_learnable_edge(src_idx, dst_idx, 1.0);
                }
            }
        }

        diff_graph
    }

    /// Builds a differentiable graph mirroring `graph`'s topology, with
    /// every edge initialized to `init_prob` (default 1.0).
    pub fn from_graph_with_prob<U, V>(
        graph: &crate::graph::Graph<U, V>,
        init_prob: Option<f64>,
    ) -> DifferentiableGraph<()>
    where
        U: Clone,
        V: Clone,
    {
        use crate::graph::traits::{GraphBase, GraphQuery};

        let num_nodes = graph.node_count();
        let mut diff_graph = DifferentiableGraph::new(num_nodes);

        let prob = init_prob.unwrap_or(1.0);

        for node in graph.nodes() {
            let src_idx = node.index().index();
            for neighbor in graph.neighbors(node.index()) {
                let dst_idx = neighbor.index();
                diff_graph.add_learnable_edge(src_idx, dst_idx, prob);
            }
        }

        diff_graph
    }
807
    /// Enables/disables straight-through estimation, keeping the cached
    /// flag and the config flag in sync.
    pub fn set_ste(&mut self, use_ste: bool) {
        self.use_ste = use_ste;
        self.config.use_ste = use_ste;
    }

    /// The STE corrections recorded by the most recent `discretize()`.
    pub fn get_ste_corrections(&self) -> &HashMap<(usize, usize), f64> {
        &self.ste_corrections
    }
818}
819
/// Samples relaxed/discrete categorical variables via the Gumbel-Softmax
/// trick: softmax((logits + Gumbel noise) / temperature).
pub struct GumbelSoftmaxSampler {
    // Softmax temperature; lower values sharpen samples toward one-hot.
    temperature: f64,
}

impl GumbelSoftmaxSampler {
    /// Creates a sampler with the given softmax temperature.
    pub fn new(temperature: f64) -> Self {
        Self { temperature }
    }

    /// Draws a soft (relaxed) one-hot sample: softmax((logits + g) / T)
    /// with i.i.d. Gumbel noise g per entry.
    pub fn sample_soft(&self, logits: &[f64]) -> Vec<f64> {
        let gumbel_noise: Vec<f64> = logits.iter().map(|_| self.gumbel_sample()).collect();
        self.softmax_with_noise(logits, &gumbel_noise)
    }

    /// Numerically-stable softmax over `logits + noise` at the sampler's
    /// temperature (max-subtraction prevents overflow in exp).
    fn softmax_with_noise(&self, logits: &[f64], noise: &[f64]) -> Vec<f64> {
        let max_logit = logits
            .iter()
            .zip(noise)
            .map(|(&l, &g)| l + g)
            .fold(f64::NEG_INFINITY, f64::max);

        let exp_logits: Vec<f64> = logits
            .iter()
            .zip(noise)
            .map(|(&l, &g)| ((l + g - max_logit) / self.temperature).exp())
            .collect();

        let sum_exp: f64 = exp_logits.iter().sum();
        exp_logits.iter().map(|&e| e / sum_exp).collect()
    }

    /// One-hot vector selecting the argmax of `soft` (empty input yields
    /// an empty result).
    fn to_hard(soft: &[f64]) -> Vec<f64> {
        let mut result = vec![0.0; soft.len()];
        if let Some(max_idx) = soft
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
        {
            result[max_idx] = 1.0;
        }
        result
    }

    /// Draws a hard one-hot sample: the argmax of a fresh soft sample.
    pub fn sample_hard(&self, logits: &[f64]) -> Vec<f64> {
        Self::to_hard(&self.sample_soft(logits))
    }

    /// Straight-through pair (hard, soft) derived from a SINGLE draw, so
    /// the hard vector is exactly the argmax of the returned soft vector.
    ///
    /// Fix: the previous version drew two independent samples, so the pair
    /// was inconsistent whenever the Gumbel noise differed between draws.
    pub fn sample_ste(&self, logits: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let soft = self.sample_soft(logits);
        let hard = Self::to_hard(&soft);
        (hard, soft)
    }

    /// One standard Gumbel(0, 1) draw: -ln(-ln(u)) with u ~ Uniform(0, 1).
    fn gumbel_sample(&self) -> f64 {
        #[cfg(feature = "rand")]
        {
            // Clamp u away from {0, 1} so both logarithms stay finite
            // (random::<f64>() can return exactly 0.0).
            let u: f64 = random::<f64>().clamp(1e-7, 1.0 - 1e-7);
            -(-u.ln()).ln()
        }
        #[cfg(not(feature = "rand"))]
        {
            // Deterministic fallback when no RNG is available.
            let u: f64 = 0.5;
            -(-u.ln()).ln()
        }
    }

    /// Overrides the softmax temperature.
    pub fn set_temperature(&mut self, temp: f64) {
        self.temperature = temp;
    }

    /// Gumbel draw using a caller-supplied RNG; the sampling range already
    /// excludes 0 and 1, so the logs are always finite.
    #[cfg(feature = "rand")]
    pub fn gumbel_sample_with_rng<R: Rng>(&self, rng: &mut R) -> f64 {
        let u: f64 = rng.gen_range(1e-7..1.0 - 1e-7);
        -(-u.ln()).ln()
    }

    /// Soft sample using a caller-supplied RNG for reproducibility.
    #[cfg(feature = "rand")]
    pub fn sample_soft_with_rng(&self, logits: &[f64], rng: &mut impl Rng) -> Vec<f64> {
        let gumbel_noise: Vec<f64> = logits
            .iter()
            .map(|_| self.gumbel_sample_with_rng(rng))
            .collect();
        self.softmax_with_noise(logits, &gumbel_noise)
    }
}
938
/// Policy deciding how per-edge gradients translate into structural edits.
pub trait EdgeEditPolicy: Send + Sync {
    /// Should a currently-absent edge be switched on, given its gradient?
    fn should_add_edge(&self, gradient: f64, current_prob: f64) -> bool;

    /// Should a currently-present edge be switched off, given its gradient?
    fn should_remove_edge(&self, gradient: f64, current_prob: f64) -> bool;

    /// Computes the edge's next probability from one gradient step.
    fn update_probability(&self, current_prob: f64, gradient: f64, learning_rate: f64) -> f64;
}

/// Threshold-based policy: edits fire when the gradient crosses a fixed
/// threshold, and probabilities are clamped into [min_prob, max_prob].
#[derive(Debug, Clone)]
pub struct ThresholdEditPolicy {
    pub add_threshold: f64,
    pub remove_threshold: f64,
    pub min_prob: f64,
    pub max_prob: f64,
}

impl Default for ThresholdEditPolicy {
    fn default() -> Self {
        ThresholdEditPolicy {
            add_threshold: 0.1,
            remove_threshold: -0.1,
            min_prob: 0.01,
            max_prob: 0.99,
        }
    }
}

impl EdgeEditPolicy for ThresholdEditPolicy {
    fn should_add_edge(&self, gradient: f64, current_prob: f64) -> bool {
        // Only absent edges (p < 0.5) with a strong positive signal.
        current_prob < 0.5 && gradient > self.add_threshold
    }

    fn should_remove_edge(&self, gradient: f64, current_prob: f64) -> bool {
        // Only present edges (p > 0.5) with a strong negative signal.
        current_prob > 0.5 && gradient < self.remove_threshold
    }

    fn update_probability(&self, current_prob: f64, gradient: f64, learning_rate: f64) -> f64 {
        (current_prob + learning_rate * gradient).clamp(self.min_prob, self.max_prob)
    }
}
989
/// Accumulates raw edge/node gradients between steps and applies classical
/// momentum to the edge gradients.
#[derive(Debug, Default, Clone)]
pub struct GradientRecorder {
    // Most recently recorded raw gradient per edge.
    edge_gradients: HashMap<(usize, usize), f64>,
    // Most recently recorded raw gradient per node.
    node_gradients: HashMap<usize, f64>,
    // Running momentum velocity per edge; survives clear().
    edge_velocities: HashMap<(usize, usize), f64>,
    // Momentum coefficient (e.g. 0.9).
    momentum: f64,
}

impl GradientRecorder {
    /// Creates an empty recorder with the given momentum coefficient.
    pub fn new(momentum: f64) -> Self {
        GradientRecorder {
            momentum,
            ..Default::default()
        }
    }

    /// Records (overwriting) the raw gradient for edge (src, dst).
    pub fn record_edge_gradient(&mut self, src: usize, dst: usize, gradient: f64) {
        self.edge_gradients.insert((src, dst), gradient);
    }

    /// Records (overwriting) the raw gradient for a node.
    pub fn record_node_gradient(&mut self, node_id: usize, gradient: f64) {
        self.node_gradients.insert(node_id, gradient);
    }

    /// The last raw gradient recorded for edge (src, dst), if any.
    pub fn get_edge_gradient(&self, src: usize, dst: usize) -> Option<f64> {
        self.edge_gradients.get(&(src, dst)).copied()
    }

    /// All currently recorded raw edge gradients.
    pub fn get_all_edge_gradients(&self) -> &HashMap<(usize, usize), f64> {
        &self.edge_gradients
    }

    /// Folds the recorded gradients into the running velocities
    /// (v <- momentum * v + g) and returns the updated velocities for the
    /// recorded edges.
    pub fn apply_momentum(&mut self) -> HashMap<(usize, usize), f64> {
        let mut momentum_gradients = HashMap::new();

        for (&key, &grad) in &self.edge_gradients {
            let velocity = self.edge_velocities.entry(key).or_insert(0.0);
            *velocity = self.momentum * *velocity + grad;
            momentum_gradients.insert(key, *velocity);
        }

        momentum_gradients
    }

    /// Drops the recorded raw gradients but keeps the momentum velocities.
    pub fn clear(&mut self) {
        self.edge_gradients.clear();
        self.node_gradients.clear();
    }

    /// Drops everything, including the momentum velocities.
    pub fn reset(&mut self) {
        self.clear();
        self.edge_velocities.clear();
    }
}
1075
1076pub struct GraphTransformer<T> {
1078 policy: Box<dyn EdgeEditPolicy>,
1080 recorder: GradientRecorder,
1082 _marker: std::marker::PhantomData<T>,
1084}
1085
1086impl<T: Clone + Default> GraphTransformer<T> {
1087 pub fn new(policy: Box<dyn EdgeEditPolicy>) -> Self {
1089 Self {
1090 policy,
1091 recorder: GradientRecorder::new(0.9),
1092 _marker: std::marker::PhantomData,
1093 }
1094 }
1095
1096 pub fn transform(&mut self, graph: &mut DifferentiableGraph<T>) -> Vec<StructureEdit> {
1098 let mut edits = Vec::new();
1099
1100 let momentum_gradients = self.recorder.apply_momentum();
1102
1103 for ((src, dst), edge) in &mut graph.edges {
1105 if let Some(&gradient) = momentum_gradients.get(&(*src, *dst)) {
1106 let before = edge.probability;
1107
1108 if self.policy.should_remove_edge(gradient, edge.probability) {
1110 let new_prob = self.policy.update_probability(
1111 edge.probability,
1112 gradient,
1113 graph.config.edge_learning_rate,
1114 );
1115
1116 let after = new_prob;
1117 edge.probability = new_prob;
1118 edge.exists = new_prob > 0.5;
1119
1120 edits.push(StructureEdit {
1121 operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Remove),
1122 gradient,
1123 before,
1124 after,
1125 });
1126 }
1127 else if self.policy.should_add_edge(gradient, edge.probability) {
1129 let new_prob = self.policy.update_probability(
1130 edge.probability,
1131 gradient,
1132 graph.config.edge_learning_rate,
1133 );
1134
1135 let after = new_prob;
1136 edge.probability = new_prob;
1137 edge.exists = new_prob > 0.5;
1138
1139 edits.push(StructureEdit {
1140 operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Add),
1141 gradient,
1142 before,
1143 after,
1144 });
1145 }
1146 else {
1148 let new_prob = self.policy.update_probability(
1149 edge.probability,
1150 gradient,
1151 graph.config.edge_learning_rate,
1152 );
1153
1154 let after = new_prob;
1155 edge.probability = new_prob;
1156 edge.exists = new_prob > 0.5;
1157
1158 edits.push(StructureEdit {
1159 operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Modify),
1160 gradient,
1161 before,
1162 after,
1163 });
1164 }
1165 }
1166 }
1167
1168 edits
1169 }
1170
1171 pub fn record_gradients(&mut self, gradients: &HashMap<(usize, usize), f64>) {
1173 for ((src, dst), &grad) in gradients {
1174 self.recorder.record_edge_gradient(*src, *dst, grad);
1175 }
1176 }
1177}
1178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_differentiable_edge() {
        let mut edge = DifferentiableEdge::new(0, 1, 0.5);

        assert_eq!(edge.src, 0);
        assert_eq!(edge.dst, 1);
        // p = 0.5 corresponds to zero logits.
        assert!((edge.logits - 0.0).abs() < 1e-6);
        assert!((edge.probability - 0.5).abs() < 1e-6);

        // Under gradient descent a negative gradient pushes the logits up.
        edge.update_logits(-0.1, 0.01);
        assert!(edge.logits > 0.0);
    }

    #[test]
    fn test_differentiable_graph() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::new(4);

        graph.add_learnable_edge(0, 1, 0.5);
        graph.add_learnable_edge(1, 2, 0.8);
        graph.add_learnable_edge(2, 3, 0.3);

        assert_eq!(graph.num_edges(), 3);
        assert_eq!(graph.num_nodes(), 4);

        let prob_matrix = graph.get_probability_matrix();
        assert!((prob_matrix[0][1] - 0.5).abs() < 1e-6);
        assert!((prob_matrix[1][2] - 0.8).abs() < 1e-6);

        graph.discretize();
        // Existence is the strict threshold p > 0.5, so 0.5 itself is "off".
        assert!(!graph.get_edge_exists(0, 1).unwrap());
        assert!(graph.get_edge_exists(1, 2).unwrap());
        assert!(!graph.get_edge_exists(2, 3).unwrap());
    }

    #[test]
    fn test_structure_gradient_computation() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
        graph.add_learnable_edge(0, 1, 0.5);
        graph.add_learnable_edge(1, 2, 0.8);

        // Disable STE so gradients are plain sigmoid backprop.
        graph.set_ste(false);

        let mut loss_gradients = HashMap::new();
        loss_gradients.insert((0, 1), 0.5);
        loss_gradients.insert((1, 2), -0.3);
        let gradients = graph.compute_structure_gradients(&loss_gradients);

        assert!(gradients.contains_key(&(0, 1)));
        assert!(gradients.contains_key(&(1, 2)));

        // Gradient signs follow the loss-gradient signs.
        assert!(*gradients.get(&(0, 1)).unwrap() > 0.0);
        assert!(*gradients.get(&(1, 2)).unwrap() < 0.0);
    }

    #[test]
    fn test_gumbel_softmax_sampler() {
        let sampler = GumbelSoftmaxSampler::new(1.0);
        let logits = vec![1.0, 2.0, 3.0];

        // A soft sample is a probability distribution of the right length.
        let soft = sampler.sample_soft(&logits);
        assert_eq!(soft.len(), 3);
        assert!((soft.iter().sum::<f64>() - 1.0).abs() < 1e-5);

        // A hard sample is one-hot.
        let hard = sampler.sample_hard(&logits);
        assert_eq!(hard.len(), 3);
        assert_eq!(hard.iter().filter(|&&x| x > 0.5).count(), 1);

        // The STE pair preserves lengths.
        let (hard_ste, soft_ste) = sampler.sample_ste(&logits);
        assert_eq!(hard_ste.len(), 3);
        assert_eq!(soft_ste.len(), 3);
    }

    #[test]
    fn test_threshold_edit_policy() {
        let policy = ThresholdEditPolicy::default();

        assert!(policy.should_add_edge(0.2, 0.3));
        assert!(!policy.should_add_edge(0.05, 0.3));
        assert!(policy.should_remove_edge(-0.2, 0.7));
        assert!(!policy.should_remove_edge(-0.05, 0.7));

        let new_prob = policy.update_probability(0.5, 0.1, 0.01);
        assert!((new_prob - 0.501).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_recorder_with_momentum() {
        let mut recorder = GradientRecorder::new(0.9);

        recorder.record_edge_gradient(0, 1, 0.5);
        recorder.record_edge_gradient(1, 2, -0.3);

        // First application: velocities start at zero, so v == g.
        let momentum_grads = recorder.apply_momentum();
        assert!((momentum_grads.get(&(0, 1)).unwrap() - 0.5).abs() < 1e-6);
        assert!((momentum_grads.get(&(1, 2)).unwrap() + 0.3).abs() < 1e-6);

        // clear() drops gradients but keeps the velocities.
        recorder.clear();
        recorder.record_edge_gradient(0, 1, 0.6);
        recorder.record_edge_gradient(1, 2, -0.2);

        let momentum_grads2 = recorder.apply_momentum();

        // v' = momentum * v + g.
        let expected_01 = 0.9 * 0.5 + 0.6;
        let expected_12 = 0.9 * (-0.3) + (-0.2);
        assert!((momentum_grads2.get(&(0, 1)).unwrap() - expected_01).abs() < 1e-6);
        assert!((momentum_grads2.get(&(1, 2)).unwrap() - expected_12).abs() < 1e-6);
    }

    #[test]
    fn test_optimization_step() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
        graph.add_learnable_edge(0, 1, 0.5);
        graph.add_learnable_edge(1, 2, 0.8);

        let mut loss_gradients = HashMap::new();
        loss_gradients.insert((0, 1), 0.5);
        loss_gradients.insert((1, 2), -0.3);

        let gradients = graph.optimization_step(loss_gradients);

        assert!(gradients.contains_key(&(0, 1)));
        assert!(gradients.contains_key(&(1, 2)));

        // Annealing never pushes the temperature above its starting value.
        assert!(graph.temperature() <= 1.0);
    }

    #[test]
    fn test_gradient_computation_with_low_temperature() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
        graph.add_learnable_edge(0, 1, 0.5);
        graph.config.temperature = 0.1;

        let mut loss_gradients = HashMap::new();
        loss_gradients.insert((0, 1), 1.0);

        let gradients = graph.compute_structure_gradients(&loss_gradients);

        for &grad in gradients.values() {
            assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
        }
    }

    #[test]
    fn test_gradient_computation_with_zero_probability() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
        graph.add_learnable_edge(0, 1, 1e-7);

        let mut loss_gradients = HashMap::new();
        loss_gradients.insert((0, 1), 1.0);

        let gradients = graph.compute_structure_gradients(&loss_gradients);

        for &grad in gradients.values() {
            assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
        }
    }

    #[test]
    fn test_gradient_computation_with_one_probability() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
        graph.add_learnable_edge(0, 1, 1.0 - 1e-7);

        let mut loss_gradients = HashMap::new();
        loss_gradients.insert((0, 1), 1.0);

        let gradients = graph.compute_structure_gradients(&loss_gradients);

        for &grad in gradients.values() {
            assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
        }
    }

    #[test]
    fn test_smoothness_gradient_computation() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::with_config(
            4,
            GradientConfig::new(1.0, true, 0.01, 0.01).with_smoothness(0.1),
        );

        // Three edges sharing source node 0 with spread-out probabilities.
        graph.add_learnable_edge(0, 1, 0.8);
        graph.add_learnable_edge(0, 2, 0.2);
        graph.add_learnable_edge(0, 3, 0.5);

        let mut loss_gradients = HashMap::new();
        loss_gradients.insert((0, 1), -0.5);
        loss_gradients.insert((0, 2), -0.5);
        loss_gradients.insert((0, 3), -0.5);

        let gradients = graph.compute_structure_gradients(&loss_gradients);

        assert!(gradients.contains_key(&(0, 1)));
        assert!(gradients.contains_key(&(0, 2)));
        assert!(gradients.contains_key(&(0, 3)));
    }

    #[test]
    fn test_sparsity_gradient_computation() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::with_config(
            3,
            GradientConfig::new(1.0, true, 0.01, 0.01).with_sparsity(0.1),
        );

        graph.add_learnable_edge(0, 1, 0.5);
        graph.add_learnable_edge(1, 2, 0.5);

        // Force logits of opposite signs to isolate the sparsity term.
        if let Some(edge) = graph.edges.get_mut(&(0, 1)) {
            edge.logits = 2.0;
        }
        if let Some(edge) = graph.edges.get_mut(&(1, 2)) {
            edge.logits = -2.0;
        }

        // Zero loss gradients: only the regularizer contributes.
        let mut loss_gradients = HashMap::new();
        loss_gradients.insert((0, 1), 0.0);
        loss_gradients.insert((1, 2), 0.0);

        let gradients = graph.compute_structure_gradients(&loss_gradients);

        // sparsity_weight * signum(logits).
        assert!(*gradients.get(&(0, 1)).unwrap() > 0.0);
        assert!(*gradients.get(&(1, 2)).unwrap() < 0.0);
    }

    #[test]
    fn test_ste_correction() {
        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
        graph.add_learnable_edge(0, 1, 0.6);
        graph.add_learnable_edge(1, 2, 0.4);
        graph.discretize();

        let corrections = graph.get_ste_corrections();

        // hard - soft: 1.0 - 0.6 = 0.4 and 0.0 - 0.4 = -0.4.
        assert!((corrections.get(&(0, 1)).unwrap() - 0.4).abs() < 0.01);
        assert!((corrections.get(&(1, 2)).unwrap() + 0.4).abs() < 0.01);
    }

    #[test]
    fn test_momentum_classical() {
        let mut recorder = GradientRecorder::new(0.9);

        // Repeatedly applying a unit gradient gives 1, 1.9, 2.71, ...
        recorder.record_edge_gradient(0, 1, 1.0);
        let momentum_grads_1 = recorder.apply_momentum();
        assert!((momentum_grads_1.get(&(0, 1)).unwrap() - 1.0).abs() < 1e-6);

        recorder.clear();
        recorder.record_edge_gradient(0, 1, 1.0);
        let momentum_grads_2 = recorder.apply_momentum();
        assert!((momentum_grads_2.get(&(0, 1)).unwrap() - 1.9).abs() < 1e-6);

        recorder.clear();
        recorder.record_edge_gradient(0, 1, 1.0);
        let momentum_grads_3 = recorder.apply_momentum();
        assert!((momentum_grads_3.get(&(0, 1)).unwrap() - 2.71).abs() < 1e-6);
    }

    #[test]
    fn test_graph_conversion() {
        use crate::graph::traits::{GraphBase, GraphQuery};

        let mut diff_graph = DifferentiableGraph::<()>::new(4);
        diff_graph.add_learnable_edge(0, 1, 0.8);
        diff_graph.add_learnable_edge(1, 2, 0.3);
        diff_graph.add_learnable_edge(2, 3, 0.9);

        diff_graph.discretize();

        let graph = diff_graph.to_graph();

        assert_eq!(graph.node_count(), 4);

        let nodes: Vec<_> = graph.nodes().collect();
        assert_eq!(nodes.len(), 4);

        let n0 = nodes[0].index();
        let n1 = nodes[1].index();
        let n2 = nodes[2].index();
        let n3 = nodes[3].index();

        // Only edges with p > 0.5 survive discretization.
        assert!(graph.has_edge(n0, n1));
        assert!(!graph.has_edge(n1, n2));
        assert!(graph.has_edge(n2, n3));
    }

    #[test]
    fn test_from_graph() {
        use crate::graph::builders::GraphBuilder;

        let graph = GraphBuilder::directed()
            .with_nodes(vec![(0, ()), (1, ()), (2, ()), (3, ())])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)])
            .build()
            .unwrap();

        let diff_graph = DifferentiableGraph::<()>::from_graph(&graph, None);

        assert_eq!(diff_graph.num_nodes(), 4);
        assert_eq!(diff_graph.num_edges(), 3);
        assert!(diff_graph.get_edge_probability(0, 1).is_some());
        assert!(diff_graph.get_edge_probability(1, 2).is_some());
        assert!(diff_graph.get_edge_probability(2, 3).is_some());
    }
}