use crate::mathbox::MathBox;
use crate::unicode_maps::{get_greek, get_symbol, to_subscript, to_superscript, BRACKETS};
use latex2mathml::{latex_to_mathml, DisplayStyle};
use roxmltree::{Document, Node};
use std::fmt;
/// Errors produced while turning LaTeX or MathML into rendered text.
///
/// Each variant carries a human-readable description of what went wrong.
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can
/// compare and duplicate errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RenderError {
    /// The LaTeX input could not be converted to MathML.
    LatexConversion(String),
    /// The MathML text could not be parsed as XML.
    MathMLParse(String),
    /// A MathML element did not have the shape its handler requires
    /// (for example the wrong number of children).
    InvalidStructure(String),
}
impl fmt::Display for RenderError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
RenderError::LatexConversion(e) => write!(f, "LaTeX conversion error: {}", e),
RenderError::MathMLParse(e) => write!(f, "MathML parse error: {}", e),
RenderError::InvalidStructure(e) => write!(f, "Invalid math structure: {}", e),
}
}
}
// Marks `RenderError` as a standard error type; no methods are overridden.
impl std::error::Error for RenderError {}
/// Renders LaTeX or MathML input into text-based math layouts.
///
/// `Debug` and `Clone` are derived so the renderer can be logged and
/// duplicated like any other small configuration value.
#[derive(Debug, Clone)]
pub struct MathRenderer {
    /// When `true`, single-line sub/superscripts are emitted as Unicode
    /// script characters where a mapping exists.
    use_unicode_scripts: bool,
}
impl MathRenderer {
pub fn new() -> Self {
Self {
use_unicode_scripts: true,
}
}
pub fn use_unicode_scripts(mut self, use_unicode: bool) -> Self {
self.use_unicode_scripts = use_unicode;
self
}
pub fn render_latex(&self, latex: &str) -> Result<String, RenderError> {
let mathml = latex_to_mathml(latex, DisplayStyle::Inline)
.map_err(|e| RenderError::LatexConversion(e.to_string()))?;
self.render_mathml(&mathml)
}
/// Renders a MathML document (given as XML text) to its text representation.
///
/// # Errors
///
/// Returns [`RenderError::MathMLParse`] if `mathml` is not well-formed XML,
/// or the error raised by the element handlers while laying out the tree.
pub fn render_mathml(&self, mathml: &str) -> Result<String, RenderError> {
    let parsed =
        Document::parse(mathml).map_err(|err| RenderError::MathMLParse(err.to_string()))?;
    let rendered = self.process_element(&parsed.root_element())?;
    Ok(rendered.to_string())
}
/// Converts a LaTeX expression to MathML and lays it out as a [`MathBox`],
/// without flattening the result to a string.
///
/// # Errors
///
/// Returns [`RenderError::LatexConversion`] or [`RenderError::MathMLParse`]
/// when the respective stage fails, or the error raised by the element
/// handlers.
pub fn render_to_box(&self, latex: &str) -> Result<MathBox, RenderError> {
    let mathml = latex_to_mathml(latex, DisplayStyle::Inline)
        .map_err(|err| RenderError::LatexConversion(err.to_string()))?;
    let parsed =
        Document::parse(&mathml).map_err(|err| RenderError::MathMLParse(err.to_string()))?;
    let root = parsed.root_element();
    self.process_element(&root)
}
/// Dispatches a MathML element to the handler for its tag name.
///
/// Container-like and unrecognised elements are rendered as a plain row of
/// their children. `mphantom` is laid out like a row but returned as a blank
/// box of the same width, height and baseline, because MathML defines it as
/// invisible content that still occupies space.
///
/// # Errors
///
/// Propagates any [`RenderError`] raised by the selected handler.
fn process_element(&self, node: &Node) -> Result<MathBox, RenderError> {
    match node.tag_name().name() {
        "math" | "mrow" | "mstyle" | "mpadded" | "mtd" | "menclose" => self.process_row(node),
        "mphantom" => {
            // Measure the hidden content, then emit only its footprint.
            let hidden = self.process_row(node)?;
            Ok(MathBox::empty(hidden.width, hidden.height, hidden.baseline))
        }
        "mi" | "mn" | "mtext" => self.process_text(node),
        "mo" => self.process_operator(node),
        "msup" => self.process_superscript(node),
        "msub" => self.process_subscript(node),
        "msubsup" => self.process_subsup(node),
        "mfrac" => self.process_fraction(node),
        "msqrt" => self.process_sqrt(node),
        "mroot" => self.process_nthroot(node),
        "mover" => self.process_over(node),
        "munder" => self.process_under(node),
        "munderover" => self.process_underover(node),
        "mtable" => self.process_table(node),
        "mtr" => self.process_table_row(node),
        "mfenced" => self.process_fenced(node),
        "mspace" => Ok(MathBox::from_text(" ")),
        // Only the first element child (the presentation part) is rendered.
        "semantics" => match node.children().find(|n| n.is_element()) {
            Some(child) => self.process_element(&child),
            None => Ok(MathBox::empty(0, 1, 0)),
        },
        "annotation" | "annotation-xml" => Ok(MathBox::empty(0, 1, 0)),
        // Unknown tags: fall back to rendering whatever children they have.
        _ => self.process_row(node),
    }
}
/// Renders the element children of `node` side by side, with spacing
/// inserted around operators and multi-line children.
fn process_row(&self, node: &Node) -> Result<MathBox, RenderError> {
    let add_spacing = true;
    self.process_row_inner(node, add_spacing)
}
/// Renders the element children of `node` side by side without inserting
/// any extra spacing.
fn process_row_compact(&self, node: &Node) -> Result<MathBox, RenderError> {
    let add_spacing = false;
    self.process_row_inner(node, add_spacing)
}
/// Lays out the element children of `node` horizontally.
///
/// A node without element children is rendered from its own text content
/// (or as an empty one-row box when it has none). With `add_spacing` set, a
/// space is inserted between neighbours when either is taller than one row,
/// and selected `mo` operators are padded with a space on each side.
///
/// # Errors
///
/// Propagates the first error raised while rendering a child.
fn process_row_inner(&self, node: &Node, add_spacing: bool) -> Result<MathBox, RenderError> {
    /// Whether operator text `op` gets surrounding spaces. `+`, `-`, `±`
    /// and `∓` only count when they are not the first child of the row.
    fn is_spaced_operator(op: &str, is_first: bool) -> bool {
        let binary = !is_first && matches!(op, "+" | "-" | "±" | "∓");
        let relation = matches!(
            op,
            "=" | "≤" | "≥" | "≠" | "≈" | "≡" | "→" | "⇒" | "⟹" | "×" | "÷" | "·"
        );
        binary || relation
    }

    let elements: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    if elements.is_empty() {
        let text = self.get_text_content(node);
        return Ok(if text.is_empty() {
            MathBox::empty(0, 1, 0)
        } else {
            MathBox::from_text(&text)
        });
    }

    let mut parts = Vec::new();
    let mut prev_tall = false;
    for (idx, element) in elements.iter().enumerate() {
        let rendered = self.process_element(element)?;
        let tall = rendered.height > 1;

        // Separate multi-line boxes from their neighbours.
        if add_spacing && idx > 0 && (prev_tall || tall) {
            parts.push(MathBox::from_text(" "));
        }

        let padded = add_spacing
            && element.tag_name().name() == "mo"
            && is_spaced_operator(&self.get_text_content(element), idx == 0);

        // The leading pad is skipped when a multi-line separator was (or
        // could have been) emitted just above.
        if padded && !prev_tall && !tall {
            parts.push(MathBox::from_text(" "));
        }
        parts.push(rendered);
        if padded {
            parts.push(MathBox::from_text(" "));
        }
        prev_tall = tall;
    }
    Ok(MathBox::concat_horizontal(&parts))
}
/// Renders a token element (`mi`, `mn`, `mtext`) from its text content,
/// substituting the `get_greek` mapping when the text has one.
fn process_text(&self, node: &Node) -> Result<MathBox, RenderError> {
    let text = self.get_text_content(node);
    let shown = get_greek(&text)
        .map(|greek| greek.to_string())
        .unwrap_or(text);
    Ok(MathBox::from_text(&shown))
}
/// Renders an `mo` element.
///
/// Operator text that starts with a backslash is treated as a command name
/// and looked up first with `get_symbol`, then with `get_greek`; when neither
/// lookup succeeds, or when there is no backslash, the text is emitted
/// unchanged.
fn process_operator(&self, node: &Node) -> Result<MathBox, RenderError> {
    let text = self.get_text_content(node);
    let replacement = text.strip_prefix('\\').and_then(|cmd| {
        get_symbol(cmd)
            .map(|sym| sym.to_string())
            .or_else(|| get_greek(cmd).map(|greek| greek.to_string()))
    });
    let rendered = replacement.unwrap_or(text);
    Ok(MathBox::from_text(&rendered))
}
/// Renders `msup` (base with a superscript).
///
/// When Unicode scripts are enabled and both parts are a single row, the
/// superscript is converted with `to_superscript` and appended to the base
/// text. Otherwise the superscript is placed above and to the right of the
/// base; the layout reserves as many rows as the superscript needs, so a
/// multi-row superscript (for example a fraction) no longer overlaps the
/// base.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// two element children; propagates errors from rendering either child.
fn process_superscript(&self, node: &Node) -> Result<MathBox, RenderError> {
    let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    if children.len() != 2 {
        return Err(RenderError::InvalidStructure(
            "msup requires exactly 2 children".to_string(),
        ));
    }
    let base = self.process_element(&children[0])?;
    // Grouped scripts are rendered without operator spacing so that the
    // Unicode conversion below sees e.g. "n+1" rather than "n + 1".
    let sup = if children[1].tag_name().name() == "mrow" {
        self.process_row_compact(&children[1])?
    } else {
        self.process_element(&children[1])?
    };
    if self.use_unicode_scripts && base.height == 1 && sup.height == 1 {
        let sup_text = sup.to_string();
        if let Some(unicode_sup) = to_superscript(sup_text.trim()) {
            let combined = format!("{}{}", base.to_string(), unicode_sup);
            return Ok(MathBox::from_text(&combined));
        }
    }
    // Always reserve at least one script row, matching the previous layout
    // for single-row superscripts.
    let sup_rows = sup.height.max(1);
    let width = base.width + sup.width;
    let height = base.height + sup_rows;
    let mut result = MathBox::empty(width, height, base.baseline + sup_rows);
    result.blit(&base, 0, sup_rows);
    result.blit(&sup, base.width, 0);
    Ok(result)
}
/// Renders `msub` (base with a subscript).
///
/// When Unicode scripts are enabled and both parts are a single row, the
/// subscript is converted with `to_subscript` and appended to the base text.
/// Otherwise the subscript is placed below and to the right of the base; the
/// layout reserves as many rows as the subscript needs, so a multi-row
/// subscript is no longer drawn past the bottom of the box.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// two element children; propagates errors from rendering either child.
fn process_subscript(&self, node: &Node) -> Result<MathBox, RenderError> {
    let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    if children.len() != 2 {
        return Err(RenderError::InvalidStructure(
            "msub requires exactly 2 children".to_string(),
        ));
    }
    let base = self.process_element(&children[0])?;
    // Grouped scripts are rendered without operator spacing so that the
    // Unicode conversion below sees e.g. "i+1" rather than "i + 1".
    let sub = if children[1].tag_name().name() == "mrow" {
        self.process_row_compact(&children[1])?
    } else {
        self.process_element(&children[1])?
    };
    if self.use_unicode_scripts && base.height == 1 && sub.height == 1 {
        let sub_text = sub.to_string();
        if let Some(unicode_sub) = to_subscript(sub_text.trim()) {
            let combined = format!("{}{}", base.to_string(), unicode_sub);
            return Ok(MathBox::from_text(&combined));
        }
    }
    // Always reserve at least one script row, matching the previous layout
    // for single-row subscripts.
    let sub_rows = sub.height.max(1);
    let width = base.width + sub.width;
    let height = base.height + sub_rows;
    let mut result = MathBox::empty(width, height, base.baseline);
    result.blit(&base, 0, 0);
    result.blit(&sub, base.width, base.height);
    Ok(result)
}
/// Renders `msubsup` (base with both a subscript and a superscript).
///
/// * Big operators (integrals, sums, products, unions, intersections) get
///   their limits stacked above and below the operator.
/// * Otherwise, when Unicode scripts are enabled and all three parts are a
///   single row, both scripts are converted to Unicode script characters.
/// * Otherwise the scripts are placed to the right of the base, above and
///   below it. The layout reserves as many rows as each script needs, so
///   multi-row scripts no longer overlap the base or fall outside the box.
///
/// For non-operator bases, grouped (`mrow`) scripts are rendered without
/// operator spacing, consistent with `msup`/`msub`, so that the Unicode
/// conversion is not defeated by inserted spaces.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// three element children; propagates errors from rendering any child.
fn process_subsup(&self, node: &Node) -> Result<MathBox, RenderError> {
    let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    if children.len() != 3 {
        return Err(RenderError::InvalidStructure(
            "msubsup requires exactly 3 children".to_string(),
        ));
    }
    let base_text = self.get_text_content(&children[0]);
    let is_big_operator = matches!(
        base_text.as_str(),
        "∫" | "∬" | "∭" | "∮" | "∑" | "∏" | "⋃" | "⋂"
    );
    // Stacked operator limits keep their regular spacing; ordinary scripts
    // are rendered compactly like in `msup`/`msub`.
    let render_script = |child: &Node| -> Result<MathBox, RenderError> {
        if !is_big_operator && child.tag_name().name() == "mrow" {
            self.process_row_compact(child)
        } else {
            self.process_element(child)
        }
    };
    let base = self.process_element(&children[0])?;
    let sub = render_script(&children[1])?;
    let sup = render_script(&children[2])?;
    if is_big_operator {
        return Ok(MathBox::stack_vertical(&[sup, base, sub]));
    }
    if self.use_unicode_scripts && base.height == 1 && sub.height == 1 && sup.height == 1 {
        let sub_text = sub.to_string();
        let sup_text = sup.to_string();
        if let (Some(unicode_sub), Some(unicode_sup)) =
            (to_subscript(sub_text.trim()), to_superscript(sup_text.trim()))
        {
            let combined = format!("{}{}{}", base.to_string(), unicode_sub, unicode_sup);
            return Ok(MathBox::from_text(&combined));
        }
    }
    // Always reserve at least one row per script, matching the previous
    // layout for single-row scripts.
    let sup_rows = sup.height.max(1);
    let sub_rows = sub.height.max(1);
    let script_width = sub.width.max(sup.width);
    let width = base.width + script_width;
    let height = sup_rows + base.height + sub_rows;
    let mut result = MathBox::empty(width, height, base.baseline + sup_rows);
    result.blit(&base, 0, sup_rows);
    result.blit(&sup, base.width, 0);
    result.blit(&sub, base.width, sup_rows + base.height);
    Ok(result)
}
/// Renders `mfrac` as numerator, a horizontal bar, and denominator, with
/// both parts centred over the wider of the two. The bar row is the
/// baseline of the resulting box.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// two element children; propagates errors from rendering either child.
fn process_fraction(&self, node: &Node) -> Result<MathBox, RenderError> {
    let mut parts = node.children().filter(|n| n.is_element());
    let (num_node, den_node) = match (parts.next(), parts.next(), parts.next()) {
        (Some(num_node), Some(den_node), None) => (num_node, den_node),
        _ => {
            return Err(RenderError::InvalidStructure(
                "mfrac requires exactly 2 children".to_string(),
            ))
        }
    };
    let numerator = self.process_element(&num_node)?;
    let denominator = self.process_element(&den_node)?;

    let bar_row = numerator.height;
    let width = numerator.width.max(denominator.width);
    let mut result = MathBox::empty(width, bar_row + 1 + denominator.height, bar_row);
    result.blit(&numerator, (width - numerator.width) / 2, 0);
    result.fill_row(bar_row, '─');
    result.blit(&denominator, (width - denominator.width) / 2, bar_row + 1);
    Ok(result)
}
/// Renders `msqrt`: the radicand (all children laid out as a row) with a
/// `√` in front and a row of `_` above it.
///
/// A single-row radicand is copied character by character from its string
/// form, and the box width is taken from that character count; taller
/// radicands are blitted as a box.
fn process_sqrt(&self, node: &Node) -> Result<MathBox, RenderError> {
    let radicand = self.process_row(node)?;
    let flat = if radicand.height == 1 {
        Some(radicand.to_string())
    } else {
        None
    };

    let (width, height, baseline) = match &flat {
        Some(text) => (1 + text.chars().count(), 2, 1),
        None => (radicand.width + 1, radicand.height + 1, radicand.baseline + 1),
    };
    let mut result = MathBox::empty(width, height, baseline);

    // Overline on the top row, starting after the radical sign column.
    for x in 1..width {
        result.set(x, 0, '_');
    }
    result.set(0, 1, '√');

    match flat {
        Some(text) => {
            for (offset, ch) in text.chars().enumerate() {
                result.set(1 + offset, 1, ch);
            }
        }
        None => {
            result.blit(&radicand, 1, 1);
        }
    }
    Ok(result)
}
/// Renders `mroot` (radicand followed by the root index).
///
/// When the radicand is at most one row tall and the index converts with
/// `to_superscript`, the result is the single-line form
/// `<index>√<radicand>`. Taller radicands always use the box layout below,
/// because flattening them into one text string would destroy their
/// two-dimensional structure.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// two element children; propagates errors from rendering either child.
fn process_nthroot(&self, node: &Node) -> Result<MathBox, RenderError> {
    let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    if children.len() != 2 {
        return Err(RenderError::InvalidStructure(
            "mroot requires exactly 2 children".to_string(),
        ));
    }
    let inner = self.process_element(&children[0])?;
    let index = self.process_element(&children[1])?;
    if inner.height <= 1 {
        let index_text = index.to_string();
        if let Some(unicode_idx) = to_superscript(index_text.trim()) {
            let text = format!("{}√{}", unicode_idx, inner.to_string());
            return Ok(MathBox::from_text(&text));
        }
    }
    // Box layout: index at the top-left, radical sign on the bottom row,
    // overline across the top, radicand starting one row down.
    let width = index.width + inner.width + 2;
    let height = (inner.height + 1).max(index.height);
    let mut result = MathBox::empty(width, height, height / 2);
    result.blit(&index, 0, 0);
    result.set(index.width, height - 1, '√');
    for x in (index.width + 1)..width {
        result.set(x, 0, '─');
    }
    result.blit(&inner, index.width + 2, 1);
    Ok(result)
}
/// Renders `mover` (base with something drawn above it).
///
/// For a single-row base, a recognised accent (hat, tilde, bar, dot, double
/// dot, right arrow) is appended to the base text as a combining character.
/// Anything else is stacked vertically above the base.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// two element children; propagates errors from rendering either child.
fn process_over(&self, node: &Node) -> Result<MathBox, RenderError> {
    let parts: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    if parts.len() != 2 {
        return Err(RenderError::InvalidStructure(
            "mover requires exactly 2 children".to_string(),
        ));
    }
    let base = self.process_element(&parts[0])?;
    let over = self.process_element(&parts[1])?;
    let over_text = over.to_string().trim().to_string();

    let combining = if base.height == 1 {
        match over_text.as_str() {
            "^" | "ˆ" => Some("̂"),
            "~" | "˜" => Some("̃"),
            "¯" | "-" => Some("̄"),
            "." => Some("̇"),
            ".." | "¨" => Some("̈"),
            "→" => Some("⃗"),
            _ => None,
        }
    } else {
        None
    };

    match combining {
        Some(mark) => {
            let text = format!("{}{}", base.to_string(), mark);
            Ok(MathBox::from_text(&text))
        }
        None => Ok(MathBox::stack_vertical(&[over, base])),
    }
}
/// Renders `munder` (base with something drawn below it).
///
/// For the named operators `lim`, `max`, `min`, `sup` and `inf` with an
/// underscript of at most one row, the underscript is written inline: as
/// Unicode subscript characters when `to_subscript` can convert it, otherwise
/// in parentheses after the name. Multi-row underscripts, and all other
/// bases, use the stacked layout with both parts centred, because flattening
/// a multi-row box into one text string would garble it.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// two element children; propagates errors from rendering either child.
fn process_under(&self, node: &Node) -> Result<MathBox, RenderError> {
    let children: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    if children.len() != 2 {
        return Err(RenderError::InvalidStructure(
            "munder requires exactly 2 children".to_string(),
        ));
    }
    let base_text = self.get_text_content(&children[0]);
    let base = self.process_element(&children[0])?;
    let under = self.process_element(&children[1])?;

    let is_limit_like = matches!(base_text.as_str(), "lim" | "max" | "min" | "sup" | "inf");
    if is_limit_like && under.height <= 1 {
        let under_text = under.to_string();
        let under_trimmed = under_text.trim();
        let combined = match to_subscript(under_trimmed) {
            Some(subscript) => format!("{}{}", base_text, subscript),
            None => format!("{}({})", base_text, under_trimmed),
        };
        return Ok(MathBox::from_text(&combined));
    }

    let width = base.width.max(under.width);
    let height = base.height + under.height;
    let mut result = MathBox::empty(width, height, base.baseline);
    let base_offset = (width - base.width) / 2;
    result.blit(&base, base_offset, 0);
    let under_offset = (width - under.width) / 2;
    result.blit(&under, under_offset, base.height);
    Ok(result)
}
/// Renders `munderover` by stacking the overscript, the base and the
/// underscript vertically, in that order from top to bottom.
///
/// # Errors
///
/// Returns [`RenderError::InvalidStructure`] unless the element has exactly
/// three element children; propagates errors from rendering any child.
fn process_underover(&self, node: &Node) -> Result<MathBox, RenderError> {
    let parts: Vec<_> = node.children().filter(|n| n.is_element()).collect();
    match parts.as_slice() {
        // Document order is base, underscript, overscript.
        [base_node, under_node, over_node] => {
            let base = self.process_element(base_node)?;
            let under = self.process_element(under_node)?;
            let over = self.process_element(over_node)?;
            Ok(MathBox::stack_vertical(&[over, base, under]))
        }
        _ => Err(RenderError::InvalidStructure(
            "munderover requires exactly 3 children".to_string(),
        )),
    }
}
/// Renders `mtable` as a grid.
///
/// Only `mtr` children, and within them only `mtd` children, are considered.
/// Each column is as wide as its widest cell and each row as tall as its
/// tallest cell; cells are centred horizontally in their column, placed at
/// the top of their row, and columns are separated by two blank columns.
/// A table without rows yields an empty one-row box.
///
/// # Errors
///
/// Propagates the first error raised while rendering a cell.
fn process_table(&self, node: &Node) -> Result<MathBox, RenderError> {
    let mut grid: Vec<Vec<MathBox>> = Vec::new();
    for row_node in node
        .children()
        .filter(|n| n.is_element() && n.tag_name().name() == "mtr")
    {
        let mut cells = Vec::new();
        for cell_node in row_node
            .children()
            .filter(|n| n.is_element() && n.tag_name().name() == "mtd")
        {
            cells.push(self.process_row(&cell_node)?);
        }
        grid.push(cells);
    }
    if grid.is_empty() {
        return Ok(MathBox::empty(0, 1, 0));
    }

    // Measure: widest cell per column, tallest cell per row.
    let num_cols = grid.iter().map(Vec::len).max().unwrap_or(0);
    let mut col_widths = vec![0; num_cols];
    let mut row_heights = Vec::with_capacity(grid.len());
    for cells in &grid {
        let mut tallest = 0;
        for (slot, cell) in col_widths.iter_mut().zip(cells) {
            *slot = (*slot).max(cell.width);
            tallest = tallest.max(cell.height);
        }
        row_heights.push(tallest);
    }

    let spacing = 2;
    let total_width: usize =
        col_widths.iter().sum::<usize>() + spacing * num_cols.saturating_sub(1);
    let total_height: usize = row_heights.iter().sum();
    let mut result = MathBox::empty(total_width, total_height, total_height / 2);

    // Place: walk rows top to bottom and columns left to right.
    let mut y_pos = 0;
    for (cells, &row_height) in grid.iter().zip(&row_heights) {
        let mut x_pos = 0;
        for (cell, &col_width) in cells.iter().zip(&col_widths) {
            let centring = (col_width - cell.width) / 2;
            result.blit(cell, x_pos + centring, y_pos);
            x_pos += col_width + spacing;
        }
        y_pos += row_height;
    }
    Ok(result)
}
/// Renders an `mtr` that is processed on its own (outside `process_table`):
/// every element child is rendered as a row and the results are joined
/// horizontally with a single-space box between neighbours.
///
/// # Errors
///
/// Propagates the first error raised while rendering a cell.
fn process_table_row(&self, node: &Node) -> Result<MathBox, RenderError> {
    let mut cells: Vec<MathBox> = Vec::new();
    for cell_node in node.children().filter(|n| n.is_element()) {
        cells.push(self.process_row(&cell_node)?);
    }

    let gap = MathBox::from_text(" ");
    let mut parts = Vec::new();
    let mut is_first = true;
    for cell in cells {
        if !is_first {
            parts.push(gap.clone());
        }
        parts.push(cell);
        is_first = false;
    }
    Ok(MathBox::concat_horizontal(&parts))
}
/// Renders `mfenced`: the children laid out as a row, wrapped in the
/// delimiters given by the `open` and `close` attributes (defaulting to
/// `(` and `)`).
///
/// Content of at most one row is wrapped inline as text. Taller content gets
/// one column on each side filled with the characters that `BRACKETS`
/// supplies for the content height.
fn process_fenced(&self, node: &Node) -> Result<MathBox, RenderError> {
    let open = node.attribute("open").unwrap_or("(");
    let close = node.attribute("close").unwrap_or(")");
    let inner = self.process_row(node)?;

    if inner.height <= 1 {
        let text = [open, inner.to_string().as_str(), close].concat();
        return Ok(MathBox::from_text(&text));
    }

    let rows = inner.height;
    let left_chars = BRACKETS.get_left(open, rows);
    let right_chars = BRACKETS.get_right(close, rows);

    // One delimiter column on each side of the content.
    let right_col = inner.width + 1;
    let mut result = MathBox::empty(right_col + 1, rows, inner.baseline);
    for (y, &ch) in left_chars.iter().enumerate() {
        result.set(0, y, ch);
    }
    for (y, &ch) in right_chars.iter().enumerate() {
        result.set(right_col, y, ch);
    }
    result.blit(&inner, 1, 0);
    Ok(result)
}
/// Concatenates the direct text children of `node` (text inside nested
/// elements is not included) and trims surrounding whitespace.
fn get_text_content(&self, node: &Node) -> String {
    let joined: String = node
        .children()
        .filter(|child| child.is_text())
        .filter_map(|child| child.text())
        .collect();
    joined.trim().to_string()
}
}
impl Default for MathRenderer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `latex` with a default-configured renderer, panicking if
    /// rendering fails.
    fn render(latex: &str) -> String {
        MathRenderer::new().render_latex(latex).unwrap()
    }

    /// Both operands of a simple sum appear in the output.
    #[test]
    fn test_simple_expression() {
        let result = render("x + y");
        assert!(result.contains('x'));
        assert!(result.contains('y'));
    }

    /// The exponent appears either as a Unicode superscript or as a plain digit.
    #[test]
    fn test_superscript() {
        let result = render("x^2");
        assert!(result.contains('²') || result.contains('2'));
    }

    /// A fraction shows numerator, denominator and the fraction bar.
    #[test]
    fn test_fraction() {
        let result = render(r"\frac{a}{b}");
        assert!(result.contains('a'));
        assert!(result.contains('b'));
        assert!(result.contains('─'));
    }
}