esquel 0.1.2

Create Mermaid charts from SQL scripts.
Documentation
use crate::{
    parser::{parse_statements, SqlType, Statement},
    SqlDialect,
};
use anyhow::Result;
use clap::ArgEnum;
use std::{
    collections::{HashMap, HashSet},
    fmt::Display,
};

/// Layout direction of the generated Mermaid flowchart.
///
/// Its `Display` form is the keyword written right after `graph` in the chart header:
/// `TB` (top to bottom), `BT` (bottom to top), `RL` (right to left), `LR` (left to right).
// NOTE(review): per-variant `///` docs are deliberately not added — with
// `#[derive(ArgEnum)]` they may be picked up as CLI help text for the possible values.
#[derive(Debug, Clone, Copy, ArgEnum)]
pub enum Flow {
    TB,
    BT,
    RL,
    LR,
}

impl Display for Flow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::TB => "TB",
            Self::BT => "BT",
            Self::LR => "LR",
            Self::RL => "RL",
        })
    }
}

/// Rendering options applied when a [`Mermaid`] chart is produced.
#[derive(Debug)]
struct Config {
    /// Direction keyword emitted after `graph` in the chart header.
    direction: Flow,
    /// Whether node labels are prefixed with the relation's Font Awesome icon.
    show_icons: bool,
}

impl Default for Config {
    /// Top-to-bottom layout with icons enabled.
    fn default() -> Config {
        let direction = Flow::TB;
        let show_icons = true;
        Config {
            direction,
            show_icons,
        }
    }
}

/// Classification of a relation, derived from the kinds of statement that create it.
#[derive(Debug, Clone, Copy)]
enum RelationType {
    /// Created only by statements other than `CREATE VIEW`.
    Table,
    /// Created only by `CREATE VIEW` statements.
    View,
    /// Created by both kinds of statement, or never created in the script at all.
    Unknown,
}

impl RelationType {
    /// Mermaid Font Awesome prefix for a node label of this kind.
    ///
    /// The returned text includes the trailing space that separates the icon from the
    /// relation name; `Unknown` relations get no icon (empty string).
    fn fa_icon(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a decision about its icon.
        match self {
            Self::Table => "fa:fa-table ",
            Self::View => "fa:fa-eye ",
            Self::Unknown => "",
        }
    }
}

/// A relation paired with the Mermaid node id assigned to it.
struct Relation {
    /// Table / view / unknown classification used to pick the node icon.
    kind: RelationType,
    /// Node id used to reference this relation in the chart.
    id: String,
}

/// Builds Mermaid diagrams from the statements of a SQL script.
pub struct Mermaid {
    /// Rendering options (layout direction, icon display).
    config: Config,
    /// Statements parsed from the SQL script.
    statements: Vec<Statement>,
}

impl Mermaid {
    /// Parses `sql` using `dialect` and creates a chart generator with the default
    /// configuration (see `Config::default`).
    ///
    /// # Errors
    ///
    /// Returns any error produced by `parse_statements` while parsing `sql`.
    pub fn try_new(sql: &str, dialect: SqlDialect) -> Result<Self> {
        let statements = parse_statements(sql, dialect)?;
        Ok(Self {
            statements,
            config: Default::default(),
        })
    }

    /// Sets the direction in which the flowchart is laid out.
    pub fn flow(&mut self, flow: Flow) -> &mut Self {
        self.config.direction = flow;
        self
    }

    /// Hides relation icons in node labels when `no_icons` is `true`, shows them when `false`.
    pub fn no_icons(&mut self, no_icons: bool) -> &mut Self {
        self.config.show_icons = !no_icons;
        self
    }

    /// Collects every relation named in the statements (as a target or a source) and
    /// classifies it.
    ///
    /// A relation is a `Table` if it is only ever the target of non-view statements, a
    /// `View` if it is only ever the target of `CREATE VIEW`, and `Unknown` if it is the
    /// target of both kinds or is never a target (it only appears as a source).
    fn relations(&self) -> Vec<(&str, RelationType)> {
        let mut potential_tables = HashSet::new();
        let mut potential_views = HashSet::new();
        let mut all_relations = HashSet::new();
        for Statement {
            target,
            kind,
            sources,
        } in &self.statements
        {
            all_relations.insert(target);
            all_relations.extend(sources);
            // Only `CREATE VIEW` marks a view; every other statement kind marks a table.
            match kind {
                SqlType::CreateView => potential_views.insert(target),
                _ => potential_tables.insert(target),
            };
        }
        all_relations
            .into_iter()
            .map(|relation| {
                let is_potential_table = potential_tables.contains(relation);
                let is_potential_view = potential_views.contains(relation);
                let relation_type = match (is_potential_table, is_potential_view) {
                    (true, false) => RelationType::Table,
                    (false, true) => RelationType::View,
                    _ => RelationType::Unknown,
                };
                (relation.as_str(), relation_type)
            })
            .collect()
    }

    /// Pairs every relation with a unique node id.
    ///
    /// Relations are sorted by name before ids are handed out, so the id assignment (and
    /// therefore the chart) does not depend on `HashSet` iteration order.
    fn get_ids(&self) -> Vec<(&str, Relation)> {
        let mut relations = self.relations();
        // Names come from a set and are unique, so an unstable sort is still deterministic.
        relations.sort_unstable_by_key(|&(name, _)| name);
        relations
            .into_iter()
            .zip(IdIter::new())
            .map(|((name, kind), id)| (name, Relation { kind, id }))
            .collect()
    }

    /// Looks up the node id assigned to relation `name`.
    ///
    /// # Panics
    ///
    /// Panics if `name` has no id; `get_ids` assigns one to every relation that appears in
    /// any statement, so this indicates a bug.
    fn node_id<'a>(node_ids: &'a HashMap<&str, Relation>, name: &str) -> &'a str {
        &node_ids
            .get(name)
            .expect("every relation named in a statement is assigned an id by `get_ids`")
            .id
    }

    /// Renders the statements as a Mermaid flowchart definition.
    ///
    /// The output is `graph <direction>`, then one node declaration per relation (sorted by
    /// name), then one `source --> target` edge per distinct dependency (sorted and
    /// de-duplicated), each on its own line indented by four spaces.
    pub fn to_flowchart(&self) -> String {
        let ids = self.get_ids();
        let show_icons = self.config.show_icons;
        let declarations: Vec<String> = ids
            .iter()
            .map(|(name, Relation { id, kind })| {
                let icon = if show_icons { kind.fa_icon() } else { "" };
                format!("{id}[{icon}{name}]")
            })
            .collect();

        let node_ids: HashMap<&str, Relation> = ids.into_iter().collect();
        let mut connections = Vec::new();
        for Statement {
            target, sources, ..
        } in &self.statements
        {
            // Resolve the target once per statement instead of once per source.
            let target_id = Self::node_id(&node_ids, target);
            connections.extend(sources.iter().map(|source| {
                let source_id = Self::node_id(&node_ids, source);
                format!("{source_id} --> {target_id}")
            }));
        }
        // The same dependency may appear in several statements; emit each edge once, in a
        // stable order.
        connections.sort_unstable();
        connections.dedup();

        let declarations = declarations.join("\n    ");
        let connections = connections.join("\n    ");
        let direction = self.config.direction;
        format!("graph {direction}\n    {declarations}\n    {connections}")
    }
}

/// Endless iterator over short, unique node ids: `A`, `B`, …, `Z`, `AA`, `AB`, …, `ZZ`,
/// `AAA`, …
///
/// Each id is a running counter written in bijective base 26 with the digits `A`–`Z`, so
/// only the counter is stored instead of every previously generated id.
struct IdIter {
    /// Zero-based position of the next id to yield.
    index: usize,
}

impl IdIter {
    /// Digits of the id numeral system, in ascending order.
    const ALPHABET: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ";

    fn new() -> Self {
        Self { index: 0 }
    }

    /// Returns the id at zero-based position `index` (`0 -> "A"`, `25 -> "Z"`,
    /// `26 -> "AA"`), or `None` if the position cannot be represented.
    fn id_at(index: usize) -> Option<String> {
        let base = Self::ALPHABET.len();
        // Bijective numeration has no zero digit, so work from the one-based position.
        let mut remaining = index.checked_add(1)?;
        let mut digits = Vec::new();
        while remaining > 0 {
            remaining -= 1;
            digits.push(char::from(Self::ALPHABET[remaining % base]));
            remaining /= base;
        }
        // Digits are produced least-significant first.
        Some(digits.into_iter().rev().collect())
    }
}

impl Iterator for IdIter {
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        let id = Self::id_at(self.index)?;
        self.index += 1;
        Some(id)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// First ids are single letters; the 27th id starts the two-letter range.
    #[test]
    fn id_generation() {
        let ids: Vec<_> = IdIter::new().take(3).collect();
        assert_eq!(ids, vec!["A".to_owned(), "B".to_owned(), "C".to_owned()]);
        let ids: Vec<_> = IdIter::new().skip(26).take(3).collect();
        assert_eq!(ids, vec!["AA".to_owned(), "AB".to_owned(), "AC".to_owned()]);
    }

    /// The leading letter advances after `AZ`, and `ZZ` rolls over to three letters.
    #[test]
    fn id_generation_rollover() {
        assert_eq!(IdIter::new().nth(51).unwrap(), "AZ");
        assert_eq!(IdIter::new().nth(52).unwrap(), "BA");
        assert_eq!(IdIter::new().nth(701).unwrap(), "ZZ");
        let ids: Vec<_> = IdIter::new().skip(702).take(2).collect();
        assert_eq!(ids, vec!["AAA".to_owned(), "AAB".to_owned()]);
    }

    /// Node ids must never collide, otherwise distinct relations would merge in the chart.
    #[test]
    fn id_generation_unique() {
        let ids: Vec<_> = IdIter::new().take(1000).collect();
        let unique: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(unique.len(), ids.len());
    }
}