Skip to main content

god_graph/transformer/optimization/
constraints.rs

1//! Topology Constraints and Validation
2//!
3//! This module defines topology constraints for LLM computation graphs
4//! and provides validation utilities.
5//!
6//! ## Constraint Types
7//!
8//! - Residual Connection: Ensure residual paths are connected
9//! - Attention Head Balance: Ensure attention heads have balanced weights
10//! - Gradient Flow: Ensure gradient flow paths exist
11//! - Custom: User-defined constraint functions
12
13use crate::errors::GraphResult;
14use crate::graph::traits::{GraphBase, GraphQuery};
15use crate::graph::Graph;
16use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
17use std::collections::HashMap;
18
19/// Type alias for custom constraint functions
20type ConstraintFn = Box<dyn Fn(&Graph<OperatorType, WeightTensor>) -> GraphResult<bool> + Send + Sync>;
21
/// Severity level for topology defects.
///
/// Variants are declared from least to most severe, so the derived ordering
/// satisfies `Info < Warning < Error < Critical`. This lets callers filter or
/// sort defects by severity (for example `defect.severity >= Severity::Error`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Informational - no action required
    Info,
    /// Warning - should be reviewed
    Warning,
    /// Error - should be fixed
    Error,
    /// Critical - must be fixed immediately
    Critical,
}
34
/// Type of topology defect.
///
/// Implements `PartialEq`/`Eq`/`Hash` so defect kinds can be compared in
/// assertions and used as map or set keys (e.g. to count defects per kind).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DefectType {
    /// Isolated node with no connections
    IsolatedNode,
    /// Disconnected component in graph
    DisconnectedComponent,
    /// Cycle detected in feedforward graph
    UnexpectedCycle,
    /// Missing residual connection
    MissingResidual,
    /// Unbalanced attention heads
    UnbalancedAttention,
    /// Gradient flow blocked
    BlockedGradientFlow,
    /// Custom defect type, identified by a free-form label
    Custom(String),
}
53
54/// Topology defect report
55#[derive(Debug, Clone)]
56pub struct TopologyDefect {
57    /// Type of defect
58    pub defect_type: DefectType,
59    /// Location of the defect (node index)
60    pub location: usize,
61    /// Severity level
62    pub severity: Severity,
63    /// Description of the issue
64    pub description: String,
65    /// Suggested fix
66    pub suggested_fix: Option<String>,
67}
68
69/// Topology constraint definition
70pub enum TopologyConstraint {
71    /// Residual connection must exist between specific nodes
72    ResidualConnection {
73        /// Source layer name
74        from_layer: String,
75        /// Target layer name
76        to_layer: String,
77    },
78    /// Attention heads must have balanced weight norms
79    AttentionHeadBalance {
80        /// Layer name
81        layer: String,
82        /// Tolerance threshold
83        tolerance: f64,
84    },
85    /// Gradient flow path must exist
86    GradientFlow {
87        /// Source node/layer
88        from: String,
89        /// Target node/layer
90        to: String,
91    },
92    /// Custom constraint function
93    Custom(ConstraintFn),
94}
95
96impl Clone for TopologyConstraint {
97    fn clone(&self) -> Self {
98        match self {
99            Self::ResidualConnection { from_layer, to_layer } => {
100                Self::ResidualConnection {
101                    from_layer: from_layer.clone(),
102                    to_layer: to_layer.clone(),
103                }
104            }
105            Self::AttentionHeadBalance { layer, tolerance } => {
106                Self::AttentionHeadBalance {
107                    layer: layer.clone(),
108                    tolerance: *tolerance,
109                }
110            }
111            Self::GradientFlow { from, to } => Self::GradientFlow {
112                from: from.clone(),
113                to: to.clone(),
114            },
115            // Custom constraints cannot be cloned, return a placeholder
116            Self::Custom(_) => Self::ResidualConnection {
117                from_layer: String::new(),
118                to_layer: String::new(),
119            },
120        }
121    }
122}
123
124impl std::fmt::Debug for TopologyConstraint {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        match self {
127            Self::ResidualConnection { from_layer, to_layer } => f
128                .debug_struct("ResidualConnection")
129                .field("from_layer", from_layer)
130                .field("to_layer", to_layer)
131                .finish(),
132            Self::AttentionHeadBalance { layer, tolerance } => f
133                .debug_struct("AttentionHeadBalance")
134                .field("layer", layer)
135                .field("tolerance", tolerance)
136                .finish(),
137            Self::GradientFlow { from, to } => f
138                .debug_struct("GradientFlow")
139                .field("from", from)
140                .field("to", to)
141                .finish(),
142            Self::Custom(_) => f.debug_struct("Custom").finish(),
143        }
144    }
145}
146
147/// Constraint validation report
148#[derive(Debug, Clone)]
149pub struct ConstraintReport {
150    /// Whether all constraints are satisfied
151    pub all_satisfied: bool,
152    /// Number of satisfied constraints
153    pub satisfied_count: usize,
154    /// Number of violated constraints
155    pub violated_count: usize,
156    /// Details for each constraint
157    pub constraint_details: Vec<ConstraintDetail>,
158}
159
/// Result of evaluating one constraint.
#[derive(Debug, Clone)]
pub struct ConstraintDetail {
    /// Human-readable description of the constraint that was checked.
    pub description: String,
    /// `true` when the constraint holds.
    pub satisfied: bool,
    /// Explanation of the violation; `None` when there is nothing to report.
    pub violation_details: Option<String>,
}
170
171/// Topology validator for LLM computation graphs
172pub struct TopologyValidator {
173    constraints: Vec<TopologyConstraint>,
174    validation_cache: HashMap<String, bool>,
175}
176
177impl TopologyValidator {
178    /// Create a new topology validator
179    pub fn new() -> Self {
180        Self {
181            constraints: Vec::new(),
182            validation_cache: HashMap::new(),
183        }
184    }
185
186    /// Create validator with predefined constraints for common architectures
187    pub fn with_default_constraints() -> Self {
188        let mut validator = Self::new();
189        
190        // Add common constraints for transformer architectures
191        validator.add_constraint(TopologyConstraint::ResidualConnection {
192            from_layer: "attention".to_string(),
193            to_layer: "attention_output".to_string(),
194        });
195        
196        validator.add_constraint(TopologyConstraint::ResidualConnection {
197            from_layer: "mlp".to_string(),
198            to_layer: "mlp_output".to_string(),
199        });
200
201        validator
202    }
203
204    /// Add a constraint
205    pub fn add_constraint(&mut self, constraint: TopologyConstraint) {
206        self.constraints.push(constraint);
207        self.validation_cache.clear();
208    }
209
210    /// Remove all constraints
211    pub fn clear_constraints(&mut self) {
212        self.constraints.clear();
213        self.validation_cache.clear();
214    }
215
216    /// Get the number of constraints
217    pub fn constraint_count(&self) -> usize {
218        self.constraints.len()
219    }
220
221    /// Validate all constraints on a graph
222    ///
223    /// # Arguments
224    ///
225    /// * `graph` - Graph to validate
226    ///
227    /// # Returns
228    ///
229    /// Constraint validation report
230    pub fn validate(&mut self, graph: &Graph<OperatorType, WeightTensor>) -> GraphResult<ConstraintReport> {
231        let mut details = Vec::new();
232        let mut satisfied_count = 0;
233
234        for constraint in &self.constraints {
235            let (satisfied, description, violation) = match constraint {
236                TopologyConstraint::ResidualConnection { from_layer, to_layer } => {
237                    self.validate_residual_connection(graph, from_layer, to_layer)?
238                }
239                TopologyConstraint::AttentionHeadBalance { layer, tolerance } => {
240                    self.validate_attention_balance(graph, layer, *tolerance)?
241                }
242                TopologyConstraint::GradientFlow { from, to } => {
243                    self.validate_gradient_flow(graph, from, to)?
244                }
245                TopologyConstraint::Custom(func) => {
246                    let result = func(graph)?;
247                    (result, "Custom constraint".to_string(), None)
248                }
249            };
250
251            if satisfied {
252                satisfied_count += 1;
253            }
254
255            details.push(ConstraintDetail {
256                description,
257                satisfied,
258                violation_details: violation,
259            });
260        }
261
262        Ok(ConstraintReport {
263            all_satisfied: satisfied_count == self.constraints.len(),
264            satisfied_count,
265            violated_count: self.constraints.len() - satisfied_count,
266            constraint_details: details,
267        })
268    }
269
270    /// Detect topology defects in a graph
271    ///
272    /// # Arguments
273    ///
274    /// * `graph` - Graph to analyze
275    ///
276    /// # Returns
277    ///
278    /// List of detected defects
279    pub fn detect_defects(
280        &self,
281        graph: &Graph<OperatorType, WeightTensor>,
282    ) -> GraphResult<Vec<TopologyDefect>> {
283        use crate::algorithms::community::connected_components;
284
285        let mut defects = Vec::new();
286
287        // Check for isolated nodes
288        for node_ref in graph.nodes() {
289            let node_id = node_ref.index();
290            let neighbor_count = graph.neighbors(node_id).count();
291
292            if neighbor_count == 0 {
293                defects.push(TopologyDefect {
294                    defect_type: DefectType::IsolatedNode,
295                    location: node_id.index(),
296                    severity: Severity::Warning,
297                    description: format!("Node {} has no outgoing edges", node_id.index()),
298                    suggested_fix: Some("Connect the node to the computation graph or remove it".to_string()),
299                });
300            }
301        }
302
303        // Check for disconnected components
304        let components = connected_components(graph);
305        if components.len() > 1 {
306            for (i, component) in components.iter().enumerate().skip(1) {
307                defects.push(TopologyDefect {
308                    defect_type: DefectType::DisconnectedComponent,
309                    location: component.first().map(|idx| idx.index()).unwrap_or(0),
310                    severity: Severity::Error,
311                    description: format!("Found disconnected component {} with {} nodes", i, component.len()),
312                    suggested_fix: Some("Add edges to connect this component to the main graph".to_string()),
313                });
314            }
315        }
316
317        Ok(defects)
318    }
319
320    /// Validate a residual connection constraint
321    fn validate_residual_connection(
322        &self,
323        graph: &Graph<OperatorType, WeightTensor>,
324        from_layer: &str,
325        to_layer: &str,
326    ) -> GraphResult<(bool, String, Option<String>)> {
327        // Simplified implementation
328        // In a full implementation, we would search for actual residual connections
329        
330        let found = graph.nodes().any(|n| {
331            matches!(n.data(), OperatorType::Residual)
332        });
333
334        let description = format!("ResidualConnection: {} -> {}", from_layer, to_layer);
335        
336        if found {
337            Ok((true, description, None))
338        } else {
339            Ok((
340                false,
341                description,
342                Some(format!("No residual connection found between {} and {}", from_layer, to_layer)),
343            ))
344        }
345    }
346
347    /// Validate attention head balance
348    fn validate_attention_balance(
349        &self,
350        _graph: &Graph<OperatorType, WeightTensor>,
351        layer: &str,
352        tolerance: f64,
353    ) -> GraphResult<(bool, String, Option<String>)> {
354        // Simplified implementation
355        // In a full implementation, we would compare attention head weight norms
356
357        let description = format!("AttentionHeadBalance: {} (tolerance: {})", layer, tolerance);
358
359        // Assume balanced for now
360        Ok((true, description, None))
361    }
362
363    /// Validate gradient flow path
364    fn validate_gradient_flow(
365        &self,
366        graph: &Graph<OperatorType, WeightTensor>,
367        from: &str,
368        to: &str,
369    ) -> GraphResult<(bool, String, Option<String>)> {
370        use crate::algorithms::traversal::bfs;
371        use crate::node::NodeIndex;
372
373        // Simplified: check if there's a path from any node matching 'from' to any node matching 'to'
374        let mut path_exists = false;
375
376        for start_node in graph.nodes() {
377            let mut visited: std::collections::HashSet<usize> = std::collections::HashSet::new();
378            
379            bfs(graph, start_node.index(), |n: NodeIndex, _depth: usize| {
380                visited.insert(n.index());
381                true
382            });
383
384            // Check if target is reachable
385            path_exists = visited.iter().any(|&n| {
386                let node_idx = NodeIndex::new(n, 0);
387                if let Ok(node_data) = graph.get_node(node_idx) {
388                    format!("{:?}", node_data).contains(to)
389                } else {
390                    false
391                }
392            });
393
394            if path_exists {
395                break;
396            }
397        }
398
399        let description = format!("GradientFlow: {} -> {}", from, to);
400
401        if path_exists {
402            Ok((true, description, None))
403        } else {
404            Ok((
405                false,
406                description,
407                Some(format!("No gradient flow path from {} to {}", from, to)),
408            ))
409        }
410    }
411}
412
413impl Default for TopologyValidator {
414    fn default() -> Self {
415        Self::new()
416    }
417}
418
419/// Assembly validation report
420#[derive(Debug, Clone)]
421pub struct AssemblyReport {
422    /// Whether the assembly is valid
423    pub is_valid: bool,
424    /// Number of modules checked
425    pub module_count: usize,
426    /// Number of interface mismatches
427    pub interface_mismatches: usize,
428    /// Details about each module
429    pub module_details: Vec<ModuleDetail>,
430}
431
/// Per-module entry of an assembly report.
#[derive(Debug, Clone)]
pub struct ModuleDetail {
    /// Name identifying the module.
    pub name: String,
    /// Input dimension, when it is known for this module.
    pub input_dim: Option<usize>,
    /// Output dimension, when it is known for this module.
    pub output_dim: Option<usize>,
    /// `true` when the module's interfaces are considered compatible.
    pub interfaces_match: bool,
}
444
445/// Validate assembly of modules
446///
447/// # Arguments
448///
449/// * `graph` - Graph representing the assembled modules
450///
451/// # Returns
452///
453/// Assembly validation report
454pub fn validate_assembly(
455    graph: &Graph<OperatorType, WeightTensor>,
456) -> GraphResult<AssemblyReport> {
457    let mut module_details = Vec::new();
458    let interface_mismatches = 0;
459
460    for node_ref in graph.nodes() {
461        let node_data = node_ref.data();
462        
463        // Extract input/output dimensions based on operator type
464        let (input_dim, output_dim) = match node_data {
465            OperatorType::Linear { in_features, out_features } => {
466                (Some(*in_features), Some(*out_features))
467            }
468            OperatorType::Attention { hidden_dim, .. } => {
469                (Some(*hidden_dim), Some(*hidden_dim))
470            }
471            OperatorType::MLP { hidden_dim, .. } => {
472                (Some(*hidden_dim), Some(*hidden_dim))
473            }
474            _ => (None, None),
475        };
476
477        module_details.push(ModuleDetail {
478            name: format!("{:?}", node_data),
479            input_dim,
480            output_dim,
481            interfaces_match: true, // Simplified
482        });
483    }
484
485    Ok(AssemblyReport {
486        is_valid: interface_mismatches == 0,
487        module_count: graph.node_count(),
488        interface_mismatches,
489        module_details,
490    })
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::traits::GraphOps;

    /// Adding a constraint is reflected in the constraint count.
    #[test]
    fn test_topology_validator() {
        let mut validator = TopologyValidator::new();

        validator.add_constraint(TopologyConstraint::ResidualConnection {
            from_layer: "attn".to_string(),
            to_layer: "output".to_string(),
        });

        assert_eq!(validator.constraint_count(), 1);
    }

    /// A lone node without edges is reported as an isolated-node defect.
    #[test]
    fn test_defect_detection() {
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        // The only node in the graph, so it has no neighbours.
        graph
            .add_node(OperatorType::Linear {
                in_features: 512,
                out_features: 1024,
            })
            .unwrap();

        let validator = TopologyValidator::new();
        let defects = validator.detect_defects(&graph).unwrap();

        assert!(!defects.is_empty(), "Should detect isolated node as a defect");
        assert!(
            defects
                .iter()
                .any(|d| matches!(d.defect_type, DefectType::IsolatedNode)),
            "The reported defect should be of type IsolatedNode"
        );
    }

    /// A single linear module yields a valid report carrying its dimensions.
    #[test]
    fn test_assembly_validation() {
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        // The returned node handle is not needed by this test.
        graph
            .add_node(OperatorType::Linear {
                in_features: 512,
                out_features: 1024,
            })
            .unwrap();

        let report = validate_assembly(&graph).unwrap();

        assert_eq!(report.module_count, 1);
        assert!(report.is_valid);
        assert_eq!(report.interface_mismatches, 0);
        assert_eq!(report.module_details.len(), 1);
        assert_eq!(report.module_details[0].input_dim, Some(512));
        assert_eq!(report.module_details[0].output_dim, Some(1024));
    }
}