use std::collections::HashMap;
use crate::errors::{GraphError, GraphResult};
use crate::graph::Graph;
use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
use crate::tensor::dense::DenseTensor;
use crate::tensor::differentiable::GradientConfig;
use crate::tensor::traits::TensorBase;
/// Hyper-parameters for learning edge structure and edge weights together.
#[derive(Debug, Clone)]
pub struct UnifiedConfig {
/// Gradient settings. Its `temperature` field is read whenever edge logits
/// are converted into existence probabilities, and `with_sparsity` forwards
/// to it.
pub gradient_config: GradientConfig,
/// Step size applied to edge structure logits during a joint optimization step.
pub structure_learning_rate: f64,
/// Presumably the step size for weight updates; not read by the code
/// visible in this section of the file.
pub param_learning_rate: f64,
/// An edge is flagged as existing only when its existence probability is
/// strictly greater than this value; pruning uses the same value.
pub discretization_threshold: f64,
/// Joint-optimization toggle. NOTE(review): not consulted by the code
/// visible in this section of the file.
pub enable_joint_optimization: bool,
}
impl Default for UnifiedConfig {
fn default() -> Self {
Self {
gradient_config: GradientConfig::default(),
structure_learning_rate: 0.01,
param_learning_rate: 0.001,
discretization_threshold: 0.5,
enable_joint_optimization: true,
}
}
}
impl UnifiedConfig {
    /// Creates a configuration with the given structure and parameter
    /// learning rates; every other field keeps its `Default` value.
    pub fn new(structure_lr: f64, param_lr: f64) -> Self {
        Self::default()
            .with_structure_lr(structure_lr)
            .with_param_lr(param_lr)
    }

    /// Returns the configuration with `gradient_config` replaced by the
    /// result of calling its own `with_sparsity(weight)`.
    pub fn with_sparsity(self, weight: f64) -> Self {
        Self {
            gradient_config: self.gradient_config.with_sparsity(weight),
            ..self
        }
    }

    /// Returns the configuration with `structure_learning_rate` set to `lr`.
    pub fn with_structure_lr(self, lr: f64) -> Self {
        Self {
            structure_learning_rate: lr,
            ..self
        }
    }

    /// Returns the configuration with `param_learning_rate` set to `lr`.
    pub fn with_param_lr(self, lr: f64) -> Self {
        Self {
            param_learning_rate: lr,
            ..self
        }
    }

    /// Returns the configuration with `discretization_threshold` set to `threshold`.
    pub fn with_threshold(self, threshold: f64) -> Self {
        Self {
            discretization_threshold: threshold,
            ..self
        }
    }
}
/// Per-edge state: a weight tensor plus a learnable "does this edge exist" score.
#[derive(Debug, Clone)]
pub struct EdgeData {
/// Edge weight tensor; updated by `update_weight`.
pub weight: DenseTensor,
/// Log-odds of the edge existing; initialised from the initial probability
/// and moved by `update_logits`.
pub structure_logits: f64,
/// Existence probability; set from the initial probability at construction
/// and recomputed from `structure_logits` by `discretize`.
pub existence_prob: f64,
/// Hard decision: `true` when the probability exceeded the threshold the
/// last time it was evaluated (0.5 at construction).
pub exists: bool,
/// Gradient passed to the most recent `update_logits` call, `None` before the first.
pub structure_gradient: Option<f64>,
/// Copy of the gradient passed to the most recent `update_weight` call,
/// `None` before the first.
pub weight_gradient: Option<DenseTensor>,
}
impl EdgeData {
/// Creates edge data from an initial weight and an initial existence probability.
///
/// `init_prob` is clamped to `[0.0, 1.0]` before being stored, so
/// `existence_prob` always holds a valid probability that agrees with the
/// value `structure_logits` is derived from. For inputs already inside
/// `[0.0, 1.0]` nothing changes. The edge starts as existing when the
/// probability is strictly greater than 0.5, and both gradient slots start
/// empty.
pub fn new(weight: DenseTensor, init_prob: f64) -> Self {
    // An out-of-range value used to be stored verbatim while the logits were
    // computed from a clamped copy, leaving the two fields inconsistent.
    let prob = init_prob.clamp(0.0, 1.0);
    Self {
        weight,
        structure_logits: Self::prob_to_logits(prob),
        existence_prob: prob,
        exists: prob > 0.5,
        structure_gradient: None,
        weight_gradient: None,
    }
}
/// Converts a probability into log-odds, `ln(p / (1 - p))`.
///
/// The input is first clamped to `[1e-7, 1 - 1e-7]` so that 0 and 1 map to
/// large finite logits instead of infinities.
fn prob_to_logits(prob: f64) -> f64 {
    const EPS: f64 = 1e-7;
    let clamped = prob.clamp(EPS, 1.0 - EPS);
    let odds = clamped / (1.0 - clamped);
    odds.ln()
}
/// Converts logits into a probability with a temperature-scaled sigmoid,
/// `1 / (1 + exp(-logits / temperature))`.
///
/// A temperature of exactly zero (either sign of zero) is treated as the
/// limit of the sigmoid as the temperature approaches zero from above: a
/// hard step returning `1.0` for positive logits, `0.0` for negative logits
/// and `0.5` for zero logits. Without this guard, zero logits produced
/// `0 / 0 = NaN`, and a `-0.0` temperature inverted the result.
///
/// NaN logits still yield NaN. Negative temperatures are passed through the
/// formula unchanged, which mirrors the curve.
pub fn logits_to_prob(logits: f64, temperature: f64) -> f64 {
    if temperature == 0.0 {
        if logits > 0.0 {
            return 1.0;
        }
        if logits < 0.0 {
            return 0.0;
        }
        if logits == 0.0 {
            return 0.5;
        }
        // NaN logits fall through so the formula below propagates NaN.
    }
    1.0 / (1.0 + (-logits / temperature).exp())
}
/// Moves the structure logits by `learning_rate * gradient` and remembers
/// the gradient.
///
/// The step is *added* to the logits, whereas `update_weight` subtracts its
/// step, so a positive gradient raises the edge's existence score.
pub fn update_logits(&mut self, gradient: f64, learning_rate: f64) {
    let step = learning_rate * gradient;
    self.structure_gradient = Some(gradient);
    self.structure_logits += step;
}
/// Takes one descent step on the weight: `weight -= learning_rate * gradient`.
///
/// The learning rate is wrapped in a scalar tensor and multiplied into the
/// gradient via `TensorOps::mul`; a copy of the unscaled gradient is then
/// stored in `weight_gradient`.
pub fn update_weight(&mut self, gradient: &DenseTensor, learning_rate: f64) {
    use crate::tensor::traits::TensorOps;

    let step = gradient.mul(&DenseTensor::scalar(learning_rate));
    self.weight = self.weight.sub(&step);
    self.weight_gradient = Some(gradient.clone());
}
/// Recomputes the existence probability from the current logits at the given
/// temperature and sets `exists` to whether that probability is strictly
/// greater than `threshold`.
pub fn discretize(&mut self, temperature: f64, threshold: f64) {
    let prob = Self::logits_to_prob(self.structure_logits, temperature);
    self.exists = prob > threshold;
    self.existence_prob = prob;
}
}
/// Per-node payload: a feature tensor and an optional bias tensor.
#[derive(Debug, Clone)]
pub struct NodeData {
/// Feature tensor supplied when the node is created.
pub features: DenseTensor,
/// Optional bias tensor; `None` unless set through `with_bias`.
pub bias: Option<DenseTensor>,
}
impl NodeData {
pub fn new(features: DenseTensor) -> Self {
Self {
features,
bias: None,
}
}
pub fn with_bias(mut self, bias: DenseTensor) -> Self {
self.bias = Some(bias);
self
}
}
/// A graph of `NodeData` nodes and `EdgeData` edges bundled with the
/// configuration that drives structure/weight optimization.
pub struct UnifiedGraph {
/// Underlying graph storage.
graph: Graph<NodeData, EdgeData>,
/// Learning rates, threshold and gradient settings used by the optimization methods.
config: UnifiedConfig,
}
impl UnifiedGraph {
pub fn new(config: UnifiedConfig) -> Self {
Self {
graph: Graph::directed(),
config,
}
}
pub fn from_graph(base_graph: Graph<NodeData, EdgeData>, config: UnifiedConfig) -> Self {
Self {
graph: base_graph,
config,
}
}
/// Adds a node carrying `features` (and no bias) and returns its index, or
/// whatever error the underlying graph reports.
pub fn add_node(&mut self, features: DenseTensor) -> GraphResult<crate::node::NodeIndex> {
    self.graph.add_node(NodeData::new(features))
}
/// Adds an edge from `src` to `dst` with the given weight and initial
/// existence probability, returning the positional part of the new edge's
/// index.
///
/// # Errors
///
/// Returns `GraphError::NotFound` when either endpoint cannot be looked up
/// (`src` is checked first), and otherwise forwards any error from the
/// underlying graph's `add_edge`.
pub fn add_edge(
    &mut self,
    src: crate::node::NodeIndex,
    dst: crate::node::NodeIndex,
    weight: DenseTensor,
    init_prob: f64,
) -> GraphResult<usize> {
    // Any lookup failure is reported as NotFound, whatever the original cause.
    self.graph
        .get_node(src)
        .map_err(|_| GraphError::NotFound(format!("Node {:?} not found", src)))?;
    self.graph
        .get_node(dst)
        .map_err(|_| GraphError::NotFound(format!("Node {:?} not found", dst)))?;

    self.graph
        .add_edge(src, dst, EdgeData::new(weight, init_prob))
        .map(|edge| edge.index())
}
pub fn get_edge_data(&self, edge_idx: usize) -> Result<&EdgeData, GraphError> {
use crate::edge::EdgeIndex;
let idx = EdgeIndex::new(edge_idx, 0);
self.graph.get_edge(idx)
}
pub fn get_edge_data_mut(&mut self, edge_idx: usize) -> Result<&mut EdgeData, GraphError> {
use crate::edge::EdgeIndex;
let idx = EdgeIndex::new(edge_idx, 0);
self.graph.get_edge(idx)?;
Ok(&mut self.graph[idx])
}
/// Propagates `input` through the graph in topological order and returns the
/// final activation.
///
/// For every node that has at least one incident edge, the running
/// activation is multiplied by the transpose of each incident edge's weight
/// (only edges flagged `exists` contribute), the products are summed into a
/// zero tensor shaped like the running activation, and `relu` of that sum
/// becomes the new running activation. Nodes without incident edges leave
/// it untouched. Node features are not read.
///
/// NOTE(review): whether `incident_edges` yields incoming or outgoing edges
/// is not visible here — confirm it matches the intended data flow.
///
/// # Errors
///
/// Returns `GraphError::InvalidFormat` when the topological sort fails.
pub fn forward(&mut self, input: &DenseTensor) -> GraphResult<DenseTensor> {
    use crate::algorithms::traversal::topological_sort;
    use crate::tensor::traits::TensorOps;

    let order = topological_sort(&self.graph)
        .map_err(|e| GraphError::InvalidFormat(format!("Topological sort failed: {}", e)))?;

    let mut activation = input.clone();
    for node in order {
        let edges: Vec<_> = self.graph.incident_edges(node).collect();
        if edges.is_empty() {
            continue;
        }

        let mut sum = DenseTensor::zeros(activation.shape().to_vec());
        for edge in edges {
            // Edges that cannot be looked up, or are switched off, add nothing.
            let active = self.graph.get_edge(edge).ok().filter(|data| data.exists);
            if let Some(data) = active {
                let transposed = data.weight.transpose(None);
                sum = sum.add(&activation.matmul(&transposed));
            }
        }
        activation = sum.relu();
    }
    Ok(activation)
}
/// Returns `(output - target)` multiplied by itself via `TensorOps::mul`.
///
/// Presumably this is the elementwise squared error with no reduction —
/// confirm against `TensorOps::mul`. No state on `self` is read or changed.
pub fn compute_loss(&mut self, target: &DenseTensor, output: &DenseTensor) -> DenseTensor {
    use crate::tensor::traits::TensorOps;

    let residual = output.sub(target);
    let squared = residual.mul(&residual);
    squared
}
/// Placeholder backward pass: ignores `_loss`, changes nothing and always
/// returns `Ok(())`.
///
/// In particular it does not populate any edge's `weight_gradient`.
pub fn backward(&mut self, _loss: &DenseTensor) -> GraphResult<()> {
Ok(())
}
/// Derives a structure gradient for every edge that has a stored weight
/// gradient.
///
/// The value for an edge is the sum of absolute values of its
/// `weight_gradient` entries. Edges without a stored weight gradient are
/// left out of the map. Keys are `(edge position, 0)`; `_loss` is not used.
///
/// Edges are read in place while iterating the graph. Previously each edge
/// was fetched again through an index rebuilt from its position with a
/// hard-coded `0` and then fully cloned (weight tensor included), which was
/// wasteful and could skip edges whose real index differs in that second
/// component.
pub fn compute_structure_gradients(&mut self, _loss: &DenseTensor) -> GraphResult<HashMap<(usize, usize), f64>> {
    let mut gradients = HashMap::new();
    for edge in self.graph.edges() {
        if let Some(grad) = edge.data.weight_gradient.as_ref() {
            let grad_norm: f64 = grad.data().iter().map(|&x| x.abs()).sum();
            gradients.insert((edge.index.index(), 0), grad_norm);
        }
    }
    Ok(gradients)
}
/// Runs one combined optimization step.
///
/// Steps, in order: call `backward`, collect structure gradients, update the
/// logits of every edge that has one (using `structure_learning_rate`),
/// re-discretize every edge with the configured temperature and threshold,
/// and finally prune weak edges.
///
/// Each edge is addressed with the index handle the graph itself reported.
/// Previously the handle was reduced to its position and rebuilt with a
/// hard-coded second component of `0` (presumably a generation counter), so
/// an edge whose real handle differed there was silently skipped.
///
/// # Errors
///
/// Forwards errors from `backward`, `compute_structure_gradients` and
/// `prune_weak_edges`.
pub fn joint_optimization_step(&mut self, loss: &DenseTensor) -> GraphResult<()> {
    self.backward(loss)?;
    let structure_grads = self.compute_structure_gradients(loss)?;

    let temperature = self.config.gradient_config.temperature;
    let structure_lr = self.config.structure_learning_rate;
    let discretization_threshold = self.config.discretization_threshold;

    // Snapshot the handles first: the loop below needs mutable access.
    let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
    for edge_idx in edge_indices {
        // Guard the indexing operator, which would panic on a stale handle.
        if self.graph.get_edge(edge_idx).is_err() {
            continue;
        }
        let edge_data = &mut self.graph[edge_idx];
        if let Some(&struct_grad) = structure_grads.get(&(edge_idx.index(), 0)) {
            edge_data.update_logits(struct_grad, structure_lr);
        }
        edge_data.discretize(temperature, discretization_threshold);
    }

    self.prune_weak_edges()?;
    Ok(())
}
/// Removes every edge that is flagged as not existing and whose existence
/// probability is strictly below the configured discretization threshold.
///
/// Returns the number of edges actually removed. Removal stays best-effort:
/// a failed `remove_edge` is skipped, but it is no longer counted as a
/// pruned edge, so the result reflects what really left the graph.
pub fn prune_weak_edges(&mut self) -> GraphResult<usize> {
    let threshold = self.config.discretization_threshold;

    // Collect first: edges cannot be removed while the graph is being iterated.
    let edges_to_remove: Vec<_> = self.graph.edges()
        .filter(|e| !e.data.exists && e.data.existence_prob < threshold)
        .map(|e| e.index)
        .collect();

    let mut pruned = 0;
    for edge_idx in edges_to_remove {
        if self.graph.remove_edge(edge_idx).is_ok() {
            pruned += 1;
        }
    }
    Ok(pruned)
}
/// Re-evaluates every edge's existence probability and `exists` flag using
/// the configured temperature and discretization threshold.
///
/// Each edge is addressed with the index handle reported by the graph.
/// Previously the handle was reduced to its position and rebuilt with a
/// hard-coded second component of `0` (presumably a generation counter), so
/// an edge whose real handle differed there was silently left stale.
///
/// Always returns `Ok(())`.
pub fn discretize(&mut self) -> GraphResult<()> {
    let temperature = self.config.gradient_config.temperature;
    let threshold = self.config.discretization_threshold;

    // Snapshot the handles first: the loop below needs mutable access.
    let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
    for edge_idx in edge_indices {
        // Guard the indexing operator, which would panic on a stale handle.
        if self.graph.get_edge(edge_idx).is_ok() {
            self.graph[edge_idx].discretize(temperature, threshold);
        }
    }
    Ok(())
}
/// Shared access to the underlying graph.
pub fn graph(&self) -> &Graph<NodeData, EdgeData> {
&self.graph
}
/// Mutable access to the underlying graph; changes made through it bypass
/// the checks performed by this type's own methods.
pub fn graph_mut(&mut self) -> &mut Graph<NodeData, EdgeData> {
&mut self.graph
}
/// The configuration this graph was created with.
pub fn config(&self) -> &UnifiedConfig {
&self.config
}
/// Number of edges currently stored in the underlying graph.
pub fn edge_count(&self) -> usize {
self.graph.edge_count()
}
/// Number of nodes currently stored in the underlying graph.
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
/// Counts the edges still stored in the graph whose `exists` flag is `false`.
///
/// Despite the name, edges already removed by pruning are not included;
/// only switched-off edges that remain in the graph are counted.
pub fn num_pruned_edges(&self) -> usize {
    let mut inactive = 0;
    for edge in self.graph.edges() {
        if !edge.data.exists {
            inactive += 1;
        }
    }
    inactive
}
}
// Unit tests for the unified graph. Every test is additionally gated on the
// `tensor` feature and is compiled out when that feature is disabled.
#[cfg(test)]
mod tests {
use super::*;
// Two nodes joined by one edge: checks the node and edge counts.
#[test]
#[cfg(feature = "tensor")]
fn test_unified_graph_basic() {
let config = UnifiedConfig::default()
.with_structure_lr(0.01)
.with_param_lr(0.001);
let mut graph = UnifiedGraph::new(config);
let features1 = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
let features2 = DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]);
let n1 = graph.add_node(features1).unwrap();
let n2 = graph.add_node(features2).unwrap();
assert_eq!(graph.node_count(), 2);
let weight = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
let _edge = graph.add_edge(n1, n2, weight, 0.8).unwrap();
assert_eq!(graph.edge_count(), 1);
}
// An initial probability of 0.5 gives a logit of 0, so a positive gradient
// must push the logit above 0 and the edge above the 0.5 threshold.
#[test]
#[cfg(feature = "tensor")]
fn test_edge_data_update() {
let weight = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
let mut edge_data = EdgeData::new(weight, 0.5);
edge_data.update_logits(0.1, 0.01);
assert!(edge_data.structure_logits > 0.0);
edge_data.discretize(1.0, 0.5);
assert!(edge_data.exists);
}
// Runs forward, loss and one joint optimization step on a single 3x3 edge
// and checks that the step succeeds without emptying the graph.
#[test]
#[cfg(feature = "tensor")]
fn test_unified_graph_joint_optimization() {
let config = UnifiedConfig::default()
.with_structure_lr(0.01)
.with_param_lr(0.001)
.with_sparsity(0.1);
let mut graph = UnifiedGraph::new(config);
let features1 = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
let features2 = DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]);
let _n1 = graph.add_node(features1).unwrap();
let _n2 = graph.add_node(features2).unwrap();
let weight = DenseTensor::from_vec(vec![
0.1, 0.2, 0.3,
0.4, 0.5, 0.6,
0.7, 0.8, 0.9,
], vec![3, 3]);
let _edge = graph.add_edge(_n1, _n2, weight, 0.8).unwrap();
let initial_edges = graph.edge_count();
assert_eq!(initial_edges, 1);
let target = DenseTensor::from_vec(vec![0.5, 0.5, 0.5], vec![1, 3]);
let input = DenseTensor::from_vec(vec![1.0, 1.0, 1.0], vec![1, 3]);
let output = graph.forward(&input).unwrap();
let loss = graph.compute_loss(&target, &output);
let result = graph.joint_optimization_step(&loss);
assert!(result.is_ok());
assert!(graph.node_count() > 0);
assert!(graph.edge_count() > 0);
println!("✓ Joint optimization step completed successfully");
}
// Adds a low-probability edge (0.2 against a threshold of 0.3), then
// discretizes and prunes. Only success of both calls is asserted; the pruned
// count is printed, not checked.
#[test]
#[cfg(feature = "tensor")]
fn test_unified_graph_pruning() {
let config = UnifiedConfig::default()
.with_structure_lr(0.1)
.with_threshold(0.3);
let mut graph = UnifiedGraph::new(config);
let features1 = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
let features2 = DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]);
let n1 = graph.add_node(features1).unwrap();
let n2 = graph.add_node(features2).unwrap();
let weight = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
let _edge = graph.add_edge(n1, n2, weight, 0.2).unwrap();
let result = graph.discretize();
assert!(result.is_ok());
let pruned = graph.prune_weak_edges();
assert!(pruned.is_ok());
let pruned_count = pruned.unwrap();
println!("✓ Pruning test completed: {} edges pruned", pruned_count);
}
}