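pub struct DifferentiableGraph<T = Vec<f64>> { /* private fields */ }
Differentiable graph structure: structural transformations that support gradient computation.
Core idea: parameterize the discrete graph structure in a continuous space so that gradients can be backpropagated to the structure parameters.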
§Architecture Notes
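§Integration with Automatic Differentiation Frameworks
The current implementation computes gradients manually. Integrating with a true automatic differentiation framework (such as dfdx) would require:
- storing logits as Tensor1D<f64> instead of f64
- building the computation graph: logits → probability → adjacency_matrix → loss
- calling loss.backward() to obtain the gradients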
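§Conversion to and from Graph
Use to_graph() to convert a differentiable graph into a regular Graph,
and from_graph() to initialize a differentiable graph from an existing graph.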
Implementations§
impl<T: Clone + Default> DifferentiableGraph<T>
pub fn with_config(num_nodes: usize, config: GradientConfig) -> Self
Creates a differentiable graph with num_nodes nodes and the given configuration.
pub fn init_nodes(&mut self, features: Option<T>)
Initializes the nodes, optionally with the given features.
pub fn add_learnable_edge(&mut self, src: usize, dst: usize, init_prob: f64)
Adds a learnable edge from src to dst, initialized with probability init_prob.
pub fn remove_edge( &mut self, src: usize, dst: usize, ) -> Option<DifferentiableEdge>
Removes the edge from src to dst, returning it if it existed.
pub fn get_probability_matrix(&self) -> Vec<Vec<f64>>
Returns the probability matrix of all edges.
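A minimal sketch of how such a probability matrix could be assembled, assuming each learnable edge stores one logit and that its probability is A_ij = σ(logit_ij / τ), the relation used by compute_structure_gradients. The helpers sigmoid, logit_from_prob and probability_matrix are hypothetical, and mapping init_prob to a logit via τ·ln(p / (1 − p)) is an assumption about initialization, not something this page confirms.

```rust
/// Logistic sigmoid σ(x) = 1 / (1 + e^(-x)).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Assumed inverse mapping: the logit whose σ(logit / τ) equals p.
fn logit_from_prob(p: f64, temperature: f64) -> f64 {
    temperature * (p / (1.0 - p)).ln()
}

/// Builds an n x n matrix with A_ij = σ(logit_ij / τ) for every learnable
/// edge; pairs without a learnable edge stay at 0.0 in this sketch.
fn probability_matrix(
    n: usize,
    logits: &[((usize, usize), f64)],
    temperature: f64,
) -> Vec<Vec<f64>> {
    let mut m = vec![vec![0.0; n]; n];
    for &((src, dst), logit) in logits {
        m[src][dst] = sigmoid(logit / temperature);
    }
    m
}

fn main() {
    let tau = 1.0;
    let logits = vec![
        ((0, 1), logit_from_prob(0.8, tau)),
        ((1, 2), logit_from_prob(0.3, tau)),
    ];
    let m = probability_matrix(3, &logits, tau);
    println!("{:?}", m);
}
```

Storing logits rather than probabilities keeps the parameters unconstrained, so gradient steps can never push a probability outside (0, 1).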
pub fn get_adjacency_matrix(&self) -> Vec<Vec<f64>>
Returns the discrete adjacency matrix (using the straight-through estimator, STE).
pub fn anneal_temperature(&mut self)
Anneals (lowers) the temperature by one step.
pub fn with_temperature_annealing(self, steps: usize) -> Self
Configures temperature annealing over the given number of steps.
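The annealing schedule itself is not documented on this page. The sketch below assumes a geometric decay from an initial temperature to a floor over steps calls; Annealer, min and factor are hypothetical names. Lowering τ sharpens σ(logit / τ), pushing edge probabilities toward 0 or 1.

```rust
/// Hypothetical annealing schedule: τ decays geometrically from `initial`
/// to `min` over `steps` calls to `anneal`, then stays at `min`.
struct Annealer {
    temperature: f64,
    min: f64,
    factor: f64,
}

impl Annealer {
    fn new(initial: f64, min: f64, steps: usize) -> Self {
        // Per-step factor chosen so that initial * factor^steps == min.
        let factor = (min / initial).powf(1.0 / steps.max(1) as f64);
        Annealer { temperature: initial, min, factor }
    }

    /// One annealing step, clamped so τ never drops below `min`.
    fn anneal(&mut self) {
        self.temperature = (self.temperature * self.factor).max(self.min);
    }
}

fn main() {
    let mut a = Annealer::new(1.0, 0.1, 10);
    for step in 1..=12 {
        a.anneal();
        println!("step {step}: tau = {:.4}", a.temperature);
    }
}
```

A geometric schedule changes τ by a constant ratio per step; a linear schedule would be an equally plausible choice.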
pub fn discretize(&mut self)
Discretizes all edges (forward pass).
If STE mode is enabled, the STE correction term (hard - soft) is stored and later used to correct the gradients.
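A minimal per-edge sketch of this forward step, assuming a hard threshold of 0.5 on the soft probability (the actual cut-off is not stated here). EdgeState and discretize are hypothetical stand-ins, not the crate's types.

```rust
/// Per-edge result of the (hypothetical) discretization step.
struct EdgeState {
    soft: f64,           // continuous probability A = σ(logit / τ)
    hard: f64,           // discretized existence, 0.0 or 1.0
    ste_correction: f64, // hard - soft, kept only in STE mode
}

/// Thresholds the soft probability at 0.5 (an assumed cut-off) and,
/// in STE mode, stores hard - soft for later gradient correction.
fn discretize(soft: f64, use_ste: bool) -> EdgeState {
    let hard = if soft > 0.5 { 1.0 } else { 0.0 };
    let ste_correction = if use_ste { hard - soft } else { 0.0 };
    EdgeState { soft, hard, ste_correction }
}

fn main() {
    for (soft, use_ste) in [(0.7, true), (0.2, true), (0.7, false)] {
        let e = discretize(soft, use_ste);
        println!("soft={} hard={} correction={:.2}", e.soft, e.hard, e.ste_correction);
    }
}
```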
pub fn compute_structure_gradients( &mut self, loss_gradients: &HashMap<(usize, usize), f64>, ) -> HashMap<(usize, usize), f64>
Computes the structure gradients.
§Arguments
loss_gradients - gradient of the loss with respect to edge existence: {(src, dst): ∂L/∂A_ij}
§Returns
HashMap {(src, dst): ∂L/∂logits}, which can be used to update the edges' logit parameters.
§Gradient Computation
Gradient computation follows the chain rule:
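∂L/∂logits = ∂L/∂A * ∂A/∂logits
where A = σ(logits/τ), so:
∂A/∂logits = A * (1 - A) / τ
§STE Correction
When STE mode is enabled, the STE correction term is added to the gradient:
gradient = ∂L/∂logits + (hard - soft)
This ensures that the discretization in the forward pass is consistent with the continuous gradients in the backward pass.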
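The chain rule above can be sketched for a map of edges as follows. This assumes logits and STE corrections live in HashMaps keyed by (src, dst); structure_gradients here is a hypothetical free function, not the method's actual body. With logit = 0 and τ = 1, A = 0.5 and ∂A/∂logit = 0.25, so ∂L/∂A = 2 yields ∂L/∂logit = 0.5.

```rust
use std::collections::HashMap;

/// Logistic sigmoid σ(x) = 1 / (1 + e^(-x)).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Chain rule per edge: ∂L/∂logit = ∂L/∂A · A(1 - A)/τ with A = σ(logit/τ).
/// If STE corrections are supplied, the stored (hard - soft) term is added.
fn structure_gradients(
    loss_gradients: &HashMap<(usize, usize), f64>,
    logits: &HashMap<(usize, usize), f64>,
    ste_corrections: Option<&HashMap<(usize, usize), f64>>,
    temperature: f64,
) -> HashMap<(usize, usize), f64> {
    let mut grads = HashMap::new();
    for (&edge, &dl_da) in loss_gradients {
        // Ignore gradients for pairs that have no learnable edge.
        let Some(&logit) = logits.get(&edge) else {
            continue;
        };
        let a = sigmoid(logit / temperature);
        let mut grad = dl_da * a * (1.0 - a) / temperature;
        if let Some(corrections) = ste_corrections {
            grad += corrections.get(&edge).copied().unwrap_or(0.0);
        }
        grads.insert(edge, grad);
    }
    grads
}

fn main() {
    let logits = HashMap::from([((0, 1), 0.0), ((1, 2), 2.0)]);
    let loss_gradients = HashMap::from([((0, 1), 2.0), ((1, 2), -1.0)]);
    let grads = structure_gradients(&loss_gradients, &logits, None, 1.0);
    println!("{:?}", grads);
}
```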
§Regularization
§L1 Sparsity Regularization
L_sparse = λ_sparse * Σ|logits|
∂L_sparse/∂logits = λ_sparse * sign(logits)
Gradient-descent update: logits -= lr * gradient
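- Positive logits → positive gradient → logits decrease toward 0 → probability decreases toward σ(0) = 0.5
- Negative logits → negative gradient → logits increase toward 0 → probability increases toward σ(0) = 0.5
As formulated, the L1 penalty on the logits pulls every logit toward 0, i.e. toward a probability of 0.5, so this term does not by itself drive edge probabilities toward 0.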
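A small sketch of one gradient-descent step on the L1 term alone makes the direction of movement concrete. l1_step is a hypothetical helper; it treats sign(0) as 0 (Rust's f64::signum would return 1.0 for +0.0, so it is not used here).

```rust
/// Logistic sigmoid σ(x) = 1 / (1 + e^(-x)).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// One gradient-descent step on λ·|logit| alone: logit -= lr · λ · sign(logit).
/// sign(0) is taken as 0, so a logit that is exactly 0 is left unchanged.
fn l1_step(logit: f64, lambda: f64, lr: f64) -> f64 {
    let sign = if logit > 0.0 {
        1.0
    } else if logit < 0.0 {
        -1.0
    } else {
        0.0
    };
    logit - lr * lambda * sign
}

fn main() {
    for start in [2.0, -2.0] {
        let next = l1_step(start, 1.0, 0.1);
        // Both probabilities move toward σ(0) = 0.5.
        println!("logit {start} -> {next}: p {:.4} -> {:.4}", sigmoid(start), sigmoid(next));
    }
}
```

With a fixed step size, repeated steps make a logit oscillate around 0 rather than settle exactly on it, which is the usual behavior of subgradient descent on |x|.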
§L2 Smoothness Regularization
L_smooth = λ_smooth * Σ_{(i,j),(i,k)∈E} (A_ij - A_ik)²
∂L_smooth/∂A_ij = 2 * λ_smooth * Σ_k (A_ij - A_ik)
(The formula is written for pairs of edges that share the source node i.)
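Smoothness regularization encourages:
- edges that share a source node to have similar probabilities
- edges that share a target node to have similar probabilities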
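The shared-source gradient can be sketched directly from the formula. This is a hypothetical helper that only covers pairs (i, j), (i, k); the shared-target counterpart is not shown. For A_01 = 0.8, A_02 = 0.4 and λ_smooth = 0.5, it gives +0.4 and -0.4, so a descent step lowers the larger probability and raises the smaller one.

```rust
use std::collections::HashMap;

/// Gradient of L_smooth w.r.t. each A_ij over edges that share a source node:
/// ∂L_smooth/∂A_ij = 2 · λ_smooth · Σ_k (A_ij - A_ik), with k ranging over the
/// other destinations of i. Shared-target pairs are not included here.
fn smooth_gradients(
    probs: &HashMap<(usize, usize), f64>,
    lambda_smooth: f64,
) -> HashMap<(usize, usize), f64> {
    let mut grads = HashMap::new();
    for (&(i, j), &a_ij) in probs {
        let mut sum = 0.0;
        for (&(s, k), &a_ik) in probs {
            if s == i && k != j {
                sum += a_ij - a_ik;
            }
        }
        grads.insert((i, j), 2.0 * lambda_smooth * sum);
    }
    grads
}

fn main() {
    let probs = HashMap::from([((0, 1), 0.8), ((0, 2), 0.4), ((1, 2), 0.9)]);
    let grads = smooth_gradients(&probs, 0.5);
    println!("{:?}", grads);
}
```

The double loop is O(|E|²) and is meant for illustration; grouping edges by source node first would avoid scanning all edges for every edge.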
pub fn optimization_step( &mut self, loss_gradients: HashMap<(usize, usize), f64>, ) -> HashMap<(usize, usize), f64>
Performs one optimization step: discretize -> compute gradients -> update.
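The gradient and update stages of such a step can be sketched as below, assuming the logits are kept in a HashMap and updated with plain gradient descent (logits -= lr * gradient). optimization_step here is a hypothetical free function; discretization, STE correction and regularization are deliberately left out to keep it short.

```rust
use std::collections::HashMap;

/// Logistic sigmoid σ(x) = 1 / (1 + e^(-x)).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Hypothetical single step: for every learnable edge, compute
/// ∂L/∂logit = ∂L/∂A · A(1 - A)/τ and apply logit -= lr · grad.
/// Returns the per-edge gradients that were applied.
fn optimization_step(
    logits: &mut HashMap<(usize, usize), f64>,
    loss_gradients: &HashMap<(usize, usize), f64>,
    temperature: f64,
    lr: f64,
) -> HashMap<(usize, usize), f64> {
    let mut grads = HashMap::new();
    for (&edge, logit) in logits.iter_mut() {
        let a = sigmoid(*logit / temperature);
        // Edges with no incoming loss gradient get a zero gradient.
        let dl_da = loss_gradients.get(&edge).copied().unwrap_or(0.0);
        let grad = dl_da * a * (1.0 - a) / temperature;
        *logit -= lr * grad;
        grads.insert(edge, grad);
    }
    grads
}

fn main() {
    let mut logits = HashMap::from([((0, 1), 0.0)]);
    let loss_gradients = HashMap::from([((0, 1), 1.0)]);
    for step in 0..3 {
        optimization_step(&mut logits, &loss_gradients, 1.0, 1.0);
        println!("step {step}: p(0,1) = {:.4}", sigmoid(logits[&(0, 1)]));
    }
}
```

A positive ∂L/∂A means the loss grows with the edge's probability, so the update lowers that edge's logit and hence its probability.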
pub fn get_learnable_edges(&self) -> Vec<&DifferentiableEdge>
Returns the list of differentiable edges.
pub fn config(&self) -> &GradientConfig
Returns the configuration.
pub fn set_config(&mut self, config: GradientConfig)
Sets the configuration.
pub fn temperature(&self) -> f64
Returns the current temperature.
pub fn set_temperature(&mut self, temp: f64)
Sets the temperature.
pub fn to_graph(&self) -> Graph<usize, f64>
Converts to a regular Graph.
Builds the Graph from the discretized edge existence. An edge's weight is 1.0 if it exists and 0.0 if it does not.
§Note
The graph created by this method uses node indices as node data and f64 edge weights.
Node indices are created via NodeIndex::new(index, generation),
where generation is managed internally by Graph.
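Because Graph's constructor API is not shown on this page, the sketch below stands in a plain edge list for it: each learnable edge becomes a (src, dst, weight) triple whose weight is the discretized existence. to_edge_list is hypothetical, the 0.5 threshold is an assumption, and whether 0.0-weight edges are actually kept in the resulting Graph is not confirmed here.

```rust
/// Hypothetical stand-in for to_graph(): emits (src, dst, weight) triples,
/// where weight is 1.0 if the edge probability exceeds an assumed 0.5
/// threshold and 0.0 otherwise.
fn to_edge_list(probs: &[((usize, usize), f64)]) -> Vec<(usize, usize, f64)> {
    probs
        .iter()
        .map(|&((src, dst), p)| (src, dst, if p > 0.5 { 1.0 } else { 0.0 }))
        .collect()
}

fn main() {
    let edges = to_edge_list(&[((0, 1), 0.9), ((1, 2), 0.1)]);
    println!("{:?}", edges);
}
```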
pub fn to_graph_with_types( &self, node_types: &HashMap<usize, OperatorType>, edge_weights: &HashMap<(usize, usize), WeightTensor>, ) -> Graph<OperatorType, WeightTensor>
pub fn from_graph<U, V>( graph: &Graph<U, V>, init_probs: Option<HashMap<(usize, usize), f64>>, ) -> DifferentiableGraph<()>
pub fn from_graph_with_prob<U, V>( graph: &Graph<U, V>, init_prob: Option<f64>, ) -> DifferentiableGraph<()>
Trait Implementations§
impl<T: Clone> Clone for DifferentiableGraph<T>
fn clone(&self) -> DifferentiableGraph<T>
fn clone_from(&mut self, source: &Self)
Auto Trait Implementations§
impl<T> Freeze for DifferentiableGraph<T>
impl<T> RefUnwindSafe for DifferentiableGraph<T> where T: RefUnwindSafe
impl<T> Send for DifferentiableGraph<T> where T: Send
impl<T> Sync for DifferentiableGraph<T> where T: Sync
impl<T> Unpin for DifferentiableGraph<T> where T: Unpin
impl<T> UnsafeUnpin for DifferentiableGraph<T>
impl<T> UnwindSafe for DifferentiableGraph<T> where T: UnwindSafe
Blanket Implementations§
impl<T> BorrowMut<T> for T where T: ?Sized
- fn borrow_mut(&mut self) -> &mut T
impl<T> CloneToUninit for T where T: Clone
impl<T> FmtForward for T
- fn fmt_binary(self) -> FmtBinary<Self> where Self: Binary
- fn fmt_display(self) -> FmtDisplay<Self> where Self: Display
- fn fmt_lower_exp(self) -> FmtLowerExp<Self> where Self: LowerExp
- fn fmt_lower_hex(self) -> FmtLowerHex<Self> where Self: LowerHex
- fn fmt_octal(self) -> FmtOctal<Self> where Self: Octal
- fn fmt_pointer(self) -> FmtPointer<Self> where Self: Pointer
- fn fmt_upper_exp(self) -> FmtUpperExp<Self> where Self: UpperExp
- fn fmt_upper_hex(self) -> FmtUpperHex<Self> where Self: UpperHex
impl<T> IntoEither for T
- fn into_either(self, into_left: bool) -> Either<Self, Self>
- fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
impl<T> Pipe for T where T: ?Sized
- fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> R where Self: Sized
- fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> R where R: 'a
- fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> R where R: 'a
- fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
- fn pipe_borrow_mut<'a, B, R>(&'a mut self, func: impl FnOnce(&'a mut B) -> R) -> R
- fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
- fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
- fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
impl<T> Pointable for T
impl<SS, SP> SupersetOf<SS> for SP where SS: SubsetOf<SP>
- fn to_subset(&self) -> Option<SS>
- fn is_in_subset(&self) -> bool
- fn to_subset_unchecked(&self) -> SS
- fn from_subset(element: &SS) -> SP
impl<T> Tap for T
- fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
- fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
- fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
- fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
- fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
- fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
- fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
- fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
- fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
- fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
- fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
- fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
- fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self