Skip to main content

flowscope_export/
mermaid.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use flowscope_core::{AnalyzeResult, EdgeType, NodeType};
5
/// Selects which lineage diagram [`export_mermaid`] renders.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MermaidView {
    /// Every view below, combined into one Markdown document of fenced `mermaid` blocks.
    All,
    /// One node per script, linked when one script writes a table another reads.
    Script,
    /// Table-like nodes (tables, views, CTEs) linked by data flow / derivation edges.
    Table,
    /// Column-level lineage with `table.column` nodes.
    Column,
    /// Scripts and tables drawn together in a single diagram.
    Hybrid,
}
14
15pub fn export_mermaid(result: &AnalyzeResult, view: MermaidView) -> String {
16    match view {
17        MermaidView::All => generate_all_views(result),
18        MermaidView::Script => generate_script_view(result),
19        MermaidView::Table => generate_table_view(result),
20        MermaidView::Column => generate_column_view(result),
21        MermaidView::Hybrid => generate_hybrid_view(result),
22    }
23}
24
25fn generate_all_views(result: &AnalyzeResult) -> String {
26    let sections = vec![
27        "# Lineage Diagrams".to_string(),
28        String::new(),
29        "## Script View".to_string(),
30        "```mermaid".to_string(),
31        generate_script_view(result),
32        "```".to_string(),
33        String::new(),
34        "## Hybrid View (Scripts + Tables)".to_string(),
35        "```mermaid".to_string(),
36        generate_hybrid_view(result),
37        "```".to_string(),
38        String::new(),
39        "## Table View".to_string(),
40        "```mermaid".to_string(),
41        generate_table_view(result),
42        "```".to_string(),
43        String::new(),
44        "## Column View".to_string(),
45        "```mermaid".to_string(),
46        generate_column_view(result),
47        "```".to_string(),
48    ];
49
50    sections.join("\n")
51}
52
/// Converts `id` into a Mermaid-safe node identifier.
///
/// Every character that is not alphanumeric (per `char::is_alphanumeric`) and not `_` becomes
/// `_`. Inputs that differ only in such characters therefore map to the same identifier.
///
/// Two cases would otherwise yield an unparsable diagram, so they are prefixed with `_`:
/// - an empty result, which would leave the node without an id;
/// - a result that is a flowchart keyword. Mermaid's docs warn that a lowercase `end` node
///   breaks a flowchart, and the other listed words are flowchart statement keywords.
fn sanitize_id(id: &str) -> String {
    // Bare words the flowchart grammar reads as keywords instead of node ids.
    const RESERVED: &[&str] = &[
        "end",
        "graph",
        "flowchart",
        "subgraph",
        "style",
        "linkStyle",
        "classDef",
        "class",
        "click",
    ];

    let sanitized: String = id
        .chars()
        .map(|c| {
            if c.is_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();

    if sanitized.is_empty() || RESERVED.contains(&sanitized.as_str()) {
        format!("_{sanitized}")
    } else {
        sanitized
    }
}
64
/// Escapes `label` for use inside a double-quoted Mermaid label (`id["..."]`, `-->|"..."|`).
///
/// Mermaid does not understand backslash escapes inside quoted text, so a raw `"` would end
/// the string early. Quotes are therefore written as the entity code `#quot;`, which Mermaid
/// decodes back to `"`. Line breaks (`\n` and `\r`) are flattened to spaces so the label stays
/// on one line.
fn escape_label(label: &str) -> String {
    let mut escaped = String::with_capacity(label.len());
    for c in label.chars() {
        match c {
            '"' => escaped.push_str("#quot;"),
            '\n' | '\r' => escaped.push(' '),
            _ => escaped.push(c),
        }
    }
    escaped
}
68
/// Tables and views referenced by all statements sharing one `source_name`.
///
/// Table names are each node's qualified name, falling back to its label.
#[derive(Debug)]
struct ScriptInfo {
    /// The statements' `source_name`, or `"default"` when a statement has none.
    source_name: String,
    /// Tables/views read by the script. Includes tables referenced without any data-flow edge.
    tables_read: HashSet<Arc<str>>,
    /// Tables/views that are the target of a data-flow edge in one of the script's statements.
    tables_written: HashSet<Arc<str>>,
}
75
76fn extract_script_info(result: &AnalyzeResult) -> Vec<ScriptInfo> {
77    let mut script_map: HashMap<String, ScriptInfo> = HashMap::new();
78
79    for stmt in &result.statements {
80        let source_name = stmt
81            .source_name
82            .clone()
83            .unwrap_or_else(|| "default".to_string());
84
85        let entry = script_map
86            .entry(source_name.clone())
87            .or_insert_with(|| ScriptInfo {
88                source_name: source_name.clone(),
89                tables_read: HashSet::new(),
90                tables_written: HashSet::new(),
91            });
92
93        for node in &stmt.nodes {
94            if matches!(node.node_type, NodeType::Table | NodeType::View) {
95                let is_written = stmt
96                    .edges
97                    .iter()
98                    .any(|edge| edge.to == node.id && edge.edge_type == EdgeType::DataFlow);
99                let is_read = stmt
100                    .edges
101                    .iter()
102                    .any(|edge| edge.from == node.id && edge.edge_type == EdgeType::DataFlow);
103
104                let table_name = node
105                    .qualified_name
106                    .clone()
107                    .unwrap_or_else(|| node.label.clone());
108
109                if is_written {
110                    entry.tables_written.insert(table_name.clone());
111                }
112                if is_read || !is_written {
113                    entry.tables_read.insert(table_name);
114                }
115            }
116        }
117    }
118
119    script_map.into_values().collect()
120}
121
122fn generate_script_view(result: &AnalyzeResult) -> String {
123    let scripts = extract_script_info(result);
124    let mut lines = vec!["flowchart LR".to_string()];
125
126    for script in &scripts {
127        let id = sanitize_id(&script.source_name);
128        let label = escape_label(&script.source_name);
129        lines.push(format!("    {id}[\"{label}\"]"));
130    }
131
132    for producer in &scripts {
133        for consumer in &scripts {
134            if producer.source_name == consumer.source_name {
135                continue;
136            }
137
138            let shared_tables: Vec<_> = producer
139                .tables_written
140                .iter()
141                .filter(|table| consumer.tables_read.contains(*table))
142                .collect();
143
144            if !shared_tables.is_empty() {
145                let producer_id = sanitize_id(&producer.source_name);
146                let consumer_id = sanitize_id(&consumer.source_name);
147                let label = if shared_tables.len() > 3 {
148                    let first_three: Vec<_> = shared_tables.iter().take(3).collect();
149                    format!(
150                        "{}...",
151                        first_three
152                            .iter()
153                            .map(|value| value.as_ref())
154                            .collect::<Vec<_>>()
155                            .join(", ")
156                    )
157                } else {
158                    shared_tables
159                        .iter()
160                        .map(|value| value.as_ref())
161                        .collect::<Vec<_>>()
162                        .join(", ")
163                };
164                lines.push(format!(
165                    "    {producer_id} -->|\"{}\"| {consumer_id}",
166                    escape_label(&label)
167                ));
168            }
169        }
170    }
171
172    lines.join("\n")
173}
174
175fn generate_table_view(result: &AnalyzeResult) -> String {
176    let mut lines = vec!["flowchart LR".to_string()];
177    let mut table_ids: HashMap<String, String> = HashMap::new();
178    let mut edges = HashSet::new();
179
180    for stmt in &result.statements {
181        let table_nodes: Vec<_> = stmt
182            .nodes
183            .iter()
184            .filter(|node| node.node_type.is_table_like())
185            .collect();
186
187        for node in &table_nodes {
188            let key = node
189                .qualified_name
190                .as_deref()
191                .unwrap_or(&node.label)
192                .to_string();
193            if !table_ids.contains_key(&key) {
194                let id = sanitize_id(&key);
195                table_ids.insert(key.clone(), id.clone());
196                let escaped_label = escape_label(&node.label);
197                let shape = match node.node_type {
198                    NodeType::Cte => format!("([\"{escaped_label}\"])"),
199                    NodeType::View => format!("[/\"{escaped_label}\"/]"),
200                    _ => format!("[\"{escaped_label}\"]"),
201                };
202                lines.push(format!("    {id}{shape}"));
203            }
204        }
205
206        for edge in &stmt.edges {
207            if edge.edge_type == EdgeType::DataFlow || edge.edge_type == EdgeType::Derivation {
208                let source_node = table_nodes.iter().find(|node| node.id == edge.from);
209                let target_node = table_nodes.iter().find(|node| node.id == edge.to);
210
211                if let (Some(source), Some(target)) = (source_node, target_node) {
212                    let source_key = source
213                        .qualified_name
214                        .as_deref()
215                        .unwrap_or(&source.label)
216                        .to_string();
217                    let target_key = target
218                        .qualified_name
219                        .as_deref()
220                        .unwrap_or(&target.label)
221                        .to_string();
222                    let edge_key = format!("{source_key}->{target_key}");
223
224                    if source_key != target_key && edges.insert(edge_key) {
225                        let source_id = table_ids.get(&source_key).cloned().unwrap_or_else(|| {
226                            let id = sanitize_id(&source_key);
227                            table_ids.insert(source_key.clone(), id.clone());
228                            id
229                        });
230                        let target_id = table_ids.get(&target_key).cloned().unwrap_or_else(|| {
231                            let id = sanitize_id(&target_key);
232                            table_ids.insert(target_key.clone(), id.clone());
233                            id
234                        });
235                        lines.push(format!("    {source_id} --> {target_id}"));
236                    }
237                }
238            }
239        }
240    }
241
242    lines.join("\n")
243}
244
245#[derive(Debug)]
246struct ColumnMapping {
247    source_table: String,
248    source_column: String,
249    target_table: String,
250    target_column: String,
251    edge_type: EdgeType,
252}
253
254fn extract_column_mappings(result: &AnalyzeResult) -> Vec<ColumnMapping> {
255    let mut mappings = Vec::new();
256
257    for stmt in &result.statements {
258        let table_nodes: Vec<_> = stmt
259            .nodes
260            .iter()
261            .filter(|node| node.node_type.is_table_like())
262            .collect();
263        let column_nodes: Vec<_> = stmt
264            .nodes
265            .iter()
266            .filter(|node| node.node_type == NodeType::Column)
267            .collect();
268
269        let mut column_to_table: HashMap<&str, &str> = HashMap::new();
270        for edge in &stmt.edges {
271            if edge.edge_type == EdgeType::Ownership {
272                if let Some(table_node) = table_nodes.iter().find(|node| node.id == edge.from) {
273                    let table_name = table_node
274                        .qualified_name
275                        .as_deref()
276                        .unwrap_or(&table_node.label);
277                    column_to_table.insert(edge.to.as_ref(), table_name);
278                }
279            }
280        }
281
282        for edge in &stmt.edges {
283            if edge.edge_type == EdgeType::Derivation || edge.edge_type == EdgeType::DataFlow {
284                let source_col = column_nodes.iter().find(|col| col.id == edge.from);
285                let target_col = column_nodes.iter().find(|col| col.id == edge.to);
286
287                if let (Some(source), Some(target)) = (source_col, target_col) {
288                    let source_table = column_to_table
289                        .get(edge.from.as_ref())
290                        .copied()
291                        .unwrap_or("Output");
292                    let target_table = column_to_table
293                        .get(edge.to.as_ref())
294                        .copied()
295                        .unwrap_or("Output");
296
297                    mappings.push(ColumnMapping {
298                        source_table: source_table.to_string(),
299                        source_column: source.label.to_string(),
300                        target_table: target_table.to_string(),
301                        target_column: target.label.to_string(),
302                        edge_type: edge.edge_type,
303                    });
304                }
305            }
306        }
307    }
308
309    mappings
310}
311
312fn generate_column_view(result: &AnalyzeResult) -> String {
313    let mut lines = vec!["flowchart LR".to_string()];
314    let mappings = extract_column_mappings(result);
315    let mut nodes = HashSet::new();
316    let mut edges = HashSet::new();
317
318    for mapping in mappings {
319        let source_id = sanitize_id(&format!(
320            "{}_{}",
321            mapping.source_table, mapping.source_column
322        ));
323        let target_id = sanitize_id(&format!(
324            "{}_{}",
325            mapping.target_table, mapping.target_column
326        ));
327        let source_label = format!("{}.{}", mapping.source_table, mapping.source_column);
328        let target_label = format!("{}.{}", mapping.target_table, mapping.target_column);
329
330        if nodes.insert(source_id.clone()) {
331            lines.push(format!(
332                "    {source_id}[\"{}\"]",
333                escape_label(&source_label)
334            ));
335        }
336        if nodes.insert(target_id.clone()) {
337            lines.push(format!(
338                "    {target_id}[\"{}\"]",
339                escape_label(&target_label)
340            ));
341        }
342
343        let edge_key = format!("{source_id}->{target_id}");
344        if edges.insert(edge_key) {
345            let edge_label = match mapping.edge_type {
346                EdgeType::Derivation => "derived",
347                _ => "flows",
348            };
349            lines.push(format!("    {source_id} -->|{edge_label}| {target_id}"));
350        }
351    }
352
353    lines.join("\n")
354}
355
356fn generate_hybrid_view(result: &AnalyzeResult) -> String {
357    let mut lines = vec!["flowchart LR".to_string()];
358    let scripts = extract_script_info(result);
359
360    let mut script_ids = HashMap::new();
361    for script in &scripts {
362        let id = sanitize_id(&format!("script_{}", script.source_name));
363        script_ids.insert(script.source_name.clone(), id.clone());
364        lines.push(format!(
365            "    {id}{{\"{}\"}}",
366            escape_label(&script.source_name)
367        ));
368    }
369
370    let mut table_ids = HashMap::new();
371    for stmt in &result.statements {
372        for node in &stmt.nodes {
373            if node.node_type.is_table_like() {
374                let key = node
375                    .qualified_name
376                    .as_deref()
377                    .unwrap_or(&node.label)
378                    .to_string();
379                if !table_ids.contains_key(&key) {
380                    let id = sanitize_id(&format!("table_{}", key));
381                    table_ids.insert(key.clone(), id.clone());
382                    lines.push(format!("    {id}[\"{}\"]", escape_label(&node.label)));
383                }
384            }
385        }
386    }
387
388    for script in scripts {
389        let script_id = script_ids
390            .get(&script.source_name)
391            .cloned()
392            .unwrap_or_else(|| {
393                let id = sanitize_id(&format!("script_{}", script.source_name));
394                script_ids.insert(script.source_name.clone(), id.clone());
395                id
396            });
397
398        for table in &script.tables_read {
399            if let Some(table_id) = table_ids.get(table.as_ref()) {
400                lines.push(format!("    {script_id} --> {table_id}"));
401            }
402        }
403        for table in &script.tables_written {
404            if let Some(table_id) = table_ids.get(table.as_ref()) {
405                lines.push(format!("    {table_id} --> {script_id}"));
406            }
407        }
408    }
409
410    lines.join("\n")
411}