Skip to main content

Module differentiable

Module differentiable 

Source
Expand description

可微图结构变换模块

本模块实现了图结构变换操作的梯度计算,支持:

  • 可微边编辑(添加/删除/修改边权重)
  • 可微节点编辑(添加/删除节点)
  • Straight-Through Estimator (STE) 用于离散操作
  • Gumbel-Softmax 松弛用于可微采样
  • 图结构优化的梯度传播

§核心概念

§连续松弛表示

传统图结构是离散的:边要么存在 (1) 要么不存在 (0)。 为了支持梯度计算,我们使用连续松弛:

A_soft = σ(A_logits / τ)

其中:
- A_logits: 边的对数几率(可学习参数)
- τ: 温度参数(控制离散程度)
- σ: sigmoid 函数

§Straight-Through Estimator (STE)

对于需要离散输出的场景,使用 STE:

  • 前向传播:硬阈值(0/1)
  • 反向传播:软梯度(通过 sigmoid)
A_hard = (A_soft > 0.5).to_f64()
gradient = A_hard - A_soft.detach() + A_soft

§示例

use god_gragh::graph::Graph;
use god_gragh::tensor::differentiable::{
    DifferentiableGraph, EdgeEditPolicy, GradientConfig
};

// 创建可微图
let mut diff_graph = DifferentiableGraph::new(4);

// 添加可学习边
diff_graph.add_learnable_edge(0, 1, 0.5);
diff_graph.add_learnable_edge(1, 2, 0.8);

// 计算损失对边权重的梯度
let loss = compute_loss(&diff_graph);
let gradients = diff_graph.compute_structure_gradients(loss);

// 基于梯度更新结构
diff_graph.update_structure(&gradients, learning_rate=0.01);

Structs§

DifferentiableEdge
可微边:包含可学习的存在概率
DifferentiableGraph
可微图结构:支持梯度计算的结构变换
DifferentiableNode
可微节点:包含可学习的存在概率和特征
GradientConfig
图结构变换的梯度配置
GradientRecorder
结构梯度记录器:记录所有结构变换的梯度
GraphTransformer
图结构变换器:执行具体的结构编辑操作
GumbelSoftmaxSampler
Gumbel-Softmax 采样器:用于可微离散采样
StructureEdit
结构编辑操作(带梯度信息)
ThresholdEditPolicy
基于阈值的编辑策略

Enums§

EdgeEditOp
边编辑操作类型
EditOperation
编辑操作枚举
NodeEditOp
节点编辑操作类型

Traits§

EdgeEditPolicy
边编辑策略:定义如何基于梯度编辑边