use crate::{
parser::{parse_statements, SqlType, Statement},
SqlDialect,
};
use anyhow::Result;
use clap::ArgEnum;
use std::{
collections::{HashMap, HashSet},
fmt::Display,
};
// Layout direction of the generated Mermaid flowchart. `Display` writes the variant name,
// which is emitted right after `graph` in `Mermaid::to_flowchart`'s output.
// NOTE(review): plain `//` comments on purpose — clap's `ArgEnum` derive may turn `///`
// doc comments into CLI help text, which would change `--help` output.
#[derive(Debug, Clone, Copy, ArgEnum)]
pub enum Flow {
    // Top to bottom (the value `Config::default` uses).
    TB,
    // Bottom to top.
    BT,
    // Right to left.
    RL,
    // Left to right.
    LR,
}
impl Display for Flow {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match self {
Self::TB => "TB",
Self::BT => "BT",
Self::LR => "LR",
Self::RL => "RL",
})
}
}
/// Rendering options used by [`Mermaid`] when producing a flowchart.
#[derive(Debug)]
struct Config {
    /// Layout direction written after `graph` in the output.
    direction: Flow,
    /// Whether node labels are prefixed with the icon returned by `RelationType::fa_icon`.
    show_icons: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
direction: Flow::TB,
show_icons: true,
}
}
}
/// Kind of relation a name refers to, inferred from the statements that target it.
#[derive(Debug, Clone, Copy)]
enum RelationType {
    /// Only ever the target of statements other than `SqlType::CreateView`.
    Table,
    /// Only ever the target of `SqlType::CreateView` statements.
    View,
    /// Targeted by both kinds of statement, or only ever used as a source.
    Unknown,
}
impl RelationType {
    /// Mermaid Font Awesome icon prefix for a node label.
    ///
    /// The returned text includes a trailing space so it can be concatenated directly in
    /// front of the relation name. `Unknown` relations get no icon (empty string).
    fn fa_icon(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a decision about its icon.
        match self {
            Self::Table => "fa:fa-table ",
            Self::View => "fa:fa-eye ",
            Self::Unknown => "",
        }
    }
}
/// A relation's inferred type paired with the Mermaid node id assigned to it.
struct Relation {
    /// Inferred relation type; selects the label icon.
    kind: RelationType,
    /// Short node id produced by `IdIter` (e.g. `A`, `AB`).
    id: String,
}
/// Renders parsed SQL statements as a Mermaid flowchart of relation dependencies.
pub struct Mermaid {
    /// Output options (direction, icons).
    config: Config,
    /// Statements parsed from the input SQL.
    statements: Vec<Statement>,
}
impl Mermaid {
    /// Parses `sql` using `dialect` and builds a renderer with the default configuration
    /// (top-to-bottom layout, icons shown).
    ///
    /// # Errors
    ///
    /// Returns any error produced by `parse_statements`.
    pub fn try_new(sql: &str, dialect: SqlDialect) -> Result<Self> {
        let statements = parse_statements(sql, dialect)?;
        Ok(Self {
            statements,
            config: Default::default(),
        })
    }

    /// Sets the layout direction of the generated flowchart.
    pub fn flow(&mut self, flow: Flow) -> &mut Self {
        self.config.direction = flow;
        self
    }

    /// Hides relation-type icons when `no_icons` is `true`; shows them when it is `false`.
    pub fn no_icons(&mut self, no_icons: bool) -> &mut Self {
        self.config.show_icons = !no_icons;
        self
    }

    /// Returns every relation named by any statement (as a target or a source) together
    /// with its inferred type.
    ///
    /// A relation is a `View` if it is only ever the target of `SqlType::CreateView`
    /// statements, and a `Table` if it is only ever the target of other statement kinds.
    /// It is `Unknown` if it is the target of both kinds, or only ever appears as a source.
    /// The order of the returned vector is unspecified, because it is drained from a `HashSet`.
    fn relations(&self) -> Vec<(&str, RelationType)> {
        let mut potential_tables = HashSet::new();
        let mut potential_views = HashSet::new();
        let mut all_relations = HashSet::new();
        for Statement {
            target,
            kind,
            sources,
        } in &self.statements
        {
            all_relations.insert(target);
            all_relations.extend(sources);
            match kind {
                SqlType::CreateView => potential_views.insert(target),
                _ => potential_tables.insert(target),
            };
        }
        all_relations
            .into_iter()
            .map(|relation| {
                let is_potential_table = potential_tables.contains(relation);
                let is_potential_view = potential_views.contains(relation);
                let relation_type = match (is_potential_table, is_potential_view) {
                    (true, false) => RelationType::Table,
                    (false, true) => RelationType::View,
                    _ => RelationType::Unknown,
                };
                (relation.as_str(), relation_type)
            })
            .collect()
    }

    /// Assigns each relation a short node id (`A`, `B`, ..., `Z`, `AA`, ...).
    ///
    /// Ids are handed out in ascending order of relation name, so the result is
    /// deterministic even though `relations` is not ordered.
    fn get_ids(&self) -> Vec<(&str, Relation)> {
        let mut relations = self.relations();
        relations.sort_unstable_by_key(|(relation, _)| *relation);
        relations
            .into_iter()
            .zip(IdIter::new())
            .map(|((relation, kind), id)| (relation, Relation { kind, id }))
            .collect()
    }

    /// Renders the statements as a Mermaid `graph` definition.
    ///
    /// The output has three parts, in this order:
    /// - a `graph <direction>` header line;
    /// - one node declaration per relation, `<id>[<icon><name>]`, sorted by name;
    /// - one `<source id> --> <target id>` edge per distinct source/target pair, sorted
    ///   with duplicates removed.
    pub fn to_flowchart(&self) -> String {
        let ids = self.get_ids();
        let declarations: Vec<_> = ids
            .iter()
            .map(|(name, Relation { id, kind })| {
                let icon = if self.config.show_icons {
                    kind.fa_icon()
                } else {
                    ""
                };
                format!("{id}[{icon}{name}]")
            })
            .collect();

        let id_lookup: HashMap<_, _> = ids.into_iter().collect();
        // Borrow once so the `move` closure below copies this reference, not the map.
        let id_lookup = &id_lookup;
        let mut connections: Vec<_> = self
            .statements
            .iter()
            .flat_map(
                |Statement {
                     target, sources, ..
                 }| {
                    // `relations` registers every target and every source, so these lookups
                    // can only fail if that invariant is broken.
                    // The target is looked up once per statement, not once per source.
                    let target_id = &id_lookup
                        .get(&target.as_str())
                        .expect("every statement target is registered by `relations`")
                        .id;
                    sources.iter().map(move |source| {
                        let source_id = &id_lookup
                            .get(&source.as_str())
                            .expect("every statement source is registered by `relations`")
                            .id;
                        format!("{source_id} --> {target_id}")
                    })
                },
            )
            .collect();
        connections.sort_unstable();
        connections.dedup();

        let declarations = declarations.join("\n ");
        let connections = connections.join("\n ");
        let direction = self.config.direction;
        format!("graph {direction}\n {declarations}\n {connections}")
    }
}
/// Endless generator of short, unique node ids in spreadsheet-column order:
/// `A`..`Z`, then `AA`..`ZZ`, then `AAA`, and so on.
struct IdIter {
    /// Alphabet appended as the last character of every id (`A`..=`Z`).
    letters: Vec<char>,
    /// All ids of the previous length; each one is extended by every letter in turn.
    prefixes: Vec<String>,
    /// Ids of the current length produced so far; they become the next `prefixes`.
    ids: Vec<String>,
    /// Index into `prefixes` of the prefix currently being extended.
    prefix_pos: usize,
    /// Index into `letters` of the next letter to append.
    letter_pos: usize,
}

impl IdIter {
    /// Creates a generator whose first id is `A`.
    fn new() -> Self {
        Self {
            letters: ('A'..='Z').collect(),
            // A single empty prefix makes the first pass yield the one-letter ids.
            prefixes: vec![String::new()],
            ids: Vec::new(),
            prefix_pos: 0,
            letter_pos: 0,
        }
    }
}

impl Iterator for IdIter {
    type Item = String;

    /// Yields the next id. The sequence never ends, so this always returns `Some`.
    fn next(&mut self) -> Option<Self::Item> {
        let new_id = format!(
            "{}{}",
            self.prefixes[self.prefix_pos], self.letters[self.letter_pos]
        );
        // Keep every id of the current length: once the length is exhausted they become
        // the prefixes for ids one character longer.
        self.ids.push(new_id.clone());

        if self.letter_pos + 1 < self.letters.len() {
            self.letter_pos += 1;
        } else {
            self.letter_pos = 0;
            if self.prefix_pos + 1 < self.prefixes.len() {
                self.prefix_pos += 1;
            } else {
                // Every id of this length has been produced: promote them to prefixes.
                // `mem::take` moves the vector (leaving `ids` empty) instead of cloning it.
                self.prefixes = std::mem::take(&mut self.ids);
                self.prefix_pos = 0;
            }
        }
        Some(new_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn id_generation() {
        let ids: Vec<_> = IdIter::new().take(3).collect();
        assert_eq!(ids, vec!["A".to_owned(), "B".to_owned(), "C".to_owned()]);
        let ids: Vec<_> = IdIter::new().skip(26).take(3).collect();
        assert_eq!(ids, vec!["AA".to_owned(), "AB".to_owned(), "AC".to_owned()]);
    }

    /// The sequence must roll over cleanly at every boundary: `Z` -> `AA`,
    /// `AZ` -> `BA` (next prefix), and `ZZ` -> `AAA` (next length).
    #[test]
    fn id_generation_rollover() {
        let ids: Vec<_> = IdIter::new().skip(25).take(2).collect();
        assert_eq!(ids, ["Z", "AA"]);
        let ids: Vec<_> = IdIter::new().skip(51).take(2).collect();
        assert_eq!(ids, ["AZ", "BA"]);
        let ids: Vec<_> = IdIter::new().skip(701).take(2).collect();
        assert_eq!(ids, ["ZZ", "AAA"]);
    }

    /// Node ids must never repeat, or Mermaid would merge distinct relations.
    #[test]
    fn id_generation_unique() {
        let ids: std::collections::HashSet<String> = IdIter::new().take(1000).collect();
        assert_eq!(ids.len(), 1000);
    }
}