kitmd 0.2.1

A terminal-based Markdown and Mermaid renderer/viewer using the Kitty graphics protocol.
use std::collections::{BTreeMap, HashMap, VecDeque};

use crate::mermaid_engine::config::LayoutConfig;
use crate::mermaid_engine::ir::Graph;
use crate::mermaid_engine::theme::Theme;

use super::text::measure_label;
use super::{
    DiagramData, EdgeLayout, Layout, NodeLayout, SankeyLayout, SankeyLinkLayout, SankeyNodeLayout,
    resolve_node_style,
};

/// Computes the layout for a sankey diagram.
///
/// Nodes become fixed-width vertical bars arranged in columns (ranks). Every edge becomes a
/// flow band whose thickness is proportional to its value. The steps are:
///
/// 1. **Values** — each edge's value is parsed from its label (surrounding whitespace is
///    ignored). A missing or unparsable label counts as `1.0`. Negative values clamp to
///    `0.0`, and non-finite values become `0.0`. Edges whose endpoints are not nodes of
///    the graph are ignored.
/// 2. **Ranks** — a node's column is the length of the longest path reaching it from a
///    source (Kahn's topological order). If the graph has a cycle, every node stays in
///    column 0.
/// 3. **Scale** — a node's total is `max(inflow, outflow)`, or `1.0` when both are zero.
///    One flow-to-pixel scale is chosen so that the fullest column (the sum of its node
///    totals) fits in `SANKEY_HEIGHT`.
/// 4. **Vertical placement** — columns are placed left to right:
///    - Each node is aligned with the highest of its incoming bands from earlier
///      columns; column-0 nodes start at `y = 0`.
///    - Nodes in the same column are then pushed apart so they never overlap and stay
///      on the canvas.
/// 5. **Bands** — bands leave a source stacked with farther targets first, then in
///    declaration order. They enter a target stacked top-to-bottom by their source-side
///    position.
///
/// Zero-thickness links are omitted from both the generic `edges` and the sankey links.
pub(super) fn compute_sankey_layout(graph: &Graph, theme: &Theme, config: &LayoutConfig) -> Layout {
    const SANKEY_MIN_WIDTH: f32 = 560.0;
    const SANKEY_MAX_WIDTH: f32 = 640.0;
    const SANKEY_HEIGHT: f32 = 360.0;
    const SANKEY_NODE_WIDTH: f32 = 10.0;
    /// Extra canvas width granted for every column beyond the second.
    const SANKEY_WIDTH_PER_EXTRA_RANK: f32 = 25.0;
    const SANKEY_PALETTE: [&str; 10] = [
        "#4e79a7", "#f28e2c", "#e15759", "#76b7b2", "#59a14f", "#edc949", "#af7aa1", "#ff9da7",
        "#9c755f", "#bab0ab",
    ];

    #[derive(Debug, Clone, Copy)]
    struct SankeyEdgeData {
        from_idx: usize,
        to_idx: usize,
        value: f32,
    }

    /// Returns the longest-path column of every node (sources are column 0).
    ///
    /// A cyclic graph has unbounded longest paths, so in that case every node falls back
    /// to column 0.
    fn assign_ranks(outgoing: &[Vec<usize>], edges: &[SankeyEdgeData]) -> Vec<usize> {
        let node_count = outgoing.len();
        let mut indegree = vec![0usize; node_count];
        for edge in edges {
            indegree[edge.to_idx] += 1;
        }

        let mut queue: VecDeque<usize> =
            (0..node_count).filter(|&idx| indegree[idx] == 0).collect();
        let mut topo = Vec::with_capacity(node_count);
        while let Some(node_idx) = queue.pop_front() {
            topo.push(node_idx);
            for &edge_idx in &outgoing[node_idx] {
                let to_idx = edges[edge_idx].to_idx;
                if indegree[to_idx] > 0 {
                    indegree[to_idx] -= 1;
                    if indegree[to_idx] == 0 {
                        queue.push_back(to_idx);
                    }
                }
            }
        }

        let mut ranks = vec![0usize; node_count];
        // Only a complete topological order means the graph is acyclic.
        if topo.len() == node_count {
            for &node_idx in &topo {
                for &edge_idx in &outgoing[node_idx] {
                    let to_idx = edges[edge_idx].to_idx;
                    ranks[to_idx] = ranks[to_idx].max(ranks[node_idx] + 1);
                }
            }
        }
        ranks
    }

    /// Pushes the bars of one column apart so they do not overlap and stay within
    /// `[0, height]`.
    ///
    /// `column` holds `(node_idx, y)` pairs sorted by desired `y`.
    fn resolve_collisions(column: &mut [(usize, f32)], node_h: &[f32], height: f32) {
        // Downward pass: no bar may start above the bottom of the bar before it.
        let mut floor = 0.0f32;
        for (node_idx, y) in column.iter_mut() {
            *y = (*y).max(floor);
            floor = *y + node_h[*node_idx];
        }
        // Upward pass: if the column now runs past the canvas bottom, slide bars back up.
        // The scale guarantees each column's total height fits, so `max(0.0)` only
        // absorbs float rounding.
        let mut ceiling = height;
        for (node_idx, y) in column.iter_mut().rev() {
            *y = (*y).min(ceiling - node_h[*node_idx]).max(0.0);
            ceiling = *y;
        }
    }

    // Node indices follow declaration order, then id, so every later "sort by node index"
    // is a sort by declaration order.
    let mut node_ids: Vec<String> = graph.nodes.keys().cloned().collect();
    node_ids.sort_by(|a, b| {
        let order_a = graph.node_order.get(a).copied().unwrap_or(usize::MAX);
        let order_b = graph.node_order.get(b).copied().unwrap_or(usize::MAX);
        order_a.cmp(&order_b).then_with(|| a.cmp(b))
    });
    let node_count = node_ids.len();
    let id_to_idx: HashMap<&str, usize> = node_ids
        .iter()
        .enumerate()
        .map(|(idx, id)| (id.as_str(), idx))
        .collect();

    // Collect edges between known nodes, with per-node adjacency lists and flow totals.
    let mut edges_data: Vec<SankeyEdgeData> = Vec::new();
    let mut incoming: Vec<Vec<usize>> = vec![Vec::new(); node_count];
    let mut outgoing: Vec<Vec<usize>> = vec![Vec::new(); node_count];
    let mut in_total = vec![0.0f32; node_count];
    let mut out_total = vec![0.0f32; node_count];
    for edge in &graph.edges {
        let (Some(&from_idx), Some(&to_idx)) = (
            id_to_idx.get(edge.from.as_str()),
            id_to_idx.get(edge.to.as_str()),
        ) else {
            continue;
        };
        let raw_value = edge
            .label
            .as_deref()
            .and_then(|text| text.trim().parse::<f32>().ok())
            .unwrap_or(1.0);
        let value = if raw_value.is_finite() {
            raw_value.max(0.0)
        } else {
            0.0
        };
        let edge_idx = edges_data.len();
        edges_data.push(SankeyEdgeData {
            from_idx,
            to_idx,
            value,
        });
        outgoing[from_idx].push(edge_idx);
        incoming[to_idx].push(edge_idx);
        out_total[from_idx] += value;
        in_total[to_idx] += value;
    }

    let ranks = assign_ranks(&outgoing, &edges_data);
    let num_ranks = ranks.iter().copied().max().unwrap_or(0) + 1;

    // Horizontal layout. The canvas grows with the number of columns within
    // [MIN, MAX], but is never narrower than the bars themselves.
    let sankey_width = (SANKEY_MIN_WIDTH
        + num_ranks.saturating_sub(2) as f32 * SANKEY_WIDTH_PER_EXTRA_RANK)
        .clamp(SANKEY_MIN_WIDTH, SANKEY_MAX_WIDTH)
        .max(SANKEY_NODE_WIDTH * num_ranks as f32);
    let gap_x = if num_ranks > 1 {
        ((sankey_width - SANKEY_NODE_WIDTH * num_ranks as f32) / (num_ranks - 1) as f32).max(0.0)
    } else {
        0.0
    };
    let node_x: Vec<f32> = ranks
        .iter()
        .map(|&rank| rank as f32 * (SANKEY_NODE_WIDTH + gap_x))
        .collect();

    // Pushing in index order keeps each column in declaration order.
    let mut rank_nodes: Vec<Vec<usize>> = vec![Vec::new(); num_ranks];
    for (idx, &rank) in ranks.iter().enumerate() {
        rank_nodes[rank].push(idx);
    }

    // A node's size is the larger of its inflow and outflow. Zero-flow nodes still get
    // one unit so they remain visible.
    let totals: Vec<f32> = in_total
        .iter()
        .zip(&out_total)
        .map(|(&inflow, &outflow)| {
            let total = inflow.max(outflow);
            if total > 0.0 { total } else { 1.0 }
        })
        .collect();
    // Bars in a column are stacked, so the scale must fit the fullest column, not just
    // the largest single node. Otherwise multi-node columns overflow and overlap.
    let max_column_total = rank_nodes
        .iter()
        .map(|column| column.iter().map(|&idx| totals[idx]).sum::<f32>())
        .fold(0.0, f32::max)
        .max(1.0);
    let scale = SANKEY_HEIGHT / max_column_total;
    let node_h: Vec<f32> = totals.iter().map(|&total| total * scale).collect();
    let edge_thickness: Vec<f32> = edges_data.iter().map(|edge| edge.value * scale).collect();

    // Offsets of each band below the top of its source bar. Bands to farther targets are
    // stacked first, then targets in declaration order (node index order). Edges to the
    // same target keep their declaration order because the sort is stable.
    let mut outbound_offset = vec![0.0f32; edges_data.len()];
    for edges_out in &outgoing {
        let mut ordered = edges_out.clone();
        ordered.sort_by(|&a, &b| {
            let (target_a, target_b) = (edges_data[a].to_idx, edges_data[b].to_idx);
            ranks[target_b]
                .cmp(&ranks[target_a])
                .then(target_a.cmp(&target_b))
        });
        let mut offset = 0.0f32;
        for edge_idx in ordered {
            outbound_offset[edge_idx] = offset;
            offset += edge_thickness[edge_idx];
        }
    }

    // Vertical placement, column by column. Earlier columns are final before a later
    // column reads their positions.
    let mut node_y = vec![0.0f32; node_count];
    for (rank, column_nodes) in rank_nodes.iter().enumerate() {
        let mut column: Vec<(usize, f32)> = column_nodes
            .iter()
            .map(|&node_idx| {
                // Align with the highest band arriving from an already-placed column.
                let desired = incoming[node_idx]
                    .iter()
                    .filter(|&&edge_idx| ranks[edges_data[edge_idx].from_idx] < rank)
                    .map(|&edge_idx| {
                        node_y[edges_data[edge_idx].from_idx] + outbound_offset[edge_idx]
                    })
                    .fold(f32::INFINITY, f32::min);
                let max_y = (SANKEY_HEIGHT - node_h[node_idx]).max(0.0);
                let y = if desired.is_finite() {
                    desired.clamp(0.0, max_y)
                } else {
                    0.0
                };
                (node_idx, y)
            })
            .collect();
        // Stable sort: ties keep declaration order.
        column.sort_by(|a, b| a.1.total_cmp(&b.1));
        resolve_collisions(&mut column, &node_h, SANKEY_HEIGHT);
        for (node_idx, y) in column {
            node_y[node_idx] = y;
        }
    }

    // Source-side top edge of every band, given the final node positions.
    let link_top: Vec<f32> = edges_data
        .iter()
        .zip(&outbound_offset)
        .map(|(edge, &offset)| node_y[edge.from_idx] + offset)
        .collect();

    // Bands enter each target stacked top-to-bottom by source-side position. They tile
    // the target bar instead of spilling past its bottom edge.
    let mut inbound_offset = vec![0.0f32; edges_data.len()];
    for edges_in in &incoming {
        let mut ordered = edges_in.clone();
        ordered.sort_by(|&a, &b| link_top[a].total_cmp(&link_top[b]));
        let mut offset = 0.0f32;
        for edge_idx in ordered {
            inbound_offset[edge_idx] = offset;
            offset += edge_thickness[edge_idx];
        }
    }

    // Resolve styles. The palette colour fills nodes without an explicit fill, and bars
    // have no outline unless the style requests one.
    let mut node_colors = Vec::with_capacity(node_count);
    for (idx, id) in node_ids.iter().enumerate() {
        let mut style = resolve_node_style(id.as_str(), graph);
        let color = style
            .fill
            .get_or_insert_with(|| SANKEY_PALETTE[idx % SANKEY_PALETTE.len()].to_string())
            .clone();
        if style.stroke.is_none() {
            style.stroke = Some("none".to_string());
        }
        if style.stroke_width.is_none() {
            style.stroke_width = Some(0.0);
        }
        node_colors.push((color, style));
    }

    let mut nodes = BTreeMap::new();
    let mut sankey_nodes = Vec::with_capacity(node_count);
    for (idx, id) in node_ids.iter().enumerate() {
        let label = graph
            .nodes
            .get(id)
            .map(|node| node.label.clone())
            .unwrap_or_else(|| id.clone());
        let (color, style) = &node_colors[idx];
        let label_block = measure_label(&label, theme, config);
        nodes.insert(
            id.clone(),
            NodeLayout {
                id: id.clone(),
                x: node_x[idx],
                y: node_y[idx],
                width: SANKEY_NODE_WIDTH,
                height: node_h[idx],
                label: label_block,
                shape: crate::mermaid_engine::ir::NodeShape::Rectangle,
                style: style.clone(),
                link: graph.node_links.get(id).cloned(),
                anchor_subgraph: None,
                hidden: false,
                icon: None,
            },
        );
        sankey_nodes.push(SankeyNodeLayout {
            id: id.clone(),
            label,
            total: totals[idx],
            rank: ranks[idx],
            x: node_x[idx],
            y: node_y[idx],
            width: SANKEY_NODE_WIDTH,
            height: node_h[idx],
            color: color.clone(),
        });
    }

    let mut edges = Vec::with_capacity(edges_data.len());
    let mut sankey_links = Vec::with_capacity(edges_data.len());
    for (edge_idx, edge) in edges_data.iter().enumerate() {
        let thickness = edge_thickness[edge_idx];
        if thickness <= 0.0 {
            continue;
        }
        let from_id = node_ids[edge.from_idx].clone();
        let to_id = node_ids[edge.to_idx].clone();
        let start_x = node_x[edge.from_idx] + SANKEY_NODE_WIDTH;
        let end_x = node_x[edge.to_idx];
        // Band centre lines at the source's right edge and the target's left edge.
        let start_y = node_y[edge.from_idx] + outbound_offset[edge_idx] + thickness / 2.0;
        let end_y = node_y[edge.to_idx] + inbound_offset[edge_idx] + thickness / 2.0;
        let (color_start, _) = &node_colors[edge.from_idx];
        let (color_end, _) = &node_colors[edge.to_idx];
        let gradient_id = format!("sankey-grad-{edge_idx}");

        edges.push(EdgeLayout {
            from: from_id.clone(),
            to: to_id.clone(),
            label: None,
            start_label: None,
            end_label: None,
            label_anchor: None,
            start_label_anchor: None,
            end_label_anchor: None,
            points: vec![(start_x, start_y), (end_x, end_y)],
            directed: false,
            arrow_start: false,
            arrow_end: false,
            arrow_start_kind: None,
            arrow_end_kind: None,
            start_decoration: None,
            end_decoration: None,
            style: crate::mermaid_engine::ir::EdgeStyle::Solid,
            override_style: crate::mermaid_engine::ir::EdgeStyleOverride {
                stroke: Some(color_start.clone()),
                stroke_width: Some(thickness),
                dasharray: None,
                label_color: None,
            },
        });
        sankey_links.push(SankeyLinkLayout {
            source: from_id,
            target: to_id,
            value: edge.value,
            thickness,
            start: (start_x, start_y),
            end: (end_x, end_y),
            color_start: color_start.clone(),
            color_end: color_end.clone(),
            gradient_id,
        });
    }

    Layout {
        kind: graph.kind,
        nodes,
        edges,
        subgraphs: Vec::new(),
        width: sankey_width,
        height: SANKEY_HEIGHT,
        diagram: DiagramData::Sankey(SankeyLayout {
            width: sankey_width,
            height: SANKEY_HEIGHT,
            node_width: SANKEY_NODE_WIDTH,
            nodes: sankey_nodes,
            links: sankey_links,
        }),
    }
}