use crate::pdfium_backend::TextCell;
use image::RgbImage;
use ort::session::Session;
use ort::value::{Tensor, TensorRef};
/// Side length, in pixels, of the square image that `preprocess` builds for the encoder.
const SIDE: u32 = 448;
/// Per-channel mean subtracted in `preprocess` after pixel values are scaled to `[0, 1]`.
// The literals carry more digits than `f32` can hold; presumably they are copied
// verbatim from the reference model configuration (TODO confirm).
#[allow(clippy::excessive_precision)]
const MEAN: [f32; 3] = [0.94247851, 0.94254675, 0.94292611];
/// Per-channel divisor applied in `preprocess` after the mean is subtracted.
// Same reason for the lint exception as `MEAN` (TODO confirm).
#[allow(clippy::excessive_precision)]
const STD: [f32; 3] = [0.17910956, 0.17940403, 0.17931663];
/// Upper bound on the number of tags emitted by one decoding loop.
const MAX_STEPS: usize = 1024;
/// First dimension of the decoder cache tensor (`[N_LAYERS, past, 1, EMBED_DIM]`).
const N_LAYERS: usize = 6;
/// Last dimension of the decoder cache tensor.
const EMBED_DIM: usize = 512;
/// Tag id that seeds the decoder input sequence.
pub const START: i64 = 2;
/// Tag id that terminates decoding; it is never pushed to the output sequence.
pub const END: i64 = 3;
// Tag ids compared directly against the arg-max of the decoder logits.
// What this file establishes about them:
//   * `NL` starts a new grid row in `build_table_cells`.
//   * `LCEL` / `XCEL` extend the column span of the cell to their left.
//   * `UCEL` / `XCEL` extend the row span of the cell above them.
//   * `FCEL`, `ECEL`, `XCEL`, `CHED`, `RHED`, `SROW` open a cell (`CELL_TAGS`).
// NOTE(review): the names presumably follow the OTSL vocabulary (empty cell, full
// cell, left-looking, up-looking, cross, new line, column header, row header,
// section row) - confirm against the model's word map.
pub const ECEL: i64 = 4; pub const FCEL: i64 = 5; pub const LCEL: i64 = 6; pub const UCEL: i64 = 7; pub const XCEL: i64 = 8; pub const NL: i64 = 9; pub const CHED: i64 = 10; pub const RHED: i64 = 11; pub const SROW: i64 = 12;
/// One cell of a predicted table: its grid placement, the tag that opened it,
/// and its bounding box in centre/size form.
#[derive(Debug, Clone)]
pub struct TableCell {
/// Zero-based grid row of the cell's top-left corner.
pub row: usize,
/// Zero-based grid column of the cell's top-left corner.
pub col: usize,
/// Number of grid columns covered (at least 1).
pub colspan: usize,
/// Number of grid rows covered (at least 1).
pub rowspan: usize,
/// Tag id that opened the cell (one of the `CELL_TAGS` values).
pub tag: i64,
/// Box centre, x. `predict_table_rows` treats it as a fraction of the table region's width.
pub cx: f32,
/// Box centre, y. `predict_table_rows` treats it as a fraction of the table region's height.
pub cy: f32,
/// Box width, in the same units as `cx`.
pub w: f32,
/// Box height, in the same units as `cy`.
pub h: f32,
}
/// The three ONNX sessions that together predict a table's structure.
pub struct TableFormer {
// Run once per image; produces `cross_k`, `cross_v` and `enc_out`.
encoder: Session,
// Run once per generated tag; produces `logits`, `out_cache` and `hidden`.
decoder: Session,
// Run once per table; turns the collected hidden states into `boxes`.
bbox: Session,
}
/// Owned copies of the encoder outputs, kept so they can be re-fed to the
/// decoder on every step and to the bbox model at the end.
struct EncodeOut {
// Shape of the `cross_k` output.
ck_shape: Vec<usize>,
// Flat data of the `cross_k` output.
ck: Vec<f32>,
// Shape of the `cross_v` output.
cv_shape: Vec<usize>,
// Flat data of the `cross_v` output.
cv: Vec<f32>,
// Shape of the `enc_out` output.
eo_shape: Vec<usize>,
// Flat data of the `enc_out` output.
eo: Vec<f32>,
}
impl TableFormer {
/// Loads the three models using the crate-wide intra-op thread count.
///
/// Returns `None` under the same conditions as [`TableFormer::load_with`].
pub fn load() -> Option<Self> {
    let threads = crate::intra_threads();
    Self::load_with(threads)
}
/// Loads the encoder, decoder and bbox sessions with `intra` intra-op threads.
///
/// Model paths come from `DOCLING_TABLEFORMER_ENCODER`, `DOCLING_TABLEFORMER_DECODER`
/// and `DOCLING_TABLEFORMER_BBOX`, falling back to files under `models/tableformer/`.
///
/// Returns `None` when any model file is missing or any session fails to build.
/// Loading stops at the first failure, so later models are not loaded needlessly.
pub fn load_with(intra: usize) -> Option<Self> {
    // An unset (or non-UTF-8) variable falls back to the bundled default path.
    let model_path = |var: &str, default: &str| -> String {
        std::env::var(var).unwrap_or_else(|_| default.to_string())
    };
    let enc = model_path(
        "DOCLING_TABLEFORMER_ENCODER",
        "models/tableformer/encoder.onnx",
    );
    let dec = model_path(
        "DOCLING_TABLEFORMER_DECODER",
        "models/tableformer/decoder.onnx",
    );
    let bbx = model_path("DOCLING_TABLEFORMER_BBOX", "models/tableformer/bbox.onnx");

    // Check all files up front so a partial install never loads any session.
    let all_present = [enc.as_str(), dec.as_str(), bbx.as_str()]
        .iter()
        .all(|p| std::path::Path::new(p).exists());
    if !all_present {
        return None;
    }

    // The caller only ever sees `None` on failure, so the error detail is
    // dropped here instead of being formatted and thrown away.
    let build = |path: &str| -> Option<Session> {
        Session::builder()
            .ok()?
            .with_intra_threads(intra)
            .ok()?
            .commit_from_file(path)
            .ok()
    };

    let encoder = build(&enc)?;
    let decoder = build(&dec)?;
    let bbox = build(&bbx)?;
    Some(Self {
        encoder,
        decoder,
        bbox,
    })
}
/// Runs the encoder on `img` and copies its three outputs into owned buffers.
///
/// # Errors
/// Returns a message when preprocessing fails, when the encoder run fails, or
/// when one of `cross_k`, `cross_v`, `enc_out` cannot be read as an `f32` tensor.
fn encode(&mut self, img: &RgbImage) -> Result<EncodeOut, String> {
    let pixels = preprocess(img)?;
    let outputs = self
        .encoder
        .run(ort::inputs!["image" => pixels])
        .map_err(|e| format!("tableformer: encode: {e}"))?;

    // Copies one named output out of the session-owned memory.
    let fetch = |name: &str| -> Result<(Vec<usize>, Vec<f32>), String> {
        match outputs[name].try_extract_tensor::<f32>() {
            Ok((dims, values)) => Ok((
                dims.iter().map(|&d| d as usize).collect(),
                values.to_vec(),
            )),
            Err(e) => Err(format!("tableformer: {name}: {e}")),
        }
    };

    let cross_k = fetch("cross_k")?;
    let cross_v = fetch("cross_v")?;
    let memory = fetch("enc_out")?;

    Ok(EncodeOut {
        ck_shape: cross_k.0,
        ck: cross_k.1,
        cv_shape: cross_v.0,
        cv: cross_v.1,
        eo_shape: memory.0,
        eo: memory.1,
    })
}
/// Runs one decoder step and returns the predicted tag with its hidden state.
///
/// * `tags` - the whole tag sequence generated so far (fed as a `[len, 1]` tensor).
/// * `enc` - encoder outputs; `cross_k` / `cross_v` are passed by reference.
/// * `cache` / `cache_past` - decoder cache data and its second dimension. Both
///   are overwritten with the step's `out_cache`. While `*cache_past == 0` the
///   pre-built `empty_cache` tensor is fed instead of `cache`.
///
/// Returns `(argmax(logits), hidden)`.
///
/// # Errors
/// Returns a message when a tensor cannot be built, the run fails, an output
/// cannot be read, or `out_cache` has no valid second dimension.
fn decode_step(
    &mut self,
    tags: &[i64],
    enc: &EncodeOut,
    cache: &mut Vec<f32>,
    cache_past: &mut usize,
    empty_cache: &Tensor<f32>,
) -> Result<(i64, Vec<f32>), String> {
    let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.to_vec()))
        .map_err(|e| format!("tableformer: tags: {e}"))?;
    let ck_t = TensorRef::from_array_view((enc.ck_shape.as_slice(), enc.ck.as_slice()))
        .map_err(|e| format!("tableformer: cross_k: {e}"))?;
    let cv_t = TensorRef::from_array_view((enc.cv_shape.as_slice(), enc.cv.as_slice()))
        .map_err(|e| format!("tableformer: cross_v: {e}"))?;
    let dout = if *cache_past == 0 {
        // First step: the cache has a zero-length axis, so use the pre-built empty tensor.
        self.decoder.run(ort::inputs![
            "tags" => tags_t, "cross_k" => ck_t, "cross_v" => cv_t, "cache" => empty_cache])
    } else {
        let cache_t = TensorRef::from_array_view((
            [N_LAYERS, *cache_past, 1, EMBED_DIM],
            cache.as_slice(),
        ))
        .map_err(|e| format!("tableformer: cache: {e}"))?;
        self.decoder.run(ort::inputs![
            "tags" => tags_t, "cross_k" => ck_t, "cross_v" => cv_t, "cache" => cache_t])
    }
    .map_err(|e| format!("tableformer: decode: {e}"))?;

    let (_, logits) = dout["logits"]
        .try_extract_tensor::<f32>()
        .map_err(|e| format!("tableformer: logits: {e}"))?;
    let raw = argmax(logits) as i64;

    let (oshape, ocache) = dout["out_cache"]
        .try_extract_tensor::<f32>()
        .map_err(|e| format!("tableformer: out_cache: {e}"))?;
    let next_cache = ocache.to_vec();
    // A malformed output (rank < 2 or a negative dim) becomes an error instead
    // of an index panic or a wrapped-around length.
    let next_past = oshape
        .iter()
        .nth(1)
        .and_then(|&d| usize::try_from(d).ok())
        .ok_or_else(|| "tableformer: out_cache: unexpected shape".to_string())?;

    let (_, hidden) = dout["hidden"]
        .try_extract_tensor::<f32>()
        .map_err(|e| format!("tableformer: hidden: {e}"))?;
    let hidden = hidden.to_vec();

    // Commit the new cache only after every output has been copied out.
    *cache = next_cache;
    *cache_past = next_past;
    Ok((raw, hidden))
}
/// Allocates the `[N_LAYERS, 0, 1, EMBED_DIM]` cache tensor fed to the decoder
/// on its first step, using the decoder session's allocator.
fn empty_cache(&self) -> Result<Tensor<f32>, String> {
    let shape = [N_LAYERS, 0usize, 1, EMBED_DIM];
    match Tensor::<f32>::new(self.decoder.allocator(), shape) {
        Ok(tensor) => Ok(tensor),
        Err(e) => Err(format!("tableformer: empty cache: {e}")),
    }
}
/// Decodes the tag sequence for a table image.
///
/// Starts from `START` and appends one tag per decoder step until `END` is
/// predicted or `MAX_STEPS` tags have been emitted. Two corrections are applied
/// to each raw prediction: `XCEL` becomes `LCEL`, and an `LCEL` directly after a
/// `UCEL` becomes `FCEL`. Neither `START` nor `END` appears in the result.
///
/// # Errors
/// Propagates any encoder or decoder failure.
pub fn predict_otsl(&mut self, img: &RgbImage) -> Result<Vec<i64>, String> {
    let encoded = self.encode(img)?;
    let seed_cache = self.empty_cache()?;

    let mut history: Vec<i64> = vec![START];
    let mut emitted: Vec<i64> = Vec::new();
    let mut kv_cache: Vec<f32> = Vec::new();
    let mut kv_len = 0usize;
    let mut after_ucel = false;

    // Each pass that does not break emits exactly one tag, so a bounded `for`
    // is the same limit as "emitted.len() < MAX_STEPS".
    for _ in 0..MAX_STEPS {
        let (predicted, _) = self.decode_step(
            &history,
            &encoded,
            &mut kv_cache,
            &mut kv_len,
            &seed_cache,
        )?;
        let widened = if predicted == XCEL { LCEL } else { predicted };
        let tag = if after_ucel && widened == LCEL {
            FCEL
        } else {
            widened
        };
        if tag == END {
            break;
        }
        emitted.push(tag);
        history.push(tag);
        after_ucel = tag == UCEL;
    }
    Ok(emitted)
}
/// Predicts the cells of a table image: grid placement plus bounding boxes.
///
/// Decodes the tag sequence exactly as `predict_otsl` does, and while decoding
/// collects the decoder hidden state of every step that needs a box. The
/// collected states are sent to the bbox model in one batch, horizontal-span
/// boxes are merged, and the result is paired with the tag grid.
///
/// Returns an empty vector when no hidden state was collected.
///
/// # Errors
/// Propagates any encoder, decoder or bbox model failure.
pub fn predict_table_structure(&mut self, img: &RgbImage) -> Result<Vec<TableCell>, String> {
    let enc = self.encode(img)?;
    let mut tags: Vec<i64> = vec![START];
    let mut otsl: Vec<i64> = Vec::new();
    // Row-major `[n, EMBED_DIM]` buffer of collected hidden states.
    let mut hiddens: Vec<f32> = Vec::new();
    let mut n = 0usize;
    let mut prev_ucel = false;
    // `skip` suppresses collection for the tag right after NL / UCEL (and for
    // the very first tag); `first_lcel` is true until an LCEL run has started.
    let mut skip = true;
    let mut first_lcel = true;
    let mut bbox_ind = 0usize;
    let mut cur_bbox_ind = 0usize;
    // span-start box index -> span-end box index (-1 while the end is unknown).
    let mut merge: std::collections::HashMap<usize, i64> = std::collections::HashMap::new();
    let mut cache: Vec<f32> = Vec::new();
    let mut cache_past = 0usize;
    let empty = self.empty_cache()?;

    while otsl.len() < MAX_STEPS {
        let (raw, hidden) =
            self.decode_step(&tags, &enc, &mut cache, &mut cache_past, &empty)?;
        let mut tag = raw;
        if tag == XCEL {
            tag = LCEL;
        }
        if prev_ucel && tag == LCEL {
            tag = FCEL;
        }
        if tag == END {
            break;
        }
        if !skip && matches!(tag, FCEL | ECEL | CHED | RHED | SROW | NL | UCEL) {
            hiddens.extend_from_slice(&hidden);
            n += 1;
            if !first_lcel {
                // This tag closes the LCEL run opened at `cur_bbox_ind`.
                merge.insert(cur_bbox_ind, bbox_ind as i64);
            }
            bbox_ind += 1;
        }
        if tag != LCEL {
            first_lcel = true;
        } else if first_lcel {
            // First LCEL of a run: record where the span starts.
            hiddens.extend_from_slice(&hidden);
            n += 1;
            first_lcel = false;
            cur_bbox_ind = bbox_ind;
            merge.insert(cur_bbox_ind, -1);
            bbox_ind += 1;
        }
        skip = matches!(tag, NL | UCEL | XCEL);
        prev_ucel = tag == UCEL;
        otsl.push(tag);
        tags.push(tag);
    }
    if n == 0 {
        return Ok(Vec::new());
    }

    let tag_h = Tensor::from_array(([n, EMBED_DIM], hiddens))
        .map_err(|e| format!("tableformer: tag_h: {e}"))?;
    // `enc` is not needed past this point, so hand its buffers to the tensor
    // instead of cloning them.
    let EncodeOut { eo_shape, eo, .. } = enc;
    let eo_t = Tensor::from_array((eo_shape, eo))
        .map_err(|e| format!("tableformer: eo: {e}"))?;
    let bout = self
        .bbox
        .run(ort::inputs!["enc_out" => eo_t, "tag_h" => tag_h])
        .map_err(|e| format!("tableformer: bbox: {e}"))?;
    let (_, raw) = bout["boxes"]
        .try_extract_tensor::<f32>()
        .map_err(|e| format!("tableformer: boxes: {e}"))?;
    let boxes: Vec<[f32; 4]> = raw
        .chunks_exact(4)
        .map(|c| [c[0], c[1], c[2], c[3]])
        .collect();
    let merged = merge_spans(&boxes, &merge);
    Ok(build_table_cells(&otsl, &merged))
}
/// Predicts a table's structure and fills it with the page's words.
///
/// * `page_image` - rendered page; it is resampled to a height of 1024 pixels.
/// * `page_h` - page height in the coordinate system of `region` and `words`.
/// * `region` - table bounds as `[left, top, right, bottom]`.
/// * `words` - text cells; each goes to the predicted cell that covers the
///   largest share of its area (the first such cell on ties).
///
/// Returns a `rows x cols` grid of space-joined, trimmed word texts. A spanning
/// cell repeats its text in every grid slot it covers. Returns `None` when the
/// scaled crop is empty, prediction fails, or no cells are produced.
pub fn predict_table_rows(
    &mut self,
    page_image: &RgbImage,
    page_h: f32,
    region: [f32; 4],
    words: &[TextCell],
) -> Option<Vec<Vec<String>>> {
    // Resample the page so that its height is exactly 1024 pixels.
    let scale = 1024.0 / page_image.height() as f32;
    let scaled_w = (page_image.width() as f32 * scale) as u32;
    let scaled_page = crate::resample::inter_area(page_image, scaled_w, 1024);

    // Map the region from page units to pixels of the resampled page.
    let to_px = 1024.0 / page_h;
    let [left, top, right, bottom] = region;
    let x0 = (left * to_px).round().max(0.0) as u32;
    let y0 = (top * to_px).round().max(0.0) as u32;
    let x1 = ((right * to_px).round() as u32).min(scaled_page.width());
    let y1 = ((bottom * to_px).round() as u32).min(scaled_page.height());
    if x1 <= x0 || y1 <= y0 {
        return None;
    }

    let crop = image::imageops::crop_imm(&scaled_page, x0, y0, x1 - x0, y1 - y0).to_image();
    let cells = self.predict_table_structure(&crop).ok()?;
    if cells.is_empty() {
        return None;
    }

    // Turn centre/size boxes (fractions of the crop) into page-unit corner boxes.
    let (region_w, region_h) = (right - left, bottom - top);
    let boxes: Vec<[f32; 4]> = cells
        .iter()
        .map(|cell| {
            let (half_w, half_h) = (cell.w / 2.0, cell.h / 2.0);
            [
                left + (cell.cx - half_w) * region_w,
                top + (cell.cy - half_h) * region_h,
                left + (cell.cx + half_w) * region_w,
                top + (cell.cy + half_h) * region_h,
            ]
        })
        .collect();

    // Give each word to the cell with the strictly largest overlap ratio.
    let mut cell_words: Vec<Vec<usize>> = vec![Vec::new(); cells.len()];
    for (word_idx, word) in words.iter().enumerate() {
        let word_area = ((word.r - word.l) * (word.b - word.t)).max(1.0);
        let mut winner: Option<(f32, usize)> = None;
        for (cell_idx, rect) in boxes.iter().enumerate() {
            let overlap_w = (word.r.min(rect[2]) - word.l.max(rect[0])).max(0.0);
            let overlap_h = (word.b.min(rect[3]) - word.t.max(rect[1])).max(0.0);
            let ratio = overlap_w * overlap_h / word_area;
            let improves = match winner {
                None => true,
                Some((top_ratio, _)) => ratio > top_ratio,
            };
            if ratio > 0.0 && improves {
                winner = Some((ratio, cell_idx));
            }
        }
        if let Some((_, cell_idx)) = winner {
            cell_words[cell_idx].push(word_idx);
        }
    }

    let num_rows = cells
        .iter()
        .fold(0usize, |acc, c| acc.max(c.row + c.rowspan));
    let num_cols = cells
        .iter()
        .fold(0usize, |acc, c| acc.max(c.col + c.colspan));
    if num_rows == 0 || num_cols == 0 {
        return None;
    }

    let mut grid = vec![vec![String::new(); num_cols]; num_rows];
    for (cell, members) in cells.iter().zip(cell_words) {
        let text = members
            .iter()
            .map(|&i| words[i].text.trim())
            .collect::<Vec<_>>()
            .join(" ");
        // The grid was sized from the largest row+rowspan / col+colspan, so
        // these ranges are always in bounds.
        for line in &mut grid[cell.row..cell.row + cell.rowspan] {
            for slot in &mut line[cell.col..cell.col + cell.colspan] {
                *slot = text.clone();
            }
        }
    }
    Some(grid)
}
}
/// Resizes `img` to `SIDE x SIDE` with bilinear interpolation, normalises each
/// channel as `(v / 255 - MEAN) / STD`, and packs the result into a
/// `[1, 3, SIDE, SIDE]` tensor.
///
/// # Errors
/// Returns a message when the image has zero width or height (previously this
/// panicked inside `clamp`), or when the tensor cannot be created.
fn preprocess(img: &RgbImage) -> Result<Tensor<f32>, String> {
    if img.width() == 0 || img.height() == 0 {
        return Err(format!(
            "tableformer: input: empty image ({}x{})",
            img.width(),
            img.height()
        ));
    }
    let nn = (SIDE * SIDE) as usize;
    let side = SIDE as usize;
    let (sw, sh) = (img.width() as i32, img.height() as i32);
    let sxr = sw as f32 / SIDE as f32;
    let syr = sh as f32 / SIDE as f32;

    // Horizontal sample positions depend only on the output column, so compute
    // them once instead of once per output row.
    let columns: Vec<(u32, u32, f32)> = (0..side)
        .map(|w| {
            let fx = (w as f32 + 0.5) * sxr - 0.5;
            let x_floor = fx.floor();
            let x0c = (x_floor as i32).clamp(0, sw - 1) as u32;
            let x1c = (x_floor as i32 + 1).clamp(0, sw - 1) as u32;
            (x0c, x1c, fx - x_floor)
        })
        .collect();

    let mut data = vec![0f32; 3 * nn];
    for h in 0..side {
        // Pixel-centre alignment: output centre (h + 0.5) maps to the source centre.
        let fy = (h as f32 + 0.5) * syr - 0.5;
        let y_floor = fy.floor();
        let wy = fy - y_floor;
        let y0c = (y_floor as i32).clamp(0, sh - 1) as u32;
        let y1c = (y_floor as i32 + 1).clamp(0, sh - 1) as u32;
        for (w, &(x0c, x1c, wx)) in columns.iter().enumerate() {
            let p00 = img.get_pixel(x0c, y0c);
            let p01 = img.get_pixel(x1c, y0c);
            let p10 = img.get_pixel(x0c, y1c);
            let p11 = img.get_pixel(x1c, y1c);
            // NOTE(review): the plane is indexed `w * side + h`, i.e. the two
            // spatial axes are stored transposed. Presumably this is the layout
            // the exported encoder expects - confirm before "fixing" it.
            let idx = w * side + h;
            for c in 0..3 {
                let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
                let bot = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
                let v = top * (1.0 - wy) + bot * wy;
                data[c * nn + idx] = (v / 255.0 - MEAN[c]) / STD[c];
            }
        }
    }
    Tensor::from_array(([1usize, 3, side, side], data))
        .map_err(|e| format!("tableformer: input: {e}"))
}
/// Merges two `[cx, cy, w, h]` boxes into one box of the same form.
///
/// The result runs from the left edge of `b1` to the right edge of `b2`. Its
/// height is `b2`'s bottom minus `b1`'s top, and its top is the higher of the
/// two tops.
fn mergebboxes(b1: [f32; 4], b2: [f32; 4]) -> [f32; 4] {
    let [cx1, cy1, w1, h1] = b1;
    let [cx2, cy2, w2, h2] = b2;

    let left1 = cx1 - w1 / 2.0;
    let top1 = cy1 - h1 / 2.0;
    let right2 = cx2 + w2 / 2.0;
    let bottom2 = cy2 + h2 / 2.0;
    let top2 = cy2 - h2 / 2.0;

    let width = right2 - left1;
    let height = bottom2 - top1;
    let top = top2.min(top1);

    [left1 + width / 2.0, top + height / 2.0, width, height]
}
/// Collapses span boxes according to `merge` (span-start index -> span-end index).
///
/// * A box whose index is a key is replaced by its merge with the partner box.
///   A negative partner means "the last box"; any partner index is clamped to
///   the last box.
/// * A box whose index appears as a non-negative value, and is not itself a
///   key, is dropped.
/// * Every other box is copied through unchanged.
fn merge_spans(boxes: &[[f32; 4]], merge: &std::collections::HashMap<usize, i64>) -> Vec<[f32; 4]> {
    let mut absorbed: std::collections::HashSet<usize> = std::collections::HashSet::new();
    for &target in merge.values() {
        if target >= 0 {
            absorbed.insert(target as usize);
        }
    }

    // Only read inside the loop, which never runs for an empty slice.
    let last = boxes.len().saturating_sub(1);

    let mut merged = Vec::new();
    for (idx, current) in boxes.iter().copied().enumerate() {
        match merge.get(&idx) {
            Some(&target) => {
                let partner = if target < 0 {
                    last
                } else {
                    (target as usize).min(last)
                };
                merged.push(mergebboxes(current, boxes[partner]));
            }
            None if absorbed.contains(&idx) => {}
            None => merged.push(current),
        }
    }
    merged
}
/// Tags that open a table cell in `build_table_cells`.
const CELL_TAGS: [i64; 6] = [FCEL, ECEL, XCEL, CHED, RHED, SROW];

/// Lays the tag sequence out as a grid and turns every cell-opening tag into a
/// [`TableCell`].
///
/// `NL` separates grid rows. A cell's column span extends over the `LCEL` /
/// `XCEL` tags directly to its right; its row span extends over the `UCEL` /
/// `XCEL` tags directly below it in the same column. Cells take boxes from
/// `boxes` in creation order; a missing box becomes all zeros.
fn build_table_cells(otsl: &[i64], boxes: &[[f32; 4]]) -> Vec<TableCell> {
    // `split` yields one slice per row, including empty leading/trailing rows.
    let rows: Vec<&[i64]> = otsl.split(|&t| t == NL).collect();

    let mut cells: Vec<TableCell> = Vec::new();
    for (r, &row) in rows.iter().enumerate() {
        for (c, &tag) in row.iter().enumerate() {
            if CELL_TAGS.contains(&tag) {
                let colspan = 1 + row[c + 1..]
                    .iter()
                    .take_while(|&&t| matches!(t, LCEL | XCEL))
                    .count();
                let rowspan = 1 + rows[r + 1..]
                    .iter()
                    .take_while(|below| below.get(c).is_some_and(|&t| matches!(t, UCEL | XCEL)))
                    .count();
                // The number of cells built so far is the next box index.
                let [cx, cy, w, h] = boxes.get(cells.len()).copied().unwrap_or([0.0; 4]);
                cells.push(TableCell {
                    row: r,
                    col: c,
                    colspan,
                    rowspan,
                    tag,
                    cx,
                    cy,
                    w,
                    h,
                });
            }
        }
    }
    cells
}
/// Returns the index of the greatest value in `v` under `f32::total_cmp`.
///
/// Ties go to the last of the equal elements. An empty slice yields `0`.
fn argmax(v: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in v.iter().enumerate() {
        let replaces = match best {
            None => true,
            // A later element wins unless it is strictly smaller.
            Some((_, cur)) => val.total_cmp(&cur) != std::cmp::Ordering::Less,
        };
        if replaces {
            best = Some((idx, val));
        }
    }
    match best {
        Some((idx, _)) => idx,
        None => 0,
    }
}