Skip to main content

codetwin/layouts/
dependency_graph.rs

1use super::trait_def::Layout;
2use crate::core::ir::{Blueprint, Element};
3use crate::drivers;
4use anyhow::Result;
5use petgraph::algo::is_cyclic_directed;
6use petgraph::graph::DiGraph;
7use std::collections::HashMap;
8
9pub struct DependencyGraphLayout;
10
11impl Layout for DependencyGraphLayout {
12    fn format(&self, blueprints: &[Blueprint]) -> Result<Vec<(String, String)>> {
13        let mut graph = DiGraph::new();
14        let mut node_indices = HashMap::new();
15
16        // Build graph: one node per module
17        for blueprint in blueprints {
18            let module_name = extract_module_name(&blueprint.source_path);
19                node_indices
20                    .entry(module_name.clone())
21                    .or_insert_with(|| graph.add_node(module_name.clone()));
22        }
23
24        // Add edges for dependencies
25        for blueprint in blueprints {
26            let module_name = extract_module_name(&blueprint.source_path);
27            if let Some(&source_idx) = node_indices.get(&module_name) {
28                for dep in &blueprint.dependencies {
29                    if let Some(&target_idx) = node_indices.get(dep) {
30                        graph.add_edge(source_idx, target_idx, ());
31                    }
32                }
33            }
34        }
35
36        // Check for cycles
37        let has_cycles = is_cyclic_directed(&graph);
38
39        // Generate Mermaid diagram
40        let mermaid_diagram = generate_mermaid_diagram(&graph, &node_indices, has_cycles);
41
42        // Generate module list with descriptions
43        let module_list = generate_module_list(blueprints);
44
45        // Generate markdown output
46        let content = format!(
47            "{}\n\n{}\n\n{}",
48            mermaid_diagram,
49            module_list,
50            generate_footer(has_cycles)
51        );
52
53        Ok(vec![("architecture.md".to_string(), content)])
54    }
55}
56
/// Extract a module name from a source file path.
///
/// Normally this is the file stem (`"src/engine.rs"` -> `"engine"`). Files
/// whose stem is `mod` (Rust `engine/mod.rs`) or `__init__` (Python
/// `pkg/__init__.py`) define the module named after their directory. For
/// these, the parent directory's name is returned. Without this, every such
/// file in a project would collapse into one `mod` / `__init__` node.
///
/// Falls back to the plain stem when there is no named parent directory. Falls
/// back to `"unknown"` when the path has no stem or it is not valid UTF-8.
fn extract_module_name(path: &std::path::Path) -> String {
    let stem = path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .unwrap_or("unknown");

    if matches!(stem, "mod" | "__init__") {
        let dir_name = path
            .parent()
            .and_then(|parent| parent.file_name())
            .and_then(|name| name.to_str());
        if let Some(dir_name) = dir_name {
            return dir_name.to_string();
        }
    }

    stem.to_string()
}
64
65/// Generate Mermaid directed graph diagram
66fn generate_mermaid_diagram(
67    graph: &DiGraph<String, ()>,
68    node_indices: &HashMap<String, petgraph::graph::NodeIndex>,
69    _has_cycles: bool,
70) -> String {
71    let mut diagram = String::from("## Dependency Graph\n\n```mermaid\ngraph TD\n");
72
73    // Add nodes with styling
74    for name in node_indices.keys() {
75        diagram.push_str(&format!("    {}[{}]\n", sanitize_id(name), name));
76    }
77
78    // Add edges
79    for edge in graph.raw_edges() {
80        let from_name = &graph[edge.source()];
81        let to_name = &graph[edge.target()];
82        diagram.push_str(&format!(
83            "    {} --> {}\n",
84            sanitize_id(from_name),
85            sanitize_id(to_name)
86        ));
87    }
88
89    diagram.push_str("```\n");
90    diagram
91}
92
/// Turn a module name into a Mermaid-safe node identifier.
///
/// `-` and `.` become `_`. Every other character that is neither alphanumeric
/// nor `_` is dropped.
///
/// A trailing `_` is appended in two cases:
/// - the result would be empty, which would produce invalid Mermaid;
/// - the result would be exactly `end`, which Mermaid's flowchart parser
///   treats as a block keyword and which breaks the diagram.
fn sanitize_id(name: &str) -> String {
    let mut id: String = name
        .chars()
        .filter_map(|c| match c {
            '-' | '.' => Some('_'),
            c if c.is_alphanumeric() || c == '_' => Some(c),
            _ => None,
        })
        .collect();

    if id.is_empty() || id == "end" {
        id.push('_');
    }
    id
}
101
102/// Generate markdown list of modules with their descriptions
103fn generate_module_list(blueprints: &[Blueprint]) -> String {
104    let mut list = String::from("## Modules\n\n");
105
106    for blueprint in blueprints {
107        let module_name = extract_module_name(&blueprint.source_path);
108        list.push_str(&format!("### `{}`\n\n", module_name));
109
110        // Add file path
111        list.push_str(&format!(
112            "**File**: {}\n\n",
113            blueprint.source_path.display()
114        ));
115
116        // Add elements count
117        let class_count = blueprint
118            .elements
119            .iter()
120            .filter(|e| matches!(e, Element::Class(_)))
121            .count();
122        let function_count = blueprint
123            .elements
124            .iter()
125            .filter(|e| matches!(e, Element::Function(_)))
126            .count();
127
128        let terminology = drivers::terminology_for_language(&blueprint.language);
129        list.push_str(&format!(
130            "**Contents**: {} {}, {} {}\n\n",
131            class_count,
132            terminology.element_type_plural,
133            function_count,
134            terminology.function_label_plural
135        ));
136
137        // Add elements summary
138        if !blueprint.elements.is_empty() {
139            list.push_str(&format!(
140                "**Key {} and {}**:\n\n",
141                terminology.element_type_plural, terminology.function_label_plural
142            ));
143            for element in &blueprint.elements {
144                match element {
145                    Element::Class(class) => {
146                        list.push_str(&format!(
147                            "- `{}` ({})\n",
148                            class.name, terminology.element_type_singular
149                        ));
150                    }
151                    Element::Function(func) => {
152                        list.push_str(&format!(
153                            "- `{}()` ({})\n",
154                            func.name, terminology.function_label
155                        ));
156                    }
157                    Element::Module(_) => {}
158                }
159            }
160            list.push('\n');
161        }
162
163        // Add dependencies
164        if !blueprint.dependencies.is_empty() {
165            list.push_str("**Dependencies**: ");
166            list.push_str(&blueprint.dependencies.join(", "));
167            list.push_str("\n\n");
168        }
169    }
170
171    list
172}
173
/// Build the report's closing text.
///
/// When `has_cycles` is set, this is a warning block suggesting the cycles be
/// broken up. Otherwise it is a one-line all-clear.
fn generate_footer(has_cycles: bool) -> String {
    if !has_cycles {
        return "✅ No circular dependencies detected.".to_owned();
    }

    concat!(
        "⚠️ **Circular Dependencies Detected**\n\n",
        "This architecture contains circular dependencies. ",
        "Consider refactoring to break these cycles."
    )
    .to_owned()
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Build a Rust blueprint with no elements for `path` and `dependencies`.
    fn make_blueprint(path: &str, dependencies: &[&str]) -> Blueprint {
        Blueprint {
            source_path: PathBuf::from(path),
            language: "rust".to_string(),
            elements: vec![],
            dependencies: dependencies.iter().map(|d| d.to_string()).collect(),
        }
    }

    #[test]
    fn test_extract_module_name() {
        assert_eq!(
            extract_module_name(&PathBuf::from("src/engine.rs")),
            "engine"
        );
        assert_eq!(extract_module_name(&PathBuf::from("main.rs")), "main");
        assert_eq!(
            extract_module_name(&PathBuf::from("src/drivers/rust.rs")),
            "rust"
        );
    }

    #[test]
    fn test_sanitize_id() {
        assert_eq!(sanitize_id("engine"), "engine");
        assert_eq!(sanitize_id("std-lib"), "std_lib");
        assert_eq!(sanitize_id("my.module"), "my_module");
        assert_eq!(sanitize_id("some-module!"), "some_module");
    }

    #[test]
    fn test_dependency_graph_format() {
        let blueprints = vec![
            make_blueprint("src/main.rs", &["engine"]),
            make_blueprint("src/engine.rs", &["config"]),
        ];

        let result = DependencyGraphLayout.format(&blueprints).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "architecture.md");

        let content = &result[0].1;
        assert!(content.contains("graph TD"));
        assert!(content.contains("main"));
        assert!(content.contains("engine"));
        assert!(content.contains("main --> engine"));
        // `config` has no blueprint of its own. It shows up in `engine`'s
        // dependency list in the module section but gets no graph edge.
        assert!(content.contains("config"));
        assert!(!content.contains("--> config"));
    }

    #[test]
    fn test_no_cycles() {
        let blueprints = vec![
            make_blueprint("src/a.rs", &["b"]),
            make_blueprint("src/b.rs", &[]),
        ];

        let result = DependencyGraphLayout.format(&blueprints).unwrap();
        let content = &result[0].1;

        assert!(content.contains("No circular dependencies detected"));
    }

    #[test]
    fn test_cycle_detected() {
        let blueprints = vec![
            make_blueprint("src/a.rs", &["b"]),
            make_blueprint("src/b.rs", &["a"]),
        ];

        let result = DependencyGraphLayout.format(&blueprints).unwrap();
        let content = &result[0].1;

        assert!(content.contains("Circular Dependencies Detected"));
        assert!(!content.contains("No circular dependencies detected"));
    }
}