Skip to main content

ascii_petgraph/
graph.rs

1//! Core `RenderedGraph` type that ties together physics, rendering, and styling.
2
3use std::fmt::Display;
4
5use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
6use petgraph::visit::EdgeRef;
7use ratatui::buffer::Buffer;
8use ratatui::layout::Rect;
9use ratatui::style::Color;
10use ratatui::widgets::Widget;
11
12use crate::physics::{PhysicsConfig, PhysicsEngine};
13use crate::render::{CharGrid, GraphRenderer, RenderedEdge, RenderedNode};
14use crate::style::{BoxBorder, EdgeStyle, NodeStyle};
15
16/// A rendered graph ready for display in a TUI.
17///
18/// The graph structure is immutable after creation, but colors and styles
19/// can be modified at any time.
20pub struct RenderedGraph<N, E> {
21    /// The underlying petgraph.
22    graph: DiGraph<N, E>,
23    /// Physics engine for layout.
24    physics: PhysicsEngine,
25    /// Renderer configuration.
26    renderer: GraphRenderer,
27    /// Per-node styles (indexed by node index).
28    node_styles: Vec<NodeStyle>,
29    /// Per-edge styles (indexed by edge index).
30    edge_styles: Vec<EdgeStyle>,
31    /// Default node style.
32    default_node_style: NodeStyle,
33    /// Default edge style.
34    default_edge_style: EdgeStyle,
35    /// Cached rendered nodes (updated after layout).
36    rendered_nodes: Vec<RenderedNode<()>>,
37    /// Cached rendered edges.
38    rendered_edges: Vec<RenderedEdge<()>>,
39    /// Whether layout is dirty and needs recalculation.
40    layout_dirty: bool,
41}
42
43impl<N: Display + Clone, E: Display + Clone> RenderedGraph<N, E> {
44    /// Create a new rendered graph from a petgraph DiGraph.
45    pub fn from_graph(graph: DiGraph<N, E>) -> Self {
46        let node_count = graph.node_count();
47        let edge_count = graph.edge_count();
48
49        let physics = PhysicsEngine::new(&graph, PhysicsConfig::default());
50        let node_styles = vec![NodeStyle::default(); node_count];
51        let edge_styles = vec![EdgeStyle::default(); edge_count];
52
53        Self {
54            graph,
55            physics,
56            renderer: GraphRenderer::default(),
57            node_styles,
58            edge_styles,
59            default_node_style: NodeStyle::default(),
60            default_edge_style: EdgeStyle::default(),
61            rendered_nodes: Vec::new(),
62            rendered_edges: Vec::new(),
63            layout_dirty: true,
64        }
65    }
66
67    /// Create a builder for more configuration options.
68    pub fn builder() -> RenderedGraphBuilder<N, E> {
69        RenderedGraphBuilder::new()
70    }
71
72    /// Get a reference to the underlying graph.
73    pub fn graph(&self) -> &DiGraph<N, E> {
74        &self.graph
75    }
76
77    /// Get the physics configuration.
78    pub fn physics_config(&self) -> &PhysicsConfig {
79        &self.physics.config
80    }
81
82    /// Set the physics configuration.
83    pub fn set_physics_config(&mut self, config: PhysicsConfig) {
84        self.physics.config = config;
85    }
86
87    /// Advance the physics simulation by one step.
88    pub fn tick(&mut self) {
89        self.physics.tick(&self.graph);
90        self.layout_dirty = true;
91    }
92
93    /// Check if the simulation has converged.
94    pub fn is_stable(&self) -> bool {
95        self.physics.is_stable()
96    }
97
98    /// Run the physics simulation until stable.
99    pub fn run_simulation(&mut self) {
100        self.physics.run(&self.graph);
101        self.physics.normalize_positions();
102        self.layout_dirty = true;
103    }
104
105    /// Get the number of iterations run so far.
106    pub fn iterations(&self) -> usize {
107        self.physics.iterations()
108    }
109
110    // === Color API ===
111
112    /// Helper to mutate a node's style.
113    fn mutate_node_style<F>(&mut self, node: NodeIndex, f: F)
114    where
115        F: FnOnce(&mut NodeStyle),
116    {
117        if let Some(style) = self.node_styles.get_mut(node.index()) {
118            f(style);
119        }
120    }
121
122    /// Helper to mutate an edge's style.
123    fn mutate_edge_style<F>(&mut self, edge: EdgeIndex, f: F)
124    where
125        F: FnOnce(&mut EdgeStyle),
126    {
127        if let Some(style) = self.edge_styles.get_mut(edge.index()) {
128            f(style);
129        }
130    }
131
132    /// Helper to mutate all node styles.
133    fn mutate_all_node_styles<F>(&mut self, f: F)
134    where
135        F: Fn(&mut NodeStyle),
136    {
137        for style in &mut self.node_styles {
138            f(style);
139        }
140    }
141
142    /// Helper to mutate all edge styles.
143    fn mutate_all_edge_styles<F>(&mut self, f: F)
144    where
145        F: Fn(&mut EdgeStyle),
146    {
147        for style in &mut self.edge_styles {
148            f(style);
149        }
150    }
151
152    /// Set the border color of a node.
153    pub fn set_node_border_color(&mut self, node: NodeIndex, color: Color) {
154        self.mutate_node_style(node, |style| style.border_color = color);
155    }
156
157    /// Set the text color of a node.
158    pub fn set_node_text_color(&mut self, node: NodeIndex, color: Color) {
159        self.mutate_node_style(node, |style| style.text_color = color);
160    }
161
162    /// Set both border and text color of a node.
163    pub fn set_node_colors(&mut self, node: NodeIndex, border: Color, text: Color) {
164        self.mutate_node_style(node, |style| {
165            style.border_color = border;
166            style.text_color = text;
167        });
168    }
169
170    /// Set the line color of an edge.
171    pub fn set_edge_color(&mut self, edge: EdgeIndex, color: Color) {
172        self.mutate_edge_style(edge, |style| style.line_color = color);
173    }
174
175    /// Set the text color of an edge label.
176    pub fn set_edge_text_color(&mut self, edge: EdgeIndex, color: Color) {
177        self.mutate_edge_style(edge, |style| style.text_color = color);
178    }
179
180    /// Set both line and text color of an edge.
181    pub fn set_edge_colors(&mut self, edge: EdgeIndex, line: Color, text: Color) {
182        self.mutate_edge_style(edge, |style| {
183            style.line_color = line;
184            style.text_color = text;
185        });
186    }
187
188    /// Set border color for all nodes.
189    pub fn set_all_node_border_colors(&mut self, color: Color) {
190        self.mutate_all_node_styles(|style| style.border_color = color);
191    }
192
193    /// Set text color for all nodes.
194    pub fn set_all_node_text_colors(&mut self, color: Color) {
195        self.mutate_all_node_styles(|style| style.text_color = color);
196    }
197
198    /// Set line color for all edges.
199    pub fn set_all_edge_colors(&mut self, color: Color) {
200        self.mutate_all_edge_styles(|style| style.line_color = color);
201    }
202
203    /// Reset all node styles to default.
204    pub fn reset_node_styles(&mut self) {
205        let default = self.default_node_style.clone();
206        self.mutate_all_node_styles(|style| *style = default.clone());
207    }
208
209    /// Reset all edge styles to default.
210    pub fn reset_edge_styles(&mut self) {
211        let default = self.default_edge_style.clone();
212        self.mutate_all_edge_styles(|style| *style = default.clone());
213    }
214
215    /// Set the default node style (used for new nodes and reset).
216    pub fn set_default_node_style(&mut self, style: NodeStyle) {
217        self.default_node_style = style;
218    }
219
220    /// Set the default edge style.
221    pub fn set_default_edge_style(&mut self, style: EdgeStyle) {
222        self.default_edge_style = style;
223    }
224
225    /// Set the box border style for all nodes.
226    pub fn set_border_style(&mut self, border: BoxBorder) {
227        self.mutate_all_node_styles(|style| style.border = border);
228    }
229
230    /// Set the scaling mode for handling large graphs.
231    pub fn set_scaling_mode(&mut self, mode: crate::render::ScalingMode) {
232        self.renderer.scaling_mode = mode;
233        self.layout_dirty = true;
234    }
235
236    /// Auto-detect and apply appropriate scaling mode based on terminal width.
237    pub fn auto_scale(&mut self, max_width: usize) {
238        use crate::render::ScalingMode;
239        
240        // First try full labels
241        self.renderer.scaling_mode = ScalingMode::Full;
242        self.layout_dirty = true;
243        self.update_layout();
244        
245        let current_width = self.rendered_nodes.iter()
246            .map(|n| n.x + n.width)
247            .max()
248            .unwrap_or(0) + self.renderer.padding;
249        
250        if current_width <= max_width {
251            return;
252        }
253        
254        // Try truncating labels
255        self.renderer.scaling_mode = ScalingMode::Truncate(8);
256        self.layout_dirty = true;
257        self.update_layout();
258        
259        let current_width = self.rendered_nodes.iter()
260            .map(|n| n.x + n.width)
261            .max()
262            .unwrap_or(0) + self.renderer.padding;
263        
264        if current_width <= max_width {
265            return;
266        }
267        
268        // Fall back to numeric IDs
269        self.renderer.scaling_mode = ScalingMode::NumericIds;
270        self.layout_dirty = true;
271    }
272
273    // === Rendering ===
274
275    /// Update the rendered layout from current physics positions.
276    fn update_layout(&mut self) {
277        if !self.layout_dirty {
278            return;
279        }
280
281        self.rendered_nodes.clear();
282        self.rendered_edges.clear();
283
284        // Calculate node positions and sizes
285        for (idx, node_idx) in self.graph.node_indices().enumerate() {
286            let pos = self.physics.position(node_idx);
287            let label = self.graph[node_idx].to_string();
288            let display_label = self.renderer.display_label(node_idx, &label);
289            let width = self.renderer.node_width(&display_label);
290
291            let x = (pos.x * self.renderer.scale_x) as usize + self.renderer.padding;
292            let y = (pos.y * self.renderer.scale_y) as usize + self.renderer.padding;
293
294            self.rendered_nodes.push(RenderedNode {
295                index: node_idx,
296                label: (),
297                x,
298                y,
299                width,
300                height: self.renderer.node_height,
301                style: self.node_styles.get(idx).cloned().unwrap_or_default(),
302            });
303        }
304
305        // Calculate edge paths and parallel offsets
306        // First, collect all edges and detect parallel pairs
307        use std::collections::HashMap;
308        let mut parallel_groups: HashMap<(usize, usize), Vec<EdgeIndex>> = HashMap::new();
309        
310        for edge in self.graph.edge_references() {
311            let s = edge.source().index();
312            let t = edge.target().index();
313            // Use canonical ordering to group edges between same nodes
314            let key = (s.min(t), s.max(t));
315            parallel_groups.entry(key).or_default().push(edge.id());
316        }
317        
318        // Assign offsets to parallel edges
319        let mut edge_offsets: HashMap<EdgeIndex, i32> = HashMap::new();
320        for edges in parallel_groups.values() {
321            if edges.len() > 1 {
322                // Multiple edges between same nodes - assign symmetric offsets
323                // For 2 edges: -1, +1; for 3 edges: -2, 0, +2; etc.
324                let count = edges.len() as i32;
325                for (i, &edge_id) in edges.iter().enumerate() {
326                    let offset = 2 * i as i32 - (count - 1);
327                    edge_offsets.insert(edge_id, offset);
328                }
329            }
330        }
331        
332        for edge in self.graph.edge_references() {
333            let idx = edge.id().index();
334            let offset = edge_offsets.get(&edge.id()).copied().unwrap_or(0);
335            self.rendered_edges.push(RenderedEdge {
336                index: edge.id(),
337                label: (),
338                source: edge.source(),
339                target: edge.target(),
340                path: Vec::new(), // Path calculated during rendering
341                style: self.edge_styles.get(idx).cloned().unwrap_or_default(),
342                parallel_offset: offset,
343            });
344        }
345
346        self.layout_dirty = false;
347    }
348
349    /// Render the graph to a character grid.
350    pub fn render_to_grid(&mut self) -> CharGrid {
351        self.update_layout();
352
353        // Calculate grid size with extra room for edge labels
354        let max_label_len = self.graph.edge_weights()
355            .map(|w| w.to_string().len())
356            .max()
357            .unwrap_or(0);
358        
359        let max_x = self.rendered_nodes.iter()
360            .map(|n| n.x + n.width)
361            .max()
362            .unwrap_or(0) + self.renderer.padding + max_label_len + 2;
363        let max_y = self.rendered_nodes.iter()
364            .map(|n| n.y + n.height)
365            .max()
366            .unwrap_or(0) + self.renderer.padding + 2;
367
368        let mut grid = CharGrid::new(max_x.max(1), max_y.max(1));
369
370        // Render edges first (so nodes draw on top)
371        for (idx, edge) in self.rendered_edges.iter().enumerate() {
372            let edge_with_label = RenderedEdge {
373                index: edge.index,
374                label: self.graph.edge_weight(edge.index)
375                    .map(|w| w.to_string())
376                    .unwrap_or_default(),
377                source: edge.source,
378                target: edge.target,
379                path: edge.path.clone(),
380                style: self.edge_styles.get(idx).cloned().unwrap_or_default(),
381                parallel_offset: edge.parallel_offset,
382            };
383            self.renderer.render_edge(&mut grid, &edge_with_label, &self.rendered_nodes);
384        }
385
386        // Render nodes
387        for (idx, node) in self.rendered_nodes.iter().enumerate() {
388            let label = self.renderer.display_label(
389                node.index,
390                &self.graph[node.index],
391            );
392            let node_with_label = RenderedNode {
393                index: node.index,
394                label,
395                x: node.x,
396                y: node.y,
397                width: node.width,
398                height: node.height,
399                style: self.node_styles.get(idx).cloned().unwrap_or_default(),
400            };
401            self.renderer.render_node(&mut grid, &node_with_label);
402        }
403
404        grid
405    }
406
407    /// Create a widget for rendering with ratatui.
408    pub fn widget(&mut self) -> GraphWidget<'_, N, E> {
409        GraphWidget { graph: self }
410    }
411}
412
413/// Builder for RenderedGraph with configuration options.
414pub struct RenderedGraphBuilder<N, E> {
415    graph: Option<DiGraph<N, E>>,
416    physics_config: PhysicsConfig,
417    border_style: BoxBorder,
418    default_node_style: NodeStyle,
419    default_edge_style: EdgeStyle,
420}
421
422impl<N, E> RenderedGraphBuilder<N, E> {
423    pub fn new() -> Self {
424        Self {
425            graph: None,
426            physics_config: PhysicsConfig::default(),
427            border_style: BoxBorder::default(),
428            default_node_style: NodeStyle::default(),
429            default_edge_style: EdgeStyle::default(),
430        }
431    }
432
433    pub fn graph(mut self, graph: DiGraph<N, E>) -> Self {
434        self.graph = Some(graph);
435        self
436    }
437
438    pub fn physics_config(mut self, config: PhysicsConfig) -> Self {
439        self.physics_config = config;
440        self
441    }
442
443    pub fn border_style(mut self, border: BoxBorder) -> Self {
444        self.border_style = border;
445        self.default_node_style.border = border;
446        self
447    }
448
449    pub fn default_node_style(mut self, style: NodeStyle) -> Self {
450        self.default_node_style = style;
451        self
452    }
453
454    pub fn default_edge_style(mut self, style: EdgeStyle) -> Self {
455        self.default_edge_style = style;
456        self
457    }
458
459    pub fn gravity(mut self, g: f64) -> Self {
460        self.physics_config.gravity = g;
461        self
462    }
463
464    pub fn spring_constant(mut self, k: f64) -> Self {
465        self.physics_config.spring_constant = k;
466        self
467    }
468
469    pub fn repulsion_constant(mut self, r: f64) -> Self {
470        self.physics_config.repulsion_constant = r;
471        self
472    }
473}
474
475impl<N: Display + Clone, E: Display + Clone> RenderedGraphBuilder<N, E> {
476    pub fn build(self) -> RenderedGraph<N, E> {
477        let graph = self.graph.expect("graph is required");
478        let mut rendered = RenderedGraph::from_graph(graph);
479        rendered.set_physics_config(self.physics_config);
480        rendered.set_default_node_style(self.default_node_style.clone());
481        rendered.set_default_edge_style(self.default_edge_style.clone());
482        rendered.set_border_style(self.border_style);
483        rendered
484    }
485}
486
487impl<N, E> Default for RenderedGraphBuilder<N, E> {
488    fn default() -> Self {
489        Self::new()
490    }
491}
492
493/// Ratatui widget for rendering the graph.
494pub struct GraphWidget<'a, N, E> {
495    graph: &'a mut RenderedGraph<N, E>,
496}
497
498impl<N: Display + Clone, E: Display + Clone> Widget for GraphWidget<'_, N, E> {
499    fn render(self, area: Rect, buf: &mut Buffer) {
500        let grid = self.graph.render_to_grid();
501        let (grid_width, grid_height) = grid.size();
502
503        for (gx, gy, cell) in grid.iter() {
504            // Map grid coordinates to buffer coordinates
505            let bx = area.x + gx as u16;
506            let by = area.y + gy as u16;
507
508            // Check bounds
509            if bx >= area.x + area.width || by >= area.y + area.height {
510                continue;
511            }
512            if gx >= grid_width || gy >= grid_height {
513                continue;
514            }
515
516            let buf_cell = buf.cell_mut((bx, by));
517            if let Some(buf_cell) = buf_cell {
518                buf_cell.set_char(cell.char);
519                buf_cell.set_fg(cell.fg);
520                if cell.bg != Color::Reset {
521                    buf_cell.set_bg(cell.bg);
522                }
523            }
524        }
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rendered_graph_creation() {
        let mut graph: DiGraph<&str, &str> = DiGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        graph.add_edge(a, b, "edge");

        let rendered = RenderedGraph::from_graph(graph);
        assert_eq!(rendered.graph().node_count(), 2);
        assert_eq!(rendered.graph().edge_count(), 1);
    }

    #[test]
    fn test_simulation_runs() {
        let mut graph: DiGraph<&str, &str> = DiGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let c = graph.add_node("C");
        graph.add_edge(a, b, "1");
        graph.add_edge(b, c, "2");

        let mut rendered = RenderedGraph::from_graph(graph);
        rendered.run_simulation();

        assert!(rendered.is_stable());
        assert!(rendered.iterations() >= 10);
    }

    #[test]
    fn test_color_api() {
        let mut graph: DiGraph<&str, &str> = DiGraph::new();
        let a = graph.add_node("A");
        let b = graph.add_node("B");
        let e = graph.add_edge(a, b, "edge");

        let mut rendered = RenderedGraph::from_graph(graph);

        // Per-node colors.
        rendered.set_node_border_color(a, Color::Red);
        rendered.set_node_text_color(a, Color::Green);
        assert_eq!(rendered.node_styles[a.index()].border_color, Color::Red);
        assert_eq!(rendered.node_styles[a.index()].text_color, Color::Green);

        // Out-of-range indices are ignored rather than panicking.
        rendered.set_node_border_color(NodeIndex::new(99), Color::Red);
        assert_eq!(rendered.node_styles.len(), 2);

        // Per-edge colors.
        rendered.set_edge_color(e, Color::Blue);
        rendered.set_edge_text_color(e, Color::Yellow);
        assert_eq!(rendered.edge_styles[e.index()].line_color, Color::Blue);
        assert_eq!(rendered.edge_styles[e.index()].text_color, Color::Yellow);

        // Bulk operations touch every item but only the targeted field.
        rendered.set_all_node_border_colors(Color::White);
        rendered.set_all_edge_colors(Color::Gray);
        assert!(rendered
            .node_styles
            .iter()
            .all(|s| s.border_color == Color::White));
        assert_eq!(rendered.node_styles[a.index()].text_color, Color::Green);
        assert!(rendered
            .edge_styles
            .iter()
            .all(|s| s.line_color == Color::Gray));
        assert_eq!(rendered.edge_styles[e.index()].text_color, Color::Yellow);

        // Reset restores the stored defaults.
        rendered.reset_node_styles();
        rendered.reset_edge_styles();
        let node_default = &rendered.default_node_style;
        assert!(rendered.node_styles.iter().all(|s| {
            s.border_color == node_default.border_color && s.text_color == node_default.text_color
        }));
        let edge_default = &rendered.default_edge_style;
        assert!(rendered.edge_styles.iter().all(|s| {
            s.line_color == edge_default.line_color && s.text_color == edge_default.text_color
        }));
    }

    #[test]
    fn test_render_to_grid() {
        let mut graph: DiGraph<&str, &str> = DiGraph::new();
        let a = graph.add_node("Hello");
        let b = graph.add_node("World");
        graph.add_edge(a, b, "");

        let mut rendered = RenderedGraph::from_graph(graph);
        rendered.run_simulation();

        let grid = rendered.render_to_grid();
        let (width, height) = grid.size();

        assert!(width > 0);
        assert!(height > 0);

        // At least one cell must have been drawn.
        let has_content = grid.iter().any(|(_, _, cell)| cell.char != ' ');
        assert!(has_content);
    }

    #[test]
    fn test_builder() {
        let mut graph: DiGraph<&str, &str> = DiGraph::new();
        graph.add_node("A");
        graph.add_node("B");

        let rendered = RenderedGraph::builder()
            .graph(graph)
            .border_style(BoxBorder::Rounded)
            .gravity(2.0)
            .spring_constant(0.2)
            .build();

        assert_eq!(rendered.physics_config().gravity, 2.0);
        assert_eq!(rendered.physics_config().spring_constant, 0.2);
    }

    #[test]
    fn test_scaling_mode() {
        let mut graph: DiGraph<&str, &str> = DiGraph::new();
        let a = graph.add_node("VeryLongNodeLabel");
        let b = graph.add_node("AnotherLongLabel");
        graph.add_edge(a, b, "");

        let mut rendered = RenderedGraph::from_graph(graph);
        rendered.run_simulation();

        // Test different scaling modes
        rendered.set_scaling_mode(crate::render::ScalingMode::Full);
        let grid1 = rendered.render_to_grid();

        rendered.set_scaling_mode(crate::render::ScalingMode::Truncate(5));
        rendered.render_to_grid(); // Trigger re-render with truncated mode

        rendered.set_scaling_mode(crate::render::ScalingMode::NumericIds);
        let grid3 = rendered.render_to_grid();

        // NumericIds should produce smallest grid
        let (w1, _) = grid1.size();
        let (w3, _) = grid3.size();
        assert!(w3 <= w1);
    }
}