use crate::core::OCRError;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
/// Artifacts from decoding one table in a batch: the structure tokens, one
/// 8-value (4-point, x/y interleaved) bbox per cell token, and the mean
/// token confidence score.
type TableDecodeArtifacts = (Vec<String>, Vec<[f32; 8]>, f32);
/// Result alias for single-table decode operations.
type TableDecodeResult = Result<TableDecodeArtifacts, OCRError>;
pub fn wrap_table_html(tokens: &[String]) -> String {
render_table_html(tokens, None)
}
pub fn wrap_table_html_with_content(tokens: &[String], cell_texts: &[Option<String>]) -> String {
render_table_html(tokens, Some(cell_texts))
}
fn render_table_html(tokens: &[String], cell_texts: Option<&[Option<String>]>) -> String {
let mut result = Vec::new();
let mut td_index = 0;
let mut idx = 0usize;
result.push("<html><body>".to_string());
let has_table_tag = tokens
.first()
.map(|t| t.contains("<table"))
.unwrap_or(false);
if !has_table_tag {
result.push("<table>".to_string());
}
while idx < tokens.len() {
let tag = tokens[idx].as_str();
if tag == "<td></td>" {
result.push("<td>".to_string());
if let Some(texts) = cell_texts
&& let Some(Some(text)) = texts.get(td_index)
{
result.push(text.clone());
}
result.push("</td>".to_string());
td_index += 1;
idx += 1;
continue;
}
if tag.starts_with("<td") {
let parsed = parse_td_tag(tokens, idx);
result.push(format!("<td{}>", parsed.attrs));
let mut is_bold = false;
let next_idx = parsed.next_index;
if next_idx < tokens.len() && tokens[next_idx] == "<b>" {
is_bold = true;
}
if let Some(texts) = cell_texts
&& let Some(Some(text)) = texts.get(td_index)
{
if is_bold {
result.push("<b>".to_string());
}
result.push(text.clone());
if is_bold {
result.push("</b>".to_string());
}
}
result.push("</td>".to_string());
td_index += 1;
idx = parsed.next_index;
continue;
}
result.push(tokens[idx].clone());
idx += 1;
}
if !has_table_tag {
result.push("</table>".to_string());
}
result.push("</body></html>".to_string());
result.join("")
}
/// Grid placement of a single table cell within the logical row/column grid.
#[derive(Debug, Clone, Default)]
pub struct CellGridInfo {
    /// Zero-based row index of the cell's top-left corner.
    pub row: usize,
    /// Zero-based column index of the cell's top-left corner.
    pub col: usize,
    /// Number of rows the cell spans (1 for a plain cell).
    pub row_span: usize,
    /// Number of columns the cell spans (1 for a plain cell).
    pub col_span: usize,
}
/// Computes the grid placement (row, col, row_span, col_span) of every cell
/// in a structure-token stream, honoring rowspan/colspan occupancy: slots
/// claimed by a rowspan from an earlier row are skipped when placing later
/// cells.
pub fn parse_cell_grid_info(tokens: &[String]) -> Vec<CellGridInfo> {
    use std::collections::HashSet;
    let mut cells: Vec<CellGridInfo> = Vec::new();
    // Grid slots already claimed by rowspans reaching down from earlier rows.
    let mut claimed: HashSet<(usize, usize)> = HashSet::new();
    let mut row = 0usize;
    let mut col = 0usize;
    let mut pos = 0usize;
    while pos < tokens.len() {
        match tokens[pos].as_str() {
            "<tr>" => {
                // New row: restart at column 0, skipping claimed slots.
                col = 0;
                while claimed.contains(&(row, col)) {
                    col += 1;
                }
                pos += 1;
            }
            "</tr>" => {
                row += 1;
                pos += 1;
            }
            "<td></td>" => {
                // Merged empty-cell token: always a 1x1 cell.
                while claimed.contains(&(row, col)) {
                    col += 1;
                }
                cells.push(CellGridInfo {
                    row,
                    col,
                    row_span: 1,
                    col_span: 1,
                });
                col += 1;
                pos += 1;
            }
            token if token.starts_with("<td") => {
                // Cell with attributes, possibly split over several tokens.
                let parsed = parse_td_tag(tokens, pos);
                while claimed.contains(&(row, col)) {
                    col += 1;
                }
                cells.push(CellGridInfo {
                    row,
                    col,
                    row_span: parsed.row_span,
                    col_span: parsed.col_span,
                });
                // Reserve the slots this cell covers in the rows below
                // (empty range when row_span <= 1).
                for r in 1..parsed.row_span {
                    for c in 0..parsed.col_span {
                        claimed.insert((row + r, col + c));
                    }
                }
                col += parsed.col_span;
                pos = parsed.next_index;
            }
            _ => pos += 1,
        }
    }
    cells
}
/// Extracts the numeric value of `attr="N"` from a token fragment, e.g.
/// `colspan="2"` yields `Some(2)`. Returns `None` when the attribute is
/// absent, unterminated, or non-numeric.
fn parse_span_attr(token: &str, attr: &str) -> Option<usize> {
    let needle = format!("{}=\"", attr);
    let value_start = token.find(&needle)? + needle.len();
    let rest = &token[value_start..];
    let value_end = rest.find('"')?;
    rest[..value_end].parse::<usize>().ok()
}

/// Result of scanning a `<td ...>` opening tag (possibly split over tokens).
#[derive(Debug, Clone)]
struct ParsedTdTag {
    /// Raw attribute text to splice back into the rendered `<td...>` tag.
    attrs: String,
    /// Parsed rowspan value (defaults to 1 when absent).
    row_span: usize,
    /// Parsed colspan value (defaults to 1 when absent).
    col_span: usize,
    /// Index of the first token after the cell: just past `</td>` when the
    /// cell closes, otherwise at the next structural token.
    next_index: usize,
}

/// Scans a `<td>` cell starting at `start_idx`, collecting attribute text
/// and span values even when the tag is split across several tokens
/// (e.g. `"<td"`, `" colspan=\"2\""`, `">"`, `"</td>"`).
fn parse_td_tag(tokens: &[String], start_idx: usize) -> ParsedTdTag {
    // Accumulates one attribute fragment into the parse state.
    fn absorb(parsed: &mut ParsedTdTag, fragment: &str) {
        parsed.attrs.push_str(fragment);
        if let Some(v) = parse_span_attr(fragment, "colspan") {
            parsed.col_span = v;
        }
        if let Some(v) = parse_span_attr(fragment, "rowspan") {
            parsed.row_span = v;
        }
    }
    let mut parsed = ParsedTdTag {
        attrs: String::new(),
        row_span: 1,
        col_span: 1,
        next_index: start_idx + 1,
    };
    // Attributes embedded directly in the opening token, e.g. `<td colspan="2">`.
    if let Some(stripped) = tokens.get(start_idx).and_then(|t| t.strip_prefix("<td")) {
        let inline = stripped.split('>').next().unwrap_or("");
        if !inline.is_empty() {
            absorb(&mut parsed, inline);
        }
    }
    // Attribute fragments carried by following tokens, up to the tag close
    // or the next structural token.
    let mut scan = start_idx + 1;
    while let Some(token) = tokens.get(scan) {
        let token = token.as_str();
        let at_boundary = token == ">"
            || token == "</td>"
            || token.starts_with("<td")
            || token == "<tr>"
            || token == "</tr>";
        if at_boundary {
            break;
        }
        absorb(&mut parsed, token);
        scan += 1;
    }
    // Advance to just past this cell's `</td>`, stopping early if the cell
    // never closes before the next structural token.
    let mut after = scan;
    while let Some(token) = tokens.get(after) {
        let token = token.as_str();
        if token == "</td>" {
            after += 1;
            break;
        }
        if token.starts_with("<td") || token == "<tr>" || token == "</tr>" {
            break;
        }
        after += 1;
    }
    // Guarantee forward progress for callers that loop on next_index.
    parsed.next_index = after.max(start_idx + 1);
    parsed
}
/// Batched output of [`TableStructureDecode::decode`]: one entry per batch
/// element in each field.
#[derive(Debug, Clone)]
pub struct TableStructureDecodeOutput {
    /// Decoded structure tokens for each batch element.
    pub structure_tokens: Vec<Vec<String>>,
    /// One 8-value bbox per cell token, per batch element.
    pub bboxes: Vec<Vec<[f32; 8]>>,
    /// Mean token confidence score per batch element.
    pub structure_scores: Vec<f32>,
}
/// Greedy decoder for table-structure model outputs, configured from a token
/// dictionary file (see [`TableStructureDecode::from_dict_path`]).
#[derive(Debug, Clone)]
pub struct TableStructureDecode {
    // Full token dictionary with "sos" at index 0 and "eos" at the last index.
    character_dict: Vec<String>,
    // Token indices skipped during decoding (the sos/eos sentinels).
    ignored_tokens: Vec<usize>,
    // Indices of cell-opening tokens; these trigger bbox extraction.
    td_token_indices: Vec<usize>,
    // Index of the "eos" token, which terminates decoding.
    end_idx: usize,
}
impl TableStructureDecode {
    /// Builds a decoder from a token-dictionary file.
    ///
    /// Preprocessing:
    /// 1. Merge span-less cells: ensure `<td></td>` exists and drop the bare
    ///    `<td>` entry (the `merge_no_span_structure` behavior).
    /// 2. Surround the dictionary with `sos`/`eos` sentinels; their indices
    ///    become the ignored-token set and `eos` the stop token.
    ///
    /// # Errors
    /// Returns `OCRError::ConfigError` when the dictionary file cannot be
    /// opened or read.
    pub fn from_dict_path(dict_path: &Path) -> Result<Self, OCRError> {
        let mut character_dict = Self::load_dict(dict_path)?;
        // Hard-wired on; kept as a named flag to mirror the upstream config.
        let merge_no_span_structure = true;
        if merge_no_span_structure {
            if !character_dict.contains(&"<td></td>".to_string()) {
                character_dict.push("<td></td>".to_string());
            }
            if let Some(pos) = character_dict.iter().position(|s| s == "<td>") {
                character_dict.remove(pos);
            }
        }
        let beg_str = "sos";
        let end_str = "eos";
        let original_dict_size = character_dict.len();
        // Final layout: [sos, ...dictionary entries..., eos].
        let mut final_dict = Vec::with_capacity(original_dict_size + 2);
        final_dict.push(beg_str.to_string()); final_dict.extend(character_dict); final_dict.push(end_str.to_string());
        tracing::debug!("Dictionary processing complete:");
        tracing::debug!(" Original dict size: {}", original_dict_size);
        tracing::debug!(" Final dict size: {}", final_dict.len());
        tracing::debug!(
            " First 10 dict entries: {:?}",
            &final_dict[..10.min(final_dict.len())]
        );
        tracing::debug!(
            " Last 10 dict entries: {:?}",
            &final_dict[final_dict.len().saturating_sub(10)..]
        );
        // sos at index 0, eos at the last index; both are skipped while
        // decoding. final_dict always holds at least the two sentinels, so
        // the subtraction cannot underflow.
        let start_idx = 0; let end_idx = final_dict.len() - 1;
        let ignored_tokens = vec![start_idx, end_idx];
        // Any of these spellings marks a cell opening; "<td>" may have been
        // removed above, in which case filter_map simply drops it.
        let td_tokens = ["<td>", "<td", "<td></td>"];
        let td_token_indices: Vec<usize> = td_tokens
            .iter()
            .filter_map(|&token| final_dict.iter().position(|s| s == token))
            .collect();
        tracing::debug!("TD token indices: {:?}", td_token_indices);
        tracing::debug!(
            "Ignored tokens (sos={}, eos={}): {:?}",
            start_idx,
            end_idx,
            ignored_tokens
        );
        Ok(Self {
            character_dict: final_dict,
            ignored_tokens,
            td_token_indices,
            end_idx,
        })
    }
    /// Reads a dictionary file, one token per line.
    ///
    /// Only trailing whitespace is trimmed — leading spaces are significant
    /// for attribute tokens such as ` colspan="2"`. Blank lines are skipped.
    ///
    /// # Errors
    /// Returns `OCRError::ConfigError` on any I/O failure.
    fn load_dict(path: &Path) -> Result<Vec<String>, OCRError> {
        let file = File::open(path).map_err(|e| OCRError::ConfigError {
            message: format!("Failed to open dictionary file '{}': {}", path.display(), e),
        })?;
        let reader = BufReader::new(file);
        let mut dict = Vec::new();
        for line in reader.lines() {
            let line = line.map_err(|e| OCRError::ConfigError {
                message: format!("Failed to read dictionary line: {}", e),
            })?;
            let trimmed = line.trim_end();
            if !trimmed.is_empty() {
                dict.push(trimmed.to_string());
            }
        }
        Ok(dict)
    }
    /// Decodes a whole batch of model outputs.
    ///
    /// `structure_logits` is indexed `[batch, seq, vocab]`, `bbox_preds`
    /// `[batch, seq, 8]`; `shape_info` carries one per-image scaling record
    /// per batch entry (see `extract_bbox` for its layout).
    ///
    /// # Errors
    /// Propagates any error from `decode_single`.
    pub fn decode(
        &self,
        structure_logits: &ndarray::Array3<f32>,
        bbox_preds: &ndarray::Array3<f32>,
        shape_info: &[[f32; 6]],
    ) -> Result<TableStructureDecodeOutput, OCRError> {
        let batch_size = structure_logits.shape()[0];
        let mut structure_tokens_batch = Vec::with_capacity(batch_size);
        let mut bboxes_batch = Vec::with_capacity(batch_size);
        let mut scores_batch = Vec::with_capacity(batch_size);
        for batch_idx in 0..batch_size {
            let (tokens, bboxes, score) =
                self.decode_single(structure_logits, bbox_preds, batch_idx, shape_info)?;
            structure_tokens_batch.push(tokens);
            bboxes_batch.push(bboxes);
            scores_batch.push(score);
        }
        Ok(TableStructureDecodeOutput {
            structure_tokens: structure_tokens_batch,
            bboxes: bboxes_batch,
            structure_scores: scores_batch,
        })
    }
    /// Greedily decodes one batch entry.
    ///
    /// At each sequence step the argmax token is taken. Decoding stops at
    /// `eos` (only after the first step), `sos`/`eos` tokens are skipped,
    /// and a bounding box is extracted for every cell (`<td>`-family) token.
    /// The returned score is the mean of the selected tokens' raw logit
    /// values (not softmax probabilities), or 0.0 when no tokens were kept.
    fn decode_single(
        &self,
        structure_logits: &ndarray::Array3<f32>,
        bbox_preds: &ndarray::Array3<f32>,
        batch_idx: usize,
        shape_info: &[[f32; 6]],
    ) -> TableDecodeResult {
        let seq_len = structure_logits.shape()[1];
        let mut structure_tokens = Vec::new();
        let mut bboxes = Vec::new();
        let mut scores = Vec::new();
        tracing::debug!(
            "Starting token decoding for batch {}, sequence length {}",
            batch_idx,
            seq_len
        );
        tracing::debug!("Structure logits shape: {:?}", structure_logits.shape());
        tracing::debug!("Bbox preds shape: {:?}", bbox_preds.shape());
        for seq_idx in 0..seq_len {
            let (token_idx, token_prob) = self.argmax_at(structure_logits, batch_idx, seq_idx);
            // Stop at eos, but never at the very first step.
            if seq_idx > 0 && token_idx == self.end_idx {
                tracing::debug!(
                    "Stopping at end token (idx: {}) at sequence position {}",
                    token_idx,
                    seq_idx
                );
                break;
            }
            // Skip sentinel tokens (sos, and eos at position 0).
            if self.ignored_tokens.contains(&token_idx) {
                tracing::debug!(
                    "Skipping ignored token at seq_idx {}: token_idx={}, token='{}'",
                    seq_idx,
                    token_idx,
                    self.character_dict
                        .get(token_idx)
                        .unwrap_or(&"<INVALID>".to_string())
                );
                continue;
            }
            // Out-of-range indices become placeholder tokens rather than
            // aborting the decode.
            let token = self
                .character_dict
                .get(token_idx)
                .cloned()
                .unwrap_or_else(|| format!("UNK_{}", token_idx));
            tracing::debug!(
                "Decoded token at seq_idx {}: token_idx={}, dict_size={}, token='{}', prob={:.6}",
                seq_idx,
                token_idx,
                self.character_dict.len(),
                token,
                token_prob
            );
            structure_tokens.push(token.clone());
            scores.push(token_prob);
            // Cell-opening tokens carry a bbox prediction at the same step.
            if self.td_token_indices.contains(&token_idx) {
                let bbox = self.extract_bbox(bbox_preds, batch_idx, seq_idx, shape_info)?;
                tracing::debug!("Extracted bbox for TD token '{}': {:?}", token, bbox);
                bboxes.push(bbox);
            }
        }
        tracing::info!(
            "Decoded {} structure tokens: {:?}",
            structure_tokens.len(),
            structure_tokens
        );
        tracing::info!("Extracted {} bounding boxes", bboxes.len());
        let mean_score = if scores.is_empty() {
            0.0
        } else {
            let sum: f32 = scores.iter().copied().sum();
            sum / (scores.len() as f32)
        };
        Ok((structure_tokens, bboxes, mean_score))
    }
    /// Returns (index, value) of the largest logit at `[batch_idx, seq_idx]`.
    ///
    /// Ties resolve to the lowest index because the comparison is strict.
    /// The returned value is the raw logit, used directly as a confidence
    /// score by the caller.
    fn argmax_at(
        &self,
        logits: &ndarray::Array3<f32>,
        batch_idx: usize,
        seq_idx: usize,
    ) -> (usize, f32) {
        let vocab_size = logits.shape()[2];
        let mut max_idx = 0;
        let mut max_val = f32::NEG_INFINITY;
        for vocab_idx in 0..vocab_size {
            let val = logits[[batch_idx, seq_idx, vocab_idx]];
            if val > max_val {
                max_val = val;
                max_idx = vocab_idx;
            }
        }
        (max_idx, max_val)
    }
    /// Reads the 8 bbox coordinates at `[batch_idx, seq_idx]` and maps them
    /// back to original-image pixel space.
    ///
    /// `shape_info` rows are `[orig_h, orig_w, scale, pad_h, pad_w,
    /// target_size]`; `target_size / scale` recovers the original image's
    /// longest side, which the predicted coordinates are assumed to be
    /// normalized against (NOTE(review): confirm against the preprocessing
    /// code). Even indices are x (clamped to width), odd indices y (clamped
    /// to height). If `shape_info` has no entry for this batch index the
    /// raw predictions are returned unscaled.
    ///
    /// # Errors
    /// Returns `OCRError::InvalidInput` when `scale` or `target_size` is
    /// non-positive.
    fn extract_bbox(
        &self,
        bbox_preds: &ndarray::Array3<f32>,
        batch_idx: usize,
        seq_idx: usize,
        shape_info: &[[f32; 6]],
    ) -> Result<[f32; 8], OCRError> {
        let mut bbox = [0.0f32; 8];
        for (idx, coord) in bbox.iter_mut().enumerate() {
            *coord = bbox_preds[[batch_idx, seq_idx, idx]];
        }
        if let Some(shape) = shape_info.get(batch_idx) {
            let [orig_h, orig_w, scale, _pad_h, _pad_w, target_size] = *shape;
            if scale <= 0.0 || target_size <= 0.0 {
                return Err(OCRError::InvalidInput {
                    message: format!(
                        "Invalid shape info for batch {}: scale={} target_size={}",
                        batch_idx, scale, target_size
                    ),
                });
            }
            let longest_side = target_size / scale;
            for (idx, coord_ref) in bbox.iter_mut().enumerate() {
                let mut coord = *coord_ref * longest_side;
                if idx % 2 == 0 {
                    coord = coord.clamp(0.0, orig_w);
                } else {
                    coord = coord.clamp(0.0, orig_h);
                }
                *coord_ref = coord;
            }
        }
        Ok(bbox)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `load_dict` keeps leading whitespace (significant for attribute
    /// tokens such as ` colspan="2"`) while dropping blank lines.
    ///
    /// Previously this test was an empty stub that always passed; it now
    /// exercises the loader against a temp file.
    #[test]
    fn test_load_dict() {
        let path = std::env::temp_dir().join(format!(
            "table_structure_dict_test_{}.txt",
            std::process::id()
        ));
        std::fs::write(&path, "<tr>\n<td\n colspan=\"2\"\n\n</tr>\n")
            .expect("failed to write temp dict file");
        let loaded = TableStructureDecode::load_dict(&path);
        // Clean up before asserting so a failure doesn't leave the file behind.
        let _ = std::fs::remove_file(&path);
        let dict = loaded.expect("load_dict should succeed");
        assert_eq!(dict, vec!["<tr>", "<td", " colspan=\"2\"", "</tr>"]);
    }

    /// Mirrors the dictionary preprocessing in `from_dict_path`:
    /// `<td>` is merged into `<td></td>` and sos/eos sentinels are added.
    #[test]
    fn test_dictionary_processing() {
        let temp_dict = vec![
            "<html>".to_string(),
            "<body>".to_string(),
            "<table>".to_string(),
            "<tr>".to_string(),
            "<td>".to_string(),
            "<td".to_string(),
            " colspan=\"4\"".to_string(),
            ">".to_string(),
            "</td>".to_string(),
            "</tr>".to_string(),
            "</table>".to_string(),
            "</body>".to_string(),
            "</html>".to_string(),
        ];
        let mut processed_dict = temp_dict.clone();
        let merge_no_span_structure = true;
        if merge_no_span_structure {
            if !processed_dict.contains(&"<td></td>".to_string()) {
                processed_dict.push("<td></td>".to_string());
            }
            if let Some(pos) = processed_dict.iter().position(|s| s == "<td>") {
                processed_dict.remove(pos);
            }
        }
        assert!(!processed_dict.contains(&"<td>".to_string()));
        assert!(processed_dict.contains(&"<td></td>".to_string()));
        let beg_str = "sos";
        let end_str = "eos";
        let mut final_dict = vec![beg_str.to_string()];
        final_dict.extend(processed_dict);
        final_dict.push(end_str.to_string());
        assert_eq!(final_dict[0], "sos");
        assert_eq!(final_dict[final_dict.len() - 1], "eos");
        assert!(final_dict.contains(&"<html>".to_string()));
        assert!(final_dict.contains(&"<td".to_string()));
        assert!(final_dict.contains(&" colspan=\"4\"".to_string()));
    }

    /// With all-zero logits, argmax must return index 0 (strict `>` keeps
    /// the first maximum). Skipped when the model dictionary is absent.
    #[test]
    fn test_argmax() -> Result<(), OCRError> {
        use ndarray::Array3;
        let dict_path = Path::new("models/table_structure_dict.txt");
        if !dict_path.exists() {
            return Ok(());
        }
        let decoder = TableStructureDecode::from_dict_path(dict_path)?;
        let logits = Array3::zeros((1, 5, 50));
        let (idx, _prob) = decoder.argmax_at(&logits, 0, 0);
        assert_eq!(idx, 0);
        Ok(())
    }

    #[test]
    fn test_parse_cell_grid_info_simple() {
        let tokens = vec![
            "<tr>".to_string(),
            "<td></td>".to_string(),
            "<td></td>".to_string(),
            "</tr>".to_string(),
            "<tr>".to_string(),
            "<td></td>".to_string(),
            "<td></td>".to_string(),
            "</tr>".to_string(),
        ];
        let grid = parse_cell_grid_info(&tokens);
        assert_eq!(grid.len(), 4);
        assert_eq!(grid[0].row, 0);
        assert_eq!(grid[0].col, 0);
        assert_eq!(grid[0].row_span, 1);
        assert_eq!(grid[0].col_span, 1);
        assert_eq!(grid[1].row, 0);
        assert_eq!(grid[1].col, 1);
        assert_eq!(grid[2].row, 1);
        assert_eq!(grid[2].col, 0);
        assert_eq!(grid[3].row, 1);
        assert_eq!(grid[3].col, 1);
    }

    #[test]
    fn test_parse_cell_grid_info_colspan() {
        let tokens = vec![
            "<tr>".to_string(),
            "<td colspan=\"2\"></td>".to_string(),
            "</tr>".to_string(),
            "<tr>".to_string(),
            "<td></td>".to_string(),
            "<td></td>".to_string(),
            "</tr>".to_string(),
        ];
        let grid = parse_cell_grid_info(&tokens);
        assert_eq!(grid.len(), 3);
        assert_eq!(grid[0].row, 0);
        assert_eq!(grid[0].col, 0);
        assert_eq!(grid[0].col_span, 2);
        assert_eq!(grid[1].row, 1);
        assert_eq!(grid[1].col, 0);
        assert_eq!(grid[2].row, 1);
        assert_eq!(grid[2].col, 1);
    }

    /// The second row's single cell must land in column 1 because column 0
    /// is claimed by the rowspan from row 0.
    #[test]
    fn test_parse_cell_grid_info_rowspan() {
        let tokens = vec![
            "<tr>".to_string(),
            "<td rowspan=\"2\"></td>".to_string(),
            "<td></td>".to_string(),
            "</tr>".to_string(),
            "<tr>".to_string(),
            "<td></td>".to_string(),
            "</tr>".to_string(),
        ];
        let grid = parse_cell_grid_info(&tokens);
        assert_eq!(grid.len(), 3);
        assert_eq!(grid[0].row, 0);
        assert_eq!(grid[0].col, 0);
        assert_eq!(grid[0].row_span, 2);
        assert_eq!(grid[1].row, 0);
        assert_eq!(grid[1].col, 1);
        assert_eq!(grid[2].row, 1);
        assert_eq!(grid[2].col, 1);
    }

    /// Span attributes must also be picked up when the opening tag is split
    /// across several tokens ("<td", " colspan=\"2\"", ">", "</td>").
    #[test]
    fn test_parse_cell_grid_info_split_tokens_with_spans() {
        let tokens = vec![
            "<tr>",
            "<td",
            " colspan=\"2\"",
            ">",
            "</td>",
            "</tr>",
            "<tr>",
            "<td",
            " rowspan=\"2\"",
            ">",
            "</td>",
            "<td></td>",
            "</tr>",
            "<tr>",
            "<td></td>",
            "</tr>",
        ]
        .into_iter()
        .map(str::to_string)
        .collect::<Vec<_>>();
        let grid = parse_cell_grid_info(&tokens);
        assert_eq!(grid.len(), 4);
        assert_eq!(grid[0].row, 0);
        assert_eq!(grid[0].col_span, 2);
        assert_eq!(grid[1].row, 1);
        assert_eq!(grid[1].col, 0);
        assert_eq!(grid[1].row_span, 2);
        assert_eq!(grid[2].row, 1);
        assert_eq!(grid[2].col, 1);
        assert_eq!(grid[3].row, 2);
        assert_eq!(grid[3].col, 1);
    }

    #[test]
    fn test_wrap_table_html_with_split_tokens() {
        let tokens = vec!["<tr>", "<td", " colspan=\"2\"", ">", "</td>", "</tr>"]
            .into_iter()
            .map(str::to_string)
            .collect::<Vec<_>>();
        let cell_texts = vec![Some("Cell A".to_string())];
        let html = wrap_table_html_with_content(&tokens, &cell_texts);
        assert!(html.contains("<td colspan=\"2\">Cell A</td>"));
        assert!(html.starts_with("<html><body><table>"));
        assert!(html.ends_with("</table></body></html>"));
    }

    #[test]
    fn test_parse_span_attr() {
        assert_eq!(parse_span_attr("<td colspan=\"2\">", "colspan"), Some(2));
        assert_eq!(parse_span_attr("<td rowspan=\"3\">", "rowspan"), Some(3));
        assert_eq!(
            parse_span_attr("<td colspan=\"2\" rowspan=\"3\">", "colspan"),
            Some(2)
        );
        assert_eq!(
            parse_span_attr("<td colspan=\"2\" rowspan=\"3\">", "rowspan"),
            Some(3)
        );
        assert_eq!(parse_span_attr("<td></td>", "colspan"), None);
        assert_eq!(parse_span_attr("<td>", "rowspan"), None);
    }

    /// `extract_bbox` must rescale via `target_size / scale` (the original
    /// longest side) and clamp x to width, y to height.
    #[test]
    fn test_extract_bbox_longest_side_scaling_matches_standard() -> Result<(), OCRError> {
        let decoder = TableStructureDecode {
            character_dict: Vec::new(),
            ignored_tokens: Vec::new(),
            td_token_indices: Vec::new(),
            end_idx: 0,
        };
        let mut bbox_preds = ndarray::Array3::<f32>::zeros((1, 1, 8));
        let preds = [0.45f32, 0.25, 0.9, 0.25, 0.45, 0.8, 0.9, 0.8];
        for (i, val) in preds.iter().enumerate() {
            bbox_preds[[0, 0, i]] = *val;
        }
        let orig_h: f32 = 600.0;
        let orig_w: f32 = 300.0;
        let target_size: f32 = 512.0;
        let scale = target_size / orig_h.max(orig_w);
        let pad_h = 0.0;
        let pad_w = target_size - (orig_w * scale);
        let shape_info = [[orig_h, orig_w, scale, pad_h, pad_w, target_size]];
        let bbox = decoder.extract_bbox(&bbox_preds, 0, 0, &shape_info)?;
        let longest_side = orig_h.max(orig_w);
        let expected = [
            (preds[0] * longest_side).clamp(0.0, orig_w),
            (preds[1] * longest_side).clamp(0.0, orig_h),
            (preds[2] * longest_side).clamp(0.0, orig_w),
            (preds[3] * longest_side).clamp(0.0, orig_h),
            (preds[4] * longest_side).clamp(0.0, orig_w),
            (preds[5] * longest_side).clamp(0.0, orig_h),
            (preds[6] * longest_side).clamp(0.0, orig_w),
            (preds[7] * longest_side).clamp(0.0, orig_h),
        ];
        for (idx, (got, exp)) in bbox.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-3,
                "bbox coord {} mismatch: got {}, expected {}",
                idx,
                got,
                exp
            );
        }
        Ok(())
    }
}