1use crate::computation_graph::{ComputationGraph, TensorOp};
35use crate::proof_storage::ProofFragment;
36use std::fmt::Write as FmtWrite;
37
38pub struct GraphVisualizer;
40
41impl GraphVisualizer {
42 pub fn to_dot(graph: &ComputationGraph) -> String {
50 let mut dot = String::new();
51 writeln!(dot, "digraph ComputationGraph {{").unwrap();
52 writeln!(dot, " rankdir=TB;").unwrap();
53 writeln!(dot, " node [shape=box, style=filled];").unwrap();
54 writeln!(dot).unwrap();
55
56 for (node_id, node) in &graph.nodes {
58 let color = Self::node_color(&node.op);
59 let shape = if graph.inputs.contains(node_id) {
60 "ellipse"
61 } else if graph.outputs.contains(node_id) {
62 "doubleoctagon"
63 } else {
64 "box"
65 };
66
67 let label = Self::format_operation(&node.op);
68 writeln!(
69 dot,
70 " \"{}\" [label=\"{}\\n{}\", fillcolor=\"{}\", shape={}];",
71 Self::escape(node_id),
72 Self::escape(node_id),
73 label,
74 color,
75 shape
76 )
77 .unwrap();
78 }
79
80 writeln!(dot).unwrap();
81
82 for (node_id, node) in &graph.nodes {
84 for input in &node.inputs {
85 writeln!(
86 dot,
87 " \"{}\" -> \"{}\";",
88 Self::escape(input),
89 Self::escape(node_id)
90 )
91 .unwrap();
92 }
93 }
94
95 writeln!(dot).unwrap();
97 writeln!(dot, " subgraph cluster_legend {{").unwrap();
98 writeln!(dot, " label=\"Legend\";").unwrap();
99 writeln!(dot, " style=filled;").unwrap();
100 writeln!(dot, " fillcolor=lightgrey;").unwrap();
101 writeln!(
102 dot,
103 " legend_input [label=\"Input\", shape=ellipse, fillcolor=lightblue];"
104 )
105 .unwrap();
106 writeln!(
107 dot,
108 " legend_output [label=\"Output\", shape=doubleoctagon, fillcolor=lightgreen];"
109 )
110 .unwrap();
111 writeln!(
112 dot,
113 " legend_compute [label=\"Compute\", shape=box, fillcolor=lightyellow];"
114 )
115 .unwrap();
116 writeln!(dot, " }}").unwrap();
117
118 writeln!(dot, "}}").unwrap();
119 dot
120 }
121
122 fn node_color(op: &TensorOp) -> &'static str {
124 match op {
125 TensorOp::Input { .. } | TensorOp::Constant { .. } => "lightblue",
126 TensorOp::MatMul | TensorOp::Einsum { .. } => "orange",
127 TensorOp::Add | TensorOp::Mul | TensorOp::Sub | TensorOp::Div => "yellow",
128 TensorOp::ReLU
129 | TensorOp::Tanh
130 | TensorOp::Sigmoid
131 | TensorOp::GELU
132 | TensorOp::Softmax { .. } => "lightgreen",
133 TensorOp::LayerNorm { .. } | TensorOp::BatchNorm { .. } => "lightcoral",
134 TensorOp::Dropout { .. } => "plum",
135 TensorOp::Reshape { .. } | TensorOp::Transpose { .. } | TensorOp::Slice { .. } => {
136 "lightyellow"
137 }
138 _ => "white",
139 }
140 }
141
142 fn format_operation(op: &TensorOp) -> String {
144 match op {
145 TensorOp::Input { name } => format!("Input({})", name),
146 TensorOp::Constant { value_cid } => format!("Const(cid:{})", &value_cid[..8]),
147 TensorOp::MatMul => "MatMul".to_string(),
148 TensorOp::Einsum { subscripts } => format!("Einsum({})", subscripts),
149 TensorOp::Add => "Add".to_string(),
150 TensorOp::Mul => "Multiply".to_string(),
151 TensorOp::Sub => "Subtract".to_string(),
152 TensorOp::Div => "Divide".to_string(),
153 TensorOp::ReLU => "ReLU".to_string(),
154 TensorOp::Tanh => "Tanh".to_string(),
155 TensorOp::Sigmoid => "Sigmoid".to_string(),
156 TensorOp::GELU => "GELU".to_string(),
157 TensorOp::Softmax { axis } => format!("Softmax(axis={})", axis),
158 TensorOp::LayerNorm {
159 normalized_shape: _,
160 eps,
161 } => format!("LayerNorm(ε={:.1e})", eps),
162 TensorOp::BatchNorm { eps, momentum } => {
163 format!("BatchNorm(ε={:.1e}, μ={:.2})", eps, momentum)
164 }
165 TensorOp::Dropout { p } => format!("Dropout({:.2})", p),
166 TensorOp::Reshape { shape } => format!("Reshape({:?})", shape),
167 TensorOp::Transpose { axes } => format!("Transpose({:?})", axes),
168 TensorOp::ReduceSum { axes, keepdims: _ } => format!("ReduceSum({:?})", axes),
169 TensorOp::ReduceMean { axes, keepdims: _ } => format!("ReduceMean({:?})", axes),
170 TensorOp::Concat { axis } => format!("Concat(axis={})", axis),
171 TensorOp::Split { axis, sections } => {
172 format!("Split(axis={}, n={})", axis, sections.len())
173 }
174 TensorOp::Gather { axis } => format!("Gather(axis={})", axis),
175 TensorOp::Scatter { axis } => format!("Scatter(axis={})", axis),
176 TensorOp::Slice {
177 start,
178 end,
179 strides,
180 } => format!("Slice({:?}:{:?}:{:?})", start, end, strides),
181 TensorOp::Pad { padding, mode: _ } => format!("Pad({:?})", padding),
182 TensorOp::Exp => "Exp".to_string(),
183 TensorOp::Log => "Log".to_string(),
184 TensorOp::Pow { exponent } => format!("Pow({})", exponent),
185 TensorOp::Sqrt => "Sqrt".to_string(),
186 TensorOp::FusedLinear => "FusedLinear".to_string(),
187 TensorOp::FusedAddReLU => "FusedAdd+ReLU".to_string(),
188 TensorOp::FusedBatchNormReLU { eps, momentum } => {
189 format!("FusedBN+ReLU(ε={:.1e}, μ={:.2})", eps, momentum)
190 }
191 TensorOp::FusedLayerNormDropout {
192 normalized_shape: _,
193 eps,
194 dropout_p,
195 } => format!("FusedLN+Dropout(ε={:.1e}, p={:.2})", eps, dropout_p),
196 }
197 }
198
199 fn escape(s: &str) -> String {
201 s.replace('\"', "\\\"")
202 .replace('\n', "\\n")
203 .replace('\t', "\\t")
204 }
205
206 pub fn graph_stats(graph: &ComputationGraph) -> String {
208 let mut stats = String::new();
209 writeln!(stats, "Graph Statistics:").unwrap();
210 writeln!(stats, " Total nodes: {}", graph.nodes.len()).unwrap();
211 writeln!(stats, " Input nodes: {}", graph.inputs.len()).unwrap();
212 writeln!(stats, " Output nodes: {}", graph.outputs.len()).unwrap();
213
214 let mut op_counts = std::collections::HashMap::new();
216 for node in graph.nodes.values() {
217 let op_name = Self::operation_name(&node.op);
218 *op_counts.entry(op_name).or_insert(0) += 1;
219 }
220
221 writeln!(stats, " Operation counts:").unwrap();
222 let mut ops: Vec<_> = op_counts.into_iter().collect();
223 ops.sort_by(|a, b| b.1.cmp(&a.1));
224 for (op, count) in ops {
225 writeln!(stats, " {}: {}", op, count).unwrap();
226 }
227
228 stats
229 }
230
231 fn operation_name(op: &TensorOp) -> &'static str {
232 match op {
233 TensorOp::Input { .. } => "Input",
234 TensorOp::Constant { .. } => "Constant",
235 TensorOp::MatMul => "MatMul",
236 TensorOp::Einsum { .. } => "Einsum",
237 TensorOp::Add => "Add",
238 TensorOp::Mul => "Mul",
239 TensorOp::Sub => "Sub",
240 TensorOp::Div => "Div",
241 TensorOp::ReLU => "ReLU",
242 TensorOp::Tanh => "Tanh",
243 TensorOp::Sigmoid => "Sigmoid",
244 TensorOp::GELU => "GELU",
245 TensorOp::Softmax { .. } => "Softmax",
246 TensorOp::LayerNorm { .. } => "LayerNorm",
247 TensorOp::BatchNorm { .. } => "BatchNorm",
248 TensorOp::Dropout { .. } => "Dropout",
249 TensorOp::Reshape { .. } => "Reshape",
250 TensorOp::Transpose { .. } => "Transpose",
251 TensorOp::ReduceSum { .. } => "ReduceSum",
252 TensorOp::ReduceMean { .. } => "ReduceMean",
253 TensorOp::Concat { .. } => "Concat",
254 TensorOp::Split { .. } => "Split",
255 TensorOp::Gather { .. } => "Gather",
256 TensorOp::Scatter { .. } => "Scatter",
257 TensorOp::Slice { .. } => "Slice",
258 TensorOp::Pad { .. } => "Pad",
259 TensorOp::Exp => "Exp",
260 TensorOp::Log => "Log",
261 TensorOp::Pow { .. } => "Pow",
262 TensorOp::Sqrt => "Sqrt",
263 TensorOp::FusedLinear => "FusedLinear",
264 TensorOp::FusedAddReLU => "FusedAddReLU",
265 TensorOp::FusedBatchNormReLU { .. } => "FusedBatchNormReLU",
266 TensorOp::FusedLayerNormDropout { .. } => "FusedLayerNormDropout",
267 }
268 }
269}
270
271pub struct ProofVisualizer;
273
274impl ProofVisualizer {
275 pub fn to_dot(proof: &ProofFragment, id: usize) -> String {
280 let mut dot = String::new();
281 writeln!(dot, "digraph ProofTree {{").unwrap();
282 writeln!(dot, " rankdir=TB;").unwrap();
283 writeln!(dot, " node [shape=box, style=\"filled,rounded\"];").unwrap();
284 writeln!(dot).unwrap();
285
286 let mut node_counter = 0;
287 Self::write_proof_node(&mut dot, proof, id, &mut node_counter);
288
289 writeln!(dot, "}}").unwrap();
290 dot
291 }
292
293 fn write_proof_node(
294 dot: &mut String,
295 proof: &ProofFragment,
296 node_id: usize,
297 counter: &mut usize,
298 ) {
299 let color = if proof.premise_refs.is_empty() {
300 "lightblue" } else {
302 "lightyellow" };
304
305 let conclusion_str = format!("{:?}", proof.conclusion);
306 writeln!(
307 dot,
308 " node_{} [label=\"{}\", fillcolor=\"{}\"];",
309 node_id,
310 GraphVisualizer::escape(&conclusion_str),
311 color
312 )
313 .unwrap();
314
315 for premise_ref in &proof.premise_refs {
317 *counter += 1;
318 let premise_id = *counter;
319 let premise_str = if let Some(ref hint) = premise_ref.conclusion_hint {
320 hint.clone()
321 } else {
322 format!("CID: {}", premise_ref.cid)
323 };
324 writeln!(
325 dot,
326 " node_{} [label=\"{}\", fillcolor=\"lightgray\"];",
327 premise_id,
328 GraphVisualizer::escape(&premise_str)
329 )
330 .unwrap();
331 writeln!(dot, " node_{} -> node_{};", node_id, premise_id).unwrap();
332 }
333
334 if let Some(ref rule_ref) = proof.rule_applied {
336 writeln!(
337 dot,
338 " node_{}_rule [label=\"Rule: {}\", shape=note, fillcolor=\"lightyellow\"];",
339 node_id,
340 GraphVisualizer::escape(&rule_ref.rule_id)
341 )
342 .unwrap();
343 writeln!(
344 dot,
345 " node_{}_rule -> node_{} [style=dashed];",
346 node_id, node_id
347 )
348 .unwrap();
349 }
350 }
351
352 pub fn explain(proof: &ProofFragment, depth: usize) -> String {
354 let mut explanation = String::new();
355 let indent = " ".repeat(depth);
356
357 writeln!(explanation, "{}Prove: {:?}", indent, proof.conclusion).unwrap();
358
359 if proof.premise_refs.is_empty() {
360 writeln!(explanation, "{} ✓ This is a known fact", indent).unwrap();
361 } else {
362 if let Some(ref rule_ref) = proof.rule_applied {
363 writeln!(explanation, "{} Using rule: {}", indent, rule_ref.rule_id).unwrap();
364 }
365 writeln!(
366 explanation,
367 "{} Requires proving {} premise(s):",
368 indent,
369 proof.premise_refs.len()
370 )
371 .unwrap();
372 for (i, premise_ref) in proof.premise_refs.iter().enumerate() {
373 let hint = premise_ref
374 .conclusion_hint
375 .as_deref()
376 .unwrap_or("(premise)");
377 writeln!(explanation, "{} {}. {}", indent, i + 1, hint).unwrap();
378 }
379 }
380
381 if let Some(complexity) = proof.metadata.complexity {
382 writeln!(explanation, "{} Complexity: {} steps", indent, complexity).unwrap();
383 }
384 writeln!(explanation, "{} Depth: {}", indent, proof.metadata.depth).unwrap();
385
386 explanation
387 }
388
389 pub fn proof_stats(proof: &ProofFragment) -> String {
391 let mut stats = String::new();
392 writeln!(stats, "Proof Statistics:").unwrap();
393 writeln!(stats, " ID: {}", proof.id).unwrap();
394 writeln!(stats, " Direct premises: {}", proof.premise_refs.len()).unwrap();
395
396 writeln!(
397 stats,
398 " Complexity: {} steps",
399 proof.metadata.complexity.unwrap_or(0)
400 )
401 .unwrap();
402 writeln!(stats, " Depth: {}", proof.metadata.depth).unwrap();
403 if let Some(ref created_by) = proof.metadata.created_by {
404 writeln!(stats, " Created by: {}", created_by).unwrap();
405 }
406
407 if proof.premise_refs.is_empty() {
408 writeln!(stats, " Type: Fact (axiom)").unwrap();
409 } else {
410 writeln!(stats, " Type: Rule application").unwrap();
411 if let Some(ref rule_ref) = proof.rule_applied {
412 writeln!(stats, " Rule: {}", rule_ref.rule_id).unwrap();
413 }
414 }
415
416 if !proof.substitution.is_empty() {
417 writeln!(stats, " Substitutions: {}", proof.substitution.len()).unwrap();
418 }
419
420 stats
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proof_storage::{ProofFragmentRef, ProofMetadata, RuleRef};
    use crate::{ComputationGraph, GraphNode, Predicate, TensorOp, Term};

    /// Builds the two-node `input -> relu` graph shared by the graph tests.
    ///
    /// When `mark_endpoints` is set, `input` is registered as a graph input right after it is
    /// added and `relu` as a graph output right after it is added.
    fn input_relu_graph(mark_endpoints: bool) -> ComputationGraph {
        let mut graph = ComputationGraph::new();

        graph
            .add_node(GraphNode::new(
                "input".to_string(),
                TensorOp::Input {
                    name: "x".to_string(),
                },
            ))
            .unwrap();
        if mark_endpoints {
            graph.mark_input("input".to_string());
        }

        graph
            .add_node(
                GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("input".to_string()),
            )
            .unwrap();
        if mark_endpoints {
            graph.mark_output("relu".to_string());
        }

        graph
    }

    /// Builds a premise-free fragment proving `test("A")` with empty metadata.
    fn leaf_proof(id: &str) -> ProofFragment {
        ProofFragment {
            id: id.to_string(),
            conclusion: Predicate::new(
                "test".to_string(),
                vec![Term::Const(crate::Constant::String("A".to_string()))],
            ),
            rule_applied: None,
            premise_refs: vec![],
            substitution: vec![],
            metadata: ProofMetadata {
                created_at: None,
                created_by: None,
                complexity: None,
                depth: 0,
                custom: std::collections::HashMap::new(),
            },
        }
    }

    #[test]
    fn test_graph_to_dot() {
        let rendered = GraphVisualizer::to_dot(&input_relu_graph(true));

        assert!(rendered.contains("digraph ComputationGraph"));
        assert!(rendered.contains("\"input\""));
        assert!(rendered.contains("\"relu\""));
        assert!(rendered.contains("\"input\" -> \"relu\""));
    }

    #[test]
    fn test_graph_stats() {
        let summary = GraphVisualizer::graph_stats(&input_relu_graph(false));

        assert!(summary.contains("Total nodes: 2"));
        assert!(summary.contains("Input: 1"));
        assert!(summary.contains("ReLU: 1"));
    }

    #[test]
    fn test_proof_to_dot() {
        let alice = Term::Const(crate::Constant::String("Alice".to_string()));
        let bob = Term::Const(crate::Constant::String("Bob".to_string()));

        let proof = ProofFragment {
            id: "proof_1".to_string(),
            conclusion: Predicate::new("ancestor".to_string(), vec![alice, bob]),
            rule_applied: Some(RuleRef {
                rule_id: "ancestor_rule".to_string(),
                rule_cid: None,
                rule: None,
            }),
            premise_refs: vec![ProofFragmentRef {
                cid: ipfrs_core::Cid::default(),
                conclusion_hint: Some("parent(Alice, Bob)".to_string()),
            }],
            substitution: vec![],
            metadata: ProofMetadata {
                created_at: None,
                created_by: None,
                complexity: Some(2),
                depth: 1,
                custom: std::collections::HashMap::new(),
            },
        };

        let rendered = ProofVisualizer::to_dot(&proof, 0);

        assert!(rendered.contains("digraph ProofTree"));
        assert!(rendered.contains("ancestor"));
        assert!(rendered.contains("parent"));
    }

    #[test]
    fn test_proof_explain() {
        let text = ProofVisualizer::explain(&leaf_proof("proof_2"), 0);

        assert!(text.contains("Prove"));
        assert!(text.contains("known fact"));
    }

    #[test]
    fn test_proof_stats() {
        let summary = ProofVisualizer::proof_stats(&leaf_proof("proof_3"));

        assert!(summary.contains("Proof Statistics"));
        assert!(summary.contains("Type: Fact"));
    }
}