Skip to main content

fleischwolf_pdf/
tableformer.rs

1//! TableFormer: table-structure recovery via docling-ibm-models, exported to
2//! ONNX by `scripts/export_tableformer.py`. The image encoder + tag-transformer
3//! encoder run once to a memory tensor; the decoder is then stepped
4//! autoregressively to emit an OTSL structure-token sequence (the same model
5//! docling runs). See PDF_CONFORMANCE.md.
6
7use crate::pdfium_backend::TextCell;
8use image::RgbImage;
9use ort::session::Session;
10use ort::value::{DynValue, Tensor};
11
/// Side length, in pixels, of the square image the encoder consumes:
/// `preprocess` resizes every table crop to `SIDE × SIDE`.
const SIDE: u32 = 448;
// Verbatim from docling's tm_config.json image_normalization (more digits than
// f32 holds; kept exact for provenance). Per-channel values, indexed like the
// RGB pixel channels, applied in `preprocess` as `(x / 255 − MEAN[c]) / STD[c]`.
#[allow(clippy::excessive_precision)]
const MEAN: [f32; 3] = [0.94247851, 0.94254675, 0.94292611];
#[allow(clippy::excessive_precision)]
const STD: [f32; 3] = [0.17910956, 0.17940403, 0.17931663];
/// Cap on structure tokens emitted per table, so a decode that never predicts
/// `END` still terminates.
const MAX_STEPS: usize = 1024;
/// Decoder geometry, fixed by the exported TableModel04_rs graph: the cached
/// decoder threads a `[N_LAYERS, past, 1, EMBED_DIM]` per-layer state cache.
const N_LAYERS: usize = 6;
/// Decoder embedding width; also the width of each per-cell hidden-state row
/// handed to the bbox decoder.
const EMBED_DIM: usize = 512;
24
// OTSL structure tokens (TableModel04_rs wordmap indices).

/// Start-of-sequence token; every decode is seeded with it.
pub const START: i64 = 2;
/// End-of-sequence token; decoding stops when it is predicted.
pub const END: i64 = 3;
/// Empty cell.
pub const ECEL: i64 = 4;
/// Full (content) cell.
pub const FCEL: i64 = 5;
/// Left-looking cell: extends the cell to its left (colspan).
pub const LCEL: i64 = 6;
/// Up-looking cell: extends the cell above (rowspan).
pub const UCEL: i64 = 7;
/// Cross cell: spans both ways (left and up).
pub const XCEL: i64 = 8;
/// New row.
pub const NL: i64 = 9;
/// Column-header cell.
pub const CHED: i64 = 10;
/// Row-header cell.
pub const RHED: i64 = 11;
/// Section-row cell.
pub const SROW: i64 = 12;
37
/// A predicted table cell: an OTSL grid position (with spans) + its box in the
/// 448 image normalized cxcywh, and the OTSL tag.
///
/// The box fields are fractions of the table crop (callers scale them by the
/// region's width/height); all four are `0.0` when the bbox decoder produced no
/// box for this cell.
#[derive(Debug, Clone)]
pub struct TableCell {
    /// Zero-based grid row of the cell's top-left slot.
    pub row: usize,
    /// Zero-based grid column of the cell's top-left slot.
    pub col: usize,
    /// Number of grid columns covered (1 = no horizontal span).
    pub colspan: usize,
    /// Number of grid rows covered (1 = no vertical span).
    pub rowspan: usize,
    /// The cell-opening OTSL tag (e.g. `FCEL`, `ECEL`, `CHED`).
    pub tag: i64,
    /// Box centre x.
    pub cx: f32,
    /// Box centre y.
    pub cy: f32,
    /// Box width.
    pub w: f32,
    /// Box height.
    pub h: f32,
}
52
/// The three ONNX Runtime sessions that make up TableFormer, created by
/// [`TableFormer::load`] / [`TableFormer::load_with`].
pub struct TableFormer {
    /// Image encoder: produces the per-layer cross-attention K/V and `enc_out`.
    encoder: Session,
    /// Autoregressive structure-tag decoder (legacy or KV-cache graph; see `kv`).
    decoder: Session,
    /// Bbox decoder: `enc_out` + per-cell hidden states → cxcywh `boxes`.
    bbox: Session,
    /// True when the decoder is the true-KV-cache export (`decoder_kv.onnx`:
    /// inputs `tag`/`cache_k`/`cache_v`, one token per step); false for the
    /// legacy layer-output-cache graph (`decoder.onnx`: full `tags` + `cache`).
    /// Detected from the session's input names, so an explicit
    /// `DOCLING_TABLEFORMER_DECODER` override works with either graph.
    kv: bool,
}
64
/// KV-cache geometry fixed by the `decoder_kv.onnx` export
/// (`[N_LAYERS, 1, KV_HEADS, past, KV_HEAD_DIM]`, `KV_HEADS × KV_HEAD_DIM = EMBED_DIM`).
/// Used here to shape the zero-`past` first-step cache tensors.
const KV_HEADS: usize = 8;
/// Per-head width of the KV-cache (see `KV_HEADS`).
const KV_HEAD_DIM: usize = 64;
69
/// The autoregressive decode state: `a` is the legacy layer-output cache, or
/// `cache_k` for the KV graph; `b` is `cache_v` (KV graph only). `None` = first
/// step (the zero-`past` empties are allocated per table by [`TableFormer::empty_cache`]).
#[derive(Default)]
struct DecodeCache {
    /// Legacy `out_cache`, or the KV graph's `out_cache_k`, from the previous step.
    a: Option<DynValue>,
    /// The KV graph's `out_cache_v` from the previous step; unused by the legacy graph.
    b: Option<DynValue>,
}
78
/// Zero-`past` first-step cache tensors: `(cache, None)` for the legacy graph,
/// `(cache_k, Some(cache_v))` for the KV graph. Built once per table and
/// borrowed by every first decode step.
type EmptyCache = (Tensor<f32>, Option<Tensor<f32>>);
82
/// Encoder outputs that drive the cached decode loop: the per-layer cross-attention
/// K/V (projected from the image memory once, constant across decode steps) and
/// `enc_out` for the bbox decoder. Kept as owned `ort` values so each decode step
/// (and the bbox run) borrows them directly — no per-step extract/copy/re-wrap.
struct EncodeOut {
    /// The encoder's `cross_k` output, fed to every decode step.
    ck: DynValue,
    /// The encoder's `cross_v` output, fed to every decode step.
    cv: DynValue,
    /// The encoder's `enc_out` output, fed to the bbox decoder.
    eo: DynValue,
}
92
93impl TableFormer {
94    /// Load the exported encoder/decoder/bbox ONNX graphs (env overrides, else
95    /// `models/tableformer/{encoder,decoder,bbox}.onnx`). Returns `None` if any is
96    /// absent, so the pipeline falls back to geometric reconstruction.
97    pub fn load() -> Option<Self> {
98        Self::load_with(crate::intra_threads())
99    }
100
101    /// Like [`load`](Self::load) but with an explicit intra-op thread count, so a
102    /// parallel page-worker pool can run each table model on fewer threads (the
103    /// throughput comes from running pages concurrently, not from one fat model).
104    pub fn load_with(intra: usize) -> Option<Self> {
105        let enc = std::env::var("DOCLING_TABLEFORMER_ENCODER")
106            .unwrap_or_else(|_| crate::resolve_asset("models/tableformer/encoder.onnx"));
107        // Decoder preference (explicit override wins): INT8 variants first
108        // unless FLEISCHWOLF_FP32 opts out; within a precision the true-KV-cache
109        // export (`decoder_kv*`, one token per step) ranks behind the legacy
110        // layer-output-cache graph it matches byte-for-byte — measured parity on
111        // corpus-sized tables (ORT batches the legacy graph's prefix
112        // re-projection efficiently), so the smaller legacy file stays the
113        // default and `decoder_kv*` serves very-large-table workloads, where its
114        // O(past) step cost wins.
115        let dec = std::env::var("DOCLING_TABLEFORMER_DECODER").unwrap_or_else(|_| {
116            let candidates: &[&str] = if crate::fp32_forced() {
117                &[
118                    "models/tableformer/decoder.onnx",
119                    "models/tableformer/decoder_kv.onnx",
120                ]
121            } else {
122                &[
123                    "models/tableformer/decoder_int8.onnx",
124                    "models/tableformer/decoder_kv_int8.onnx",
125                    "models/tableformer/decoder.onnx",
126                    "models/tableformer/decoder_kv.onnx",
127                ]
128            };
129            candidates
130                .iter()
131                .map(|p| crate::resolve_asset(p))
132                .find(|p| std::path::Path::new(p).exists())
133                .unwrap_or_else(|| "models/tableformer/decoder.onnx".to_string())
134        });
135        let bbx = std::env::var("DOCLING_TABLEFORMER_BBOX")
136            .unwrap_or_else(|_| crate::resolve_asset("models/tableformer/bbox.onnx"));
137        if [&enc, &dec, &bbx]
138            .iter()
139            .any(|p| !std::path::Path::new(p).exists())
140        {
141            // The geometric fallback is a supported, intentional configuration
142            // (docling has no ML table-structure equivalent baked in either), so
143            // this stays a single quiet stderr note rather than an error — but it
144            // fires every process (not per-worker) so a CWD-relative default that
145            // silently misses its files (a very easy mistake for anything not run
146            // from the repo root, e.g. an embedding app) is at least visible once.
147            warn_missing_once(&enc, &dec, &bbx);
148            return None;
149        }
150        // The decoder's KV-cache grows by one entry every autoregressive step, so
151        // its input shapes differ on every `run()` call. ONNX Runtime's memory
152        // pattern optimizer assumes stable shapes to plan buffer reuse; disabling
153        // it for this session avoids repeatedly re-validating/re-touching that
154        // plan (and the external-weights file) on each step.
155        let build = |path: &str, mem_pattern: bool| -> Result<Session, String> {
156            Session::builder()
157                .map_err(|e| e.to_string())?
158                .with_intra_threads(intra)
159                .map_err(|e| e.to_string())?
160                .with_memory_pattern(mem_pattern)
161                .map_err(|e| e.to_string())?
162                .commit_from_file(path)
163                .map_err(|e| format!("tableformer load {path}: {e}"))
164        };
165        match (build(&enc, true), build(&dec, false), build(&bbx, true)) {
166            (Ok(encoder), Ok(decoder), Ok(bbox)) => {
167                let kv = decoder.inputs().iter().any(|i| i.name() == "cache_k");
168                Some(Self {
169                    encoder,
170                    decoder,
171                    bbox,
172                    kv,
173                })
174            }
175            _ => None,
176        }
177    }
178
179    /// Run the image encoder and capture what the cached decoder loop needs: each
180    /// decoder layer's cross-attention K/V (projected from the image memory once,
181    /// shape `[N_LAYERS,1,H,S,head_dim]`) and `enc_out` for the bbox decoder.
182    fn encode(&mut self, img: &RgbImage) -> Result<EncodeOut, String> {
183        let input = preprocess(img)?;
184        let mut enc_out = self
185            .encoder
186            .run(ort::inputs!["image" => input])
187            .map_err(|e| format!("tableformer: encode: {e}"))?;
188        let mut grab = |name: &str| -> Result<DynValue, String> {
189            enc_out
190                .remove(name)
191                .ok_or_else(|| format!("tableformer: encoder output {name} missing"))
192        };
193        Ok(EncodeOut {
194            ck: grab("cross_k")?,
195            cv: grab("cross_v")?,
196            eo: grab("enc_out")?,
197        })
198    }
199
200    /// One doubly-cached decode step: feed the current `tags`, the constant cross
201    /// K/V, and the growing self-attention `cache`; return the raw argmax tag and
202    /// the last token's hidden state, advancing the cache. The cache stays an owned
203    /// `ort` value — the previous step's `out_cache` output is fed back directly,
204    /// never extracted or copied (it grows every step, so per-step copies were
205    /// O(steps²) float traffic). `empty_cache` is the zero-`past` value used on the
206    /// first step (ort's array constructors reject a 0-length dim, so it is
207    /// allocated through the session allocator by the caller).
208    fn decode_step(
209        &mut self,
210        tags: &[i64],
211        enc: &EncodeOut,
212        cache: &mut DecodeCache,
213        empty: &EmptyCache,
214    ) -> Result<(i64, Vec<f32>), String> {
215        let mut dout = if self.kv {
216            // KV graph: feed only the newly emitted tag; the projected K/V for
217            // the whole prefix live in cache_k/cache_v and are fed back as-is.
218            let last = *tags.last().expect("decode starts from <start>");
219            let tag_t = Tensor::from_array(([1usize, 1usize], vec![last]))
220                .map_err(|e| format!("tableformer: tag: {e}"))?;
221            match (cache.a.as_ref(), cache.b.as_ref()) {
222                (Some(k), Some(v)) => self.decoder.run(ort::inputs![
223                    "tag" => tag_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
224                    "cache_k" => k, "cache_v" => v]),
225                _ => self.decoder.run(ort::inputs![
226                    "tag" => tag_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
227                    "cache_k" => &empty.0,
228                    "cache_v" => empty.1.as_ref().expect("kv empty cache has both halves")]),
229            }
230        } else {
231            let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.to_vec()))
232                .map_err(|e| format!("tableformer: tags: {e}"))?;
233            match cache.a.as_ref() {
234                None => self.decoder.run(ort::inputs![
235                    "tags" => tags_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
236                    "cache" => &empty.0]),
237                Some(c) => self.decoder.run(ort::inputs![
238                    "tags" => tags_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
239                    "cache" => c]),
240            }
241        }
242        .map_err(|e| format!("tableformer: decode: {e}"))?;
243        let (_, logits) = dout["logits"]
244            .try_extract_tensor::<f32>()
245            .map_err(|e| format!("tableformer: logits: {e}"))?;
246        let raw = argmax(logits) as i64;
247        let (_, hidden) = dout["hidden"]
248            .try_extract_tensor::<f32>()
249            .map_err(|e| format!("tableformer: hidden: {e}"))?;
250        let hidden = hidden.to_vec();
251        if self.kv {
252            cache.a = Some(
253                dout.remove("out_cache_k")
254                    .ok_or_else(|| "tableformer: out_cache_k missing".to_string())?,
255            );
256            cache.b = Some(
257                dout.remove("out_cache_v")
258                    .ok_or_else(|| "tableformer: out_cache_v missing".to_string())?,
259            );
260        } else {
261            cache.a = Some(
262                dout.remove("out_cache")
263                    .ok_or_else(|| "tableformer: decoder output out_cache missing".to_string())?,
264            );
265        }
266        Ok((raw, hidden))
267    }
268
269    /// The zero-`past` first-step cache(s), allocated through the session
270    /// allocator (ort's array constructors reject a 0-length dim; the C API does
271    /// allow it).
272    fn empty_cache(&self) -> Result<EmptyCache, String> {
273        let alloc = self.decoder.allocator();
274        if self.kv {
275            let mk = || {
276                Tensor::<f32>::new(alloc, [N_LAYERS, 1, KV_HEADS, 0usize, KV_HEAD_DIM])
277                    .map_err(|e| format!("tableformer: empty kv cache: {e}"))
278            };
279            Ok((mk()?, Some(mk()?)))
280        } else {
281            let c = Tensor::<f32>::new(alloc, [N_LAYERS, 0usize, 1, EMBED_DIM])
282                .map_err(|e| format!("tableformer: empty cache: {e}"))?;
283            Ok((c, None))
284        }
285    }
286
287    /// Predict the OTSL structure-token sequence for a table-region image.
288    pub fn predict_otsl(&mut self, img: &RgbImage) -> Result<Vec<i64>, String> {
289        let enc = self.encode(img)?;
290        // The two structure corrections mirror docling's `predict` exactly — note
291        // its `line_num` is never incremented, so `xcel→lcel` applies on every row.
292        let mut tags: Vec<i64> = vec![START];
293        let mut out: Vec<i64> = Vec::new();
294        let mut prev_ucel = false;
295        let mut cache = DecodeCache::default();
296        let empty = self.empty_cache()?;
297        while out.len() < MAX_STEPS {
298            let (raw, _hidden) = self.decode_step(&tags, &enc, &mut cache, &empty)?;
299            let mut tag = raw;
300            if tag == XCEL {
301                tag = LCEL;
302            }
303            if prev_ucel && tag == LCEL {
304                tag = FCEL;
305            }
306            if tag == END {
307                break;
308            }
309            out.push(tag);
310            tags.push(tag);
311            prev_ucel = tag == UCEL;
312        }
313        Ok(out)
314    }
315
316    /// Full structure prediction: OTSL grid cells with per-cell boxes (in the 448
317    /// image, normalized cxcywh). Collects per-cell decoder hidden states using
318    /// docling's exact bbox bookkeeping (skip-after-row-break, first-lcel of a
319    /// horizontal span), runs the bbox decoder, merges span boxes, then lays the
320    /// cells onto the OTSL grid with row/col spans.
321    pub fn predict_table_structure(&mut self, img: &RgbImage) -> Result<Vec<TableCell>, String> {
322        let enc = self.encode(img)?;
323
324        let mut tags: Vec<i64> = vec![START];
325        let mut otsl: Vec<i64> = Vec::new();
326        let mut hiddens: Vec<f32> = Vec::new(); // flattened [n, 512]
327        let mut n = 0usize;
328        let mut prev_ucel = false;
329        let mut skip = true; // first tag after <start> is skipped
330        let mut first_lcel = true;
331        let mut bbox_ind = 0usize;
332        let mut cur_bbox_ind = 0usize;
333        let mut merge: std::collections::HashMap<usize, i64> = std::collections::HashMap::new();
334        let mut cache = DecodeCache::default();
335        let empty = self.empty_cache()?;
336        while otsl.len() < MAX_STEPS {
337            let (raw, hidden) = self.decode_step(&tags, &enc, &mut cache, &empty)?;
338            let mut tag = raw;
339            if tag == XCEL {
340                tag = LCEL;
341            }
342            if prev_ucel && tag == LCEL {
343                tag = FCEL;
344            }
345            if tag == END {
346                break;
347            }
348            // docling's tag_H_buf / bboxes_to_merge bookkeeping.
349            if !skip && matches!(tag, FCEL | ECEL | CHED | RHED | SROW | NL | UCEL) {
350                hiddens.extend_from_slice(&hidden);
351                n += 1;
352                if !first_lcel {
353                    merge.insert(cur_bbox_ind, bbox_ind as i64);
354                }
355                bbox_ind += 1;
356            }
357            if tag != LCEL {
358                first_lcel = true;
359            } else if first_lcel {
360                hiddens.extend_from_slice(&hidden);
361                n += 1;
362                first_lcel = false;
363                cur_bbox_ind = bbox_ind;
364                merge.insert(cur_bbox_ind, -1);
365                bbox_ind += 1;
366            }
367            skip = matches!(tag, NL | UCEL | XCEL);
368            prev_ucel = tag == UCEL;
369            otsl.push(tag);
370            tags.push(tag);
371        }
372        if n == 0 {
373            return Ok(Vec::new());
374        }
375        let tag_h = Tensor::from_array(([n, 512usize], hiddens))
376            .map_err(|e| format!("tableformer: tag_h: {e}"))?;
377        let bout = self
378            .bbox
379            .run(ort::inputs!["enc_out" => &enc.eo, "tag_h" => tag_h])
380            .map_err(|e| format!("tableformer: bbox: {e}"))?;
381        let (_, raw) = bout["boxes"]
382            .try_extract_tensor::<f32>()
383            .map_err(|e| format!("tableformer: boxes: {e}"))?;
384        let boxes: Vec<[f32; 4]> = raw
385            .chunks_exact(4)
386            .map(|c| [c[0], c[1], c[2], c[3]])
387            .collect();
388        let merged = merge_spans(&boxes, &merge);
389        Ok(build_table_cells(&otsl, &merged))
390    }
391
392    /// Predict a table region's Markdown grid: crop the region (docling's
393    /// page→1024px box-average then bbox crop), run the structure model, map each
394    /// cell box back to page points, match the page's word cells into cells by
395    /// intersection-over-word-area, and expand spans into a dense `rows × cols`
396    /// grid. `region` is `(l, t, r, b)` in page points (top-left). Returns `None`
397    /// if no structure is predicted.
398    pub fn predict_table_rows(
399        &mut self,
400        page_image: &RgbImage,
401        page_h: f32,
402        region: [f32; 4],
403        words: &[TextCell],
404    ) -> Option<Vec<Vec<String>>> {
405        // page → 1024px height (cv2.INTER_AREA), then crop the table bbox.
406        let sf = 1024.0 / page_image.height() as f32;
407        let pw = (page_image.width() as f32 * sf) as u32;
408        let page1024 = crate::timing::timed("tableformer.inter_area", || {
409            crate::resample::inter_area(page_image, pw, 1024)
410        });
411        let k = 1024.0 / page_h;
412        let x = (region[0] * k).round().max(0.0) as u32;
413        let y = (region[1] * k).round().max(0.0) as u32;
414        let x2 = ((region[2] * k).round() as u32).min(page1024.width());
415        let y2 = ((region[3] * k).round() as u32).min(page1024.height());
416        if x2 <= x || y2 <= y {
417            return None;
418        }
419        let crop = image::imageops::crop_imm(&page1024, x, y, x2 - x, y2 - y).to_image();
420        let cells = crate::timing::timed("tableformer.structure", || {
421            self.predict_table_structure(&crop)
422        })
423        .ok()?;
424        if cells.is_empty() {
425            return None;
426        }
427        let (rw, rh) = (region[2] - region[0], region[3] - region[1]);
428
429        // Cell boxes in page points (top-left), aligned with `cells`.
430        let boxes: Vec<[f32; 4]> = cells
431            .iter()
432            .map(|c| {
433                [
434                    region[0] + (c.cx - c.w / 2.0) * rw,
435                    region[1] + (c.cy - c.h / 2.0) * rh,
436                    region[0] + (c.cx + c.w / 2.0) * rw,
437                    region[1] + (c.cy + c.h / 2.0) * rh,
438                ]
439            })
440            .collect();
441
442        // Assign each word to the cell it overlaps most (intersection / word area).
443        let mut cell_words: Vec<Vec<usize>> = vec![Vec::new(); cells.len()];
444        for (wi, w) in words.iter().enumerate() {
445            let wa = ((w.r - w.l) * (w.b - w.t)).max(1.0);
446            let mut best: Option<(f32, usize)> = None;
447            for (ci, b) in boxes.iter().enumerate() {
448                let ix = (w.r.min(b[2]) - w.l.max(b[0])).max(0.0);
449                let iy = (w.b.min(b[3]) - w.t.max(b[1])).max(0.0);
450                let io = ix * iy / wa;
451                if io > 0.0 && best.is_none_or(|(bo, _)| io > bo) {
452                    best = Some((io, ci));
453                }
454            }
455            if let Some((_, ci)) = best {
456                cell_words[ci].push(wi);
457            }
458        }
459
460        let num_rows = cells.iter().map(|c| c.row + c.rowspan).max().unwrap_or(0);
461        let num_cols = cells.iter().map(|c| c.col + c.colspan).max().unwrap_or(0);
462        if num_rows == 0 || num_cols == 0 {
463            return None;
464        }
465        let mut grid = vec![vec![String::new(); num_cols]; num_rows];
466        for (ci, c) in cells.iter().enumerate() {
467            // Keep words in text-stream order (the order they were collected =
468            // their word index), matching docling's cell text assembly — geometric
469            // re-sorting scrambles wrapped cells (`Inference time (secs)`).
470            let wis = std::mem::take(&mut cell_words[ci]);
471            let text = wis
472                .iter()
473                .map(|&i| words[i].text.trim())
474                .collect::<Vec<_>>()
475                .join(" ");
476            // Spanned cells repeat their text across the covered grid positions.
477            for row in grid.iter_mut().skip(c.row).take(c.rowspan) {
478                for cell in row.iter_mut().skip(c.col).take(c.colspan) {
479                    *cell = text.clone();
480                }
481            }
482        }
483        Some(grid)
484    }
485}
486
/// Print, at most once per process, a stderr note that TableFormer's ONNX graphs
/// were not found, so tables fall back to geometric reconstruction.
///
/// The note names the three paths that were checked and the env vars that
/// override them. The default paths are relative (`models/tableformer/*.onnx`),
/// so they only resolve when the current directory is the repo root. That is
/// easy to miss from an embedding app or a binding run elsewhere, and used to
/// fail with no signal at all.
fn warn_missing_once(enc: &str, dec: &str, bbx: &str) {
    use std::sync::Once;

    static NOTE_PRINTED: Once = Once::new();

    let print_note = move || {
        eprintln!(
            "fleischwolf: TableFormer models not found (checked {enc}, {dec}, {bbx}); \
             tables will use geometric reconstruction instead of ML table-structure \
             recognition. Set DOCLING_TABLEFORMER_ENCODER / DOCLING_TABLEFORMER_DECODER \
             / DOCLING_TABLEFORMER_BBOX to enable it (see README.md)."
        );
    };
    NOTE_PRINTED.call_once(print_note);
}
504
505/// docling's preprocessing: bilinear (cv2.INTER_LINEAR) resize the crop to 448²,
506/// normalize `(x/255 − mean)/std`, laid out as (C, W, H) — docling transposes
507/// (2,1,0), so width is the major spatial axis. The page→1024px box-average
508/// (cv2.INTER_AREA) is the caller's job.
509fn preprocess(img: &RgbImage) -> Result<Tensor<f32>, String> {
510    let nn = (SIDE * SIDE) as usize;
511    let side = SIDE as usize;
512    let (sw, sh) = (img.width() as i32, img.height() as i32);
513    let sxr = sw as f32 / SIDE as f32;
514    let syr = sh as f32 / SIDE as f32;
515    let mut data = vec![0f32; 3 * nn];
516    for h in 0..side {
517        let fy = (h as f32 + 0.5) * syr - 0.5;
518        let wy = fy - fy.floor();
519        let y0c = (fy.floor() as i32).clamp(0, sh - 1) as u32;
520        let y1c = (fy.floor() as i32 + 1).clamp(0, sh - 1) as u32;
521        for w in 0..side {
522            let fx = (w as f32 + 0.5) * sxr - 0.5;
523            let wx = fx - fx.floor();
524            let x0c = (fx.floor() as i32).clamp(0, sw - 1) as u32;
525            let x1c = (fx.floor() as i32 + 1).clamp(0, sw - 1) as u32;
526            let p00 = img.get_pixel(x0c, y0c);
527            let p01 = img.get_pixel(x1c, y0c);
528            let p10 = img.get_pixel(x0c, y1c);
529            let p11 = img.get_pixel(x1c, y1c);
530            let idx = w * side + h; // (C, W, H): c*n + w*H + h
531            for c in 0..3 {
532                let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
533                let bot = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
534                let v = top * (1.0 - wy) + bot * wy;
535                data[c * nn + idx] = (v / 255.0 - MEAN[c]) / STD[c];
536            }
537        }
538    }
539    Tensor::from_array(([1usize, 3, side, side], data))
540        .map_err(|e| format!("tableformer: input: {e}"))
541}
542
/// docling's `mergebboxes` on cxcywh boxes: the union box of a horizontal span's
/// first cell (`b1`) and last cell (`b2`).
///
/// The left edge comes from `b1` and the right edge from `b2`. The top is the
/// higher of the two tops. The height is measured from `b1`'s top to `b2`'s
/// bottom, exactly as docling computes it.
fn mergebboxes(b1: [f32; 4], b2: [f32; 4]) -> [f32; 4] {
    let [cx1, cy1, w1, h1] = b1;
    let [cx2, cy2, w2, h2] = b2;

    let left = cx1 - w1 / 2.0;
    let first_top = cy1 - h1 / 2.0;
    let right = cx2 + w2 / 2.0;
    let last_bottom = cy2 + h2 / 2.0;

    let width = right - left;
    let height = last_bottom - first_top;
    let top = (cy2 - h2 / 2.0).min(first_top);

    [left + width / 2.0, top + height / 2.0, width, height]
}
552
553/// Apply docling's span merges: each merge key combines its box with the partner
554/// (`-1` → the last box); partners are dropped.
555fn merge_spans(boxes: &[[f32; 4]], merge: &std::collections::HashMap<usize, i64>) -> Vec<[f32; 4]> {
556    let skip: std::collections::HashSet<usize> = merge
557        .values()
558        .filter(|&&v| v >= 0)
559        .map(|&v| v as usize)
560        .collect();
561    let mut out = Vec::new();
562    for (i, &b) in boxes.iter().enumerate() {
563        if let Some(&j) = merge.get(&i) {
564            let partner = if j < 0 { boxes.len() - 1 } else { j as usize };
565            out.push(mergebboxes(b, boxes[partner.min(boxes.len() - 1)]));
566        } else if !skip.contains(&i) {
567            out.push(b);
568        }
569    }
570    out
571}
572
573const CELL_TAGS: [i64; 6] = [FCEL, ECEL, XCEL, CHED, RHED, SROW];
574
575/// Lay the OTSL tag stream onto a grid (docling's `_build_table_cells`, OTSL
576/// mode): cell tags create cells at (row, col); `lcel`/`ucel`/`xcel` are spans
577/// (counted toward the column index but not cells). Colspan/rowspan are read off
578/// the grid (consecutive `lcel`/`ucel` to the right/below). `boxes` are indexed
579/// by cell order and aligned with the cells.
580fn build_table_cells(otsl: &[i64], boxes: &[[f32; 4]]) -> Vec<TableCell> {
581    // 2D grid of tags (rows split on NL) for span lookups.
582    let mut grid: Vec<Vec<i64>> = vec![Vec::new()];
583    for &t in otsl {
584        if t == NL {
585            grid.push(Vec::new());
586        } else {
587            grid.last_mut().unwrap().push(t);
588        }
589    }
590    let mut cells = Vec::new();
591    let mut cell_id = 0usize;
592    for (r, row) in grid.iter().enumerate() {
593        for (c, &tag) in row.iter().enumerate() {
594            if !CELL_TAGS.contains(&tag) {
595                continue;
596            }
597            let mut colspan = 1;
598            while c + colspan < row.len() && matches!(row[c + colspan], LCEL | XCEL) {
599                colspan += 1;
600            }
601            let mut rowspan = 1;
602            while r + rowspan < grid.len()
603                && grid[r + rowspan]
604                    .get(c)
605                    .is_some_and(|&t| matches!(t, UCEL | XCEL))
606            {
607                rowspan += 1;
608            }
609            let b = boxes.get(cell_id).copied().unwrap_or([0.0; 4]);
610            cells.push(TableCell {
611                row: r,
612                col: c,
613                colspan,
614                rowspan,
615                tag,
616                cx: b[0],
617                cy: b[1],
618                w: b[2],
619                h: b[3],
620            });
621            cell_id += 1;
622        }
623    }
624    cells
625}
626
/// Index of the largest value in `v`, or `0` for an empty slice.
///
/// Ties resolve to the FIRST maximal index, matching `torch.argmax` in docling.
/// (`Iterator::max_by` alone would return the last one.) Values are ordered by
/// `f32::total_cmp`, so the result is deterministic even with NaNs present.
fn argmax(v: &[f32]) -> usize {
    v.iter()
        .enumerate()
        // On equal values, the lower index compares as greater, so it wins.
        .max_by(|&(i, a), &(j, b)| a.total_cmp(b).then_with(|| j.cmp(&i)))
        .map_or(0, |(i, _)| i)
}