Skip to main content

flowscope_export/
mermaid.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use flowscope_core::{AnalyzeResult, EdgeType, NodeType};
5
/// Selects which Mermaid diagram(s) `export_mermaid` produces.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MermaidView {
    /// Every view below, combined into one Markdown document with one
    /// ```` ```mermaid ```` fence per view.
    All,
    /// Scripts only; one script links to another when it writes a table the
    /// other reads.
    Script,
    /// Table-level lineage between tables, views and CTEs.
    Table,
    /// Column-level lineage.
    Column,
    /// Scripts and tables drawn together in one diagram.
    Hybrid,
}
14
15pub fn export_mermaid(result: &AnalyzeResult, view: MermaidView) -> String {
16    match view {
17        MermaidView::All => generate_all_views(result),
18        MermaidView::Script => generate_script_view(result),
19        MermaidView::Table => generate_table_view(result),
20        MermaidView::Column => generate_column_view(result),
21        MermaidView::Hybrid => generate_hybrid_view(result),
22    }
23}
24
25fn generate_all_views(result: &AnalyzeResult) -> String {
26    let sections = vec![
27        "# Lineage Diagrams".to_string(),
28        String::new(),
29        "## Script View".to_string(),
30        "```mermaid".to_string(),
31        generate_script_view(result),
32        "```".to_string(),
33        String::new(),
34        "## Hybrid View (Scripts + Tables)".to_string(),
35        "```mermaid".to_string(),
36        generate_hybrid_view(result),
37        "```".to_string(),
38        String::new(),
39        "## Table View".to_string(),
40        "```mermaid".to_string(),
41        generate_table_view(result),
42        "```".to_string(),
43        String::new(),
44        "## Column View".to_string(),
45        "```mermaid".to_string(),
46        generate_column_view(result),
47        "```".to_string(),
48    ];
49
50    sections.join("\n")
51}
52
/// Turns an arbitrary name into a Mermaid node id.
///
/// Alphanumeric characters (Unicode-aware, per `char::is_alphanumeric`) and
/// `_` are kept; every other character becomes `_`. Distinct inputs can map
/// to the same id (`a.b` and `a_b` both become `a_b`).
fn sanitize_id(id: &str) -> String {
    let mut sanitized = String::with_capacity(id.len());
    for ch in id.chars() {
        let keep = ch == '_' || ch.is_alphanumeric();
        sanitized.push(if keep { ch } else { '_' });
    }
    sanitized
}
64
/// Makes `label` safe to place inside a double-quoted Mermaid label
/// (`id["..."]` or `-->|"..."|`).
///
/// Mermaid's quoted strings have no backslash escapes: a `\"` still ends the
/// string and breaks the diagram. Double quotes are therefore written as the
/// Mermaid entity code `#quot;`. Line breaks (`\n`, `\r`, or a `\r\n` pair)
/// each collapse to a single space so the label stays on one line.
fn escape_label(label: &str) -> String {
    let mut escaped = String::with_capacity(label.len());
    let mut chars = label.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '"' => escaped.push_str("#quot;"),
            '\r' => {
                // Treat CRLF as one line break, same as a bare LF.
                if chars.peek() == Some(&'\n') {
                    chars.next();
                }
                escaped.push(' ');
            }
            '\n' => escaped.push(' '),
            other => escaped.push(other),
        }
    }
    escaped
}
68
/// Summary of the tables touched by all statements that share one source
/// (script) name.
#[derive(Debug)]
struct ScriptInfo {
    /// The statements' `source_name`, or `"default"` when they have none.
    source_name: String,
    /// Tables/views read by the script, keyed by qualified name (falling back
    /// to the node label). This also holds tables that appear in a statement
    /// without any data-flow edge into them.
    tables_read: HashSet<Arc<str>>,
    /// Tables/views that are the target of a data-flow edge in the script.
    tables_written: HashSet<Arc<str>>,
}
75
76fn extract_script_info(result: &AnalyzeResult) -> Vec<ScriptInfo> {
77    let mut script_map: HashMap<String, ScriptInfo> = HashMap::new();
78
79    for stmt in &result.statements {
80        let source_name = stmt
81            .source_name
82            .clone()
83            .unwrap_or_else(|| "default".to_string());
84
85        let entry = script_map
86            .entry(source_name.clone())
87            .or_insert_with(|| ScriptInfo {
88                source_name: source_name.clone(),
89                tables_read: HashSet::new(),
90                tables_written: HashSet::new(),
91            });
92
93        let stmt_edges: Vec<_> = result.edges_in_statement(stmt.statement_index).collect();
94
95        for node in result.nodes_in_statement(stmt.statement_index) {
96            if matches!(node.node_type, NodeType::Table | NodeType::View) {
97                let is_written = stmt_edges
98                    .iter()
99                    .any(|edge| edge.to == node.id && edge.edge_type == EdgeType::DataFlow);
100                let is_read = stmt_edges
101                    .iter()
102                    .any(|edge| edge.from == node.id && edge.edge_type == EdgeType::DataFlow);
103
104                let table_name = node
105                    .qualified_name
106                    .clone()
107                    .unwrap_or_else(|| node.label.clone());
108
109                if is_written {
110                    entry.tables_written.insert(table_name.clone());
111                }
112                if is_read || !is_written {
113                    entry.tables_read.insert(table_name);
114                }
115            }
116        }
117    }
118
119    script_map.into_values().collect()
120}
121
122fn generate_script_view(result: &AnalyzeResult) -> String {
123    let scripts = extract_script_info(result);
124    let mut lines = vec!["flowchart LR".to_string()];
125
126    for script in &scripts {
127        let id = sanitize_id(&script.source_name);
128        let label = escape_label(&script.source_name);
129        lines.push(format!("    {id}[\"{label}\"]"));
130    }
131
132    for producer in &scripts {
133        for consumer in &scripts {
134            if producer.source_name == consumer.source_name {
135                continue;
136            }
137
138            let shared_tables: Vec<_> = producer
139                .tables_written
140                .iter()
141                .filter(|table| consumer.tables_read.contains(*table))
142                .collect();
143
144            if !shared_tables.is_empty() {
145                let producer_id = sanitize_id(&producer.source_name);
146                let consumer_id = sanitize_id(&consumer.source_name);
147                let label = if shared_tables.len() > 3 {
148                    let first_three: Vec<_> = shared_tables.iter().take(3).collect();
149                    format!(
150                        "{}...",
151                        first_three
152                            .iter()
153                            .map(|value| value.as_ref())
154                            .collect::<Vec<_>>()
155                            .join(", ")
156                    )
157                } else {
158                    shared_tables
159                        .iter()
160                        .map(|value| value.as_ref())
161                        .collect::<Vec<_>>()
162                        .join(", ")
163                };
164                lines.push(format!(
165                    "    {producer_id} -->|\"{}\"| {consumer_id}",
166                    escape_label(&label)
167                ));
168            }
169        }
170    }
171
172    lines.join("\n")
173}
174
175fn generate_table_view(result: &AnalyzeResult) -> String {
176    let mut lines = vec!["flowchart LR".to_string()];
177    let mut table_ids: HashMap<String, String> = HashMap::new();
178    let mut edges = HashSet::new();
179
180    let table_nodes: Vec<_> = result
181        .nodes
182        .iter()
183        .filter(|node| node.node_type.is_table_like())
184        .collect();
185
186    for node in &table_nodes {
187        let key = node
188            .qualified_name
189            .as_deref()
190            .unwrap_or(&node.label)
191            .to_string();
192        if !table_ids.contains_key(&key) {
193            let id = sanitize_id(&key);
194            table_ids.insert(key.clone(), id.clone());
195            let escaped_label = escape_label(&node.label);
196            let shape = match node.node_type {
197                NodeType::Cte => format!("([\"{escaped_label}\"])"),
198                NodeType::View => format!("[/\"{escaped_label}\"/]"),
199                _ => format!("[\"{escaped_label}\"]"),
200            };
201            lines.push(format!("    {id}{shape}"));
202        }
203    }
204
205    for edge in &result.edges {
206        if edge.edge_type == EdgeType::DataFlow || edge.edge_type == EdgeType::Derivation {
207            let source_node = table_nodes.iter().find(|node| node.id == edge.from);
208            let target_node = table_nodes.iter().find(|node| node.id == edge.to);
209
210            if let (Some(source), Some(target)) = (source_node, target_node) {
211                let source_key = source
212                    .qualified_name
213                    .as_deref()
214                    .unwrap_or(&source.label)
215                    .to_string();
216                let target_key = target
217                    .qualified_name
218                    .as_deref()
219                    .unwrap_or(&target.label)
220                    .to_string();
221                let edge_key = format!("{source_key}->{target_key}");
222
223                if source_key != target_key && edges.insert(edge_key) {
224                    let source_id = table_ids.get(&source_key).cloned().unwrap_or_else(|| {
225                        let id = sanitize_id(&source_key);
226                        table_ids.insert(source_key.clone(), id.clone());
227                        id
228                    });
229                    let target_id = table_ids.get(&target_key).cloned().unwrap_or_else(|| {
230                        let id = sanitize_id(&target_key);
231                        table_ids.insert(target_key.clone(), id.clone());
232                        id
233                    });
234                    lines.push(format!("    {source_id} --> {target_id}"));
235                }
236            }
237        }
238    }
239
240    lines.join("\n")
241}
242
243#[derive(Debug)]
244struct ColumnMapping {
245    source_table: String,
246    source_column: String,
247    target_table: String,
248    target_column: String,
249    edge_type: EdgeType,
250}
251
252fn extract_column_mappings(result: &AnalyzeResult) -> Vec<ColumnMapping> {
253    let mut mappings = Vec::new();
254
255    let table_nodes: Vec<_> = result
256        .nodes
257        .iter()
258        .filter(|node| node.node_type.is_table_like())
259        .collect();
260    let column_nodes: Vec<_> = result
261        .nodes
262        .iter()
263        .filter(|node| node.node_type == NodeType::Column)
264        .collect();
265
266    let mut column_to_table: HashMap<&str, &str> = HashMap::new();
267    for edge in &result.edges {
268        if edge.edge_type == EdgeType::Ownership {
269            if let Some(table_node) = table_nodes.iter().find(|node| node.id == edge.from) {
270                let table_name = table_node
271                    .qualified_name
272                    .as_deref()
273                    .unwrap_or(&table_node.label);
274                column_to_table.insert(edge.to.as_ref(), table_name);
275            }
276        }
277    }
278
279    for edge in &result.edges {
280        if edge.edge_type == EdgeType::Derivation || edge.edge_type == EdgeType::DataFlow {
281            let source_col = column_nodes.iter().find(|col| col.id == edge.from);
282            let target_col = column_nodes.iter().find(|col| col.id == edge.to);
283
284            if let (Some(source), Some(target)) = (source_col, target_col) {
285                let source_table = column_to_table
286                    .get(edge.from.as_ref())
287                    .copied()
288                    .unwrap_or("Output");
289                let target_table = column_to_table
290                    .get(edge.to.as_ref())
291                    .copied()
292                    .unwrap_or("Output");
293
294                mappings.push(ColumnMapping {
295                    source_table: source_table.to_string(),
296                    source_column: source.label.to_string(),
297                    target_table: target_table.to_string(),
298                    target_column: target.label.to_string(),
299                    edge_type: edge.edge_type,
300                });
301            }
302        }
303    }
304
305    mappings
306}
307
308fn generate_column_view(result: &AnalyzeResult) -> String {
309    let mut lines = vec!["flowchart LR".to_string()];
310    let mappings = extract_column_mappings(result);
311    let mut nodes = HashSet::new();
312    let mut edges = HashSet::new();
313
314    for mapping in mappings {
315        let source_id = sanitize_id(&format!(
316            "{}_{}",
317            mapping.source_table, mapping.source_column
318        ));
319        let target_id = sanitize_id(&format!(
320            "{}_{}",
321            mapping.target_table, mapping.target_column
322        ));
323        let source_label = format!("{}.{}", mapping.source_table, mapping.source_column);
324        let target_label = format!("{}.{}", mapping.target_table, mapping.target_column);
325
326        if nodes.insert(source_id.clone()) {
327            lines.push(format!(
328                "    {source_id}[\"{}\"]",
329                escape_label(&source_label)
330            ));
331        }
332        if nodes.insert(target_id.clone()) {
333            lines.push(format!(
334                "    {target_id}[\"{}\"]",
335                escape_label(&target_label)
336            ));
337        }
338
339        let edge_key = format!("{source_id}->{target_id}");
340        if edges.insert(edge_key) {
341            let edge_label = match mapping.edge_type {
342                EdgeType::Derivation => "derived",
343                _ => "flows",
344            };
345            lines.push(format!("    {source_id} -->|{edge_label}| {target_id}"));
346        }
347    }
348
349    lines.join("\n")
350}
351
352fn generate_hybrid_view(result: &AnalyzeResult) -> String {
353    let mut lines = vec!["flowchart LR".to_string()];
354    let scripts = extract_script_info(result);
355
356    let mut script_ids = HashMap::new();
357    for script in &scripts {
358        let id = sanitize_id(&format!("script_{}", script.source_name));
359        script_ids.insert(script.source_name.clone(), id.clone());
360        lines.push(format!(
361            "    {id}{{\"{}\"}}",
362            escape_label(&script.source_name)
363        ));
364    }
365
366    let mut table_ids = HashMap::new();
367    for node in &result.nodes {
368        if node.node_type.is_table_like() {
369            let key = node
370                .qualified_name
371                .as_deref()
372                .unwrap_or(&node.label)
373                .to_string();
374            if !table_ids.contains_key(&key) {
375                let id = sanitize_id(&format!("table_{}", key));
376                table_ids.insert(key.clone(), id.clone());
377                lines.push(format!("    {id}[\"{}\"]", escape_label(&node.label)));
378            }
379        }
380    }
381
382    for script in scripts {
383        let script_id = script_ids
384            .get(&script.source_name)
385            .cloned()
386            .unwrap_or_else(|| {
387                let id = sanitize_id(&format!("script_{}", script.source_name));
388                script_ids.insert(script.source_name.clone(), id.clone());
389                id
390            });
391
392        for table in &script.tables_read {
393            if let Some(table_id) = table_ids.get(table.as_ref()) {
394                lines.push(format!("    {script_id} --> {table_id}"));
395            }
396        }
397        for table in &script.tables_written {
398            if let Some(table_id) = table_ids.get(table.as_ref()) {
399                lines.push(format!("    {table_id} --> {script_id}"));
400            }
401        }
402    }
403
404    lines.join("\n")
405}