Skip to main content

god_graph/tensor/
unified_graph.rs

1//! UnifiedGraph: 统一图结构,集成 DifferentiableGraph 和 ComputeGraph
2//!
3//! ## 核心设计
4//!
5//! **问题**: 当前 `DifferentiableGraph`(结构梯度)和 `ComputeGraph`(参数梯度)独立运作,
6//! 训练时需要手动协调两个图——这是架构缺陷。
7//!
8//! **解决方案**: 利用 God-Graph 的桶式邻接表 + Generation 索引设计,将结构参数和权重参数
9//! 统一存储在边数据中:
10//! - 结构参数(边存在性)存储在 `EdgeData.logits`
11//! - 权重参数(W 矩阵)存储在 `EdgeData.weight`
12//! - `ComputeGraph` 记录操作,支持自动微分
13//!
14//! ## 与 petgraph 的对比
15//!
16//! petgraph 的边是静态的,删除边后索引失效。
17//! God-Graph 的桶式邻接表 + Generation 索引:
18//! - 删除边后,索引可安全重用(generation 检查)
19//! - O(1) 增量更新(优于 CSR 格式)
20//! - 支持动态结构优化(DifferentiableGraph 的核心需求)
21//!
22//! ## 使用示例
23//!
24//! ```ignore
25//! use god_gragh::tensor::unified_graph::{UnifiedGraph, UnifiedConfig};
26//! use god_gragh::tensor::DenseTensor;
27//!
28//! // 1. 创建统一图
29//! let config = UnifiedConfig::default()
30//!     .with_structure_lr(0.01)
31//!     .with_param_lr(0.001)
32//!     .with_sparsity(0.1);
33//! let mut graph = UnifiedGraph::new(config);
34//!
35//! // 2. 添加边(同时包含权重和结构 logits)
36//! let weight = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
37//! graph.add_edge(0, 1, weight, 0.5); // 0.5 是初始存在概率
38//!
39//! // 3. 前向传播
40//! let input = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
41//! let output = graph.forward(&input);
42//!
43//! // 4. 计算损失
44//! let loss = compute_loss(&output);
45//!
46//! // 5. 联合优化一步:同时更新结构和参数
47//! graph.joint_optimization_step(&loss);
48//! ```
49
50use std::collections::HashMap;
51
52use crate::errors::{GraphError, GraphResult};
53use crate::graph::Graph;
54use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
55use crate::tensor::dense::DenseTensor;
56use crate::tensor::differentiable::GradientConfig;
57use crate::tensor::traits::TensorBase;
58
59/// 统一图配置
60#[derive(Debug, Clone)]
61pub struct UnifiedConfig {
62    /// 结构梯度配置
63    pub gradient_config: GradientConfig,
64    /// 结构学习率
65    pub structure_learning_rate: f64,
66    /// 参数学习率
67    pub param_learning_rate: f64,
68    /// 离散化阈值(用于 pruning)
69    pub discretization_threshold: f64,
70    /// 是否启用联合优化
71    pub enable_joint_optimization: bool,
72}
73
74impl Default for UnifiedConfig {
75    fn default() -> Self {
76        Self {
77            gradient_config: GradientConfig::default(),
78            structure_learning_rate: 0.01,
79            param_learning_rate: 0.001,
80            discretization_threshold: 0.5,
81            enable_joint_optimization: true,
82        }
83    }
84}
85
86impl UnifiedConfig {
87    /// 创建新的统一配置
88    pub fn new(structure_lr: f64, param_lr: f64) -> Self {
89        Self {
90            structure_learning_rate: structure_lr,
91            param_learning_rate: param_lr,
92            ..Default::default()
93        }
94    }
95
96    /// 启用稀疏正则化
97    pub fn with_sparsity(mut self, weight: f64) -> Self {
98        self.gradient_config = self.gradient_config.with_sparsity(weight);
99        self
100    }
101
102    /// 设置结构学习率
103    pub fn with_structure_lr(mut self, lr: f64) -> Self {
104        self.structure_learning_rate = lr;
105        self
106    }
107
108    /// 设置参数学习率
109    pub fn with_param_lr(mut self, lr: f64) -> Self {
110        self.param_learning_rate = lr;
111        self
112    }
113
114    /// 设置离散化阈值
115    pub fn with_threshold(mut self, threshold: f64) -> Self {
116        self.discretization_threshold = threshold;
117        self
118    }
119}
120
121/// 边数据:统一存储权重和结构参数
122#[derive(Debug, Clone)]
123pub struct EdgeData {
124    /// 权重张量
125    pub weight: DenseTensor,
126    /// 结构 logits(决定边是否存在)
127    pub structure_logits: f64,
128    /// 边存在概率(由 logits 计算)
129    pub existence_prob: f64,
130    /// 离散化后的存在性
131    pub exists: bool,
132    /// 结构梯度
133    pub structure_gradient: Option<f64>,
134    /// 权重梯度
135    pub weight_gradient: Option<DenseTensor>,
136}
137
138impl EdgeData {
139    /// 创建新的边数据
140    pub fn new(weight: DenseTensor, init_prob: f64) -> Self {
141        let logits = Self::prob_to_logits(init_prob);
142        Self {
143            weight,
144            structure_logits: logits,
145            existence_prob: init_prob,
146            exists: init_prob > 0.5,
147            structure_gradient: None,
148            weight_gradient: None,
149        }
150    }
151
152    /// 概率转 logits
153    fn prob_to_logits(prob: f64) -> f64 {
154        let p = prob.clamp(1e-7, 1.0 - 1e-7);
155        (p / (1.0 - p)).ln()
156    }
157
158    /// logits 转概率(带温度)
159    pub fn logits_to_prob(logits: f64, temperature: f64) -> f64 {
160        1.0 / (1.0 + (-logits / temperature).exp())
161    }
162
163    /// 更新结构 logits
164    pub fn update_logits(&mut self, gradient: f64, learning_rate: f64) {
165        self.structure_logits += learning_rate * gradient;
166        self.structure_gradient = Some(gradient);
167    }
168
169    /// 更新权重
170    pub fn update_weight(&mut self, gradient: &DenseTensor, learning_rate: f64) {
171        use crate::tensor::traits::TensorOps;
172        
173        // 简单的 SGD 更新:w = w - lr * grad
174        let lr_tensor = DenseTensor::scalar(learning_rate);
175        let scaled_grad = gradient.mul(&lr_tensor);
176        self.weight = self.weight.sub(&scaled_grad);
177        self.weight_gradient = Some(gradient.clone());
178    }
179
180    /// 离散化(使用 STE)
181    pub fn discretize(&mut self, temperature: f64, threshold: f64) {
182        self.existence_prob = Self::logits_to_prob(self.structure_logits, temperature);
183        self.exists = self.existence_prob > threshold;
184    }
185}
186
187/// 节点数据:存储特征和偏置
188#[derive(Debug, Clone)]
189pub struct NodeData {
190    /// 节点特征
191    pub features: DenseTensor,
192    /// 偏置(可选)
193    pub bias: Option<DenseTensor>,
194}
195
196impl NodeData {
197    /// 创建新的节点数据
198    pub fn new(features: DenseTensor) -> Self {
199        Self {
200            features,
201            bias: None,
202        }
203    }
204
205    /// 设置偏置
206    pub fn with_bias(mut self, bias: DenseTensor) -> Self {
207        self.bias = Some(bias);
208        self
209    }
210}
211
212/// 统一图结构:同时支持结构梯度和参数梯度
213///
214/// # 核心优势
215///
216/// 1. **统一存储**: 结构参数和权重参数存储在同一个图中
217/// 2. **联合优化**: 一步同时更新结构和参数
218/// 3. **桶式邻接表**: O(1) 边编辑,支持动态剪枝
219/// 4. **Generation 索引**: 删除边后索引可安全重用
220pub struct UnifiedGraph {
221    /// 主图结构(桶式邻接表)
222    graph: Graph<NodeData, EdgeData>,
223    /// 配置
224    config: UnifiedConfig,
225}
226
227impl UnifiedGraph {
228    /// 创建新的统一图
229    pub fn new(config: UnifiedConfig) -> Self {
230        Self {
231            graph: Graph::directed(),
232            config,
233        }
234    }
235
236    /// 从现有 Graph 构建统一图
237    pub fn from_graph(base_graph: Graph<NodeData, EdgeData>, config: UnifiedConfig) -> Self {
238        Self {
239            graph: base_graph,
240            config,
241        }
242    }
243
244    /// 添加节点
245    pub fn add_node(&mut self, features: DenseTensor) -> GraphResult<crate::node::NodeIndex> {
246        let node_data = NodeData::new(features);
247        self.graph.add_node(node_data)
248    }
249
250    /// 添加边(同时包含权重和结构参数)
251    pub fn add_edge(
252        &mut self,
253        src: crate::node::NodeIndex,
254        dst: crate::node::NodeIndex,
255        weight: DenseTensor,
256        init_prob: f64,
257    ) -> GraphResult<usize> {
258        // 验证节点存在
259        if self.graph.get_node(src).is_err() {
260            return Err(GraphError::NotFound(format!("Node {:?} not found", src)));
261        }
262        if self.graph.get_node(dst).is_err() {
263            return Err(GraphError::NotFound(format!("Node {:?} not found", dst)));
264        }
265
266        let edge_data = EdgeData::new(weight, init_prob);
267        let edge_idx = self.graph.add_edge(src, dst, edge_data)?;
268        Ok(edge_idx.index())
269    }
270
271    /// 获取边数据(通过边索引)
272    ///
273    /// # Arguments
274    ///
275    /// * `edge_idx` - 边索引
276    ///
277    /// # Returns
278    ///
279    /// 如果边存在,返回边数据引用;否则返回错误
280    pub fn get_edge_data(&self, edge_idx: usize) -> Result<&EdgeData, GraphError> {
281        use crate::edge::EdgeIndex;
282
283        let idx = EdgeIndex::new(edge_idx, 0);
284        self.graph.get_edge(idx)
285    }
286
287    /// 获取边数据(可变引用,使用 IndexMut trait)
288    ///
289    /// # Arguments
290    ///
291    /// * `edge_idx` - 边索引
292    ///
293    /// # Returns
294    ///
295    /// 如果边存在,返回边数据可变引用;否则返回错误
296    pub fn get_edge_data_mut(&mut self, edge_idx: usize) -> Result<&mut EdgeData, GraphError> {
297        use crate::edge::EdgeIndex;
298
299        let idx = EdgeIndex::new(edge_idx, 0);
300
301        // 检查边是否存在
302        self.graph.get_edge(idx)?;
303
304        // 使用 IndexMut trait 获取可变引用
305        Ok(&mut self.graph[idx])
306    }
307
308    /// 前向传播
309    ///
310    /// 通过图结构计算输出
311    pub fn forward(&mut self, input: &DenseTensor) -> GraphResult<DenseTensor> {
312        use crate::tensor::traits::TensorOps;
313        use crate::algorithms::traversal::topological_sort;
314        
315        // 按拓扑序执行节点
316        let sorted = topological_sort(&self.graph)
317            .map_err(|e| GraphError::InvalidFormat(format!("Topological sort failed: {}", e)))?;
318        
319        let mut current = input.clone();
320        
321        for node_idx in sorted {
322            // 获取入边(使用 incident_edges)
323            let incoming: Vec<_> = self.graph.incident_edges(node_idx).collect();
324            
325            if incoming.is_empty() {
326                // 输入节点
327                continue;
328            }
329            
330            // 聚合入边信息(简单求和)
331            let mut aggregated = DenseTensor::zeros(current.shape().to_vec());
332            for edge_idx in incoming {
333                if let Ok(edge_data) = self.graph.get_edge(edge_idx) {
334                    if edge_data.exists {
335                        // 矩阵乘法:input @ weight.T
336                        let weight_t = edge_data.weight.transpose(None);
337                        let contribution = current.matmul(&weight_t);
338                        aggregated = aggregated.add(&contribution);
339                    }
340                }
341            }
342            
343            // 应用激活(ReLU)
344            current = aggregated.relu();
345        }
346        
347        Ok(current)
348    }
349
350    /// 计算损失(简单的 MSE 损失示例)
351    pub fn compute_loss(&mut self, target: &DenseTensor, output: &DenseTensor) -> DenseTensor {
352        use crate::tensor::traits::TensorOps;
353        
354        // MSE: (output - target)^2
355        let diff = output.sub(target);
356        diff.mul(&diff)
357    }
358
359    /// 反向传播(简化版本)
360    pub fn backward(&mut self, _loss: &DenseTensor) -> GraphResult<()> {
361        // 简化版本:暂不实现完整的反向传播
362        // 未来可以集成 ComputeGraph 或 dfdx/candle 实现完整 autograd
363        Ok(())
364    }
365
366    /// 计算结构梯度(基于边存在概率的梯度)
367    pub fn compute_structure_gradients(&mut self, _loss: &DenseTensor) -> GraphResult<HashMap<(usize, usize), f64>> {
368        let mut gradients = HashMap::new();
369
370        // 收集所有边索引
371        let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
372
373        for edge_idx in edge_indices {
374            let edge_idx_val = edge_idx.index();
375            // 获取边数据的克隆(避免借用问题)
376            let edge_data_clone = self.get_edge_data(edge_idx_val).cloned().ok();
377
378            if let Some(edge_data) = edge_data_clone {
379                // 简化:使用边权重的梯度范数作为结构梯度
380                if let Some(grad) = edge_data.weight_gradient {
381                    // 计算梯度范数
382                    let grad_norm: f64 = grad.data().iter().map(|&x| x.abs()).sum();
383
384                    // 存储结构梯度(使用边索引作为 key)
385                    gradients.insert((edge_idx_val, 0), grad_norm);
386                }
387            }
388        }
389
390        Ok(gradients)
391    }
392
393    /// 联合优化一步:同时更新结构和参数
394    ///
395    /// # 流程
396    ///
397    /// 1. 反向传播计算参数梯度
398    /// 2. 计算结构梯度(基于权重梯度范数)
399    /// 3. 更新权重参数
400    /// 4. 更新结构参数(logits)
401    /// 5. 离散化弱边(利用桶式邻接表的 O(1) 删除)
402    pub fn joint_optimization_step(&mut self, loss: &DenseTensor) -> GraphResult<()> {
403        // 1. 反向传播(简化版本)
404        self.backward(loss)?;
405        
406        // 2. 计算结构梯度
407        let structure_grads = self.compute_structure_gradients(loss)?;
408        
409        // 3. 更新边参数(先克隆配置避免借用冲突)
410        let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
411        let temperature = self.config.gradient_config.temperature;
412        let structure_lr = self.config.structure_learning_rate;
413        let discretization_threshold = self.config.discretization_threshold;
414        
415        for edge_idx in edge_indices {
416            let edge_idx_val = edge_idx.index();
417            if let Ok(edge_data) = self.get_edge_data_mut(edge_idx_val) {
418                // 更新结构 logits
419                if let Some(&struct_grad) = structure_grads.get(&(edge_idx_val, 0)) {
420                    edge_data.update_logits(struct_grad, structure_lr);
421                }
422
423                // 更新权重(简化:不实际更新,只存储梯度)
424                // 实际使用需要集成 autograd
425
426                // 离散化
427                edge_data.discretize(temperature, discretization_threshold);
428            }
429        }
430        
431        // 4. 剪枝弱边(存在概率低于阈值的边)
432        self.prune_weak_edges()?;
433        
434        Ok(())
435    }
436
437    /// 剪枝弱边
438    ///
439    /// 利用桶式邻接表的 O(1) 删除优势
440    pub fn prune_weak_edges(&mut self) -> GraphResult<usize> {
441        let mut pruned = 0;
442        let threshold = self.config.discretization_threshold;
443        
444        // 收集要删除的边索引
445        let edges_to_remove: Vec<_> = self.graph.edges()
446            .filter(|e| !e.data.exists && e.data.existence_prob < threshold)
447            .map(|e| e.index)
448            .collect();
449        
450        // 删除边
451        for edge_idx in edges_to_remove {
452            let _ = self.graph.remove_edge(edge_idx);
453            pruned += 1;
454        }
455        
456        Ok(pruned)
457    }
458
459    /// 离散化整个图
460    pub fn discretize(&mut self) -> GraphResult<()> {
461        let temperature = self.config.gradient_config.temperature;
462        let threshold = self.config.discretization_threshold;
463        
464        let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
465        
466        for edge_idx in edge_indices {
467            let edge_idx_val = edge_idx.index();
468            if let Ok(edge_data) = self.get_edge_data_mut(edge_idx_val) {
469                edge_data.discretize(temperature, threshold);
470            }
471        }
472        
473        Ok(())
474    }
475
476    /// 获取图结构(不可变引用)
477    pub fn graph(&self) -> &Graph<NodeData, EdgeData> {
478        &self.graph
479    }
480
481    /// 获取图结构(可变引用)
482    pub fn graph_mut(&mut self) -> &mut Graph<NodeData, EdgeData> {
483        &mut self.graph
484    }
485
486    /// 获取配置
487    pub fn config(&self) -> &UnifiedConfig {
488        &self.config
489    }
490
491    /// 获取边数
492    pub fn edge_count(&self) -> usize {
493        self.graph.edge_count()
494    }
495
496    /// 获取节点数
497    pub fn node_count(&self) -> usize {
498        self.graph.node_count()
499    }
500
501    /// 获取剪枝的边数
502    pub fn num_pruned_edges(&self) -> usize {
503        self.graph.edges().filter(|e| !e.data.exists).count()
504    }
505}
506
507#[cfg(test)]
508mod tests {
509    use super::*;
510
511    #[test]
512    #[cfg(feature = "tensor")]
513    fn test_unified_graph_basic() {
514        // 创建统一图
515        let config = UnifiedConfig::default()
516            .with_structure_lr(0.01)
517            .with_param_lr(0.001);
518        let mut graph = UnifiedGraph::new(config);
519
520        // 添加节点
521        let features1 = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
522        let features2 = DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]);
523        let n1 = graph.add_node(features1).unwrap();
524        let n2 = graph.add_node(features2).unwrap();
525
526        assert_eq!(graph.node_count(), 2);
527
528        // 添加边(使用节点索引)
529        let weight = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
530        let _edge = graph.add_edge(n1, n2, weight, 0.8).unwrap();
531
532        assert_eq!(graph.edge_count(), 1);
533    }
534
535    #[test]
536    #[cfg(feature = "tensor")]
537    fn test_edge_data_update() {
538        let weight = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
539        let mut edge_data = EdgeData::new(weight, 0.5);
540
541        // 测试 logits 更新
542        edge_data.update_logits(0.1, 0.01);
543        assert!(edge_data.structure_logits > 0.0);
544
545        // 测试离散化
546        edge_data.discretize(1.0, 0.5);
547        // logits > 0 时,概率 > 0.5,所以 exists 应该为 true
548        assert!(edge_data.exists);
549    }
550
551    #[test]
552    #[cfg(feature = "tensor")]
553    fn test_unified_graph_joint_optimization() {
554        // 创建统一图
555        let config = UnifiedConfig::default()
556            .with_structure_lr(0.01)
557            .with_param_lr(0.001)
558            .with_sparsity(0.1);
559        let mut graph = UnifiedGraph::new(config);
560
561        // 添加节点
562        let features1 = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
563        let features2 = DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]);
564        let _n1 = graph.add_node(features1).unwrap();
565        let _n2 = graph.add_node(features2).unwrap();
566
567        // 添加边(权重形状需要匹配:[out_features, in_features])
568        // 对于输入 [1, 3],权重应该是 [3, 3] 才能进行矩阵乘法
569        let weight = DenseTensor::from_vec(vec![
570            0.1, 0.2, 0.3,
571            0.4, 0.5, 0.6,
572            0.7, 0.8, 0.9,
573        ], vec![3, 3]);
574        let _edge = graph.add_edge(_n1, _n2, weight, 0.8).unwrap();
575
576        let initial_edges = graph.edge_count();
577        assert_eq!(initial_edges, 1);
578
579        // 创建目标输出(用于计算 loss)
580        let target = DenseTensor::from_vec(vec![0.5, 0.5, 0.5], vec![1, 3]);
581
582        // 前向传播
583        let input = DenseTensor::from_vec(vec![1.0, 1.0, 1.0], vec![1, 3]);
584        let output = graph.forward(&input).unwrap();
585
586        // 计算 loss
587        let loss = graph.compute_loss(&target, &output);
588
589        // 联合优化一步
590        let result = graph.joint_optimization_step(&loss);
591        assert!(result.is_ok());
592
593        // 验证优化后图仍然有效
594        assert!(graph.node_count() > 0);
595        assert!(graph.edge_count() > 0);
596
597        println!("✓ Joint optimization step completed successfully");
598    }
599
600    #[test]
601    #[cfg(feature = "tensor")]
602    fn test_unified_graph_pruning() {
603        // 创建统一图,设置较低的离散化阈值
604        let config = UnifiedConfig::default()
605            .with_structure_lr(0.1)
606            .with_threshold(0.3);
607        let mut graph = UnifiedGraph::new(config);
608
609        // 添加节点和边
610        let features1 = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
611        let features2 = DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]);
612        let n1 = graph.add_node(features1).unwrap();
613        let n2 = graph.add_node(features2).unwrap();
614
615        // 添加低概率边(应该被剪枝)
616        let weight = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
617        let _edge = graph.add_edge(n1, n2, weight, 0.2).unwrap(); // 初始概率 0.2 < 0.3
618
619        // 离散化
620        let result = graph.discretize();
621        assert!(result.is_ok());
622
623        // 剪枝弱边
624        let pruned = graph.prune_weak_edges();
625        assert!(pruned.is_ok());
626
627        // 验证边被剪枝
628        let pruned_count = pruned.unwrap();
629        // 注意:pruned_count 可能为 0,因为离散化后 exists 可能为 false 但 prob 不一定低于阈值
630
631        println!("✓ Pruning test completed: {} edges pruned", pruned_count);
632    }
633}