1use crate::errors::GraphResult;
14use crate::graph::traits::{GraphBase, GraphQuery};
15use crate::graph::Graph;
16use crate::transformer::optimization::switch::{OperatorType, WeightTensor};
17use std::collections::HashMap;
18
19type ConstraintFn = Box<dyn Fn(&Graph<OperatorType, WeightTensor>) -> GraphResult<bool> + Send + Sync>;
21
/// How serious a reported [`TopologyDefect`] is.
///
/// Variants are declared from least to most severe, so the derived ordering
/// yields `Info < Warning < Error < Critical`. This lets callers filter or sort
/// defects, e.g. `defect.severity >= Severity::Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Severity {
    /// Purely informational; lowest severity.
    Info,
    /// Suspicious finding.
    Warning,
    /// Defect classified as an error.
    Error,
    /// Highest severity.
    Critical,
}
34
/// Category of a structural problem found in a computation graph.
///
/// Equality and hashing are derived so defect kinds can be compared directly
/// and used as set or map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DefectType {
    /// A node that has no neighbors.
    IsolatedNode,
    /// A connected component separate from the rest of the graph.
    DisconnectedComponent,
    /// A cycle where none is expected.
    // NOTE(review): no detector producing this variant is visible in this file.
    UnexpectedCycle,
    /// A residual connection that should exist but does not.
    // NOTE(review): no detector producing this variant is visible in this file.
    MissingResidual,
    /// Attention heads that are not balanced.
    // NOTE(review): no detector producing this variant is visible in this file.
    UnbalancedAttention,
    /// Gradient flow that is blocked between two points.
    // NOTE(review): no detector producing this variant is visible in this file.
    BlockedGradientFlow,
    /// Any other defect, identified by a free-form label.
    Custom(String),
}
53
54#[derive(Debug, Clone)]
56pub struct TopologyDefect {
57 pub defect_type: DefectType,
59 pub location: usize,
61 pub severity: Severity,
63 pub description: String,
65 pub suggested_fix: Option<String>,
67}
68
69pub enum TopologyConstraint {
71 ResidualConnection {
73 from_layer: String,
75 to_layer: String,
77 },
78 AttentionHeadBalance {
80 layer: String,
82 tolerance: f64,
84 },
85 GradientFlow {
87 from: String,
89 to: String,
91 },
92 Custom(ConstraintFn),
94}
95
96impl Clone for TopologyConstraint {
97 fn clone(&self) -> Self {
98 match self {
99 Self::ResidualConnection { from_layer, to_layer } => {
100 Self::ResidualConnection {
101 from_layer: from_layer.clone(),
102 to_layer: to_layer.clone(),
103 }
104 }
105 Self::AttentionHeadBalance { layer, tolerance } => {
106 Self::AttentionHeadBalance {
107 layer: layer.clone(),
108 tolerance: *tolerance,
109 }
110 }
111 Self::GradientFlow { from, to } => Self::GradientFlow {
112 from: from.clone(),
113 to: to.clone(),
114 },
115 Self::Custom(_) => Self::ResidualConnection {
117 from_layer: String::new(),
118 to_layer: String::new(),
119 },
120 }
121 }
122}
123
124impl std::fmt::Debug for TopologyConstraint {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 match self {
127 Self::ResidualConnection { from_layer, to_layer } => f
128 .debug_struct("ResidualConnection")
129 .field("from_layer", from_layer)
130 .field("to_layer", to_layer)
131 .finish(),
132 Self::AttentionHeadBalance { layer, tolerance } => f
133 .debug_struct("AttentionHeadBalance")
134 .field("layer", layer)
135 .field("tolerance", tolerance)
136 .finish(),
137 Self::GradientFlow { from, to } => f
138 .debug_struct("GradientFlow")
139 .field("from", from)
140 .field("to", to)
141 .finish(),
142 Self::Custom(_) => f.debug_struct("Custom").finish(),
143 }
144 }
145}
146
147#[derive(Debug, Clone)]
149pub struct ConstraintReport {
150 pub all_satisfied: bool,
152 pub satisfied_count: usize,
154 pub violated_count: usize,
156 pub constraint_details: Vec<ConstraintDetail>,
158}
159
/// Outcome of evaluating a single constraint.
///
/// Equality is derived so individual outcomes can be compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConstraintDetail {
    /// Human-readable label of the constraint that was evaluated.
    pub description: String,
    /// Whether the constraint held.
    pub satisfied: bool,
    /// Explanation of the violation, if one was produced.
    pub violation_details: Option<String>,
}
170
171pub struct TopologyValidator {
173 constraints: Vec<TopologyConstraint>,
174 validation_cache: HashMap<String, bool>,
175}
176
177impl TopologyValidator {
178 pub fn new() -> Self {
180 Self {
181 constraints: Vec::new(),
182 validation_cache: HashMap::new(),
183 }
184 }
185
186 pub fn with_default_constraints() -> Self {
188 let mut validator = Self::new();
189
190 validator.add_constraint(TopologyConstraint::ResidualConnection {
192 from_layer: "attention".to_string(),
193 to_layer: "attention_output".to_string(),
194 });
195
196 validator.add_constraint(TopologyConstraint::ResidualConnection {
197 from_layer: "mlp".to_string(),
198 to_layer: "mlp_output".to_string(),
199 });
200
201 validator
202 }
203
204 pub fn add_constraint(&mut self, constraint: TopologyConstraint) {
206 self.constraints.push(constraint);
207 self.validation_cache.clear();
208 }
209
210 pub fn clear_constraints(&mut self) {
212 self.constraints.clear();
213 self.validation_cache.clear();
214 }
215
216 pub fn constraint_count(&self) -> usize {
218 self.constraints.len()
219 }
220
221 pub fn validate(&mut self, graph: &Graph<OperatorType, WeightTensor>) -> GraphResult<ConstraintReport> {
231 let mut details = Vec::new();
232 let mut satisfied_count = 0;
233
234 for constraint in &self.constraints {
235 let (satisfied, description, violation) = match constraint {
236 TopologyConstraint::ResidualConnection { from_layer, to_layer } => {
237 self.validate_residual_connection(graph, from_layer, to_layer)?
238 }
239 TopologyConstraint::AttentionHeadBalance { layer, tolerance } => {
240 self.validate_attention_balance(graph, layer, *tolerance)?
241 }
242 TopologyConstraint::GradientFlow { from, to } => {
243 self.validate_gradient_flow(graph, from, to)?
244 }
245 TopologyConstraint::Custom(func) => {
246 let result = func(graph)?;
247 (result, "Custom constraint".to_string(), None)
248 }
249 };
250
251 if satisfied {
252 satisfied_count += 1;
253 }
254
255 details.push(ConstraintDetail {
256 description,
257 satisfied,
258 violation_details: violation,
259 });
260 }
261
262 Ok(ConstraintReport {
263 all_satisfied: satisfied_count == self.constraints.len(),
264 satisfied_count,
265 violated_count: self.constraints.len() - satisfied_count,
266 constraint_details: details,
267 })
268 }
269
270 pub fn detect_defects(
280 &self,
281 graph: &Graph<OperatorType, WeightTensor>,
282 ) -> GraphResult<Vec<TopologyDefect>> {
283 use crate::algorithms::community::connected_components;
284
285 let mut defects = Vec::new();
286
287 for node_ref in graph.nodes() {
289 let node_id = node_ref.index();
290 let neighbor_count = graph.neighbors(node_id).count();
291
292 if neighbor_count == 0 {
293 defects.push(TopologyDefect {
294 defect_type: DefectType::IsolatedNode,
295 location: node_id.index(),
296 severity: Severity::Warning,
297 description: format!("Node {} has no outgoing edges", node_id.index()),
298 suggested_fix: Some("Connect the node to the computation graph or remove it".to_string()),
299 });
300 }
301 }
302
303 let components = connected_components(graph);
305 if components.len() > 1 {
306 for (i, component) in components.iter().enumerate().skip(1) {
307 defects.push(TopologyDefect {
308 defect_type: DefectType::DisconnectedComponent,
309 location: component.first().map(|idx| idx.index()).unwrap_or(0),
310 severity: Severity::Error,
311 description: format!("Found disconnected component {} with {} nodes", i, component.len()),
312 suggested_fix: Some("Add edges to connect this component to the main graph".to_string()),
313 });
314 }
315 }
316
317 Ok(defects)
318 }
319
320 fn validate_residual_connection(
322 &self,
323 graph: &Graph<OperatorType, WeightTensor>,
324 from_layer: &str,
325 to_layer: &str,
326 ) -> GraphResult<(bool, String, Option<String>)> {
327 let found = graph.nodes().any(|n| {
331 matches!(n.data(), OperatorType::Residual)
332 });
333
334 let description = format!("ResidualConnection: {} -> {}", from_layer, to_layer);
335
336 if found {
337 Ok((true, description, None))
338 } else {
339 Ok((
340 false,
341 description,
342 Some(format!("No residual connection found between {} and {}", from_layer, to_layer)),
343 ))
344 }
345 }
346
347 fn validate_attention_balance(
349 &self,
350 _graph: &Graph<OperatorType, WeightTensor>,
351 layer: &str,
352 tolerance: f64,
353 ) -> GraphResult<(bool, String, Option<String>)> {
354 let description = format!("AttentionHeadBalance: {} (tolerance: {})", layer, tolerance);
358
359 Ok((true, description, None))
361 }
362
363 fn validate_gradient_flow(
365 &self,
366 graph: &Graph<OperatorType, WeightTensor>,
367 from: &str,
368 to: &str,
369 ) -> GraphResult<(bool, String, Option<String>)> {
370 use crate::algorithms::traversal::bfs;
371 use crate::node::NodeIndex;
372
373 let mut path_exists = false;
375
376 for start_node in graph.nodes() {
377 let mut visited: std::collections::HashSet<usize> = std::collections::HashSet::new();
378
379 bfs(graph, start_node.index(), |n: NodeIndex, _depth: usize| {
380 visited.insert(n.index());
381 true
382 });
383
384 path_exists = visited.iter().any(|&n| {
386 let node_idx = NodeIndex::new(n, 0);
387 if let Ok(node_data) = graph.get_node(node_idx) {
388 format!("{:?}", node_data).contains(to)
389 } else {
390 false
391 }
392 });
393
394 if path_exists {
395 break;
396 }
397 }
398
399 let description = format!("GradientFlow: {} -> {}", from, to);
400
401 if path_exists {
402 Ok((true, description, None))
403 } else {
404 Ok((
405 false,
406 description,
407 Some(format!("No gradient flow path from {} to {}", from, to)),
408 ))
409 }
410 }
411}
412
413impl Default for TopologyValidator {
414 fn default() -> Self {
415 Self::new()
416 }
417}
418
419#[derive(Debug, Clone)]
421pub struct AssemblyReport {
422 pub is_valid: bool,
424 pub module_count: usize,
426 pub interface_mismatches: usize,
428 pub module_details: Vec<ModuleDetail>,
430}
431
/// Interface description of a single module in an assembly report.
///
/// Equality is derived so entries can be compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModuleDetail {
    /// Display name of the module.
    pub name: String,
    /// Input dimension, when it is known for this kind of module.
    pub input_dim: Option<usize>,
    /// Output dimension, when it is known for this kind of module.
    pub output_dim: Option<usize>,
    /// Whether the module's interfaces are considered to match.
    pub interfaces_match: bool,
}
444
445pub fn validate_assembly(
455 graph: &Graph<OperatorType, WeightTensor>,
456) -> GraphResult<AssemblyReport> {
457 let mut module_details = Vec::new();
458 let interface_mismatches = 0;
459
460 for node_ref in graph.nodes() {
461 let node_data = node_ref.data();
462
463 let (input_dim, output_dim) = match node_data {
465 OperatorType::Linear { in_features, out_features } => {
466 (Some(*in_features), Some(*out_features))
467 }
468 OperatorType::Attention { hidden_dim, .. } => {
469 (Some(*hidden_dim), Some(*hidden_dim))
470 }
471 OperatorType::MLP { hidden_dim, .. } => {
472 (Some(*hidden_dim), Some(*hidden_dim))
473 }
474 _ => (None, None),
475 };
476
477 module_details.push(ModuleDetail {
478 name: format!("{:?}", node_data),
479 input_dim,
480 output_dim,
481 interfaces_match: true, });
483 }
484
485 Ok(AssemblyReport {
486 is_valid: interface_mismatches == 0,
487 module_count: graph.node_count(),
488 interface_mismatches,
489 module_details,
490 })
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::traits::GraphOps;

    /// Constraints can be added, counted and cleared.
    #[test]
    fn test_topology_validator() {
        let mut validator = TopologyValidator::new();
        assert_eq!(validator.constraint_count(), 0);

        validator.add_constraint(TopologyConstraint::ResidualConnection {
            from_layer: "attn".to_string(),
            to_layer: "output".to_string(),
        });
        assert_eq!(validator.constraint_count(), 1);

        validator.clear_constraints();
        assert_eq!(validator.constraint_count(), 0);
    }

    /// The default constraint set registers the two residual constraints.
    #[test]
    fn test_default_constraints() {
        let validator = TopologyValidator::with_default_constraints();
        assert_eq!(validator.constraint_count(), 2);
    }

    /// With nothing registered, validation trivially succeeds.
    #[test]
    fn test_validate_without_constraints() {
        let graph = Graph::<OperatorType, WeightTensor>::directed();
        let mut validator = TopologyValidator::new();

        let report = validator.validate(&graph).unwrap();

        assert!(report.all_satisfied);
        assert_eq!(report.satisfied_count, 0);
        assert_eq!(report.violated_count, 0);
        assert!(report.constraint_details.is_empty());
    }

    /// A failing custom constraint is counted as a violation.
    #[test]
    fn test_validate_custom_constraint() {
        let graph = Graph::<OperatorType, WeightTensor>::directed();
        let mut validator = TopologyValidator::new();
        validator.add_constraint(TopologyConstraint::Custom(Box::new(
            |_: &Graph<OperatorType, WeightTensor>| Ok(false),
        )));

        let report = validator.validate(&graph).unwrap();

        assert!(!report.all_satisfied);
        assert_eq!(report.satisfied_count, 0);
        assert_eq!(report.violated_count, 1);
        assert_eq!(report.constraint_details.len(), 1);
        assert_eq!(report.constraint_details[0].description, "Custom constraint");
        assert!(report.constraint_details[0].violation_details.is_none());
    }

    /// A single node without edges is reported as an isolated-node warning.
    #[test]
    fn test_defect_detection() {
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        graph
            .add_node(OperatorType::Linear {
                in_features: 512,
                out_features: 1024,
            })
            .unwrap();

        let validator = TopologyValidator::new();
        let defects = validator.detect_defects(&graph).unwrap();

        assert!(!defects.is_empty(), "Should detect isolated node as a defect");
        assert!(
            defects.iter().any(|d| {
                matches!(d.defect_type, DefectType::IsolatedNode)
                    && d.severity == Severity::Warning
            }),
            "Isolated node should be reported as a warning"
        );
    }

    /// A single linear node yields one module with its dimensions recorded.
    #[test]
    fn test_assembly_validation() {
        let mut graph = Graph::<OperatorType, WeightTensor>::directed();

        graph
            .add_node(OperatorType::Linear {
                in_features: 512,
                out_features: 1024,
            })
            .unwrap();

        let report = validate_assembly(&graph).unwrap();

        assert_eq!(report.module_count, 1);
        assert!(report.is_valid);
        assert_eq!(report.interface_mismatches, 0);
        assert_eq!(report.module_details.len(), 1);
        assert_eq!(report.module_details[0].input_dim, Some(512));
        assert_eq!(report.module_details[0].output_dim, Some(1024));
    }
}