use crate::error::MarkdownError;
use comrak::nodes::{NodeHtmlBlock, NodeValue};
use regex::Regex;
use std::cell::RefCell;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::LazyLock;
/// A heading found in a parsed Markdown document.
///
/// Values of this type are produced by [`collect_headings`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Heading {
/// Heading level, copied from the parser's heading node.
pub level: u8,
/// Text content of the heading (plain text, inline code and image titles,
/// concatenated without separators).
pub text: String,
/// Anchor id: the anchorized slug of `text`, with the caller-supplied
/// prefix prepended when a non-empty one was given.
pub id: String,
}
/// Collects every heading below `root` (including `root` itself), in
/// traversal order.
///
/// # Parameters
/// - `root`: node whose descendants are scanned for heading nodes.
/// - `prefix`: optional string prepended verbatim to each generated id.
///   `None` and `Some("")` are treated the same way: no prefix.
///
/// # Returns
/// One [`Heading`] per heading node, carrying its level, its extracted text
/// and its anchor id.
pub fn collect_headings<'a>(
    root: comrak::nodes::Node<'a>,
    prefix: Option<&str>,
) -> Vec<Heading> {
    // A single anchorizer is shared by the whole walk; per comrak's docs it
    // keeps generated anchors unique across calls.
    let mut slugger = comrak::Anchorizer::new();
    // Normalise "empty prefix" to "no prefix" once, outside the loop.
    let id_prefix = prefix.filter(|p| !p.is_empty());

    let mut headings = Vec::new();
    for candidate in root.descendants() {
        // Skip everything that is not a heading node.
        let NodeValue::Heading(heading) = candidate.data.borrow().value else {
            continue;
        };

        let text = extract_text(candidate);
        let slug = slugger.anchorize(&text);
        let id = match id_prefix {
            Some(p) => format!("{p}{slug}"),
            None => slug,
        };

        headings.push(Heading {
            level: heading.level,
            text,
            id,
        });
    }
    headings
}
/// Flattens the textual content below `root` into one whitespace-trimmed
/// string.
///
/// Text and inline-code literals are appended as-is. Code blocks are
/// preceded by a separating space and then appended verbatim. Soft/hard
/// line breaks and block-level containers (paragraphs, headings, list
/// items, block quotes, tables, rows, cells) contribute a single separating
/// space. All other node kinds are ignored.
///
/// A separating space is only inserted when the buffer is non-empty and
/// does not already end with a space, so separators never pile up.
pub fn collect_all_text<'a>(root: comrak::nodes::Node<'a>) -> String {
    /// Appends one space unless `out` is empty or already ends with a space.
    fn separate(out: &mut String) {
        if !out.is_empty() && !out.ends_with(' ') {
            out.push(' ');
        }
    }

    let mut text = String::new();
    for node in root.descendants() {
        let ast = node.data.borrow();
        match &ast.value {
            NodeValue::Text(literal) => text.push_str(literal),
            NodeValue::Code(code) => text.push_str(&code.literal),
            NodeValue::CodeBlock(block) => {
                separate(&mut text);
                text.push_str(&block.literal);
            }
            // Breaks and block-level boundaries only act as word separators.
            NodeValue::SoftBreak
            | NodeValue::LineBreak
            | NodeValue::Paragraph
            | NodeValue::Heading(_)
            | NodeValue::Item(_)
            | NodeValue::BlockQuote
            | NodeValue::Table(_)
            | NodeValue::TableRow(_)
            | NodeValue::TableCell => separate(&mut text),
            _ => {}
        }
    }
    text.trim().to_owned()
}
/// Concatenates the inline text found below `node` (and in `node` itself).
///
/// Contributions, appended in traversal order with no separators:
/// - text nodes: their literal text,
/// - inline code: the code literal,
/// - images: the image's `title` attribute.
///
/// Every other node kind is skipped.
fn extract_text<'a>(node: comrak::nodes::Node<'a>) -> String {
    node.descendants().fold(String::new(), |mut collected, child| {
        let ast = child.data.borrow();
        match &ast.value {
            NodeValue::Text(literal) => collected.push_str(literal),
            NodeValue::Code(code) => collected.push_str(&code.literal),
            NodeValue::Image(image) => collected.push_str(&image.title),
            _ => {}
        }
        collected
    })
}
/// Matches an opening `<td …>` tag and captures its raw attribute text
/// (everything between `<td` and `>`) as group 1.
///
/// Compiled lazily on first use. Header cells (`<th>`) are not matched by
/// this pattern.
static TABLE_CELL_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"<td([^>]*)>").unwrap());
/// Matches `<div class="TYPE">…</div>` where `TYPE` is one of the known
/// custom block names.
///
/// - Flags: `s` lets `.` span newlines, `i` makes the match case-insensitive.
/// - Group 1: the block type name; group 2: the inner content, matched
///   lazily up to the first `</div>`.
/// - The quotes around the class value are optional, and `class` must be the
///   only attribute on the `<div>` for the pattern to match.
///
/// Compiled lazily on first use.
static CUSTOM_BLOCK_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(
r#"(?si)<div\s+class=["']?(note|warning|tip|info|important|caution)["']?>(.*?)</div>"#,
)
.unwrap()
});
/// Horizontal alignment of a table column.
///
/// NOTE(review): not referenced by any code visible in this part of the
/// file — confirm where it is consumed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ColumnAlignment {
/// Left-aligned.
Left,
/// Centered.
Center,
/// Right-aligned.
Right,
}
/// The kinds of custom callout blocks recognised in `<div class="…">` HTML.
///
/// Each variant has a built-in alert CSS class and title, and can be parsed
/// case-insensitively from its lowercase name via `FromStr`. The type is
/// `Hash + Eq` so it can key the override maps in [`CustomBlockConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CustomBlockType {
/// Parsed from `"note"`.
Note,
/// Parsed from `"warning"`.
Warning,
/// Parsed from `"tip"`.
Tip,
/// Parsed from `"info"`.
Info,
/// Parsed from `"important"`.
Important,
/// Parsed from `"caution"`.
Caution,
}
impl CustomBlockType {
    /// Returns the built-in alert CSS class for this block type.
    pub fn default_alert_class(&self) -> &'static str {
        match *self {
            Self::Note => "alert-info",
            Self::Warning => "alert-warning",
            Self::Tip => "alert-success",
            Self::Info => "alert-primary",
            Self::Important => "alert-danger",
            Self::Caution => "alert-secondary",
        }
    }

    /// Returns the built-in display title for this block type.
    pub fn default_title(&self) -> &'static str {
        match *self {
            Self::Note => "Note",
            Self::Warning => "Warning",
            Self::Tip => "Tip",
            Self::Info => "Info",
            Self::Important => "Important",
            Self::Caution => "Caution",
        }
    }

    /// Alias for [`Self::default_alert_class`]; no configuration is consulted.
    pub fn get_alert_class(&self) -> &'static str {
        Self::default_alert_class(self)
    }

    /// Alias for [`Self::default_title`]; no configuration is consulted.
    pub fn get_title(&self) -> &'static str {
        Self::default_title(self)
    }

    /// Resolves the alert CSS class for this block type under `config`.
    ///
    /// Returns the entry from `config.class_overrides` when one exists for
    /// this type, otherwise the built-in default class.
    pub fn alert_class_with<'a>(
        &self,
        config: &'a CustomBlockConfig,
    ) -> &'a str {
        match config.class_overrides.get(self) {
            Some(custom) => custom.as_str(),
            None => self.default_alert_class(),
        }
    }

    /// Resolves the display title for this block type under `config`.
    ///
    /// Returns the entry from `config.title_overrides` when one exists for
    /// this type, otherwise the built-in default title.
    pub fn title_with<'a>(
        &self,
        config: &'a CustomBlockConfig,
    ) -> &'a str {
        match config.title_overrides.get(self) {
            Some(custom) => custom.as_str(),
            None => self.default_title(),
        }
    }
}
impl FromStr for CustomBlockType {
    type Err = MarkdownError;

    /// Parses a block type name, ignoring letter case.
    ///
    /// # Errors
    /// Returns `MarkdownError::CustomBlockError` carrying the original
    /// (un-normalised) input when the name is not one of `note`, `warning`,
    /// `tip`, `info`, `important` or `caution`.
    fn from_str(block_type: &str) -> Result<Self, Self::Err> {
        let normalized = block_type.to_lowercase();
        let parsed = match normalized.as_str() {
            "note" => Self::Note,
            "warning" => Self::Warning,
            "tip" => Self::Tip,
            "info" => Self::Info,
            "important" => Self::Important,
            "caution" => Self::Caution,
            _ => {
                // Report the caller's spelling, not the lowercased one.
                let message = format!("Unknown block type: {block_type}");
                return Err(MarkdownError::CustomBlockError(message));
            }
        };
        Ok(parsed)
    }
}
/// Per-block-type overrides for how custom blocks are rendered.
///
/// The default value has no overrides, in which case the built-in class
/// and title of each [`CustomBlockType`] are used.
#[derive(Debug, Clone, Default)]
pub struct CustomBlockConfig {
/// CSS class to use instead of the built-in alert class, keyed by block type.
pub class_overrides: HashMap<CustomBlockType, String>,
/// Title to use instead of the built-in title, keyed by block type.
pub title_overrides: HashMap<CustomBlockType, String>,
}
impl CustomBlockConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_class(
mut self,
block_type: CustomBlockType,
class: impl Into<String>,
) -> Self {
self.class_overrides.insert(block_type, class.into());
self
}
pub fn with_title(
mut self,
block_type: CustomBlockType,
title: impl Into<String>,
) -> Self {
self.title_overrides.insert(block_type, title.into());
self
}
}
/// Rewrites custom-block `<div>`s inside every HTML block node below `root`.
///
/// Each node is visited once; for HTML block nodes the literal is replaced,
/// in place, by the output of `transform_custom_blocks` under `config`.
/// Nodes of any other kind are left untouched.
pub fn process_custom_block_nodes<'a>(
    root: comrak::nodes::Node<'a>,
    config: &CustomBlockConfig,
) {
    root.descendants().for_each(|node| {
        let mut ast = node.data.borrow_mut();
        if let NodeValue::HtmlBlock(block) = &mut ast.value {
            let rewritten = transform_custom_blocks(&block.literal, config);
            block.literal = rewritten;
        }
    });
}
fn transform_custom_blocks(
html: &str,
config: &CustomBlockConfig,
) -> String {
CUSTOM_BLOCK_RE
.replace_all(html, |caps: ®ex::Captures| {
let block_type = CustomBlockType::from_str(
caps.get(1).unwrap().as_str(),
)
.expect("regex only matches known block types");
generate_custom_block_html(block_type, &caps[2], config)
})
.to_string()
}
/// Renders one custom block as an alert `<div>`.
///
/// The CSS class and title come from `config` when it has overrides for
/// `block_type`, and from the built-in defaults otherwise. `content` is
/// inserted verbatim (it is not escaped).
fn generate_custom_block_html(
    block_type: CustomBlockType,
    content: &str,
    config: &CustomBlockConfig,
) -> String {
    let css_class = block_type.alert_class_with(config);
    let heading = block_type.title_with(config);
    format!(
        r#"<div class="alert {}" role="alert"><strong>{}:</strong> {}</div>"#,
        css_class, heading, content
    )
}
/// Replaces every table node below `root` with a pre-rendered HTML block.
///
/// Each table is rendered to HTML with `options`, post-processed by
/// `process_tables`, and the result is inserted as a new HTML block node
/// (allocated in `arena`) in the table's position; the table node is then
/// detached. Tables whose rendering fails are left in the tree unchanged.
pub fn enhance_table_nodes<'a>(
    root: comrak::nodes::Node<'a>,
    arena: &'a comrak::Arena<'a>,
    options: &comrak::Options,
) {
    // Snapshot the tables first so the tree is not mutated while it is
    // being traversed.
    let mut tables: Vec<comrak::nodes::Node<'a>> = Vec::new();
    for node in root.descendants() {
        let is_table =
            matches!(node.data.borrow().value, NodeValue::Table(_));
        if is_table {
            tables.push(node);
        }
    }

    for table in tables {
        let mut rendered = String::new();
        if comrak::format_html(table, options, &mut rendered).is_ok() {
            let html_block = NodeHtmlBlock {
                // NOTE(review): presumably CommonMark HTML-block start
                // condition 6 — confirm against comrak's documentation.
                block_type: 6,
                literal: process_tables(&rendered),
            };
            // The synthetic node gets a zeroed source position.
            let origin = comrak::nodes::LineColumn { line: 0, column: 0 };
            let ast = comrak::nodes::Ast::new(
                NodeValue::HtmlBlock(html_block),
                origin,
            );
            let replacement =
                arena.alloc(comrak::nodes::AstNode::new(RefCell::new(ast)));

            table.insert_before(replacement);
            table.detach();
        }
    }
}
/// Rewrites custom-block `<div>`s in `content` using the built-in classes
/// and titles (an override-free [`CustomBlockConfig`]).
///
/// Input without recognised blocks is returned unchanged.
pub fn process_custom_blocks(content: &str) -> String {
    let defaults = CustomBlockConfig::default();
    transform_custom_blocks(content, &defaults)
}
pub fn process_tables(table_html: &str) -> String {
let table_html = table_html.replace(
"<table>",
r#"<div class="table-responsive"><table class="table">"#,
);
let table_html = table_html.replace("</table>", "</table></div>");
TABLE_CELL_RE
.replace_all(&table_html, |caps: ®ex::Captures| {
let attrs = &caps[1];
if attrs.contains("align=\"center\"") {
format!(r#"<td{attrs} class="text-center">"#)
} else if attrs.contains("align=\"right\"") {
format!(r#"<td{attrs} class="text-right">"#)
} else {
format!(r#"<td{attrs} class="text-left">"#)
}
})
.to_string()
}
// Unit tests for custom-block rewriting, table post-processing and
// block-type parsing.
#[cfg(test)]
mod tests {
use super::*;
// Known block types are matched case-insensitively and mapped to their
// built-in alert classes.
#[test]
fn test_process_custom_blocks_default_config() {
let input = r#"
<div class="note">This is a note.</div>
<div class="WARNING">This is a warning.</div>
<div class="Tip">This is a tip.</div>
"#;
let processed = process_custom_blocks(input);
assert!(processed.contains(r#"alert alert-info"#));
assert!(processed.contains(r#"alert alert-warning"#));
assert!(processed.contains(r#"alert alert-success"#));
}
// Class and title overrides from the config replace the built-in values.
#[test]
fn test_custom_block_config_overrides() {
let config = CustomBlockConfig::new()
.with_class(CustomBlockType::Note, "callout-info")
.with_title(CustomBlockType::Note, "Did you know?");
let html = generate_custom_block_html(
CustomBlockType::Note,
"test content",
&config,
);
assert!(html.contains("callout-info"));
assert!(html.contains("Did you know?:"));
}
// A div whose class is not a known block type is left byte-for-byte intact.
#[test]
fn test_unknown_block_passthrough() {
let input =
r#"<div class="unknown">Should pass through.</div>"#;
let processed = process_custom_blocks(input);
assert_eq!(processed, input);
}
// The table is wrapped in the responsive container and each cell gets the
// class matching its `align` attribute (left when none is present).
#[test]
fn test_process_tables() {
let input = r#"<table><tr><td align="center">Center</td><td align="right">Right</td><td>Left</td></tr></table>"#;
let processed = process_tables(input);
assert!(processed.contains(r#"table-responsive"#));
assert!(processed.contains(r#"text-center"#));
assert!(processed.contains(r#"text-right"#));
assert!(processed.contains(r#"text-left"#));
}
// Every table in the input is wrapped, not just the first one.
#[test]
fn test_process_multiple_tables() {
let input = "<table><tr><td>A</td></tr></table>\n<table><tr><td>B</td></tr></table>";
let processed = process_tables(input);
assert_eq!(processed.matches("table-responsive").count(), 2);
}
// Parsing an unknown name fails and the error message echoes the input.
#[test]
fn test_unknown_block_type_from_str() {
let result = CustomBlockType::from_str("unknown");
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("Unknown block type: unknown"),
"Error message should contain the unknown type"
);
}
// Alert-class-like names and the empty string are not valid block types.
#[test]
fn test_unknown_block_type_from_str_various() {
for name in ["foobar", "alert", "danger", "success", ""] {
let result = CustomBlockType::from_str(name);
assert!(
result.is_err(),
"'{name}' should not parse as a valid block type"
);
}
}
}