mdbook_diagrams/
lib.rs

1use anyhow::{Result, bail};
2use mdbook::BookItem;
3use mdbook::book::{Book, Chapter};
4use mdbook::preprocess::{Preprocessor, PreprocessorContext};
5use regex::Regex;
6use sha2::{Sha256, Digest};
7use std::collections::HashMap;
8use std::num::NonZero;
9use std::ops::Range;
10use std::path::Path;
11use std::path::PathBuf;
12use std::sync::Arc;
13use futures::future::BoxFuture;
14use futures::FutureExt;
15use tokio::io::AsyncWriteExt;
16use toml::value::Table;
17
/// Rendering mode for diagrams.
///
/// Selected through the `render-mode` configuration key; `PreRender` is used
/// when the key is absent or holds an unrecognised value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RenderMode {
    /// Pre-render diagrams at build time (default).
    PreRender,
    /// Embed mermaid code and render at runtime in the browser.
    Runtime,
}
26
/// A single pending replacement inside one chapter's markdown content.
struct ChapterEdit {
    /// Path of the chapter the edit belongs to (empty when the chapter has no path).
    chapter_path: PathBuf,
    /// Byte range of the original mermaid block within the chapter content.
    range: Range<usize>,
    /// Text that replaces the bytes covered by `range`.
    new_string: String,
    /// Cache filename that was used or created; empty when rendering failed.
    cached_filename: String,
}
34
35pub struct DiagramsPreprocessor {
36    render_mode: RenderMode,
37    mmdc_cmd: String,
38    output_format: String,
39    enable_cache: bool,
40}
41
42impl DiagramsPreprocessor {
43    pub fn new(config: Option<Table>) -> DiagramsPreprocessor {
44        let render_mode = config
45            .as_ref()
46            .and_then(|table| table.get("render-mode"))
47            .and_then(|val| val.as_str())
48            .map(|s| match s {
49                "runtime" => RenderMode::Runtime,
50                "pre-render" => RenderMode::PreRender,
51                _ => {
52                    eprintln!("[mdbook-diagrams] Invalid render-mode: {}, falling back to pre-render", s);
53                    eprintln!("[mdbook-diagrams] Available modes: runtime, pre-render");
54                    RenderMode::PreRender
55                },
56            })
57            .unwrap_or(RenderMode::PreRender);
58
59        let mmdc_cmd = config
60            .as_ref()
61            .and_then(|table| table.get("mmdc-cmd"))
62            .and_then(|val| val.as_str())
63            .map(|s| s.to_string())
64            .unwrap_or_else(|| "mmdc".to_string());
65
66        let output_format = config
67            .as_ref()
68            .and_then(|table| table.get("output-format"))
69            .and_then(|val| val.as_str())
70            .map(|s| match s {
71                "svg" => "svg".to_string(),
72                "png" => "png".to_string(),
73                _ => {
74                    eprintln!("[mdbook-diagrams] Invalid output-format: {}, falling back to svg", s);
75                    eprintln!("[mdbook-diagrams] Available formats: svg, png");
76                    "svg".to_string()
77                },
78            })
79            .unwrap_or_else(|| "svg".to_string());
80
81        let enable_cache = config
82            .as_ref()
83            .and_then(|table| table.get("enable-cache"))
84            .and_then(|val| val.as_bool())
85            .unwrap_or(true);
86
87        DiagramsPreprocessor {
88            render_mode,
89            mmdc_cmd,
90            output_format,
91            enable_cache,
92        }
93    }
94
95    /// Compute a cache hash from diagram content and rendering configuration
96    fn compute_cache_hash(content: &str, output_format: &str, mmdc_cmd: &str) -> String {
97        let mut hasher = Sha256::new();
98        hasher.update(content.as_bytes());
99        hasher.update(output_format.as_bytes());
100        hasher.update(mmdc_cmd.as_bytes());
101        let result = hasher.finalize();
102        format!("{:x}", result)
103    }
104
105    fn prepare_mermaid_files(&self, ctx: &PreprocessorContext) -> Result<()> {
106        let theme_dir = ctx.root.join("theme");
107        std::fs::create_dir_all(&theme_dir)?;
108
109        let mermaid_js_path = theme_dir.join("mermaid.min.js");
110        let mermaid_init_path = theme_dir.join("mermaid-init.js");
111
112        let mut js_updated = false;
113
114        // Download mermaid.min.js if it doesn't exist
115        if !mermaid_js_path.exists() {
116            eprintln!("Downloading mermaid.min.js...");
117            let url = "https://cdn.jsdelivr.net/npm/mermaid@11/dist/mermaid.min.js";
118            let response = reqwest::blocking::get(url)?;
119            let content = response.bytes()?;
120            std::fs::write(&mermaid_js_path, content)?;
121            js_updated = true;
122            eprintln!("Downloaded mermaid.min.js to theme/mermaid.min.js");
123        }
124
125        // Create mermaid-init.js if it doesn't exist
126        if !mermaid_init_path.exists() {
127            let init_script = r#"import mermaid from './mermaid.min.js';
128mermaid.initialize({ startOnLoad: true });
129"#;
130            std::fs::write(&mermaid_init_path, init_script)?;
131            js_updated = true;
132            eprintln!("Created mermaid-init.js at theme/mermaid-init.js");
133        }
134
135        if js_updated {
136            eprintln!("[mdbook-diagrams] mermaid.min.js and mermaid-init.js is created in theme/ directory.");
137            eprintln!("[mdbook-diagrams] To enable runtime rendering, please add the following to your book.toml:\n");
138            eprintln!("[output.html]");
139            eprintln!("additional-js = [\"theme/mermaid.min.js\", \"theme/mermaid-init.js\"]");
140        }
141
142        Ok(())
143    }
144
145    fn process_book_for_runtime_mode(&self, mut book: Book) -> Result<Book> {
146        let mermaid_re = Regex::new(r#"```mermaid\r?\n([\s\S]*?)\r?\n```"#)?;
147
148        for item in &mut book.sections {
149            Self::process_book_item_for_runtime_mode(&mermaid_re, item);
150        }
151
152        Ok(book)
153    }
154
155    /// Recursively process book items to convert mermaid blocks to HTML `pre` tags
156    fn process_book_item_for_runtime_mode(mermaid_re: &Regex, book_item: &mut BookItem) {
157        if let BookItem::Chapter(chapter) = book_item {
158            chapter.content = mermaid_re.replace_all(&chapter.content, |caps: &regex::Captures| {
159                let diagram_code = &caps[1];
160                format!("<pre class=\"mermaid\">\n{}\n</pre>", diagram_code)
161            }).to_string();
162
163            for sub_item in &mut chapter.sub_items {
164                Self::process_book_item_for_runtime_mode(mermaid_re, sub_item);
165            }
166        }
167    }
168
169    async fn async_process_book(&self, ctx: &PreprocessorContext, book: &mut Book) -> Result<()> {
170        let mermaid_re = Regex::new(r#"```mermaid\r?\n([\s\S]*?)\r?\n```"#)?;
171
172        let output_dir = ctx.root.join(&ctx.config.book.src).join("generated").join("diagrams");
173        tokio::fs::create_dir_all(&output_dir).await?;
174
175        let num_cpus = std::thread::available_parallelism()
176            .unwrap_or(NonZero::new(1).unwrap())
177            .get();
178        let semaphore = Arc::new(tokio::sync::Semaphore::new(num_cpus));
179
180        // Collect all futures from book items
181        let mut all_futures = Vec::new();
182        for item in &mut book.sections {
183            all_futures.extend(self.collect_edits_from_book_item_recursively(
184                &mermaid_re,
185                item,
186                &output_dir,
187                &semaphore,
188            ));
189        }
190
191        let edits: Vec<ChapterEdit> = futures::future::join_all(all_futures).await.into_iter()
192            .filter_map(|e| match e {
193                Ok(e) => Some(e),
194                Err(e) => {
195                    eprintln!("[mdbook-diagrams] Failed to generate diagram: {}", e);
196                    None
197                }
198            }
199        ).collect();
200
201        // Extract referenced filenames for cleanup
202        let referenced_files: std::collections::HashSet<String> = edits
203            .iter()
204            .map(|edit| edit.cached_filename.clone())
205            .filter(|name| !name.is_empty())
206            .collect();
207
208        // Clean up unreferenced cache files if caching is enabled
209        if self.enable_cache {
210            if let Err(e) = Self::cleanup_unreferenced_files(&output_dir, &referenced_files).await {
211                eprintln!("[mdbook-diagrams] Warning: Failed to clean up cache files: {}", e);
212            }
213        }
214
215        // Group edits by chapter path for easier processing
216        let mut edits_by_chapter: HashMap<PathBuf, Vec<ChapterEdit>> = HashMap::new();
217        for edit in edits {
218            edits_by_chapter.entry(edit.chapter_path.clone()).or_insert_with(Vec::new).push(edit);
219        }
220
221        // Iterate through the book mutably and apply edits recursively
222        for item in &mut book.sections {
223            DiagramsPreprocessor::apply_edits_to_book_item_recursively(item, &mut edits_by_chapter);
224        }
225
226        Ok(())
227    }
228
229    fn apply_edits_to_book_item_recursively(book_item: &mut BookItem, edits_by_chapter: &mut HashMap<PathBuf, Vec<ChapterEdit>>) {
230        if let BookItem::Chapter(chapter) = book_item {
231            let chapter_path = chapter.path.clone().unwrap_or_default();
232            if let Some(chapter_edits) = edits_by_chapter.remove(&chapter_path) {
233                // Sort edits by range start in descending order to avoid offset issues
234                let mut sorted_edits = chapter_edits;
235                sorted_edits.sort_by_key(|e| e.range.start);
236                sorted_edits.reverse();
237
238                for edit in sorted_edits {
239                    // Replace the content using the byte range
240                    chapter.content.replace_range(edit.range, &edit.new_string);
241                }
242            }
243
244            // Recursively apply to sub items
245            for sub_item in &mut chapter.sub_items {
246                DiagramsPreprocessor::apply_edits_to_book_item_recursively(sub_item, edits_by_chapter);
247            }
248        }
249    }
250
251    fn collect_edits_from_book_item_recursively(
252        &'_ self,
253        mermaid_re: & Regex,
254        book_item: & BookItem,
255        output_dir: & PathBuf,
256        semaphore: & Arc<tokio::sync::Semaphore>,
257    ) -> Vec<BoxFuture<'_, Result<ChapterEdit>>> {
258        let mut futures = Vec::new();
259        if let BookItem::Chapter(chapter) = book_item {
260            // Collect edits from a chapter
261            futures.extend(
262                self.collect_edits_from_chapter(mermaid_re, chapter, output_dir, semaphore),
263            );
264
265            // Proceed recursively for sub items
266            for sub_item in &chapter.sub_items {
267                futures.extend(self.collect_edits_from_book_item_recursively(
268                    &mermaid_re, sub_item, &output_dir, &semaphore,
269                ));
270            }
271        }
272        futures
273    }
274
275    /// Generate diagrams for all mermaid blocks in a chapter and return a list of edits to apply.
276    fn collect_edits_from_chapter(
277        &'_ self,
278        mermaid_re: & Regex,
279        chapter: & Chapter,
280        output_dir: & PathBuf,
281        semaphore: & Arc<tokio::sync::Semaphore>,
282    ) -> Vec<BoxFuture<'_, Result<ChapterEdit>>> {
283        let mut futures = Vec::new();
284
285        for cap in mermaid_re.captures_iter(&chapter.content) {
286            let full_match_range = cap.get(0).unwrap().range();
287            let mermaid_code = cap[1].to_string();
288            let original_block = cap.get(0).unwrap().as_str().to_string();
289
290            let cache_hash = Self::compute_cache_hash(&mermaid_code, &self.output_format, &self.mmdc_cmd);
291            let output_filename = format!("{}.{}", cache_hash, self.output_format);
292            let output_filepath = output_dir.join(&output_filename);
293
294            let chapter_path = chapter.path.clone().unwrap_or_default();
295
296            let relative_output_path = {
297                let chapter_dir_relative_to_src = chapter
298                    .path
299                    .as_ref()
300                    .and_then(|p| p.parent())
301                    .unwrap_or_else(|| Path::new(""));
302                let num_parent_dirs = chapter_dir_relative_to_src.components().count();
303
304                let mut path = PathBuf::new();
305                for _ in 0..num_parent_dirs {
306                    path.push("..");
307                }
308                path.push("generated");
309                path.push("diagrams");
310                path.push(&output_filename);
311                path
312            };
313
314            let semaphore_clone = semaphore.clone();
315            let mmdc_cmd = self.mmdc_cmd.clone();
316            let enable_cache = self.enable_cache;
317
318            futures.push(async move {
319                if enable_cache && output_filepath.exists() {
320                    // Cache hit - skip mmdc execution
321                    let img_tag = format!(
322                        "![diagram](./{})",
323                        relative_output_path.to_string_lossy().replace("\\", "/")
324                    );
325                    return Ok(ChapterEdit {
326                        chapter_path,
327                        range: full_match_range,
328                        new_string: img_tag,
329                        cached_filename: output_filename.clone(),
330                    });
331                }
332
333                // Cache miss or caching disabled - generate diagram
334                let result = async {
335                    let _permit = semaphore_clone.acquire().await?;
336                    let mut command = if cfg!(windows) {
337                        let mut cmd = tokio::process::Command::new("powershell");
338                        cmd.arg("-NoProfile")
339                            .arg("-Command")
340                            .arg(&mmdc_cmd)
341                            .arg("-i")
342                            .arg("-")
343                            .arg("-o")
344                            .arg(&output_filepath)
345                            .stdin(std::process::Stdio::piped())
346                            .stdout(std::process::Stdio::piped())
347                            .stderr(std::process::Stdio::piped());
348                        cmd
349                    } else {
350                        let mut cmd = tokio::process::Command::new(&mmdc_cmd);
351                        cmd.arg("-i")
352                            .arg("-")
353                            .arg("-o")
354                            .arg(&output_filepath)
355                            .stdin(std::process::Stdio::piped())
356                            .stdout(std::process::Stdio::piped())
357                            .stderr(std::process::Stdio::piped());
358                        cmd
359                    };
360
361                    let mut child = command.spawn()?;
362
363                    if let Some(mut stdin) = child.stdin.take() {
364                        AsyncWriteExt::write_all(&mut stdin, mermaid_code.as_bytes()).await?;
365                    }
366
367                    let output = child.wait_with_output().await?;
368
369                    if !output.status.success() {
370                        bail!(
371                            "mmdc failed: {}\nStderr: {}",
372                            output.status,
373                            String::from_utf8_lossy(&output.stderr)
374                        );
375                    }
376
377                    Ok::<String, anyhow::Error>(format!(
378                        "![diagram](./{})",
379                        relative_output_path.to_string_lossy().replace("\\", "/")
380                    ))
381                }.await;
382
383                // Handle result - on error, keep the original mermaid block with an error message
384                match result {
385                    Ok(img_tag) => Ok(ChapterEdit {
386                        chapter_path,
387                        range: full_match_range,
388                        new_string: img_tag,
389                        cached_filename: output_filename.clone(),
390                    }),
391                    Err(e) => {
392                        let error_msg = format!("{:#}", e);
393                        eprintln!("[mdbook-diagrams] {}", error_msg);
394
395                        // Keep original mermaid block with error comment
396                        let error_comment = format!(
397                            "<!-- Error generating diagram: {} -->\n{}",
398                            error_msg.lines().next().unwrap_or("Unknown error"),
399                            original_block
400                        );
401                        Ok(ChapterEdit {
402                            chapter_path,
403                            range: full_match_range,
404                            new_string: error_comment,
405                            cached_filename: String::new(),
406                        })
407                    }
408                }
409            }.boxed())
410        }
411        futures
412    }
413
414    /// Remove cache files that are not referenced in the current build
415    async fn cleanup_unreferenced_files(
416        output_dir: &PathBuf,
417        referenced_files: &std::collections::HashSet<String>,
418    ) -> Result<()> {
419        let mut entries = tokio::fs::read_dir(output_dir).await?;
420
421        while let Some(entry) = entries.next_entry().await? {
422            if let Ok(filename) = entry.file_name().into_string() {
423                if !referenced_files.contains(&filename) && !filename.is_empty() {
424                    if let Err(e) = tokio::fs::remove_file(entry.path()).await {
425                        eprintln!(
426                            "[mdbook-diagrams] Warning: Failed to remove unreferenced cache file {}: {}",
427                            filename, e
428                        );
429                    }
430                }
431            }
432        }
433
434        Ok(())
435    }
436}
437
438impl Preprocessor for DiagramsPreprocessor {
439    fn name(&self) -> &str {
440        "mdbook-diagrams"
441    }
442
443    fn run(&self, ctx: &PreprocessorContext, mut book: Book) -> Result<Book> {
444        match self.render_mode {
445            RenderMode::Runtime => {
446                self.prepare_mermaid_files(ctx)?;
447                book = self.process_book_for_runtime_mode(book)?;
448            }
449            RenderMode::PreRender => {
450                let runtime = tokio::runtime::Builder::new_multi_thread()
451                    .enable_all()
452                    .build()?;
453
454                runtime.block_on(self.async_process_book(ctx, &mut book))?;
455            }
456        }
457
458        Ok(book)
459    }
460}