Skip to main content

god_graph/tensor/
differentiable.rs

1//! 可微图结构变换模块
2//!
3//! 本模块实现了图结构变换操作的梯度计算,支持:
4//! - 可微边编辑(添加/删除/修改边权重)
5//! - 可微节点编辑(添加/删除节点)
6//! - Straight-Through Estimator (STE) 用于离散操作
7//! - Gumbel-Softmax 松弛用于可微采样
8//! - 图结构优化的梯度传播
9//!
10//! ## 核心概念
11//!
12//! ### 连续松弛表示
13//!
14//! 传统图结构是离散的:边要么存在 (1) 要么不存在 (0)。
15//! 为了支持梯度计算,我们使用连续松弛:
16//!
17//! ```text
18//! A_soft = σ(A_logits / τ)
19//!
20//! 其中:
21//! - A_logits: 边的对数几率(可学习参数)
22//! - τ: 温度参数(控制离散程度)
23//! - σ: sigmoid 函数
24//! ```
25//!
26//! ### Straight-Through Estimator (STE)
27//!
28//! 对于需要离散输出的场景,使用 STE:
29//! - 前向传播:硬阈值(0/1)
30//! - 反向传播:软梯度(通过 sigmoid)
31//!
32//! ```text
33//! A_hard = (A_soft > 0.5).to_f64()
34//! gradient = A_hard - A_soft.detach() + A_soft
35//! ```
36//!
37//! ## 示例
38//!
39//! ```ignore
40//! use god_gragh::graph::Graph;
41//! use god_gragh::tensor::differentiable::{
42//!     DifferentiableGraph, EdgeEditPolicy, GradientConfig
43//! };
44//!
45//! // 创建可微图
46//! let mut diff_graph = DifferentiableGraph::new(4);
47//!
48//! // 添加可学习边
49//! diff_graph.add_learnable_edge(0, 1, 0.5);
50//! diff_graph.add_learnable_edge(1, 2, 0.8);
51//!
52//! // 计算损失对边权重的梯度
53//! let loss = compute_loss(&diff_graph);
54//! let gradients = diff_graph.compute_structure_gradients(loss);
55//!
56//! // 基于梯度更新结构
57//! diff_graph.update_structure(&gradients, learning_rate=0.01);
58//! ```
59
60use std::collections::HashMap;
61
62#[cfg(all(feature = "tensor", feature = "tensor-gpu"))]
63use dfdx::prelude::*;
64
65#[cfg(feature = "rand")]
66use rand::{random, Rng};
67
68/// 图结构变换的梯度配置
69#[derive(Debug, Clone)]
70pub struct GradientConfig {
71    /// 温度参数(用于 Gumbel-Softmax)
72    pub temperature: f64,
73    /// 是否使用 Straight-Through Estimator
74    pub use_ste: bool,
75    /// 边编辑的学习率
76    pub edge_learning_rate: f64,
77    /// 节点编辑的学习率
78    pub node_learning_rate: f64,
79    /// 结构正则化权重(L1 稀疏)
80    pub sparsity_weight: f64,
81    /// 结构正则化权重(L2 平滑)
82    pub smoothness_weight: f64,
83}
84
85impl Default for GradientConfig {
86    fn default() -> Self {
87        Self {
88            temperature: 1.0,
89            use_ste: true,
90            edge_learning_rate: 0.01,
91            node_learning_rate: 0.001,
92            sparsity_weight: 0.0,
93            smoothness_weight: 0.0,
94        }
95    }
96}
97
98impl GradientConfig {
99    /// 创建新的梯度配置
100    pub fn new(temperature: f64, use_ste: bool, edge_lr: f64, node_lr: f64) -> Self {
101        Self {
102            temperature,
103            use_ste,
104            edge_learning_rate: edge_lr,
105            node_learning_rate: node_lr,
106            sparsity_weight: 0.0,
107            smoothness_weight: 0.0,
108        }
109    }
110
111    /// 启用稀疏正则化
112    pub fn with_sparsity(mut self, weight: f64) -> Self {
113        self.sparsity_weight = weight;
114        self
115    }
116
117    /// 启用平滑正则化
118    pub fn with_smoothness(mut self, weight: f64) -> Self {
119        self.smoothness_weight = weight;
120        self
121    }
122
123    /// 设置边学习率
124    pub fn with_edge_learning_rate(mut self, lr: f64) -> Self {
125        self.edge_learning_rate = lr;
126        self
127    }
128}
129
130/// 边编辑操作类型
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
132pub enum EdgeEditOp {
133    /// 添加边
134    Add,
135    /// 删除边
136    Remove,
137    /// 修改边权重
138    Modify,
139}
140
141/// 节点编辑操作类型
142#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
143pub enum NodeEditOp {
144    /// 添加节点
145    Add,
146    /// 删除节点
147    Remove,
148    /// 修改节点特征
149    Modify,
150}
151
152/// 结构编辑操作(带梯度信息)
153#[derive(Debug, Clone)]
154pub struct StructureEdit {
155    /// 操作类型
156    pub operation: EditOperation,
157    /// 梯度值
158    pub gradient: f64,
159    /// 编辑前的值
160    pub before: f64,
161    /// 编辑后的值
162    pub after: f64,
163}
164
165/// 编辑操作枚举
166#[derive(Debug, Clone)]
167pub enum EditOperation {
168    /// 边编辑 (src, dst, operation_type)
169    EdgeEdit(usize, usize, EdgeEditOp),
170    /// 节点编辑 (node_id, operation_type)
171    NodeEdit(usize, NodeEditOp),
172}
173
174/// 可微边:包含可学习的存在概率
175#[derive(Debug, Clone)]
176pub struct DifferentiableEdge {
177    /// 源节点索引
178    pub src: usize,
179    /// 目标节点索引
180    pub dst: usize,
181    /// 边的对数几率(logits)
182    pub logits: f64,
183    /// 边的存在概率(由 logits 计算)
184    pub probability: f64,
185    /// 离散化后的存在性(0 或 1)
186    pub exists: bool,
187    /// 梯度值
188    pub gradient: Option<f64>,
189}
190
191impl DifferentiableEdge {
192    /// 创建新的可微边
193    pub fn new(src: usize, dst: usize, init_probability: f64) -> Self {
194        let logits = Self::prob_to_logits(init_probability);
195        Self {
196            src,
197            dst,
198            logits,
199            probability: init_probability,
200            exists: init_probability > 0.5,
201            gradient: None,
202        }
203    }
204
205    /// 概率转 logits
206    fn prob_to_logits(prob: f64) -> f64 {
207        let p = prob.clamp(1e-7, 1.0 - 1e-7);
208        (p / (1.0 - p)).ln()
209    }
210
211    /// logits 转概率(带温度)
212    fn logits_to_prob(logits: f64, temperature: f64) -> f64 {
213        1.0 / (1.0 + (-logits / temperature).exp())
214    }
215
216    /// 离散化(使用 STE)
217    fn discretize(&mut self, temperature: f64, use_ste: bool) {
218        let prob = Self::logits_to_prob(self.logits, temperature);
219        self.probability = prob;
220        self.exists = prob > 0.5;
221
222        if use_ste {
223            // STE: 前向离散,反向连续
224            // gradient = exists - prob.detach() + prob
225            // 这里我们存储概率,梯度计算在外层
226        }
227    }
228
229    /// 基于梯度更新 logits(梯度下降)
230    /// 
231    /// # Gradient Descent
232    /// 
233    /// logits -= learning_rate * gradient
234    /// 
235    /// 其中 gradient = ∂L/∂logits(增加损失的方向)
236    pub fn update_logits(&mut self, gradient: f64, learning_rate: f64) {
237        self.logits -= learning_rate * gradient;
238        self.gradient = Some(gradient);
239    }
240}
241
242/// 可微节点:包含可学习的存在概率和特征
243#[derive(Debug, Clone)]
244pub struct DifferentiableNode<T = Vec<f64>> {
245    /// 节点索引
246    pub id: usize,
247    /// 节点存在概率
248    pub existence_prob: f64,
249    /// 节点特征(可选)
250    pub features: Option<T>,
251    /// 存在性的梯度
252    pub existence_gradient: Option<f64>,
253    /// 特征的梯度(如果是 tensor)
254    pub feature_gradient: Option<T>,
255}
256
257impl<T: Clone> DifferentiableNode<T> {
258    /// 创建新的可微节点
259    pub fn new(id: usize, features: Option<T>) -> Self {
260        Self {
261            id,
262            existence_prob: 1.0,
263            features,
264            existence_gradient: None,
265            feature_gradient: None,
266        }
267    }
268
269    /// 更新存在性
270    pub fn update_existence(&mut self, gradient: f64, learning_rate: f64) {
271        let new_prob = self.existence_prob + learning_rate * gradient;
272        self.existence_prob = new_prob.clamp(0.0, 1.0);
273        self.existence_gradient = Some(gradient);
274    }
275}
276
277/// 可微图结构:支持梯度计算的结构变换
278///
279/// 核心思想:将离散的图结构参数化为连续空间,
280/// 使得梯度可以反向传播到结构参数。
281///
282/// # Architecture Notes
283///
284/// ## 与自动微分框架的集成
285///
286/// 当前实现使用手动梯度计算。要与真正的自动微分框架(如 dfdx)集成,
287/// 需要:
288///
289/// 1. 将 `logits` 存储为 `Tensor1D<f64>` 而非 `f64`
290/// 2. 构建计算图:logits → probability → adjacency_matrix → loss
291/// 3. 调用 `loss.backward()` 获取梯度
292///
293/// ## 与 Graph 的转换
294///
295/// 使用 `to_graph()` 将可微图转换为普通 `Graph`,
296/// 使用 `from_graph()` 从现有图初始化可微图。
297#[derive(Debug, Clone)]
298pub struct DifferentiableGraph<T = Vec<f64>> {
299    /// 节点数
300    num_nodes: usize,
301    /// 可微边集合(key: (src, dst))
302    edges: HashMap<(usize, usize), DifferentiableEdge>,
303    /// 可微节点集合
304    nodes: HashMap<usize, DifferentiableNode<T>>,
305    /// 梯度配置
306    config: GradientConfig,
307    /// 温度退火步数
308    annealing_steps: usize,
309    /// 当前步数
310    current_step: usize,
311    /// STE 模式:如果为 true,在 discretize 时存储 STE 修正项
312    use_ste: bool,
313    /// STE 修正项:hard - soft
314    ste_corrections: HashMap<(usize, usize), f64>,
315}
316
317impl<T: Clone + Default> DifferentiableGraph<T> {
318    /// 创建新的可微图
319    pub fn new(num_nodes: usize) -> Self {
320        Self {
321            num_nodes,
322            edges: HashMap::new(),
323            nodes: HashMap::new(),
324            config: GradientConfig::default(),
325            annealing_steps: 0,
326            current_step: 0,
327            use_ste: true,
328            ste_corrections: HashMap::new(),
329        }
330    }
331
332    /// 创建带配置的可微图
333    pub fn with_config(num_nodes: usize, config: GradientConfig) -> Self {
334        let use_ste = config.use_ste;
335        Self {
336            num_nodes,
337            edges: HashMap::new(),
338            nodes: HashMap::new(),
339            config,
340            annealing_steps: 0,
341            current_step: 0,
342            use_ste,
343            ste_corrections: HashMap::new(),
344        }
345    }
346
347    /// 初始化节点
348    pub fn init_nodes(&mut self, features: Option<T>) {
349        for i in 0..self.num_nodes {
350            self.nodes
351                .insert(i, DifferentiableNode::new(i, features.clone()));
352        }
353    }
354
355    /// 添加可学习边
356    pub fn add_learnable_edge(&mut self, src: usize, dst: usize, init_prob: f64) {
357        let edge = DifferentiableEdge::new(src, dst, init_prob);
358        self.edges.insert((src, dst), edge);
359    }
360
361    /// 移除边
362    pub fn remove_edge(&mut self, src: usize, dst: usize) -> Option<DifferentiableEdge> {
363        self.edges.remove(&(src, dst))
364    }
365
366    /// 获取边的存在概率
367    pub fn get_edge_probability(&self, src: usize, dst: usize) -> Option<f64> {
368        self.edges.get(&(src, dst)).map(|e| e.probability)
369    }
370
371    /// 获取边的存在性(离散)
372    pub fn get_edge_exists(&self, src: usize, dst: usize) -> Option<bool> {
373        self.edges.get(&(src, dst)).map(|e| e.exists)
374    }
375
376    /// 获取所有边的概率矩阵
377    pub fn get_probability_matrix(&self) -> Vec<Vec<f64>> {
378        let mut matrix = vec![vec![0.0; self.num_nodes]; self.num_nodes];
379        for ((src, dst), edge) in &self.edges {
380            matrix[*src][*dst] = edge.probability;
381        }
382        matrix
383    }
384
385    /// 获取离散邻接矩阵(使用 STE)
386    pub fn get_adjacency_matrix(&self) -> Vec<Vec<f64>> {
387        let mut matrix = vec![vec![0.0; self.num_nodes]; self.num_nodes];
388        for ((src, dst), edge) in &self.edges {
389            if edge.exists {
390                matrix[*src][*dst] = 1.0;
391            }
392        }
393        matrix
394    }
395
396    /// 温度退火
397    pub fn anneal_temperature(&mut self) {
398        if self.annealing_steps > 0 {
399            let progress = self.current_step as f64 / self.annealing_steps as f64;
400            // 指数退火:τ_t = τ_0 * exp(-k * t)
401            let k = 3.0;
402            self.config.temperature = 1.0 * (-k * progress).exp();
403            self.config.temperature = self.config.temperature.max(0.1); // 最小温度
404        }
405        self.current_step += 1;
406    }
407
408    /// 设置温度退火
409    pub fn with_temperature_annealing(mut self, steps: usize) -> Self {
410        self.annealing_steps = steps;
411        self
412    }
413
414    /// 离散化所有边(前向传播)
415    ///
416    /// 如果启用了 STE 模式,会存储 STE 修正项 (hard - soft),
417    /// 用于后续梯度计算时修正梯度。
418    pub fn discretize(&mut self) {
419        self.ste_corrections.clear();
420
421        for (&(src, dst), edge) in &mut self.edges {
422            let prob_before = edge.probability;
423            edge.discretize(self.config.temperature, self.config.use_ste);
424
425            // 存储 STE 修正项
426            if self.use_ste {
427                let hard = if edge.exists { 1.0 } else { 0.0 };
428                let ste_correction = hard - prob_before;
429                self.ste_corrections.insert((src, dst), ste_correction);
430            }
431        }
432    }
433
434    /// 计算结构梯度
435    ///
436    /// # Arguments
437    /// * `loss_gradients` - 损失对边存在性的梯度 {(src, dst): ∂L/∂A_ij}
438    ///
439    /// # Returns
440    /// HashMap {(src, dst): ∂L/∂logits},可用于更新边的 logits 参数
441    ///
442    /// # Gradient Computation
443    ///
444    /// 梯度计算遵循链式法则:
445    /// ```text
446    /// ∂L/∂logits = ∂L/∂A * ∂A/∂logits
447    /// ```
448    ///
449    /// 其中 A = σ(logits/τ),所以:
450    /// ```text
451    /// ∂A/∂logits = A * (1 - A) / τ
452    /// ```
453    ///
454    /// ## STE 修正
455    ///
456    /// 当启用 STE 模式时,梯度会加上 STE 修正项:
457    /// ```text
458    /// gradient = ∂L/∂logits + (hard - soft)
459    /// ```
460    ///
461    /// 这确保了前向传播的离散化与反向传播的连续梯度一致。
462    ///
463    /// # Regularization
464    ///
465    /// ## L1 稀疏正则化
466    /// L_sparse = λ_sparse * Σ|logits|
467    /// ∂L_sparse/∂logits = λ_sparse * sign(logits)
468    ///
469    /// 梯度下降更新:logits -= lr * gradient
470    /// - 正 logits → 正梯度 → logits 减小 → 概率趋向 0 → 稀疏
471    /// - 负 logits → 负梯度 → logits 增大 → 概率趋向 0 → 稀疏
472    ///
473    /// ## L2 平滑正则化
474    /// L_smooth = λ_smooth * Σ_{(i,j),(i,k)∈E} (A_ij - A_ik)²
475    /// ∂L_smooth/∂A_ij = 2 * λ_smooth * Σ_k (A_ij - A_ik)
476    ///
477    /// 平滑正则化鼓励:
478    /// - 共享源节点的边有相似概率
479    /// - 共享目标节点的边有相似概率
480    pub fn compute_structure_gradients(
481        &mut self,
482        loss_gradients: &HashMap<(usize, usize), f64>,
483    ) -> HashMap<(usize, usize), f64> {
484        let mut gradients = HashMap::new();
485
486        // 直接遍历 edges,避免不必要的收集操作
487        for (&(src, dst), edge) in &self.edges {
488            if let Some(&loss_grad) = loss_gradients.get(&(src, dst)) {
489                let prob = edge.probability;
490                let logits = edge.logits;
491
492                // 链式法则:∂L/∂logits = ∂L/∂A * ∂A/∂logits
493                // 其中 ∂A/∂logits = A * (1 - A) / τ
494                let d_prob_d_logits = prob * (1.0 - prob) / self.config.temperature;
495                let mut logits_gradient = loss_grad * d_prob_d_logits;
496
497                // STE 修正:gradient = gradient + (hard - soft)
498                // 这确保了前向离散的梯度能正确传播
499                if self.use_ste {
500                    if let Some(&ste_correction) = self.ste_corrections.get(&(src, dst)) {
501                        logits_gradient += ste_correction;
502                    }
503                }
504
505                // L1 稀疏正则化梯度
506                // ∂L_sparse/∂logits = λ_sparse * sign(logits)
507                let sparse_grad = if self.config.sparsity_weight > 0.0 {
508                    self.config.sparsity_weight * logits.signum()
509                } else {
510                    0.0
511                };
512
513                // L2 平滑正则化梯度
514                // ∂L_smooth/∂logits = 2 * λ_smooth * Σ_k (A_ij - A_ik)
515                let smooth_grad = if self.config.smoothness_weight > 0.0 {
516                    self.compute_smoothness_gradient(src, dst, prob) * self.config.smoothness_weight
517                } else {
518                    0.0
519                };
520
521                let total_gradient = logits_gradient + sparse_grad + smooth_grad;
522                gradients.insert((src, dst), total_gradient);
523            }
524        }
525
526        gradients
527    }
528
529    /// 计算平滑正则化梯度
530    ///
531    /// 考虑两种相邻关系:
532    /// 1. 共享源节点的边:(src, dst) 和 (src, k)
533    /// 2. 共享目标节点的边:(src, dst) 和 (k, dst)
534    fn compute_smoothness_gradient(&self, src: usize, dst: usize, prob: f64) -> f64 {
535        let mut gradient = 0.0;
536
537        // 遍历所有边,计算平滑梯度
538        for (&(s, d), other_edge) in &self.edges {
539            let other_prob = other_edge.probability;
540
541            // 共享源节点:(src, dst) 和 (src, k)
542            if s == src && d != dst {
543                gradient += 2.0 * (prob - other_prob);
544            }
545
546            // 共享目标节点:(src, dst) 和 (k, dst)
547            if d == dst && s != src {
548                gradient += 2.0 * (prob - other_prob);
549            }
550        }
551
552        gradient
553    }
554
555    /// 基于梯度更新结构
556    pub fn update_structure(&mut self, gradients: &HashMap<(usize, usize), f64>) {
557        for ((src, dst), &gradient) in gradients {
558            if let Some(edge) = self.edges.get_mut(&(*src, *dst)) {
559                edge.update_logits(gradient, self.config.edge_learning_rate);
560            }
561        }
562    }
563
564    /// 一步优化:离散化 -> 计算梯度 -> 更新
565    pub fn optimization_step(
566        &mut self,
567        loss_gradients: HashMap<(usize, usize), f64>,
568    ) -> HashMap<(usize, usize), f64> {
569        // 1. 离散化(前向)
570        self.discretize();
571
572        // 2. 计算梯度(反向)
573        let gradients = self.compute_structure_gradients(&loss_gradients);
574
575        // 3. 更新结构
576        self.update_structure(&gradients);
577
578        // 4. 温度退火
579        self.anneal_temperature();
580
581        gradients
582    }
583
584    /// 获取可微边列表
585    pub fn get_learnable_edges(&self) -> Vec<&DifferentiableEdge> {
586        self.edges.values().collect()
587    }
588
589    /// 获取边数
590    pub fn num_edges(&self) -> usize {
591        self.edges.len()
592    }
593
594    /// 获取节点数
595    pub fn num_nodes(&self) -> usize {
596        self.num_nodes
597    }
598
599    /// 获取配置
600    pub fn config(&self) -> &GradientConfig {
601        &self.config
602    }
603
604    /// 设置配置
605    pub fn set_config(&mut self, config: GradientConfig) {
606        self.config = config;
607    }
608
609    /// 获取当前温度
610    pub fn temperature(&self) -> f64 {
611        self.config.temperature
612    }
613
614    /// 设置温度
615    pub fn set_temperature(&mut self, temp: f64) {
616        self.config.temperature = temp;
617    }
618
619    /// 获取边迭代器
620    pub fn edges(&self) -> impl Iterator<Item = (&(usize, usize), &DifferentiableEdge)> {
621        self.edges.iter()
622    }
623
624    /// 转换为普通 Graph
625    ///
626    /// 使用离散化的边存在性构建 Graph。
627    /// 边的权重为 1.0(如果存在)或 0.0(如果不存在)。
628    ///
629    /// # Note
630    ///
631    /// 此方法创建的图使用节点索引作为节点数据,边权重为 f64。
632    /// 节点索引通过 `NodeIndex::new(index, generation)` 创建,
633    /// 其中 generation 由 Graph 内部管理。
634    pub fn to_graph(&self) -> crate::graph::Graph<usize, f64> {
635        use crate::graph::traits::GraphOps;
636        use crate::graph::Graph;
637        use crate::node::NodeIndex;
638
639        let mut graph: crate::graph::Graph<usize, f64> =
640            Graph::with_capacity(self.num_nodes, self.edges.len());
641
642        // 添加节点,使用索引作为节点数据
643        // Graph 会内部管理 NodeIndex 的 generation
644        let mut node_indices: Vec<NodeIndex> = Vec::with_capacity(self.num_nodes);
645        for i in 0..self.num_nodes {
646            let result = graph.add_node(i);
647            match result {
648                Ok(idx) => node_indices.push(idx),
649                Err(_) => {
650                    // 如果失败,创建一个占位符
651                    node_indices.push(NodeIndex::new(i, 0));
652                }
653            }
654        }
655
656        // 添加存在的边
657        for (&(src, dst), edge) in &self.edges {
658            if edge.exists && src < node_indices.len() && dst < node_indices.len() {
659                let _ = graph.add_edge(node_indices[src], node_indices[dst], 1.0);
660            }
661        }
662
663        graph
664    }
665
666    /// 转换为带类型信息的 Graph(保留 OperatorType 和 WeightTensor)
667    ///
668    /// 使用离散化的边存在性构建 Graph,保留原始的节点和边类型信息。
669    ///
670    /// # Arguments
671    /// * `node_types` - 节点类型映射 (node_index -> OperatorType)
672    /// * `edge_weights` - 边权重映射 ((src, dst) -> WeightTensor)
673    ///
674    /// # Returns
675    ///
676    /// 带类型信息的 Graph<OperatorType, WeightTensor>
677    #[cfg(feature = "transformer")]
678    pub fn to_graph_with_types(
679        &self,
680        node_types: &std::collections::HashMap<usize, crate::transformer::optimization::switch::OperatorType>,
681        edge_weights: &std::collections::HashMap<(usize, usize), crate::transformer::optimization::switch::WeightTensor>,
682    ) -> crate::graph::Graph<crate::transformer::optimization::switch::OperatorType, crate::transformer::optimization::switch::WeightTensor> {
683        use crate::graph::traits::GraphOps;
684        use crate::graph::Graph;
685        use crate::node::NodeIndex;
686        use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
687
688        let mut graph: Graph<OperatorType, WeightTensor> =
689            Graph::with_capacity(self.num_nodes, self.edges.len());
690
691        // 添加节点,使用提供的类型信息
692        let mut node_indices: Vec<NodeIndex> = Vec::with_capacity(self.num_nodes);
693        for i in 0..self.num_nodes {
694            let node_type = node_types.get(&i)
695                .cloned()
696                .unwrap_or_else(|| OperatorType::Custom { name: format!("node_{}", i) });
697            
698            let result = graph.add_node(node_type);
699            match result {
700                Ok(idx) => node_indices.push(idx),
701                Err(_) => {
702                    // 如果失败,创建一个占位符
703                    node_indices.push(NodeIndex::new(i, 0));
704                }
705            }
706        }
707
708        // 添加存在的边,使用提供的权重信息
709        for (&(src, dst), edge) in &self.edges {
710            if edge.exists && src < node_indices.len() && dst < node_indices.len() {
711                let weight = edge_weights.get(&(src, dst))
712                    .cloned()
713                    .unwrap_or_else(|| WeightTensor::new(
714                        format!("edge_{}_to_{}", src, dst),
715                        vec![1.0],
716                        vec![1],
717                    ));
718                let _ = graph.add_edge(node_indices[src], node_indices[dst], weight);
719            }
720        }
721
722        graph
723    }
724
725    /// 从普通 Graph 初始化可微图
726    ///
727    /// # Arguments
728    /// * `graph` - 源图
729    /// * `init_probs` - 边的初始存在概率
730    ///   - 如果提供,只初始化指定的边
731    ///   - 如果为 None,根据图中存在的边初始化(概率设为 1.0)
732    ///
733    /// # Note
734    ///
735    /// 此方法忽略原图的节点和边数据,只使用图结构。
736    /// 节点数据默认为 `()`,边数据默认为 `()`.
737    pub fn from_graph<U, V>(
738        graph: &crate::graph::Graph<U, V>,
739        init_probs: Option<HashMap<(usize, usize), f64>>,
740    ) -> DifferentiableGraph<()>
741    where
742        U: Clone,
743        V: Clone,
744    {
745        use crate::graph::traits::{GraphBase, GraphQuery};
746
747        let num_nodes = graph.node_count();
748        let mut diff_graph = DifferentiableGraph::new(num_nodes);
749
750        if let Some(probs) = init_probs {
751            // 使用提供的概率初始化边
752            for ((src, dst), &prob) in &probs {
753                diff_graph.add_learnable_edge(*src, *dst, prob);
754            }
755        } else {
756            // 根据图中存在的边初始化(概率设为 1.0)
757            for node in graph.nodes() {
758                let src_idx = node.index().index();
759                for neighbor in graph.neighbors(node.index()) {
760                    let dst_idx = neighbor.index();
761                    diff_graph.add_learnable_edge(src_idx, dst_idx, 1.0);
762                }
763            }
764        }
765
766        diff_graph
767    }
768
769    /// 从普通图构建可微图(使用统一的初始概率)
770    ///
771    /// # Arguments
772    /// * `graph` - 原始图
773    /// * `init_prob` - 边的初始存在概率(0.0~1.0)
774    ///
775    /// # Returns
776    /// DifferentiableGraph<()> - 可微图
777    ///
778    /// # Note
779    ///
780    /// 此方法忽略原图的节点和边数据,只使用图结构。
781    pub fn from_graph_with_prob<U, V>(
782        graph: &crate::graph::Graph<U, V>,
783        init_prob: Option<f64>,
784    ) -> DifferentiableGraph<()>
785    where
786        U: Clone,
787        V: Clone,
788    {
789        use crate::graph::traits::{GraphBase, GraphQuery};
790
791        let num_nodes = graph.node_count();
792        let mut diff_graph = DifferentiableGraph::new(num_nodes);
793
794        let prob = init_prob.unwrap_or(1.0);
795
796        // 根据图中存在的边初始化
797        for node in graph.nodes() {
798            let src_idx = node.index().index();
799            for neighbor in graph.neighbors(node.index()) {
800                let dst_idx = neighbor.index();
801                diff_graph.add_learnable_edge(src_idx, dst_idx, prob);
802            }
803        }
804
805        diff_graph
806    }
807
808    /// 启用/禁用 STE 模式
809    pub fn set_ste(&mut self, use_ste: bool) {
810        self.use_ste = use_ste;
811        self.config.use_ste = use_ste;
812    }
813
814    /// 获取 STE 修正项
815    pub fn get_ste_corrections(&self) -> &HashMap<(usize, usize), f64> {
816        &self.ste_corrections
817    }
818}
819
820/// Gumbel-Softmax 采样器:用于可微离散采样
821pub struct GumbelSoftmaxSampler {
822    temperature: f64,
823}
824
825impl GumbelSoftmaxSampler {
826    /// 创建新的采样器
827    pub fn new(temperature: f64) -> Self {
828        Self { temperature }
829    }
830
831    /// 采样(软版本,可微)
832    ///
833    /// y_i = exp((log(π_i) + g_i) / τ) / Σ_j exp((log(π_j) + g_j) / τ)
834    /// 其中 g_i ~ Gumbel(0, 1)
835    pub fn sample_soft(&self, logits: &[f64]) -> Vec<f64> {
836        let gumbel_noise: Vec<f64> = logits.iter().map(|_| self.gumbel_sample()).collect();
837
838        let max_logit = logits
839            .iter()
840            .zip(&gumbel_noise)
841            .map(|(&l, &g)| l + g)
842            .fold(f64::NEG_INFINITY, f64::max);
843
844        let exp_logits: Vec<f64> = logits
845            .iter()
846            .zip(&gumbel_noise)
847            .map(|(&l, &g)| ((l + g - max_logit) / self.temperature).exp())
848            .collect();
849
850        let sum_exp: f64 = exp_logits.iter().sum();
851
852        exp_logits.iter().map(|&e| e / sum_exp).collect()
853    }
854
855    /// 采样(硬版本,不可微,用于前向)
856    pub fn sample_hard(&self, logits: &[f64]) -> Vec<f64> {
857        let soft = self.sample_soft(logits);
858        let mut result = vec![0.0; soft.len()];
859
860        // 取最大值位置为 1
861        if let Some(max_idx) = soft
862            .iter()
863            .enumerate()
864            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
865            .map(|(i, _)| i)
866        {
867            result[max_idx] = 1.0;
868        }
869
870        result
871    }
872
873    /// STE 版本:前向硬,反向软
874    pub fn sample_ste(&self, logits: &[f64]) -> (Vec<f64>, Vec<f64>) {
875        let hard = self.sample_hard(logits);
876        let soft = self.sample_soft(logits);
877
878        // STE: gradient = hard - soft.detach() + soft = hard (因为 soft 会 detach)
879        // 实际实现中,我们返回 (hard, soft) 用于后续梯度计算
880        (hard, soft)
881    }
882
883    /// Gumbel 分布采样:g = -log(-log(u)), u ~ Uniform(0,1)
884    fn gumbel_sample(&self) -> f64 {
885        #[cfg(feature = "rand")]
886        {
887            let u: f64 = random::<f64>();
888            -(-u.ln()).ln()
889        }
890        #[cfg(not(feature = "rand"))]
891        {
892            // 无 rand 特性时,使用简单确定性值
893            // 注意:这会使 Gumbel-Softmax 变成确定性函数,仅用于测试
894            let u: f64 = 0.5;
895            -(-u.ln()).ln()
896        }
897    }
898
899    /// 设置温度
900    pub fn set_temperature(&mut self, temp: f64) {
901        self.temperature = temp;
902    }
903
904    /// 使用自定义 RNG 的 Gumbel 采样
905    ///
906    /// 允许调用者提供随机数生成器,便于控制和复现结果
907    #[cfg(feature = "rand")]
908    pub fn gumbel_sample_with_rng<R: Rng>(&self, rng: &mut R) -> f64 {
909        let u: f64 = rng.gen_range(1e-7..1.0 - 1e-7);
910        -(-u.ln()).ln()
911    }
912
913    /// 使用自定义 RNG 的软采样
914    #[cfg(feature = "rand")]
915    pub fn sample_soft_with_rng(&self, logits: &[f64], rng: &mut impl Rng) -> Vec<f64> {
916        let gumbel_noise: Vec<f64> = logits
917            .iter()
918            .map(|_| self.gumbel_sample_with_rng(rng))
919            .collect();
920
921        let max_logit = logits
922            .iter()
923            .zip(&gumbel_noise)
924            .map(|(&l, &g)| l + g)
925            .fold(f64::NEG_INFINITY, f64::max);
926
927        let exp_logits: Vec<f64> = logits
928            .iter()
929            .zip(&gumbel_noise)
930            .map(|(&l, &g)| ((l + g - max_logit) / self.temperature).exp())
931            .collect();
932
933        let sum_exp: f64 = exp_logits.iter().sum();
934
935        exp_logits.iter().map(|&e| e / sum_exp).collect()
936    }
937}
938
939/// 边编辑策略:定义如何基于梯度编辑边
940pub trait EdgeEditPolicy: Send + Sync {
941    /// 决定是否添加边
942    fn should_add_edge(&self, gradient: f64, current_prob: f64) -> bool;
943
944    /// 决定是否删除边
945    fn should_remove_edge(&self, gradient: f64, current_prob: f64) -> bool;
946
947    /// 计算新的边概率
948    fn update_probability(&self, current_prob: f64, gradient: f64, learning_rate: f64) -> f64;
949}
950
951/// 基于阈值的编辑策略
952#[derive(Debug, Clone)]
953pub struct ThresholdEditPolicy {
954    /// 添加边的梯度阈值
955    pub add_threshold: f64,
956    /// 删除边的梯度阈值
957    pub remove_threshold: f64,
958    /// 概率下限
959    pub min_prob: f64,
960    /// 概率上限
961    pub max_prob: f64,
962}
963
964impl Default for ThresholdEditPolicy {
965    fn default() -> Self {
966        Self {
967            add_threshold: 0.1,
968            remove_threshold: -0.1,
969            min_prob: 0.01,
970            max_prob: 0.99,
971        }
972    }
973}
974
975impl EdgeEditPolicy for ThresholdEditPolicy {
976    fn should_add_edge(&self, gradient: f64, current_prob: f64) -> bool {
977        gradient > self.add_threshold && current_prob < 0.5
978    }
979
980    fn should_remove_edge(&self, gradient: f64, current_prob: f64) -> bool {
981        gradient < self.remove_threshold && current_prob > 0.5
982    }
983
984    fn update_probability(&self, current_prob: f64, gradient: f64, learning_rate: f64) -> f64 {
985        let new_prob = current_prob + learning_rate * gradient;
986        new_prob.clamp(self.min_prob, self.max_prob)
987    }
988}
989
990/// 结构梯度记录器:记录所有结构变换的梯度
991#[derive(Debug, Default, Clone)]
992pub struct GradientRecorder {
993    /// 边梯度记录 {(src, dst): gradient}
994    edge_gradients: HashMap<(usize, usize), f64>,
995    /// 节点梯度记录 {node_id: gradient}
996    node_gradients: HashMap<usize, f64>,
997    /// 速度历史(用于经典动量):v_t = μ * v_{t-1} + g_t
998    edge_velocities: HashMap<(usize, usize), f64>,
999    /// 动量系数
1000    momentum: f64,
1001}
1002
1003impl GradientRecorder {
1004    /// 创建新的记录器
1005    pub fn new(momentum: f64) -> Self {
1006        Self {
1007            edge_gradients: HashMap::new(),
1008            node_gradients: HashMap::new(),
1009            edge_velocities: HashMap::new(),
1010            momentum,
1011        }
1012    }
1013
1014    /// 记录边梯度
1015    pub fn record_edge_gradient(&mut self, src: usize, dst: usize, gradient: f64) {
1016        self.edge_gradients.insert((src, dst), gradient);
1017    }
1018
1019    /// 记录节点梯度
1020    pub fn record_node_gradient(&mut self, node_id: usize, gradient: f64) {
1021        self.node_gradients.insert(node_id, gradient);
1022    }
1023
1024    /// 获取边梯度
1025    pub fn get_edge_gradient(&self, src: usize, dst: usize) -> Option<f64> {
1026        self.edge_gradients.get(&(src, dst)).copied()
1027    }
1028
1029    /// 获取所有边梯度
1030    pub fn get_all_edge_gradients(&self) -> &HashMap<(usize, usize), f64> {
1031        &self.edge_gradients
1032    }
1033
1034    /// 应用经典动量
1035    ///
1036    /// 使用经典动量公式:v_t = μ * v_{t-1} + g_t
1037    /// 其中:
1038    /// - v_t: t 时刻的速度(累积梯度)
1039    /// - μ: 动量系数 (0.9 常见)
1040    /// - g_t: t 时刻的原始梯度
1041    ///
1042    /// 这与指数移动平均 (EMA) 不同:
1043    /// - EMA: g_ema = μ * g_ema + (1-μ) * g_t (会缩小梯度)
1044    /// - 经典动量:v_t = μ * v_{t-1} + g_t (保持梯度量级)
1045    pub fn apply_momentum(&mut self) -> HashMap<(usize, usize), f64> {
1046        let mut momentum_gradients = HashMap::new();
1047
1048        for ((src, dst), &grad) in &self.edge_gradients {
1049            let last_velocity = self
1050                .edge_velocities
1051                .get(&(*src, *dst))
1052                .copied()
1053                .unwrap_or(0.0);
1054            // 经典动量公式:v_t = μ * v_{t-1} + g_t
1055            let new_velocity = self.momentum * last_velocity + grad;
1056            self.edge_velocities.insert((*src, *dst), new_velocity);
1057            momentum_gradients.insert((*src, *dst), new_velocity);
1058        }
1059
1060        momentum_gradients
1061    }
1062
1063    /// 清空记录(保留速度历史)
1064    pub fn clear(&mut self) {
1065        self.edge_gradients.clear();
1066        self.node_gradients.clear();
1067    }
1068
1069    /// 清空所有状态(包括速度历史)
1070    pub fn reset(&mut self) {
1071        self.clear();
1072        self.edge_velocities.clear();
1073    }
1074}
1075
1076/// 图结构变换器:执行具体的结构编辑操作
1077pub struct GraphTransformer<T> {
1078    /// 编辑策略
1079    policy: Box<dyn EdgeEditPolicy>,
1080    /// 梯度记录器
1081    recorder: GradientRecorder,
1082    /// 标记
1083    _marker: std::marker::PhantomData<T>,
1084}
1085
1086impl<T: Clone + Default> GraphTransformer<T> {
1087    /// 创建新的变换器
1088    pub fn new(policy: Box<dyn EdgeEditPolicy>) -> Self {
1089        Self {
1090            policy,
1091            recorder: GradientRecorder::new(0.9),
1092            _marker: std::marker::PhantomData,
1093        }
1094    }
1095
1096    /// 执行结构变换
1097    pub fn transform(&mut self, graph: &mut DifferentiableGraph<T>) -> Vec<StructureEdit> {
1098        let mut edits = Vec::new();
1099
1100        // 应用动量
1101        let momentum_gradients = self.recorder.apply_momentum();
1102
1103        // 遍历所有边,决定是否编辑
1104        for ((src, dst), edge) in &mut graph.edges {
1105            if let Some(&gradient) = momentum_gradients.get(&(*src, *dst)) {
1106                let before = edge.probability;
1107
1108                // 决定是否删除
1109                if self.policy.should_remove_edge(gradient, edge.probability) {
1110                    let new_prob = self.policy.update_probability(
1111                        edge.probability,
1112                        gradient,
1113                        graph.config.edge_learning_rate,
1114                    );
1115
1116                    let after = new_prob;
1117                    edge.probability = new_prob;
1118                    edge.exists = new_prob > 0.5;
1119
1120                    edits.push(StructureEdit {
1121                        operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Remove),
1122                        gradient,
1123                        before,
1124                        after,
1125                    });
1126                }
1127                // 决定是否添加
1128                else if self.policy.should_add_edge(gradient, edge.probability) {
1129                    let new_prob = self.policy.update_probability(
1130                        edge.probability,
1131                        gradient,
1132                        graph.config.edge_learning_rate,
1133                    );
1134
1135                    let after = new_prob;
1136                    edge.probability = new_prob;
1137                    edge.exists = new_prob > 0.5;
1138
1139                    edits.push(StructureEdit {
1140                        operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Add),
1141                        gradient,
1142                        before,
1143                        after,
1144                    });
1145                }
1146                // 否则只是修改概率
1147                else {
1148                    let new_prob = self.policy.update_probability(
1149                        edge.probability,
1150                        gradient,
1151                        graph.config.edge_learning_rate,
1152                    );
1153
1154                    let after = new_prob;
1155                    edge.probability = new_prob;
1156                    edge.exists = new_prob > 0.5;
1157
1158                    edits.push(StructureEdit {
1159                        operation: EditOperation::EdgeEdit(*src, *dst, EdgeEditOp::Modify),
1160                        gradient,
1161                        before,
1162                        after,
1163                    });
1164                }
1165            }
1166        }
1167
1168        edits
1169    }
1170
1171    /// 记录梯度
1172    pub fn record_gradients(&mut self, gradients: &HashMap<(usize, usize), f64>) {
1173        for ((src, dst), &grad) in gradients {
1174            self.recorder.record_edge_gradient(*src, *dst, grad);
1175        }
1176    }
1177}
1178
1179#[cfg(test)]
1180mod tests {
1181    use super::*;
1182
1183    #[test]
1184    fn test_differentiable_edge() {
1185        let mut edge = DifferentiableEdge::new(0, 1, 0.5);
1186
1187        assert_eq!(edge.src, 0);
1188        assert_eq!(edge.dst, 1);
1189        assert!((edge.logits - 0.0).abs() < 1e-6); // log(0.5/0.5) = 0
1190        assert!((edge.probability - 0.5).abs() < 1e-6);
1191
1192        // 更新 logits(负梯度增加 logits,正梯度减小 logits)
1193        edge.update_logits(-0.1, 0.01); // 负梯度:增加 logits
1194        assert!(edge.logits > 0.0);
1195    }
1196
1197    #[test]
1198    fn test_differentiable_graph() {
1199        let mut graph = DifferentiableGraph::<Vec<f64>>::new(4);
1200
1201        // 添加边
1202        graph.add_learnable_edge(0, 1, 0.5);
1203        graph.add_learnable_edge(1, 2, 0.8);
1204        graph.add_learnable_edge(2, 3, 0.3);
1205
1206        assert_eq!(graph.num_edges(), 3);
1207        assert_eq!(graph.num_nodes(), 4);
1208
1209        // 获取概率矩阵
1210        let prob_matrix = graph.get_probability_matrix();
1211        assert!((prob_matrix[0][1] - 0.5).abs() < 1e-6);
1212        assert!((prob_matrix[1][2] - 0.8).abs() < 1e-6);
1213
1214        // 离散化
1215        graph.discretize();
1216        // 0.5 概率时 exists=false (因为 0.5 > 0.5 为 false)
1217        assert!(!graph.get_edge_exists(0, 1).unwrap());
1218        assert!(graph.get_edge_exists(1, 2).unwrap()); // 0.8 > 0.5 -> true
1219        assert!(!graph.get_edge_exists(2, 3).unwrap()); // 0.3 < 0.5 -> false
1220    }
1221
1222    #[test]
1223    fn test_structure_gradient_computation() {
1224        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
1225        graph.add_learnable_edge(0, 1, 0.5);
1226        graph.add_learnable_edge(1, 2, 0.8);
1227
1228        // 不 discretize,因为我们要测试纯梯度计算(不含 STE 修正)
1229        // 或者禁用 STE 模式
1230        graph.set_ste(false);
1231
1232        // 模拟损失梯度
1233        let mut loss_gradients = HashMap::new();
1234        loss_gradients.insert((0, 1), 0.5); // 正梯度:鼓励添加
1235        loss_gradients.insert((1, 2), -0.3); // 负梯度:鼓励删除
1236
1237        let gradients = graph.compute_structure_gradients(&loss_gradients);
1238
1239        assert!(gradients.contains_key(&(0, 1)));
1240        assert!(gradients.contains_key(&(1, 2)));
1241
1242        // 正梯度应该导致正的 logits 梯度(不考虑 STE 修正时)
1243        assert!(*gradients.get(&(0, 1)).unwrap() > 0.0);
1244        // 负梯度应该导致负的 logits 梯度
1245        assert!(*gradients.get(&(1, 2)).unwrap() < 0.0);
1246    }
1247
1248    #[test]
1249    fn test_gumbel_softmax_sampler() {
1250        let sampler = GumbelSoftmaxSampler::new(1.0);
1251        let logits = vec![1.0, 2.0, 3.0];
1252
1253        // 软采样
1254        let soft = sampler.sample_soft(&logits);
1255        assert_eq!(soft.len(), 3);
1256        assert!((soft.iter().sum::<f64>() - 1.0).abs() < 1e-5); // 和为 1
1257
1258        // 硬采样
1259        let hard = sampler.sample_hard(&logits);
1260        assert_eq!(hard.len(), 3);
1261        assert_eq!(hard.iter().filter(|&&x| x > 0.5).count(), 1); // 只有一个为 1
1262
1263        // STE 采样
1264        let (hard_ste, soft_ste) = sampler.sample_ste(&logits);
1265        assert_eq!(hard_ste.len(), 3);
1266        assert_eq!(soft_ste.len(), 3);
1267    }
1268
1269    #[test]
1270    fn test_threshold_edit_policy() {
1271        let policy = ThresholdEditPolicy::default();
1272
1273        // 测试添加边决策
1274        assert!(policy.should_add_edge(0.2, 0.3)); // 梯度>阈值,概率<0.5
1275        assert!(!policy.should_add_edge(0.05, 0.3)); // 梯度<阈值
1276
1277        // 测试删除边决策
1278        assert!(policy.should_remove_edge(-0.2, 0.7)); // 梯度<阈值,概率>0.5
1279        assert!(!policy.should_remove_edge(-0.05, 0.7)); // 梯度>阈值
1280
1281        // 测试概率更新
1282        let new_prob = policy.update_probability(0.5, 0.1, 0.01);
1283        assert!((new_prob - 0.501).abs() < 1e-6);
1284    }
1285
1286    #[test]
1287    fn test_gradient_recorder_with_momentum() {
1288        let mut recorder = GradientRecorder::new(0.9);
1289
1290        recorder.record_edge_gradient(0, 1, 0.5);
1291        recorder.record_edge_gradient(1, 2, -0.3);
1292
1293        let momentum_grads = recorder.apply_momentum();
1294
1295        // 第一轮:v_1 = 0.9 * 0 + 0.5 = 0.5
1296        assert!((momentum_grads.get(&(0, 1)).unwrap() - 0.5).abs() < 1e-6);
1297        assert!((momentum_grads.get(&(1, 2)).unwrap() + 0.3).abs() < 1e-6);
1298
1299        // 第二轮
1300        recorder.clear();
1301        recorder.record_edge_gradient(0, 1, 0.6);
1302        recorder.record_edge_gradient(1, 2, -0.2);
1303
1304        let momentum_grads2 = recorder.apply_momentum();
1305
1306        // 经典动量:v_2 = 0.9 * v_1 + g_2
1307        // v_2(0,1) = 0.9 * 0.5 + 0.6 = 1.05
1308        // v_2(1,2) = 0.9 * (-0.3) + (-0.2) = -0.47
1309        let expected_01 = 0.9 * 0.5 + 0.6;
1310        let expected_12 = 0.9 * (-0.3) + (-0.2);
1311
1312        assert!((momentum_grads2.get(&(0, 1)).unwrap() - expected_01).abs() < 1e-6);
1313        assert!((momentum_grads2.get(&(1, 2)).unwrap() - expected_12).abs() < 1e-6);
1314    }
1315
1316    #[test]
1317    fn test_optimization_step() {
1318        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
1319        graph.add_learnable_edge(0, 1, 0.5);
1320        graph.add_learnable_edge(1, 2, 0.8);
1321
1322        let mut loss_gradients = HashMap::new();
1323        loss_gradients.insert((0, 1), 0.5);
1324        loss_gradients.insert((1, 2), -0.3);
1325
1326        let gradients = graph.optimization_step(loss_gradients);
1327
1328        assert!(gradients.contains_key(&(0, 1)));
1329        assert!(gradients.contains_key(&(1, 2)));
1330
1331        // 温度应该退火
1332        assert!(graph.temperature() <= 1.0);
1333    }
1334
1335    #[test]
1336    fn test_gradient_computation_with_low_temperature() {
1337        // 测试:低温下梯度计算不应产生 NaN 或 inf
1338        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
1339        graph.add_learnable_edge(0, 1, 0.5);
1340        graph.config.temperature = 0.1; // 最小温度
1341
1342        let mut loss_gradients = HashMap::new();
1343        loss_gradients.insert((0, 1), 1.0);
1344
1345        let gradients = graph.compute_structure_gradients(&loss_gradients);
1346
1347        // 梯度应该是有限的
1348        for &grad in gradients.values() {
1349            assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
1350        }
1351    }
1352
1353    #[test]
1354    fn test_gradient_computation_with_zero_probability() {
1355        // 测试:概率接近 0 时的梯度计算
1356        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
1357        graph.add_learnable_edge(0, 1, 1e-7); // 接近 0 的概率
1358
1359        let mut loss_gradients = HashMap::new();
1360        loss_gradients.insert((0, 1), 1.0);
1361
1362        let gradients = graph.compute_structure_gradients(&loss_gradients);
1363
1364        // 梯度应该是有限的
1365        for &grad in gradients.values() {
1366            assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
1367        }
1368    }
1369
1370    #[test]
1371    fn test_gradient_computation_with_one_probability() {
1372        // 测试:概率接近 1 时的梯度计算
1373        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
1374        graph.add_learnable_edge(0, 1, 1.0 - 1e-7); // 接近 1 的概率
1375
1376        let mut loss_gradients = HashMap::new();
1377        loss_gradients.insert((0, 1), 1.0);
1378
1379        let gradients = graph.compute_structure_gradients(&loss_gradients);
1380
1381        // 梯度应该是有限的
1382        for &grad in gradients.values() {
1383            assert!(grad.is_finite(), "Gradient should be finite, got {}", grad);
1384        }
1385    }
1386
1387    #[test]
1388    fn test_smoothness_gradient_computation() {
1389        // 测试:平滑正则化梯度计算
1390        let mut graph = DifferentiableGraph::<Vec<f64>>::with_config(
1391            4,
1392            GradientConfig::new(1.0, true, 0.01, 0.01).with_smoothness(0.1),
1393        );
1394
1395        // 添加共享源节点的边
1396        graph.add_learnable_edge(0, 1, 0.8);
1397        graph.add_learnable_edge(0, 2, 0.2);
1398        graph.add_learnable_edge(0, 3, 0.5);
1399
1400        let mut loss_gradients = HashMap::new();
1401        loss_gradients.insert((0, 1), -0.5);
1402        loss_gradients.insert((0, 2), -0.5);
1403        loss_gradients.insert((0, 3), -0.5);
1404
1405        let gradients = graph.compute_structure_gradients(&loss_gradients);
1406
1407        // 平滑正则化应该使梯度趋向于平均
1408        // (0, 1) 的概率最高,平滑梯度应该为负(降低概率)
1409        // (0, 2) 的概率最低,平滑梯度应该为正(提高概率)
1410        assert!(gradients.contains_key(&(0, 1)));
1411        assert!(gradients.contains_key(&(0, 2)));
1412        assert!(gradients.contains_key(&(0, 3)));
1413    }
1414
1415    #[test]
1416    fn test_sparsity_gradient_computation() {
1417        // 测试:稀疏正则化梯度计算
1418        let mut graph = DifferentiableGraph::<Vec<f64>>::with_config(
1419            3,
1420            GradientConfig::new(1.0, true, 0.01, 0.01).with_sparsity(0.1),
1421        );
1422
1423        graph.add_learnable_edge(0, 1, 0.5);
1424        graph.add_learnable_edge(1, 2, 0.5);
1425
1426        // 设置 logits 为正
1427        if let Some(edge) = graph.edges.get_mut(&(0, 1)) {
1428            edge.logits = 2.0; // 正 logits
1429        }
1430        if let Some(edge) = graph.edges.get_mut(&(1, 2)) {
1431            edge.logits = -2.0; // 负 logits
1432        }
1433
1434        let mut loss_gradients = HashMap::new();
1435        loss_gradients.insert((0, 1), 0.0); // 无损失梯度,只有正则化梯度
1436        loss_gradients.insert((1, 2), 0.0);
1437
1438        let gradients = graph.compute_structure_gradients(&loss_gradients);
1439
1440        // 正 logits 应该得到正梯度(推向 0)
1441        assert!(*gradients.get(&(0, 1)).unwrap() > 0.0);
1442        // 负 logits 应该得到负梯度(推向 0)
1443        assert!(*gradients.get(&(1, 2)).unwrap() < 0.0);
1444    }
1445
1446    #[test]
1447    fn test_ste_correction() {
1448        // 测试:STE 修正项计算
1449        let mut graph = DifferentiableGraph::<Vec<f64>>::new(3);
1450        graph.add_learnable_edge(0, 1, 0.6); // 概率 > 0.5,离散化后为 1
1451        graph.add_learnable_edge(1, 2, 0.4); // 概率 < 0.5,离散化后为 0
1452
1453        graph.discretize();
1454
1455        let corrections = graph.get_ste_corrections();
1456
1457        // (0, 1): hard=1, soft=0.6, correction=0.4
1458        assert!((corrections.get(&(0, 1)).unwrap() - 0.4).abs() < 0.01);
1459        // (1, 2): hard=0, soft=0.4, correction=-0.4
1460        assert!((corrections.get(&(1, 2)).unwrap() + 0.4).abs() < 0.01);
1461    }
1462
1463    #[test]
1464    fn test_momentum_classical() {
1465        // 测试:经典动量公式
1466        let mut recorder = GradientRecorder::new(0.9);
1467
1468        // 第一轮
1469        recorder.record_edge_gradient(0, 1, 1.0);
1470        let momentum_grads_1 = recorder.apply_momentum();
1471        // v_1 = 0.9 * 0 + 1.0 = 1.0
1472        assert!((momentum_grads_1.get(&(0, 1)).unwrap() - 1.0).abs() < 1e-6);
1473
1474        // 第二轮
1475        recorder.clear();
1476        recorder.record_edge_gradient(0, 1, 1.0);
1477        let momentum_grads_2 = recorder.apply_momentum();
1478        // v_2 = 0.9 * 1.0 + 1.0 = 1.9
1479        assert!((momentum_grads_2.get(&(0, 1)).unwrap() - 1.9).abs() < 1e-6);
1480
1481        // 第三轮
1482        recorder.clear();
1483        recorder.record_edge_gradient(0, 1, 1.0);
1484        let momentum_grads_3 = recorder.apply_momentum();
1485        // v_3 = 0.9 * 1.9 + 1.0 = 2.71
1486        assert!((momentum_grads_3.get(&(0, 1)).unwrap() - 2.71).abs() < 1e-6);
1487    }
1488
1489    #[test]
1490    fn test_graph_conversion() {
1491        // 测试:DifferentiableGraph 与 Graph 的转换
1492        use crate::graph::traits::{GraphBase, GraphQuery};
1493
1494        let mut diff_graph = DifferentiableGraph::<()>::new(4);
1495        diff_graph.add_learnable_edge(0, 1, 0.8);
1496        diff_graph.add_learnable_edge(1, 2, 0.3);
1497        diff_graph.add_learnable_edge(2, 3, 0.9);
1498
1499        // 离散化
1500        diff_graph.discretize();
1501
1502        // 转换为普通 Graph
1503        let graph = diff_graph.to_graph();
1504
1505        // 验证节点数
1506        assert_eq!(graph.node_count(), 4);
1507
1508        // 验证边:只有概率 > 0.5 的边应该存在
1509        // 使用 graph.nodes() 获取正确的 NodeIndex
1510        let nodes: Vec<_> = graph.nodes().collect();
1511        assert_eq!(nodes.len(), 4);
1512
1513        // nodes 按索引排序,所以 nodes[0] 对应索引 0,等等
1514        let n0 = nodes[0].index();
1515        let n1 = nodes[1].index();
1516        let n2 = nodes[2].index();
1517        let n3 = nodes[3].index();
1518
1519        // 检查边是否存在
1520        assert!(graph.has_edge(n0, n1)); // 0.8 > 0.5
1521        assert!(!graph.has_edge(n1, n2)); // 0.3 < 0.5
1522        assert!(graph.has_edge(n2, n3)); // 0.9 > 0.5
1523    }
1524
1525    #[test]
1526    fn test_from_graph() {
1527        // 测试:从普通 Graph 初始化
1528        use crate::graph::builders::GraphBuilder;
1529
1530        let graph = GraphBuilder::directed()
1531            .with_nodes(vec![(0, ()), (1, ()), (2, ()), (3, ())])
1532            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)])
1533            .build()
1534            .unwrap();
1535
1536        // 使用 turbofish 语法明确指定类型
1537        let diff_graph = DifferentiableGraph::<()>::from_graph(&graph, None);
1538
1539        assert_eq!(diff_graph.num_nodes(), 4);
1540        assert_eq!(diff_graph.num_edges(), 3);
1541        assert!(diff_graph.get_edge_probability(0, 1).is_some());
1542        assert!(diff_graph.get_edge_probability(1, 2).is_some());
1543        assert!(diff_graph.get_edge_probability(2, 3).is_some());
1544    }
1545}