ipfrs_tensorlogic/
visualization.rs

//! Visualization utilities for computation graphs and proofs.
//!
//! This module provides tools for exporting graphs and proofs to DOT format
//! (Graphviz) for visualization and debugging.
//!
//! # Examples
//!
//! ## Visualizing a Computation Graph
//!
//! ```
//! use ipfrs_tensorlogic::{ComputationGraph, GraphNode, TensorOp, GraphVisualizer};
//!
//! let mut graph = ComputationGraph::new();
//!
//! // Create nodes
//! let input = GraphNode::new("input".to_string(), TensorOp::Input {
//!     name: "x".to_string(),
//! });
//! graph.add_node(input).unwrap();
//! graph.mark_input("input".to_string());
//!
//! let relu = GraphNode::new("relu".to_string(), TensorOp::ReLU)
//!     .add_input("input".to_string());
//! graph.add_node(relu).unwrap();
//! graph.mark_output("relu".to_string());
//!
//! // Export to DOT format
//! let dot = GraphVisualizer::to_dot(&graph);
//! println!("{}", dot);
//! // Save to file: std::fs::write("graph.dot", dot).unwrap();
//! // Render: dot -Tpng graph.dot -o graph.png
//! ```
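//!
//! ## Visualizing a Proof
//!
//! A minimal sketch for proof trees, assuming `ProofVisualizer` is
//! re-exported at the crate root like `GraphVisualizer` (marked `ignore`
//! because building a `ProofFragment` by hand is verbose; see the tests
//! in this module for a full construction):
//!
//! ```ignore
//! use ipfrs_tensorlogic::ProofVisualizer;
//!
//! // `proof` is a ProofFragment obtained from proof storage or a prover.
//! let dot = ProofVisualizer::to_dot(&proof, 0);
//! std::fs::write("proof.dot", dot).unwrap();
//! ```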

use crate::computation_graph::{ComputationGraph, TensorOp};
use crate::proof_storage::ProofFragment;
use std::fmt::Write as FmtWrite;

/// Visualizer for computation graphs
pub struct GraphVisualizer;

impl GraphVisualizer {
    /// Export a computation graph to DOT format
    ///
    /// The output can be rendered using Graphviz:
    /// ```bash
    /// dot -Tpng graph.dot -o graph.png
    /// dot -Tsvg graph.dot -o graph.svg
    /// ```
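    ///
    /// Output sketch for the two-node graph from the module-level example
    /// (node order depends on map iteration; legend omitted):
    /// ```text
    /// digraph ComputationGraph {
    ///   rankdir=TB;
    ///   node [shape=box, style=filled];
    ///
    ///   "input" [label="input\nInput(x)", fillcolor="lightblue", shape=ellipse];
    ///   "relu" [label="relu\nReLU", fillcolor="lightgreen", shape=doubleoctagon];
    ///
    ///   "input" -> "relu";
    /// }
    /// ```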
    pub fn to_dot(graph: &ComputationGraph) -> String {
        let mut dot = String::new();
        writeln!(dot, "digraph ComputationGraph {{").unwrap();
        writeln!(dot, "  rankdir=TB;").unwrap();
        writeln!(dot, "  node [shape=box, style=filled];").unwrap();
        writeln!(dot).unwrap();

        // Write nodes
        for (node_id, node) in &graph.nodes {
            let color = Self::node_color(&node.op);
            let shape = if graph.inputs.contains(node_id) {
                "ellipse"
            } else if graph.outputs.contains(node_id) {
                "doubleoctagon"
            } else {
                "box"
            };

            let label = Self::format_operation(&node.op);
            writeln!(
                dot,
                "  \"{}\" [label=\"{}\\n{}\", fillcolor=\"{}\", shape={}];",
                Self::escape(node_id),
                Self::escape(node_id),
                // Escape the label too: operation strings may contain
                // characters that would otherwise break the DOT syntax.
                Self::escape(&label),
                color,
                shape
            )
            .unwrap();
        }

        writeln!(dot).unwrap();

        // Write edges
        for (node_id, node) in &graph.nodes {
            for input in &node.inputs {
                writeln!(
                    dot,
                    "  \"{}\" -> \"{}\";",
                    Self::escape(input),
                    Self::escape(node_id)
                )
                .unwrap();
            }
        }

        // Add legend
        writeln!(dot).unwrap();
        writeln!(dot, "  subgraph cluster_legend {{").unwrap();
        writeln!(dot, "    label=\"Legend\";").unwrap();
        writeln!(dot, "    style=filled;").unwrap();
        writeln!(dot, "    fillcolor=lightgrey;").unwrap();
        writeln!(
            dot,
            "    legend_input [label=\"Input\", shape=ellipse, fillcolor=lightblue];"
        )
        .unwrap();
        writeln!(
            dot,
            "    legend_output [label=\"Output\", shape=doubleoctagon, fillcolor=lightgreen];"
        )
        .unwrap();
        writeln!(
            dot,
            "    legend_compute [label=\"Compute\", shape=box, fillcolor=lightyellow];"
        )
        .unwrap();
        writeln!(dot, "  }}").unwrap();

        writeln!(dot, "}}").unwrap();
        dot
    }

    /// Get color for a node based on operation type
    fn node_color(op: &TensorOp) -> &'static str {
        match op {
            TensorOp::Input { .. } | TensorOp::Constant { .. } => "lightblue",
            TensorOp::MatMul | TensorOp::Einsum { .. } => "orange",
            TensorOp::Add | TensorOp::Mul | TensorOp::Sub | TensorOp::Div => "yellow",
            TensorOp::ReLU
            | TensorOp::Tanh
            | TensorOp::Sigmoid
            | TensorOp::GELU
            | TensorOp::Softmax { .. } => "lightgreen",
            TensorOp::LayerNorm { .. } | TensorOp::BatchNorm { .. } => "lightcoral",
            TensorOp::Dropout { .. } => "plum",
            TensorOp::Reshape { .. } | TensorOp::Transpose { .. } | TensorOp::Slice { .. } => {
                "lightyellow"
            }
            _ => "white",
        }
    }

    /// Format operation for display
    fn format_operation(op: &TensorOp) -> String {
        match op {
            TensorOp::Input { name } => format!("Input({})", name),
            TensorOp::Constant { value_cid } => {
                // Truncate the CID for readability; fall back to the full
                // string if it is shorter than 8 bytes or the cut would
                // split a multi-byte character (raw slicing would panic).
                format!(
                    "Const(cid:{})",
                    value_cid.get(..8).unwrap_or(value_cid.as_str())
                )
            }
            TensorOp::MatMul => "MatMul".to_string(),
            TensorOp::Einsum { subscripts } => format!("Einsum({})", subscripts),
            TensorOp::Add => "Add".to_string(),
            TensorOp::Mul => "Multiply".to_string(),
            TensorOp::Sub => "Subtract".to_string(),
            TensorOp::Div => "Divide".to_string(),
            TensorOp::ReLU => "ReLU".to_string(),
            TensorOp::Tanh => "Tanh".to_string(),
            TensorOp::Sigmoid => "Sigmoid".to_string(),
            TensorOp::GELU => "GELU".to_string(),
            TensorOp::Softmax { axis } => format!("Softmax(axis={})", axis),
            TensorOp::LayerNorm {
                normalized_shape: _,
                eps,
            } => format!("LayerNorm(ε={:.1e})", eps),
            TensorOp::BatchNorm { eps, momentum } => {
                format!("BatchNorm(ε={:.1e}, μ={:.2})", eps, momentum)
            }
            TensorOp::Dropout { p } => format!("Dropout({:.2})", p),
            TensorOp::Reshape { shape } => format!("Reshape({:?})", shape),
            TensorOp::Transpose { axes } => format!("Transpose({:?})", axes),
            TensorOp::ReduceSum { axes, keepdims: _ } => format!("ReduceSum({:?})", axes),
            TensorOp::ReduceMean { axes, keepdims: _ } => format!("ReduceMean({:?})", axes),
            TensorOp::Concat { axis } => format!("Concat(axis={})", axis),
            TensorOp::Split { axis, sections } => {
                format!("Split(axis={}, n={})", axis, sections.len())
            }
            TensorOp::Gather { axis } => format!("Gather(axis={})", axis),
            TensorOp::Scatter { axis } => format!("Scatter(axis={})", axis),
            TensorOp::Slice {
                start,
                end,
                strides,
            } => format!("Slice({:?}:{:?}:{:?})", start, end, strides),
            TensorOp::Pad { padding, mode: _ } => format!("Pad({:?})", padding),
            TensorOp::Exp => "Exp".to_string(),
            TensorOp::Log => "Log".to_string(),
            TensorOp::Pow { exponent } => format!("Pow({})", exponent),
            TensorOp::Sqrt => "Sqrt".to_string(),
            TensorOp::FusedLinear => "FusedLinear".to_string(),
            TensorOp::FusedAddReLU => "FusedAdd+ReLU".to_string(),
            TensorOp::FusedBatchNormReLU { eps, momentum } => {
                format!("FusedBN+ReLU(ε={:.1e}, μ={:.2})", eps, momentum)
            }
            TensorOp::FusedLayerNormDropout {
                normalized_shape: _,
                eps,
                dropout_p,
            } => format!("FusedLN+Dropout(ε={:.1e}, p={:.2})", eps, dropout_p),
        }
    }

    /// Escape special characters for DOT format
    fn escape(s: &str) -> String {
        // Escape backslashes first so the escapes added below survive intact.
        s.replace('\\', "\\\\")
            .replace('"', "\\\"")
            .replace('\n', "\\n")
            .replace('\t', "\\t")
    }

    /// Export graph statistics
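    ///
    /// # Example
    ///
    /// A minimal sketch on an empty graph (exact lines follow the
    /// formatting in the body below):
    /// ```
    /// use ipfrs_tensorlogic::{ComputationGraph, GraphVisualizer};
    ///
    /// let graph = ComputationGraph::new();
    /// let stats = GraphVisualizer::graph_stats(&graph);
    /// assert!(stats.contains("Total nodes: 0"));
    /// ```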
    pub fn graph_stats(graph: &ComputationGraph) -> String {
        let mut stats = String::new();
        writeln!(stats, "Graph Statistics:").unwrap();
        writeln!(stats, "  Total nodes: {}", graph.nodes.len()).unwrap();
        writeln!(stats, "  Input nodes: {}", graph.inputs.len()).unwrap();
        writeln!(stats, "  Output nodes: {}", graph.outputs.len()).unwrap();

        // Count operation types
        let mut op_counts = std::collections::HashMap::new();
        for node in graph.nodes.values() {
            let op_name = Self::operation_name(&node.op);
            *op_counts.entry(op_name).or_insert(0) += 1;
        }

        writeln!(stats, "  Operation counts:").unwrap();
        let mut ops: Vec<_> = op_counts.into_iter().collect();
        ops.sort_by(|a, b| b.1.cmp(&a.1));
        for (op, count) in ops {
            writeln!(stats, "    {}: {}", op, count).unwrap();
        }

        stats
    }

    fn operation_name(op: &TensorOp) -> &'static str {
        match op {
            TensorOp::Input { .. } => "Input",
            TensorOp::Constant { .. } => "Constant",
            TensorOp::MatMul => "MatMul",
            TensorOp::Einsum { .. } => "Einsum",
            TensorOp::Add => "Add",
            TensorOp::Mul => "Mul",
            TensorOp::Sub => "Sub",
            TensorOp::Div => "Div",
            TensorOp::ReLU => "ReLU",
            TensorOp::Tanh => "Tanh",
            TensorOp::Sigmoid => "Sigmoid",
            TensorOp::GELU => "GELU",
            TensorOp::Softmax { .. } => "Softmax",
            TensorOp::LayerNorm { .. } => "LayerNorm",
            TensorOp::BatchNorm { .. } => "BatchNorm",
            TensorOp::Dropout { .. } => "Dropout",
            TensorOp::Reshape { .. } => "Reshape",
            TensorOp::Transpose { .. } => "Transpose",
            TensorOp::ReduceSum { .. } => "ReduceSum",
            TensorOp::ReduceMean { .. } => "ReduceMean",
            TensorOp::Concat { .. } => "Concat",
            TensorOp::Split { .. } => "Split",
            TensorOp::Gather { .. } => "Gather",
            TensorOp::Scatter { .. } => "Scatter",
            TensorOp::Slice { .. } => "Slice",
            TensorOp::Pad { .. } => "Pad",
            TensorOp::Exp => "Exp",
            TensorOp::Log => "Log",
            TensorOp::Pow { .. } => "Pow",
            TensorOp::Sqrt => "Sqrt",
            TensorOp::FusedLinear => "FusedLinear",
            TensorOp::FusedAddReLU => "FusedAddReLU",
            TensorOp::FusedBatchNormReLU { .. } => "FusedBatchNormReLU",
            TensorOp::FusedLayerNormDropout { .. } => "FusedLayerNormDropout",
        }
    }
}

/// Visualizer for proof trees
pub struct ProofVisualizer;

impl ProofVisualizer {
    /// Export a proof tree to DOT format
    ///
    /// The proof is rendered as a tree with the conclusion at the top
    /// and premises as child nodes.
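    ///
    /// Output sketch for a fragment with one premise and an applied rule,
    /// called with `id = 0` (labels come from the fragment; see the tests):
    /// ```text
    /// digraph ProofTree {
    ///   rankdir=TB;
    ///   node [shape=box, style="filled,rounded"];
    ///
    ///   node_0 [label="<conclusion>", fillcolor="lightyellow"];
    ///   node_1 [label="parent(Alice, Bob)", fillcolor="lightgray"];
    ///   node_0 -> node_1;
    ///   node_0_rule [label="Rule: ancestor_rule", shape=note, fillcolor="lightyellow"];
    ///   node_0_rule -> node_0 [style=dashed];
    /// }
    /// ```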
    pub fn to_dot(proof: &ProofFragment, id: usize) -> String {
        let mut dot = String::new();
        writeln!(dot, "digraph ProofTree {{").unwrap();
        writeln!(dot, "  rankdir=TB;").unwrap();
        writeln!(dot, "  node [shape=box, style=\"filled,rounded\"];").unwrap();
        writeln!(dot).unwrap();

        // Start the premise counter at `id` so generated premise ids can
        // never collide with the root node's id.
        let mut node_counter = id;
        Self::write_proof_node(&mut dot, proof, id, &mut node_counter);

        writeln!(dot, "}}").unwrap();
        dot
    }

    fn write_proof_node(
        dot: &mut String,
        proof: &ProofFragment,
        node_id: usize,
        counter: &mut usize,
    ) {
        let color = if proof.premise_refs.is_empty() {
            "lightblue" // Fact (no premises)
        } else {
            "lightyellow" // Rule application
        };

        let conclusion_str = format!("{:?}", proof.conclusion);
        writeln!(
            dot,
            "  node_{} [label=\"{}\", fillcolor=\"{}\"];",
            node_id,
            GraphVisualizer::escape(&conclusion_str),
            color
        )
        .unwrap();

        // Write premise references as child nodes
        for premise_ref in &proof.premise_refs {
            *counter += 1;
            let premise_id = *counter;
            let premise_str = if let Some(ref hint) = premise_ref.conclusion_hint {
                hint.clone()
            } else {
                format!("CID: {}", premise_ref.cid)
            };
            writeln!(
                dot,
                "  node_{} [label=\"{}\", fillcolor=\"lightgray\"];",
                premise_id,
                GraphVisualizer::escape(&premise_str)
            )
            .unwrap();
            writeln!(dot, "  node_{} -> node_{};", node_id, premise_id).unwrap();
        }

        // Add rule information
        if let Some(ref rule_ref) = proof.rule_applied {
            writeln!(
                dot,
                "  node_{}_rule [label=\"Rule: {}\", shape=note, fillcolor=\"lightyellow\"];",
                node_id,
                GraphVisualizer::escape(&rule_ref.rule_id)
            )
            .unwrap();
            writeln!(
                dot,
                "  node_{}_rule -> node_{} [style=dashed];",
                node_id, node_id
            )
            .unwrap();
        }
    }

    /// Generate a textual explanation of a proof
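    ///
    /// For a fact with no premises and no recorded complexity the output
    /// looks roughly like (the conclusion is `Debug`-formatted):
    /// ```text
    /// Prove: <conclusion>
    ///   ✓ This is a known fact
    ///   Depth: 0
    /// ```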
    pub fn explain(proof: &ProofFragment, depth: usize) -> String {
        let mut explanation = String::new();
        let indent = "  ".repeat(depth);

        writeln!(explanation, "{}Prove: {:?}", indent, proof.conclusion).unwrap();

        if proof.premise_refs.is_empty() {
            writeln!(explanation, "{}  ✓ This is a known fact", indent).unwrap();
        } else {
            if let Some(ref rule_ref) = proof.rule_applied {
                writeln!(explanation, "{}  Using rule: {}", indent, rule_ref.rule_id).unwrap();
            }
            writeln!(
                explanation,
                "{}  Requires proving {} premise(s):",
                indent,
                proof.premise_refs.len()
            )
            .unwrap();
            for (i, premise_ref) in proof.premise_refs.iter().enumerate() {
                let hint = premise_ref
                    .conclusion_hint
                    .as_deref()
                    .unwrap_or("(premise)");
                writeln!(explanation, "{}    {}. {}", indent, i + 1, hint).unwrap();
            }
        }

        if let Some(complexity) = proof.metadata.complexity {
            writeln!(explanation, "{}  Complexity: {} steps", indent, complexity).unwrap();
        }
        writeln!(explanation, "{}  Depth: {}", indent, proof.metadata.depth).unwrap();

        explanation
    }

    /// Generate a summary of proof statistics
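    ///
    /// Output sketch for a fact (axiom) fragment with no recorded
    /// complexity (`proof_3` is the fragment id from the test below):
    /// ```text
    /// Proof Statistics:
    ///   ID: proof_3
    ///   Direct premises: 0
    ///   Complexity: 0 steps
    ///   Depth: 0
    ///   Type: Fact (axiom)
    /// ```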
    pub fn proof_stats(proof: &ProofFragment) -> String {
        let mut stats = String::new();
        writeln!(stats, "Proof Statistics:").unwrap();
        writeln!(stats, "  ID: {}", proof.id).unwrap();
        writeln!(stats, "  Direct premises: {}", proof.premise_refs.len()).unwrap();

        writeln!(
            stats,
            "  Complexity: {} steps",
            proof.metadata.complexity.unwrap_or(0)
        )
        .unwrap();
        writeln!(stats, "  Depth: {}", proof.metadata.depth).unwrap();
        if let Some(ref created_by) = proof.metadata.created_by {
            writeln!(stats, "  Created by: {}", created_by).unwrap();
        }

        if proof.premise_refs.is_empty() {
            writeln!(stats, "  Type: Fact (axiom)").unwrap();
        } else {
            writeln!(stats, "  Type: Rule application").unwrap();
            if let Some(ref rule_ref) = proof.rule_applied {
                writeln!(stats, "  Rule: {}", rule_ref.rule_id).unwrap();
            }
        }

        if !proof.substitution.is_empty() {
            writeln!(stats, "  Substitutions: {}", proof.substitution.len()).unwrap();
        }

        stats
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ComputationGraph, GraphNode, Predicate, TensorOp, Term};

    #[test]
    fn test_graph_to_dot() {
        let mut graph = ComputationGraph::new();

        let input = GraphNode::new(
            "input".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        graph.add_node(input).unwrap();
        graph.mark_input("input".to_string());

        let relu =
            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("input".to_string());
        graph.add_node(relu).unwrap();
        graph.mark_output("relu".to_string());

        let dot = GraphVisualizer::to_dot(&graph);

        assert!(dot.contains("digraph ComputationGraph"));
        assert!(dot.contains("\"input\""));
        assert!(dot.contains("\"relu\""));
        assert!(dot.contains("\"input\" -> \"relu\""));
    }

    #[test]
    fn test_graph_stats() {
        let mut graph = ComputationGraph::new();

        let input = GraphNode::new(
            "input".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        graph.add_node(input).unwrap();

        let relu =
            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("input".to_string());
        graph.add_node(relu).unwrap();

        let stats = GraphVisualizer::graph_stats(&graph);

        assert!(stats.contains("Total nodes: 2"));
        assert!(stats.contains("Input: 1"));
        assert!(stats.contains("ReLU: 1"));
    }

    #[test]
    fn test_proof_to_dot() {
        use crate::proof_storage::{ProofFragmentRef, ProofMetadata, RuleRef};

        let conclusion = Predicate::new(
            "ancestor".to_string(),
            vec![
                Term::Const(crate::Constant::String("Alice".to_string())),
                Term::Const(crate::Constant::String("Bob".to_string())),
            ],
        );

        let proof = ProofFragment {
            id: "proof_1".to_string(),
            conclusion,
            rule_applied: Some(RuleRef {
                rule_id: "ancestor_rule".to_string(),
                rule_cid: None,
                rule: None,
            }),
            premise_refs: vec![ProofFragmentRef {
                cid: ipfrs_core::Cid::default(),
                conclusion_hint: Some("parent(Alice, Bob)".to_string()),
            }],
            substitution: vec![],
            metadata: ProofMetadata {
                created_at: None,
                created_by: None,
                complexity: Some(2),
                depth: 1,
                custom: std::collections::HashMap::new(),
            },
        };

        let dot = ProofVisualizer::to_dot(&proof, 0);

        assert!(dot.contains("digraph ProofTree"));
        assert!(dot.contains("ancestor"));
        assert!(dot.contains("parent"));
    }

    #[test]
    fn test_proof_explain() {
        use crate::proof_storage::ProofMetadata;

        let conclusion = Predicate::new(
            "test".to_string(),
            vec![Term::Const(crate::Constant::String("A".to_string()))],
        );

        let proof = ProofFragment {
            id: "proof_2".to_string(),
            conclusion,
            rule_applied: None,
            premise_refs: vec![],
            substitution: vec![],
            metadata: ProofMetadata {
                created_at: None,
                created_by: None,
                complexity: None,
                depth: 0,
                custom: std::collections::HashMap::new(),
            },
        };

        let explanation = ProofVisualizer::explain(&proof, 0);

        assert!(explanation.contains("Prove"));
        assert!(explanation.contains("known fact"));
    }

    #[test]
    fn test_proof_stats() {
        use crate::proof_storage::ProofMetadata;

        let conclusion = Predicate::new(
            "test".to_string(),
            vec![Term::Const(crate::Constant::String("A".to_string()))],
        );

        let proof = ProofFragment {
            id: "proof_3".to_string(),
            conclusion,
            rule_applied: None,
            premise_refs: vec![],
            substitution: vec![],
            metadata: ProofMetadata {
                created_at: None,
                created_by: None,
                complexity: None,
                depth: 0,
                custom: std::collections::HashMap::new(),
            },
        };

        let stats = ProofVisualizer::proof_stats(&proof);

        assert!(stats.contains("Proof Statistics"));
        assert!(stats.contains("Type: Fact"));
    }
}