// fleischwolf_pdf/tableformer.rs
//! TableFormer: table-structure recovery via docling-ibm-models, exported to
//! ONNX by `scripts/export_tableformer.py`. The image encoder and the
//! tag-transformer encoder run once to produce a memory tensor; the decoder is
//! then stepped autoregressively to emit an OTSL structure-token sequence (the
//! same model docling runs). See PDF_CONFORMANCE.md.

use std::collections::HashMap;

use image::RgbImage;
use ort::session::Session;
use ort::value::Tensor;

use crate::pdfium_backend::TextCell;

/// Side length, in pixels, of the square image fed to the encoder.
const SIDE: u32 = 448;
/// Hard cap on the number of OTSL tokens decoded for one table.
const MAX_STEPS: usize = 1024;
// Per-channel normalization copied verbatim from docling's tm_config.json
// `image_normalization`. The literals carry more digits than an f32 can hold;
// they are kept exact for provenance, hence the lint allowances.
#[allow(clippy::excessive_precision)]
const MEAN: [f32; 3] = [0.94247851, 0.94254675, 0.94292611];
#[allow(clippy::excessive_precision)]
const STD: [f32; 3] = [0.17910956, 0.17940403, 0.17931663];

// OTSL structure tokens: indices into the TableModel04_rs tag word map.

/// Sequence start (`<start>`); the decoder prefix always begins with it.
pub const START: i64 = 2;
/// Sequence end (`<end>`); decoding stops when it is predicted.
pub const END: i64 = 3;
/// Empty cell.
pub const ECEL: i64 = 4;
/// Full (content) cell.
pub const FCEL: i64 = 5;
/// Left-looking span: extends the cell to its left (colspan).
pub const LCEL: i64 = 6;
/// Up-looking span: extends the cell above (rowspan).
pub const UCEL: i64 = 7;
/// Cross span: extends both left and up.
pub const XCEL: i64 = 8;
/// New row.
pub const NL: i64 = 9;
/// Column-header cell.
pub const CHED: i64 = 10;
/// Row-header cell.
pub const RHED: i64 = 11;
/// Section-row cell.
pub const SROW: i64 = 12;

/// One predicted table cell: its OTSL grid position and spans, the OTSL tag
/// that opened it, and its box as normalized cxcywh within the 448² model input.
#[derive(Debug, Clone)]
pub struct TableCell {
    /// Grid row of the cell's top-left slot.
    pub row: usize,
    /// Grid column of the cell's top-left slot.
    pub col: usize,
    /// Number of grid columns the cell covers (1 = no horizontal span).
    pub colspan: usize,
    /// Number of grid rows the cell covers (1 = no vertical span).
    pub rowspan: usize,
    /// OTSL tag that opened the cell (e.g. `FCEL`, `ECEL`, `CHED`).
    pub tag: i64,
    /// Box center x, normalized to the model input.
    pub cx: f32,
    /// Box center y, normalized to the model input.
    pub cy: f32,
    /// Box width, normalized to the model input.
    pub w: f32,
    /// Box height, normalized to the model input.
    pub h: f32,
}

49pub struct TableFormer {
50    encoder: Session,
51    decoder: Session,
52    bbox: Session,
53}
54
55impl TableFormer {
56    /// Load the exported encoder/decoder/bbox ONNX graphs (env overrides, else
57    /// `models/tableformer/{encoder,decoder,bbox}.onnx`). Returns `None` if any is
58    /// absent, so the pipeline falls back to geometric reconstruction.
59    pub fn load() -> Option<Self> {
60        let enc = std::env::var("DOCLING_TABLEFORMER_ENCODER")
61            .unwrap_or_else(|_| "models/tableformer/encoder.onnx".to_string());
62        let dec = std::env::var("DOCLING_TABLEFORMER_DECODER")
63            .unwrap_or_else(|_| "models/tableformer/decoder.onnx".to_string());
64        let bbx = std::env::var("DOCLING_TABLEFORMER_BBOX")
65            .unwrap_or_else(|_| "models/tableformer/bbox.onnx".to_string());
66        if [&enc, &dec, &bbx]
67            .iter()
68            .any(|p| !std::path::Path::new(p).exists())
69        {
70            return None;
71        }
72        let build = |path: &str| -> Result<Session, String> {
73            Session::builder()
74                .map_err(|e| e.to_string())?
75                .with_intra_threads(crate::intra_threads())
76                .map_err(|e| e.to_string())?
77                .commit_from_file(path)
78                .map_err(|e| format!("tableformer load {path}: {e}"))
79        };
80        match (build(&enc), build(&dec), build(&bbx)) {
81            (Ok(encoder), Ok(decoder), Ok(bbox)) => Some(Self {
82                encoder,
83                decoder,
84                bbox,
85            }),
86            _ => None,
87        }
88    }
89
90    /// Predict the OTSL structure-token sequence for a table-region image.
91    pub fn predict_otsl(&mut self, img: &RgbImage) -> Result<Vec<i64>, String> {
92        let input = preprocess(img)?;
93        let enc_out = self
94            .encoder
95            .run(ort::inputs!["image" => input])
96            .map_err(|e| format!("tableformer: encode: {e}"))?;
97        let (mshape, mem) = enc_out["memory"]
98            .try_extract_tensor::<f32>()
99            .map_err(|e| format!("tableformer: memory: {e}"))?;
100        let mshape: Vec<usize> = mshape.iter().map(|&x| x as usize).collect();
101        let mem: Vec<f32> = mem.to_vec();
102
103        // Autoregressive decode: the decoder graph re-applies the layers to the
104        // whole prefix under a causal mask (statelessly reproducing the model's
105        // per-layer cache), so we just feed the growing token list back in. The
106        // two structure corrections mirror docling's `predict` exactly — note its
107        // `line_num` is never incremented, so `xcel→lcel` applies on every row.
108        let mut tags: Vec<i64> = vec![START];
109        let mut out: Vec<i64> = Vec::new();
110        let mut prev_ucel = false;
111        while out.len() < MAX_STEPS {
112            let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.clone()))
113                .map_err(|e| format!("tableformer: tags: {e}"))?;
114            let mem_t = Tensor::from_array((mshape.clone(), mem.clone()))
115                .map_err(|e| format!("tableformer: mem: {e}"))?;
116            let dout = self
117                .decoder
118                .run(ort::inputs!["tags" => tags_t, "memory" => mem_t])
119                .map_err(|e| format!("tableformer: decode: {e}"))?;
120            let (_, logits) = dout["logits"]
121                .try_extract_tensor::<f32>()
122                .map_err(|e| format!("tableformer: logits: {e}"))?;
123            let mut tag = argmax(logits) as i64;
124            if tag == XCEL {
125                tag = LCEL;
126            }
127            if prev_ucel && tag == LCEL {
128                tag = FCEL;
129            }
130            if tag == END {
131                break;
132            }
133            out.push(tag);
134            tags.push(tag);
135            prev_ucel = tag == UCEL;
136        }
137        Ok(out)
138    }
139
140    /// Full structure prediction: OTSL grid cells with per-cell boxes (in the 448
141    /// image, normalized cxcywh). Collects per-cell decoder hidden states using
142    /// docling's exact bbox bookkeeping (skip-after-row-break, first-lcel of a
143    /// horizontal span), runs the bbox decoder, merges span boxes, then lays the
144    /// cells onto the OTSL grid with row/col spans.
145    pub fn predict_table_structure(&mut self, img: &RgbImage) -> Result<Vec<TableCell>, String> {
146        let input = preprocess(img)?;
147        let enc_out = self
148            .encoder
149            .run(ort::inputs!["image" => input])
150            .map_err(|e| format!("tableformer: encode: {e}"))?;
151        let (mshape, mem) = enc_out["memory"]
152            .try_extract_tensor::<f32>()
153            .map_err(|e| format!("tableformer: memory: {e}"))?;
154        let mshape: Vec<usize> = mshape.iter().map(|&x| x as usize).collect();
155        let mem: Vec<f32> = mem.to_vec();
156        let (eshape, eo) = enc_out["enc_out"]
157            .try_extract_tensor::<f32>()
158            .map_err(|e| format!("tableformer: enc_out: {e}"))?;
159        let eshape: Vec<usize> = eshape.iter().map(|&x| x as usize).collect();
160        let eo: Vec<f32> = eo.to_vec();
161
162        let mut tags: Vec<i64> = vec![START];
163        let mut otsl: Vec<i64> = Vec::new();
164        let mut hiddens: Vec<f32> = Vec::new(); // flattened [n, 512]
165        let mut n = 0usize;
166        let mut prev_ucel = false;
167        let mut skip = true; // first tag after <start> is skipped
168        let mut first_lcel = true;
169        let mut bbox_ind = 0usize;
170        let mut cur_bbox_ind = 0usize;
171        let mut merge: std::collections::HashMap<usize, i64> = std::collections::HashMap::new();
172        while otsl.len() < MAX_STEPS {
173            let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.clone()))
174                .map_err(|e| format!("tableformer: tags: {e}"))?;
175            let mem_t = Tensor::from_array((mshape.clone(), mem.clone()))
176                .map_err(|e| format!("tableformer: mem: {e}"))?;
177            let dout = self
178                .decoder
179                .run(ort::inputs!["tags" => tags_t, "memory" => mem_t])
180                .map_err(|e| format!("tableformer: decode: {e}"))?;
181            let (_, logits) = dout["logits"]
182                .try_extract_tensor::<f32>()
183                .map_err(|e| format!("tableformer: logits: {e}"))?;
184            let mut tag = argmax(logits) as i64;
185            if tag == XCEL {
186                tag = LCEL;
187            }
188            if prev_ucel && tag == LCEL {
189                tag = FCEL;
190            }
191            if tag == END {
192                break;
193            }
194            let (_, hidden) = dout["hidden"]
195                .try_extract_tensor::<f32>()
196                .map_err(|e| format!("tableformer: hidden: {e}"))?;
197            // docling's tag_H_buf / bboxes_to_merge bookkeeping.
198            if !skip && matches!(tag, FCEL | ECEL | CHED | RHED | SROW | NL | UCEL) {
199                hiddens.extend_from_slice(hidden);
200                n += 1;
201                if !first_lcel {
202                    merge.insert(cur_bbox_ind, bbox_ind as i64);
203                }
204                bbox_ind += 1;
205            }
206            if tag != LCEL {
207                first_lcel = true;
208            } else if first_lcel {
209                hiddens.extend_from_slice(hidden);
210                n += 1;
211                first_lcel = false;
212                cur_bbox_ind = bbox_ind;
213                merge.insert(cur_bbox_ind, -1);
214                bbox_ind += 1;
215            }
216            skip = matches!(tag, NL | UCEL | XCEL);
217            prev_ucel = tag == UCEL;
218            otsl.push(tag);
219            tags.push(tag);
220        }
221        if n == 0 {
222            return Ok(Vec::new());
223        }
224        let tag_h = Tensor::from_array(([n, 512usize], hiddens))
225            .map_err(|e| format!("tableformer: tag_h: {e}"))?;
226        let eo_t = Tensor::from_array((eshape, eo)).map_err(|e| format!("tableformer: eo: {e}"))?;
227        let bout = self
228            .bbox
229            .run(ort::inputs!["enc_out" => eo_t, "tag_h" => tag_h])
230            .map_err(|e| format!("tableformer: bbox: {e}"))?;
231        let (_, raw) = bout["boxes"]
232            .try_extract_tensor::<f32>()
233            .map_err(|e| format!("tableformer: boxes: {e}"))?;
234        let boxes: Vec<[f32; 4]> = raw
235            .chunks_exact(4)
236            .map(|c| [c[0], c[1], c[2], c[3]])
237            .collect();
238        let merged = merge_spans(&boxes, &merge);
239        Ok(build_table_cells(&otsl, &merged))
240    }
241
242    /// Predict a table region's Markdown grid: crop the region (docling's
243    /// page→1024px box-average then bbox crop), run the structure model, map each
244    /// cell box back to page points, match the page's word cells into cells by
245    /// intersection-over-word-area, and expand spans into a dense `rows × cols`
246    /// grid. `region` is `(l, t, r, b)` in page points (top-left). Returns `None`
247    /// if no structure is predicted.
248    pub fn predict_table_rows(
249        &mut self,
250        page_image: &RgbImage,
251        page_h: f32,
252        region: [f32; 4],
253        words: &[TextCell],
254    ) -> Option<Vec<Vec<String>>> {
255        // page → 1024px height (cv2.INTER_AREA), then crop the table bbox.
256        let sf = 1024.0 / page_image.height() as f32;
257        let pw = (page_image.width() as f32 * sf) as u32;
258        let page1024 = crate::resample::inter_area(page_image, pw, 1024);
259        let k = 1024.0 / page_h;
260        let x = (region[0] * k).round().max(0.0) as u32;
261        let y = (region[1] * k).round().max(0.0) as u32;
262        let x2 = ((region[2] * k).round() as u32).min(page1024.width());
263        let y2 = ((region[3] * k).round() as u32).min(page1024.height());
264        if x2 <= x || y2 <= y {
265            return None;
266        }
267        let crop = image::imageops::crop_imm(&page1024, x, y, x2 - x, y2 - y).to_image();
268        let cells = self.predict_table_structure(&crop).ok()?;
269        if cells.is_empty() {
270            return None;
271        }
272        let (rw, rh) = (region[2] - region[0], region[3] - region[1]);
273
274        // Cell boxes in page points (top-left), aligned with `cells`.
275        let boxes: Vec<[f32; 4]> = cells
276            .iter()
277            .map(|c| {
278                [
279                    region[0] + (c.cx - c.w / 2.0) * rw,
280                    region[1] + (c.cy - c.h / 2.0) * rh,
281                    region[0] + (c.cx + c.w / 2.0) * rw,
282                    region[1] + (c.cy + c.h / 2.0) * rh,
283                ]
284            })
285            .collect();
286
287        // Assign each word to the cell it overlaps most (intersection / word area).
288        let mut cell_words: Vec<Vec<usize>> = vec![Vec::new(); cells.len()];
289        for (wi, w) in words.iter().enumerate() {
290            let wa = ((w.r - w.l) * (w.b - w.t)).max(1.0);
291            let mut best: Option<(f32, usize)> = None;
292            for (ci, b) in boxes.iter().enumerate() {
293                let ix = (w.r.min(b[2]) - w.l.max(b[0])).max(0.0);
294                let iy = (w.b.min(b[3]) - w.t.max(b[1])).max(0.0);
295                let io = ix * iy / wa;
296                if io > 0.0 && best.is_none_or(|(bo, _)| io > bo) {
297                    best = Some((io, ci));
298                }
299            }
300            if let Some((_, ci)) = best {
301                cell_words[ci].push(wi);
302            }
303        }
304
305        let num_rows = cells.iter().map(|c| c.row + c.rowspan).max().unwrap_or(0);
306        let num_cols = cells.iter().map(|c| c.col + c.colspan).max().unwrap_or(0);
307        if num_rows == 0 || num_cols == 0 {
308            return None;
309        }
310        let mut grid = vec![vec![String::new(); num_cols]; num_rows];
311        for (ci, c) in cells.iter().enumerate() {
312            // Keep words in text-stream order (the order they were collected =
313            // their word index), matching docling's cell text assembly — geometric
314            // re-sorting scrambles wrapped cells (`Inference time (secs)`).
315            let wis = std::mem::take(&mut cell_words[ci]);
316            let text = wis
317                .iter()
318                .map(|&i| words[i].text.trim())
319                .collect::<Vec<_>>()
320                .join(" ");
321            // Spanned cells repeat their text across the covered grid positions.
322            for row in grid.iter_mut().skip(c.row).take(c.rowspan) {
323                for cell in row.iter_mut().skip(c.col).take(c.colspan) {
324                    *cell = text.clone();
325                }
326            }
327        }
328        Some(grid)
329    }
330}
331
332/// docling's preprocessing: bilinear (cv2.INTER_LINEAR) resize the crop to 448²,
333/// normalize `(x/255 − mean)/std`, laid out as (C, W, H) — docling transposes
334/// (2,1,0), so width is the major spatial axis. The page→1024px box-average
335/// (cv2.INTER_AREA) is the caller's job.
336fn preprocess(img: &RgbImage) -> Result<Tensor<f32>, String> {
337    let nn = (SIDE * SIDE) as usize;
338    let side = SIDE as usize;
339    let (sw, sh) = (img.width() as i32, img.height() as i32);
340    let sxr = sw as f32 / SIDE as f32;
341    let syr = sh as f32 / SIDE as f32;
342    let mut data = vec![0f32; 3 * nn];
343    for h in 0..side {
344        let fy = (h as f32 + 0.5) * syr - 0.5;
345        let wy = fy - fy.floor();
346        let y0c = (fy.floor() as i32).clamp(0, sh - 1) as u32;
347        let y1c = (fy.floor() as i32 + 1).clamp(0, sh - 1) as u32;
348        for w in 0..side {
349            let fx = (w as f32 + 0.5) * sxr - 0.5;
350            let wx = fx - fx.floor();
351            let x0c = (fx.floor() as i32).clamp(0, sw - 1) as u32;
352            let x1c = (fx.floor() as i32 + 1).clamp(0, sw - 1) as u32;
353            let p00 = img.get_pixel(x0c, y0c);
354            let p01 = img.get_pixel(x1c, y0c);
355            let p10 = img.get_pixel(x0c, y1c);
356            let p11 = img.get_pixel(x1c, y1c);
357            let idx = w * side + h; // (C, W, H): c*n + w*H + h
358            for c in 0..3 {
359                let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
360                let bot = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
361                let v = top * (1.0 - wy) + bot * wy;
362                data[c * nn + idx] = (v / 255.0 - MEAN[c]) / STD[c];
363            }
364        }
365    }
366    Tensor::from_array(([1usize, 3, side, side], data))
367        .map_err(|e| format!("tableformer: input: {e}"))
368}
369
/// docling's `mergebboxes` on cxcywh boxes, reproduced quirk-for-quirk.
///
/// The result runs from `b1`'s left edge to `b2`'s right edge. Its top is the
/// smaller of the two top edges. Its height is measured from `b1`'s top to
/// `b2`'s bottom.
fn mergebboxes(b1: [f32; 4], b2: [f32; 4]) -> [f32; 4] {
    let [cx1, cy1, w1, h1] = b1;
    let [cx2, cy2, w2, h2] = b2;
    let left = cx1 - w1 / 2.0;
    let width = (cx2 + w2 / 2.0) - left;
    let height = (cy2 + h2 / 2.0) - (cy1 - h1 / 2.0);
    let top = (cy2 - h2 / 2.0).min(cy1 - h1 / 2.0);
    [left + width / 2.0, top + height / 2.0, width, height]
}

/// Apply docling's span merges to the bbox-head output.
///
/// Each key of `merge` is replaced by its box merged with the partner's box. A
/// partner of `-1` means the last box, and a partner past the end is clamped to
/// the last box. Non-negative partners are then dropped from the output. All
/// other boxes pass through in order.
fn merge_spans(boxes: &[[f32; 4]], merge: &std::collections::HashMap<usize, i64>) -> Vec<[f32; 4]> {
    let Some(last) = boxes.len().checked_sub(1) else {
        return Vec::new();
    };
    let absorbed: std::collections::HashSet<usize> = merge
        .values()
        .filter_map(|&v| (v >= 0).then_some(v as usize))
        .collect();
    boxes
        .iter()
        .enumerate()
        .filter_map(|(i, &b)| match merge.get(&i) {
            Some(&j) => {
                let partner = if j < 0 { last } else { (j as usize).min(last) };
                Some(mergebboxes(b, boxes[partner]))
            }
            None => (!absorbed.contains(&i)).then_some(b),
        })
        .collect()
}

400const CELL_TAGS: [i64; 6] = [FCEL, ECEL, XCEL, CHED, RHED, SROW];
401
402/// Lay the OTSL tag stream onto a grid (docling's `_build_table_cells`, OTSL
403/// mode): cell tags create cells at (row, col); `lcel`/`ucel`/`xcel` are spans
404/// (counted toward the column index but not cells). Colspan/rowspan are read off
405/// the grid (consecutive `lcel`/`ucel` to the right/below). `boxes` are indexed
406/// by cell order and aligned with the cells.
407fn build_table_cells(otsl: &[i64], boxes: &[[f32; 4]]) -> Vec<TableCell> {
408    // 2D grid of tags (rows split on NL) for span lookups.
409    let mut grid: Vec<Vec<i64>> = vec![Vec::new()];
410    for &t in otsl {
411        if t == NL {
412            grid.push(Vec::new());
413        } else {
414            grid.last_mut().unwrap().push(t);
415        }
416    }
417    let mut cells = Vec::new();
418    let mut cell_id = 0usize;
419    for (r, row) in grid.iter().enumerate() {
420        for (c, &tag) in row.iter().enumerate() {
421            if !CELL_TAGS.contains(&tag) {
422                continue;
423            }
424            let mut colspan = 1;
425            while c + colspan < row.len() && matches!(row[c + colspan], LCEL | XCEL) {
426                colspan += 1;
427            }
428            let mut rowspan = 1;
429            while r + rowspan < grid.len()
430                && grid[r + rowspan]
431                    .get(c)
432                    .is_some_and(|&t| matches!(t, UCEL | XCEL))
433            {
434                rowspan += 1;
435            }
436            let b = boxes.get(cell_id).copied().unwrap_or([0.0; 4]);
437            cells.push(TableCell {
438                row: r,
439                col: c,
440                colspan,
441                rowspan,
442                tag,
443                cx: b[0],
444                cy: b[1],
445                w: b[2],
446                h: b[3],
447            });
448            cell_id += 1;
449        }
450    }
451    cells
452}
453
/// Index of the largest value in `v`, or 0 when `v` is empty.
///
/// Ties resolve to the first maximal index, matching `torch.argmax` and
/// `numpy.argmax`. `Iterator::max_by` would pick the last one instead. Values
/// are ordered by `f32::total_cmp`, so a positive NaN ranks above every number.
fn argmax(v: &[f32]) -> usize {
    v.iter()
        .enumerate()
        .fold(None, |best: Option<(usize, f32)>, (i, &x)| match best {
            // Only a strictly greater value replaces the current best.
            Some((_, b)) if x.total_cmp(&b).is_le() => best,
            _ => Some((i, x)),
        })
        .map_or(0, |(i, _)| i)
}