use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use image::{DynamicImage, RgbaImage};
use ratatui_image::{picker::Picker, protocol::StatefulProtocol};
use resvg::usvg;
use crate::markdown::MermaidBlockId;
/// Returns the font database shared by every SVG parse in this module.
///
/// The database is built lazily on the first call by loading the system fonts; all
/// later calls hand back a reference to that same `Arc`, so callers can clone it cheaply.
fn font_db() -> &'static Arc<usvg::fontdb::Database> {
    /// Builds a fresh database populated with the fonts installed on this system.
    fn load_system_font_db() -> Arc<usvg::fontdb::Database> {
        let mut fonts = usvg::fontdb::Database::new();
        fonts.load_system_fonts();
        Arc::new(fonts)
    }

    static SHARED_FONTS: OnceLock<Arc<usvg::fontdb::Database>> = OnceLock::new();
    SHARED_FONTS.get_or_init(load_system_font_db)
}
/// Render state of one mermaid diagram, as tracked by [`MermaidCache`].
pub enum MermaidEntry {
    /// A background render has been spawned and its result has not arrived yet.
    Pending,
    /// Rendering succeeded; holds the image-protocol state used to draw the diagram.
    Ready(Box<StatefulProtocol>),
    /// Rendering failed; holds a human-readable error message.
    Failed(String),
    /// No graphics protocol is available, so no image is produced; holds the reason.
    /// (Presumably the UI falls back to showing the diagram source — confirm in the caller.)
    SourceOnly(String),
}
/// Tracks the render state of every mermaid diagram, keyed by its [`MermaidBlockId`].
pub struct MermaidCache {
    /// One entry per diagram that has been queued for rendering or ruled out.
    entries: HashMap<MermaidBlockId, MermaidEntry>,
}
impl MermaidCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Returns a mutable reference to the entry for `id`, if one exists.
    pub fn get_mut(&mut self, id: &MermaidBlockId) -> Option<&mut MermaidEntry> {
        self.entries.get_mut(id)
    }

    /// Stores `entry` for `id`, replacing any previous entry.
    ///
    /// NOTE(review): presumably used to apply `Action::MermaidReady` results sent by
    /// [`MermaidCache::ensure_queued`] — confirm against the action handler.
    pub fn insert(&mut self, id: MermaidBlockId, entry: MermaidEntry) {
        self.entries.insert(id, entry);
    }

    /// Makes sure diagram `id` has been scheduled for rendering, or ruled out.
    ///
    /// - If `id` already has an entry of any kind, nothing changes and `false` is returned.
    /// - If `picker` is `None`, a [`MermaidEntry::SourceOnly`] entry is recorded and `false`
    ///   is returned. Its reason is [`TMUX_DISABLED_REASON`] when `in_tmux` is set, and
    ///   `"graphics unavailable"` otherwise.
    /// - Otherwise the entry becomes [`MermaidEntry::Pending`] and `true` is returned.
    ///   `source` is then rendered on tokio's blocking thread pool. The outcome is sent as
    ///   `Action::MermaidReady(id, entry)` on `action_tx`. That outcome is `Ready`, or
    ///   `Failed` on a render error or a renderer panic.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, as `tokio::task::spawn_blocking` does.
    pub fn ensure_queued(
        &mut self,
        id: MermaidBlockId,
        source: &str,
        picker: Option<&Picker>,
        action_tx: &tokio::sync::mpsc::UnboundedSender<crate::action::Action>,
        in_tmux: bool,
    ) -> bool {
        use std::collections::hash_map::Entry;

        // One hash lookup: stop if the diagram is already known, otherwise keep the vacant
        // slot so it can be filled below without hashing `id` again.
        let slot = match self.entries.entry(id) {
            Entry::Occupied(_) => return false,
            Entry::Vacant(slot) => slot,
        };
        let Some(picker) = picker else {
            let reason = if in_tmux {
                TMUX_DISABLED_REASON.to_string()
            } else {
                "graphics unavailable".to_string()
            };
            slot.insert(MermaidEntry::SourceOnly(reason));
            return false;
        };
        slot.insert(MermaidEntry::Pending);

        let source = source.to_string();
        let picker = picker.clone();
        let tx = action_tx.clone();
        tokio::task::spawn_blocking(move || {
            // The renderer and rasterizer are third-party code running on document input.
            // If they panic, the panic is only stored in the JoinHandle, which is dropped
            // here. The entry would then stay `Pending` forever and never be retried,
            // because `id` is already present. Report the panic as a failure instead.
            let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                render_blocking(source, &picker)
            }));
            let entry = match outcome {
                Ok(Ok(protocol)) => MermaidEntry::Ready(Box::new(protocol)),
                Ok(Err(e)) => MermaidEntry::Failed(e),
                Err(payload) => MermaidEntry::Failed(format!(
                    "renderer panicked: {}",
                    panic_message(&*payload)
                )),
            };
            // Best-effort: if the receiver has been dropped, nobody is left to show it to.
            let _ = tx.send(crate::action::Action::MermaidReady(id, Box::new(entry)));
        });
        true
    }
}

impl Default for MermaidCache {
    /// Equivalent to [`MermaidCache::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Extracts the message from a panic payload.
///
/// Handles `&'static str` and `String` payloads, which cover `panic!` with and without
/// format arguments. Any other payload type yields a generic placeholder.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> &str {
    if let Some(msg) = payload.downcast_ref::<&'static str>() {
        *msg
    } else if let Some(msg) = payload.downcast_ref::<String>() {
        msg.as_str()
    } else {
        "unknown panic payload"
    }
}
/// Renders mermaid `source` into an image protocol that `picker` can display.
///
/// The pipeline is mermaid → SVG → raster image → resize protocol, all done synchronously.
/// This is why [`MermaidCache::ensure_queued`] calls it from `spawn_blocking`.
///
/// # Errors
///
/// Returns a message prefixed with `render error:` when mermaid rendering fails. Returns
/// one prefixed with `svg rasterize:` when the SVG cannot be rasterized.
fn render_blocking(source: String, picker: &Picker) -> Result<StatefulProtocol, String> {
    mermaid_rs_renderer::render(&source)
        .map_err(|err| format!("render error: {err}"))
        .and_then(|svg| svg_to_image(&svg).map_err(|err| format!("svg rasterize: {err}")))
        .map(|image| picker.new_resize_protocol(image))
}
/// Supersampling factor used when rasterizing diagram SVGs, so the image stays sharp after
/// the terminal graphics protocol resizes it.
const SVG_RENDER_SCALE: f32 = 3.0;

/// Upper bound on the pixel count of a rasterized diagram: 4096 × 4096, i.e. 64 MiB of RGBA.
///
/// Without a bound, a very large diagram would be rasterized at the full
/// [`SVG_RENDER_SCALE`]. The resulting pixmap allocation could exhaust memory, and in Rust
/// a failed allocation aborts the process instead of returning an error.
const MAX_RASTER_PIXELS: f32 = 4096.0 * 4096.0;

/// Parses `svg` with usvg, using the shared system font database, and rasterizes it with
/// resvg into a straight-alpha RGBA image.
///
/// The image is rendered at [`SVG_RENDER_SCALE`] times the SVG's intrinsic size. Diagrams
/// that would exceed [`MAX_RASTER_PIXELS`] at that scale are rendered at a smaller one.
///
/// # Errors
///
/// Returns a message when the SVG cannot be parsed, when its scaled size rounds to zero, or
/// when the pixmap cannot be created. Errors from `demultiply_alpha` are also propagated.
fn svg_to_image(svg: &str) -> Result<DynamicImage, String> {
    let opts = usvg::Options {
        fontdb: Arc::clone(font_db()),
        ..usvg::Options::default()
    };
    let tree = usvg::Tree::from_str(svg, &opts).map_err(|e| format!("usvg parse: {e}"))?;
    let size = tree.size();

    let scale = raster_scale(size.width(), size.height());
    let width = (size.width() * scale).ceil() as u32;
    let height = (size.height() * scale).ceil() as u32;
    if width == 0 || height == 0 {
        return Err("empty SVG dimensions".to_string());
    }

    let mut pixmap =
        resvg::tiny_skia::Pixmap::new(width, height).ok_or("failed to allocate pixmap")?;
    resvg::render(
        &tree,
        resvg::tiny_skia::Transform::from_scale(scale, scale),
        &mut pixmap.as_mut(),
    );
    let rgba = demultiply_alpha(pixmap.take(), width, height)?;
    Ok(DynamicImage::ImageRgba8(rgba))
}

/// Picks the rasterization scale for an SVG of intrinsic size `width` × `height`.
///
/// Returns [`SVG_RENDER_SCALE`] unless that would produce more than [`MAX_RASTER_PIXELS`]
/// pixels. In that case it returns the scale that fits the budget, which may be below 1.
fn raster_scale(width: f32, height: f32) -> f32 {
    let area = width * height;
    let scaled_area = area * SVG_RENDER_SCALE * SVG_RENDER_SCALE;
    if scaled_area.is_finite() && scaled_area <= MAX_RASTER_PIXELS {
        SVG_RENDER_SCALE
    } else {
        (MAX_RASTER_PIXELS / area).sqrt()
    }
}
/// Converts tightly packed premultiplied RGBA bytes into a straight-alpha [`RgbaImage`].
///
/// The input is the kind of buffer `tiny_skia::Pixmap::take` returns, and the conversion
/// reuses its allocation.
///
/// Each colour channel is divided by alpha using round-to-nearest integer arithmetic.
/// This avoids the downward bias of truncating a float product. It also guarantees that
/// a saturated premultiplied channel (value == alpha) maps back to exactly 255.
///
/// Fully transparent pixels become `[0, 0, 0, 0]`. Opaque pixels are left as they are,
/// since premultiplied and straight alpha coincide there.
///
/// # Errors
///
/// Returns `"image buffer size mismatch"` if `RgbaImage::from_raw` rejects the buffer
/// for a `width` × `height` image.
fn demultiply_alpha(mut data: Vec<u8>, width: u32, height: u32) -> Result<RgbaImage, String> {
    for pixel in data.chunks_exact_mut(4) {
        let alpha = pixel[3];
        match alpha {
            // Colour is meaningless when fully transparent; normalise to transparent black.
            0 => pixel[..3].fill(0),
            u8::MAX => {}
            _ => {
                for channel in &mut pixel[..3] {
                    *channel = demultiply_channel(*channel, alpha);
                }
            }
        }
    }
    RgbaImage::from_raw(width, height, data)
        .ok_or_else(|| "image buffer size mismatch".to_string())
}

/// Undoes premultiplication for one channel: `round(value * 255 / alpha)`, clamped to 255.
///
/// `alpha` must be non-zero. The clamp only matters for malformed input where
/// `value > alpha`.
fn demultiply_channel(value: u8, alpha: u8) -> u8 {
    let alpha = u32::from(alpha);
    let straight = (u32::from(value) * 255 + alpha / 2) / alpha;
    u8::try_from(straight).unwrap_or(u8::MAX)
}
/// Builds the image-protocol picker used to display rendered diagrams.
///
/// Returns `None` when the `TMUX` environment variable is set, so graphics are disabled
/// inside tmux. Otherwise it queries the terminal over stdio for its graphics support and
/// falls back to half-block rendering if that query fails.
pub fn create_picker() -> Option<Picker> {
    let inside_tmux = std::env::var("TMUX").is_ok();
    (!inside_tmux).then(|| Picker::from_query_stdio().unwrap_or_else(|_| Picker::halfblocks()))
}

/// Reason recorded in [`MermaidEntry::SourceOnly`] when no picker exists because the app
/// is running inside tmux.
pub const TMUX_DISABLED_REASON: &str = "disable tmux for graphics";
#[cfg(test)]
mod tests {
    //! Smoke tests that push real-world diagrams through the mermaid renderer and the SVG
    //! rasterizer, without needing a terminal or graphics protocol.
    use super::*;

    const SEQUENCE_DIAGRAM: &str = r#"sequenceDiagram
participant W as Worker
participant CP as CheckpointStore
participant ES as EventReader
W->>CP: Read checkpoint (last sequence)
CP-->>W: sequence_number
W->>ES: Poll events (after sequence, limit 500)
ES-->>W: batch of StoredEvents"#;

    const GRAPH_LR_1: &str = r#"graph LR
subgraph Supervisor
direction TB
F[Factory] -->|creates| W[Worker]
W -->|panics/exits| F
end
W -->|beat every cycle| HB[Heartbeat]
HB -->|checked every 10s| WD[Watchdog]
WD -->|stall > 120s| CT[Cancel Token]
CT -->|stops| W
style WD fill:#c82,stroke:#fff,color:#fff"#;

    const STATE_DIAGRAM: &str = r#"stateDiagram-v2
[*] --> CLOSED
CLOSED --> OPEN : 5 consecutive failures
OPEN --> HALF_OPEN : probe interval elapsed
HALF_OPEN --> CLOSED : probe succeeds
HALF_OPEN --> OPEN : probe fails (increased backoff)"#;

    const GRAPH_LR_2: &str = r#"graph LR
subgraph projections-pg [projections-pg :9092]
PG_W[event_log, account_registry]
end
PG_W --> PG[(PostgreSQL)]
style PG fill:#336,stroke:#fff,color:#fff"#;

    /// Renders and rasterizes `src`, prefixing any error with the stage that produced it.
    fn rasterize_diagram(src: &str) -> Result<(), String> {
        let svg = mermaid_rs_renderer::render(src).map_err(|e| format!("mermaid: {e}"))?;
        svg_to_image(&svg)
            .map(|_| ())
            .map_err(|e| format!("rasterize: {e}"))
    }

    #[test]
    fn render_four_target_diagrams() {
        let diagrams = [
            ("sequenceDiagram", SEQUENCE_DIAGRAM),
            ("graph LR (resilience)", GRAPH_LR_1),
            ("stateDiagram-v2", STATE_DIAGRAM),
            ("graph LR (deployments)", GRAPH_LR_2),
        ];

        let failed: Vec<(&str, String)> = diagrams
            .iter()
            .filter_map(|&(name, src)| rasterize_diagram(src).err().map(|why| (name, why)))
            .collect();
        let ready_count = diagrams.len() - failed.len();

        assert!(
            ready_count >= 2,
            "only {ready_count}/4 diagrams rendered successfully; failures: {failed:?}"
        );
    }
}