1use std::collections::HashMap;
51
52use crate::errors::{GraphError, GraphResult};
53use crate::graph::Graph;
54use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
55use crate::tensor::dense::DenseTensor;
56use crate::tensor::differentiable::GradientConfig;
57use crate::tensor::traits::TensorBase;
58
/// Hyper-parameters for joint structure + parameter optimization of a
/// [`UnifiedGraph`].
#[derive(Debug, Clone)]
pub struct UnifiedConfig {
    /// Settings forwarded to the differentiable-tensor gradient machinery
    /// (supplies e.g. the sigmoid `temperature` read during discretization).
    pub gradient_config: GradientConfig,
    /// Step size applied to edge-existence logits in `update_logits`.
    pub structure_learning_rate: f64,
    /// Step size applied to edge weight tensors in `update_weight`.
    pub param_learning_rate: f64,
    /// Probability cut-off above which an edge is considered to exist.
    pub discretization_threshold: f64,
    /// Whether structure and parameters are optimized in the same step.
    /// NOTE(review): not read anywhere in this file — confirm it is consumed
    /// elsewhere before relying on it.
    pub enable_joint_optimization: bool,
}
73
74impl Default for UnifiedConfig {
75 fn default() -> Self {
76 Self {
77 gradient_config: GradientConfig::default(),
78 structure_learning_rate: 0.01,
79 param_learning_rate: 0.001,
80 discretization_threshold: 0.5,
81 enable_joint_optimization: true,
82 }
83 }
84}
85
86impl UnifiedConfig {
87 pub fn new(structure_lr: f64, param_lr: f64) -> Self {
89 Self {
90 structure_learning_rate: structure_lr,
91 param_learning_rate: param_lr,
92 ..Default::default()
93 }
94 }
95
96 pub fn with_sparsity(mut self, weight: f64) -> Self {
98 self.gradient_config = self.gradient_config.with_sparsity(weight);
99 self
100 }
101
102 pub fn with_structure_lr(mut self, lr: f64) -> Self {
104 self.structure_learning_rate = lr;
105 self
106 }
107
108 pub fn with_param_lr(mut self, lr: f64) -> Self {
110 self.param_learning_rate = lr;
111 self
112 }
113
114 pub fn with_threshold(mut self, threshold: f64) -> Self {
116 self.discretization_threshold = threshold;
117 self
118 }
119}
120
/// Per-edge state for differentiable structure learning: a trainable weight
/// tensor plus a continuous edge-existence variable (logits/probability) and
/// the most recently applied gradients.
#[derive(Debug, Clone)]
pub struct EdgeData {
    /// Trainable weight tensor carried by this edge.
    pub weight: DenseTensor,
    /// Unconstrained log-odds of the edge existing; optimized directly.
    pub structure_logits: f64,
    /// `sigmoid(logits / temperature)`; refreshed by [`EdgeData::discretize`].
    pub existence_prob: f64,
    /// Hard existence flag, derived as `existence_prob > threshold`.
    pub exists: bool,
    /// Last structure gradient recorded by [`EdgeData::update_logits`].
    pub structure_gradient: Option<f64>,
    /// Last weight gradient recorded by [`EdgeData::update_weight`].
    pub weight_gradient: Option<DenseTensor>,
}
137
138impl EdgeData {
139 pub fn new(weight: DenseTensor, init_prob: f64) -> Self {
141 let logits = Self::prob_to_logits(init_prob);
142 Self {
143 weight,
144 structure_logits: logits,
145 existence_prob: init_prob,
146 exists: init_prob > 0.5,
147 structure_gradient: None,
148 weight_gradient: None,
149 }
150 }
151
152 fn prob_to_logits(prob: f64) -> f64 {
154 let p = prob.clamp(1e-7, 1.0 - 1e-7);
155 (p / (1.0 - p)).ln()
156 }
157
158 pub fn logits_to_prob(logits: f64, temperature: f64) -> f64 {
160 1.0 / (1.0 + (-logits / temperature).exp())
161 }
162
163 pub fn update_logits(&mut self, gradient: f64, learning_rate: f64) {
165 self.structure_logits += learning_rate * gradient;
166 self.structure_gradient = Some(gradient);
167 }
168
169 pub fn update_weight(&mut self, gradient: &DenseTensor, learning_rate: f64) {
171 use crate::tensor::traits::TensorOps;
172
173 let lr_tensor = DenseTensor::scalar(learning_rate);
175 let scaled_grad = gradient.mul(&lr_tensor);
176 self.weight = self.weight.sub(&scaled_grad);
177 self.weight_gradient = Some(gradient.clone());
178 }
179
180 pub fn discretize(&mut self, temperature: f64, threshold: f64) {
182 self.existence_prob = Self::logits_to_prob(self.structure_logits, temperature);
183 self.exists = self.existence_prob > threshold;
184 }
185}
186
/// Per-node payload: a feature tensor plus an optional bias.
#[derive(Debug, Clone)]
pub struct NodeData {
    /// Node feature tensor.
    pub features: DenseTensor,
    /// Optional bias tensor.
    /// NOTE(review): stored but never read in this file — confirm downstream
    /// use (e.g. in a forward pass elsewhere).
    pub bias: Option<DenseTensor>,
}
195
196impl NodeData {
197 pub fn new(features: DenseTensor) -> Self {
199 Self {
200 features,
201 bias: None,
202 }
203 }
204
205 pub fn with_bias(mut self, bias: DenseTensor) -> Self {
207 self.bias = Some(bias);
208 self
209 }
210}
211
/// A directed graph whose topology (edge existence) and parameters (edge
/// weights) are optimized jointly from gradient signals.
pub struct UnifiedGraph {
    /// Underlying graph storage with tensor node/edge payloads.
    graph: Graph<NodeData, EdgeData>,
    /// Optimization hyper-parameters.
    config: UnifiedConfig,
}
226
227impl UnifiedGraph {
228 pub fn new(config: UnifiedConfig) -> Self {
230 Self {
231 graph: Graph::directed(),
232 config,
233 }
234 }
235
236 pub fn from_graph(base_graph: Graph<NodeData, EdgeData>, config: UnifiedConfig) -> Self {
238 Self {
239 graph: base_graph,
240 config,
241 }
242 }
243
244 pub fn add_node(&mut self, features: DenseTensor) -> GraphResult<crate::node::NodeIndex> {
246 let node_data = NodeData::new(features);
247 self.graph.add_node(node_data)
248 }
249
250 pub fn add_edge(
252 &mut self,
253 src: crate::node::NodeIndex,
254 dst: crate::node::NodeIndex,
255 weight: DenseTensor,
256 init_prob: f64,
257 ) -> GraphResult<usize> {
258 if self.graph.get_node(src).is_err() {
260 return Err(GraphError::NotFound(format!("Node {:?} not found", src)));
261 }
262 if self.graph.get_node(dst).is_err() {
263 return Err(GraphError::NotFound(format!("Node {:?} not found", dst)));
264 }
265
266 let edge_data = EdgeData::new(weight, init_prob);
267 let edge_idx = self.graph.add_edge(src, dst, edge_data)?;
268 Ok(edge_idx.index())
269 }
270
271 pub fn get_edge_data(&self, edge_idx: usize) -> Result<&EdgeData, GraphError> {
281 use crate::edge::EdgeIndex;
282
283 let idx = EdgeIndex::new(edge_idx, 0);
284 self.graph.get_edge(idx)
285 }
286
287 pub fn get_edge_data_mut(&mut self, edge_idx: usize) -> Result<&mut EdgeData, GraphError> {
297 use crate::edge::EdgeIndex;
298
299 let idx = EdgeIndex::new(edge_idx, 0);
300
301 self.graph.get_edge(idx)?;
303
304 Ok(&mut self.graph[idx])
306 }
307
308 pub fn forward(&mut self, input: &DenseTensor) -> GraphResult<DenseTensor> {
312 use crate::tensor::traits::TensorOps;
313 use crate::algorithms::traversal::topological_sort;
314
315 let sorted = topological_sort(&self.graph)
317 .map_err(|e| GraphError::InvalidFormat(format!("Topological sort failed: {}", e)))?;
318
319 let mut current = input.clone();
320
321 for node_idx in sorted {
322 let incoming: Vec<_> = self.graph.incident_edges(node_idx).collect();
324
325 if incoming.is_empty() {
326 continue;
328 }
329
330 let mut aggregated = DenseTensor::zeros(current.shape().to_vec());
332 for edge_idx in incoming {
333 if let Ok(edge_data) = self.graph.get_edge(edge_idx) {
334 if edge_data.exists {
335 let weight_t = edge_data.weight.transpose(None);
337 let contribution = current.matmul(&weight_t);
338 aggregated = aggregated.add(&contribution);
339 }
340 }
341 }
342
343 current = aggregated.relu();
345 }
346
347 Ok(current)
348 }
349
350 pub fn compute_loss(&mut self, target: &DenseTensor, output: &DenseTensor) -> DenseTensor {
352 use crate::tensor::traits::TensorOps;
353
354 let diff = output.sub(target);
356 diff.mul(&diff)
357 }
358
359 pub fn backward(&mut self, _loss: &DenseTensor) -> GraphResult<()> {
361 Ok(())
364 }
365
366 pub fn compute_structure_gradients(&mut self, _loss: &DenseTensor) -> GraphResult<HashMap<(usize, usize), f64>> {
368 let mut gradients = HashMap::new();
369
370 let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
372
373 for edge_idx in edge_indices {
374 let edge_idx_val = edge_idx.index();
375 let edge_data_clone = self.get_edge_data(edge_idx_val).cloned().ok();
377
378 if let Some(edge_data) = edge_data_clone {
379 if let Some(grad) = edge_data.weight_gradient {
381 let grad_norm: f64 = grad.data().iter().map(|&x| x.abs()).sum();
383
384 gradients.insert((edge_idx_val, 0), grad_norm);
386 }
387 }
388 }
389
390 Ok(gradients)
391 }
392
393 pub fn joint_optimization_step(&mut self, loss: &DenseTensor) -> GraphResult<()> {
403 self.backward(loss)?;
405
406 let structure_grads = self.compute_structure_gradients(loss)?;
408
409 let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
411 let temperature = self.config.gradient_config.temperature;
412 let structure_lr = self.config.structure_learning_rate;
413 let discretization_threshold = self.config.discretization_threshold;
414
415 for edge_idx in edge_indices {
416 let edge_idx_val = edge_idx.index();
417 if let Ok(edge_data) = self.get_edge_data_mut(edge_idx_val) {
418 if let Some(&struct_grad) = structure_grads.get(&(edge_idx_val, 0)) {
420 edge_data.update_logits(struct_grad, structure_lr);
421 }
422
423 edge_data.discretize(temperature, discretization_threshold);
428 }
429 }
430
431 self.prune_weak_edges()?;
433
434 Ok(())
435 }
436
437 pub fn prune_weak_edges(&mut self) -> GraphResult<usize> {
441 let mut pruned = 0;
442 let threshold = self.config.discretization_threshold;
443
444 let edges_to_remove: Vec<_> = self.graph.edges()
446 .filter(|e| !e.data.exists && e.data.existence_prob < threshold)
447 .map(|e| e.index)
448 .collect();
449
450 for edge_idx in edges_to_remove {
452 let _ = self.graph.remove_edge(edge_idx);
453 pruned += 1;
454 }
455
456 Ok(pruned)
457 }
458
459 pub fn discretize(&mut self) -> GraphResult<()> {
461 let temperature = self.config.gradient_config.temperature;
462 let threshold = self.config.discretization_threshold;
463
464 let edge_indices: Vec<_> = self.graph.edges().map(|e| e.index).collect();
465
466 for edge_idx in edge_indices {
467 let edge_idx_val = edge_idx.index();
468 if let Ok(edge_data) = self.get_edge_data_mut(edge_idx_val) {
469 edge_data.discretize(temperature, threshold);
470 }
471 }
472
473 Ok(())
474 }
475
476 pub fn graph(&self) -> &Graph<NodeData, EdgeData> {
478 &self.graph
479 }
480
481 pub fn graph_mut(&mut self) -> &mut Graph<NodeData, EdgeData> {
483 &mut self.graph
484 }
485
486 pub fn config(&self) -> &UnifiedConfig {
488 &self.config
489 }
490
491 pub fn edge_count(&self) -> usize {
493 self.graph.edge_count()
494 }
495
496 pub fn node_count(&self) -> usize {
498 self.graph.node_count()
499 }
500
501 pub fn num_pruned_edges(&self) -> usize {
503 self.graph.edges().filter(|e| !e.data.exists).count()
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "tensor")]
    fn test_unified_graph_basic() {
        // Two nodes joined by a single high-probability edge.
        let cfg = UnifiedConfig::default()
            .with_structure_lr(0.01)
            .with_param_lr(0.001);
        let mut g = UnifiedGraph::new(cfg);

        let a = g
            .add_node(DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]))
            .unwrap();
        let b = g
            .add_node(DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]))
            .unwrap();
        assert_eq!(g.node_count(), 2);

        let w = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
        g.add_edge(a, b, w, 0.8).unwrap();
        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    #[cfg(feature = "tensor")]
    fn test_edge_data_update() {
        let mut edge = EdgeData::new(
            DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]),
            0.5,
        );

        // A positive gradient pushes the logits above zero...
        edge.update_logits(0.1, 0.01);
        assert!(edge.structure_logits > 0.0);

        // ...so discretizing at threshold 0.5 marks the edge as existing.
        edge.discretize(1.0, 0.5);
        assert!(edge.exists);
    }

    #[test]
    #[cfg(feature = "tensor")]
    fn test_unified_graph_joint_optimization() {
        let cfg = UnifiedConfig::default()
            .with_structure_lr(0.01)
            .with_param_lr(0.001)
            .with_sparsity(0.1);
        let mut g = UnifiedGraph::new(cfg);

        let src = g
            .add_node(DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]))
            .unwrap();
        let dst = g
            .add_node(DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]))
            .unwrap();

        let w = DenseTensor::from_vec(
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
            vec![3, 3],
        );
        g.add_edge(src, dst, w, 0.8).unwrap();
        assert_eq!(g.edge_count(), 1);

        let input = DenseTensor::from_vec(vec![1.0, 1.0, 1.0], vec![1, 3]);
        let target = DenseTensor::from_vec(vec![0.5, 0.5, 0.5], vec![1, 3]);
        let output = g.forward(&input).unwrap();
        let loss = g.compute_loss(&target, &output);

        assert!(g.joint_optimization_step(&loss).is_ok());

        // The single strong edge and both nodes should survive the step.
        assert!(g.node_count() > 0);
        assert!(g.edge_count() > 0);

        println!("✓ Joint optimization step completed successfully");
    }

    #[test]
    #[cfg(feature = "tensor")]
    fn test_unified_graph_pruning() {
        let cfg = UnifiedConfig::default()
            .with_structure_lr(0.1)
            .with_threshold(0.3);
        let mut g = UnifiedGraph::new(cfg);

        let src = g
            .add_node(DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]))
            .unwrap();
        let dst = g
            .add_node(DenseTensor::from_vec(vec![4.0, 5.0, 6.0], vec![1, 3]))
            .unwrap();

        // Initial probability (0.2) sits below the 0.3 threshold, so the
        // edge is a pruning candidate after discretization.
        let w = DenseTensor::from_vec(vec![0.1, 0.2, 0.3], vec![1, 3]);
        g.add_edge(src, dst, w, 0.2).unwrap();

        assert!(g.discretize().is_ok());

        let pruned = g.prune_weak_edges().unwrap();
        println!("✓ Pruning test completed: {} edges pruned", pruned);
    }
}
633}