1use crate::pdfium_backend::TextCell;
8use image::RgbImage;
9use ort::session::Session;
10use ort::value::{DynValue, Tensor};
11
/// Side length, in pixels, of the square input that `preprocess` resamples every image to.
const SIDE: u32 = 448;
/// Per-channel mean that `preprocess` subtracts after scaling pixel values by `1 / 255`.
// The literal carries more digits than `f32` can hold, hence the lint exemption.
#[allow(clippy::excessive_precision)]
const MEAN: [f32; 3] = [0.94247851, 0.94254675, 0.94292611];
/// Per-channel standard deviation that `preprocess` divides by after subtracting `MEAN`.
// Same lint exemption as for `MEAN`.
#[allow(clippy::excessive_precision)]
const STD: [f32; 3] = [0.17910956, 0.17940403, 0.17931663];
/// Maximum number of tags a decode loop collects before it stops on its own.
const MAX_STEPS: usize = 1024;
/// Leading dimension of the empty decoder caches built by `empty_cache`.
// NOTE(review): presumably the decoder's layer count — confirm against the exported model.
const N_LAYERS: usize = 6;
/// Trailing dimension of the empty cache used by the single-cache decoder variant.
// NOTE(review): presumably the decoder's embedding width — confirm against the exported model.
const EMBED_DIM: usize = 512;
24
// Tag ids exchanged with the decoder. The values are model vocabulary indices.
// NOTE(review): the expansions of the tag names (empty cell, column header, ...)
// are not visible in this file — confirm against the model's vocabulary.

/// Tag that seeds every decode: the tag history starts as `[START]`.
pub const START: i64 = 2;
/// Tag that stops decoding; it is never pushed to the output sequence.
pub const END: i64 = 3;
/// Cell-starting tag (listed in `CELL_TAGS`).
pub const ECEL: i64 = 4;
/// Cell-starting tag (listed in `CELL_TAGS`); also substituted for `LCEL` right after a `UCEL`.
pub const FCEL: i64 = 5;
/// Continuation tag: extends the column span of the cell to its left in `build_table_cells`.
pub const LCEL: i64 = 6;
/// Continuation tag: extends the row span of the cell above it in `build_table_cells`.
pub const UCEL: i64 = 7;
/// Counted as both a column-span and a row-span continuation in `build_table_cells`;
/// the predict loops rewrite a predicted `XCEL` to `LCEL` before storing it.
pub const XCEL: i64 = 8;
/// Row separator: starts a new grid row in `build_table_cells`.
pub const NL: i64 = 9;
/// Cell-starting tag (listed in `CELL_TAGS`).
pub const CHED: i64 = 10;
/// Cell-starting tag (listed in `CELL_TAGS`).
pub const RHED: i64 = 11;
/// Cell-starting tag (listed in `CELL_TAGS`).
pub const SROW: i64 = 12;
#[derive(Debug, Clone)]
/// One cell of a predicted table: its grid placement plus its bounding box.
pub struct TableCell {
    /// Zero-based grid row of the cell's top-left slot.
    pub row: usize,
    /// Zero-based grid column of the cell's top-left slot.
    pub col: usize,
    /// Number of grid columns covered (at least 1).
    pub colspan: usize,
    /// Number of grid rows covered (at least 1).
    pub rowspan: usize,
    /// Tag id that started the cell (one of `CELL_TAGS`).
    pub tag: i64,
    /// Box centre x; `predict_table_rows` treats it as a fraction of the table region's width.
    pub cx: f32,
    /// Box centre y; `predict_table_rows` treats it as a fraction of the table region's height.
    pub cy: f32,
    /// Box width, in the same normalised units as `cx`.
    pub w: f32,
    /// Box height, in the same normalised units as `cy`.
    pub h: f32,
}
52
/// Table-structure predictor backed by three ONNX Runtime sessions.
pub struct TableFormer {
    /// Session fed the preprocessed image; yields `cross_k`, `cross_v` and `enc_out`.
    encoder: Session,
    /// Session run once per decoding step to produce the next tag and its hidden state.
    decoder: Session,
    /// Session that turns the collected hidden states into cell boxes.
    bbox: Session,
    /// `true` when the loaded decoder declares a `cache_k` input, i.e. it takes
    /// separate `cache_k` / `cache_v` tensors instead of a single `cache`.
    kv: bool,
}
64
/// Third dimension of the empty `cache_k` / `cache_v` tensors built by `empty_cache`.
// NOTE(review): presumably the number of attention heads — confirm against the exported model.
const KV_HEADS: usize = 8;
/// Last dimension of the empty `cache_k` / `cache_v` tensors built by `empty_cache`.
// NOTE(review): presumably the per-head width — confirm against the exported model.
const KV_HEAD_DIM: usize = 64;
69
/// Decoder cache carried from one `decode_step` call to the next.
///
/// Both slots start as `None`, which makes the first step use the empty cache.
#[derive(Default)]
struct DecodeCache {
    /// `out_cache_k` of the previous step for the kv decoder, `out_cache` otherwise.
    a: Option<DynValue>,
    /// `out_cache_v` of the previous step; only written for the kv decoder.
    b: Option<DynValue>,
}
78
/// Zero-length cache tensors fed to the decoder on its first step: the first
/// element is always present, the second only for the kv decoder.
type EmptyCache = (Tensor<f32>, Option<Tensor<f32>>);
82
/// Encoder outputs kept alive for the duration of one prediction.
struct EncodeOut {
    /// Encoder output `cross_k`, passed to every decoder step.
    ck: DynValue,
    /// Encoder output `cross_v`, passed to every decoder step.
    cv: DynValue,
    /// Encoder output `enc_out`, passed to the bbox session.
    eo: DynValue,
}
92
93impl TableFormer {
94 pub fn load() -> Option<Self> {
98 Self::load_with(crate::intra_threads())
99 }
100
101 pub fn load_with(intra: usize) -> Option<Self> {
105 let enc = std::env::var("DOCLING_TABLEFORMER_ENCODER")
106 .unwrap_or_else(|_| crate::resolve_asset("models/tableformer/encoder.onnx"));
107 let dec = std::env::var("DOCLING_TABLEFORMER_DECODER").unwrap_or_else(|_| {
116 let candidates: &[&str] = if crate::fp32_forced() {
117 &[
118 "models/tableformer/decoder.onnx",
119 "models/tableformer/decoder_kv.onnx",
120 ]
121 } else {
122 &[
123 "models/tableformer/decoder_int8.onnx",
124 "models/tableformer/decoder_kv_int8.onnx",
125 "models/tableformer/decoder.onnx",
126 "models/tableformer/decoder_kv.onnx",
127 ]
128 };
129 candidates
130 .iter()
131 .map(|p| crate::resolve_asset(p))
132 .find(|p| std::path::Path::new(p).exists())
133 .unwrap_or_else(|| "models/tableformer/decoder.onnx".to_string())
134 });
135 let bbx = std::env::var("DOCLING_TABLEFORMER_BBOX")
136 .unwrap_or_else(|_| crate::resolve_asset("models/tableformer/bbox.onnx"));
137 if [&enc, &dec, &bbx]
138 .iter()
139 .any(|p| !std::path::Path::new(p).exists())
140 {
141 warn_missing_once(&enc, &dec, &bbx);
148 return None;
149 }
150 let build = |path: &str, mem_pattern: bool| -> Result<Session, String> {
156 Session::builder()
157 .map_err(|e| e.to_string())?
158 .with_intra_threads(intra)
159 .map_err(|e| e.to_string())?
160 .with_memory_pattern(mem_pattern)
161 .map_err(|e| e.to_string())?
162 .commit_from_file(path)
163 .map_err(|e| format!("tableformer load {path}: {e}"))
164 };
165 match (build(&enc, true), build(&dec, false), build(&bbx, true)) {
166 (Ok(encoder), Ok(decoder), Ok(bbox)) => {
167 let kv = decoder.inputs().iter().any(|i| i.name() == "cache_k");
168 Some(Self {
169 encoder,
170 decoder,
171 bbox,
172 kv,
173 })
174 }
175 _ => None,
176 }
177 }
178
179 fn encode(&mut self, img: &RgbImage) -> Result<EncodeOut, String> {
183 let input = preprocess(img)?;
184 let mut enc_out = self
185 .encoder
186 .run(ort::inputs!["image" => input])
187 .map_err(|e| format!("tableformer: encode: {e}"))?;
188 let mut grab = |name: &str| -> Result<DynValue, String> {
189 enc_out
190 .remove(name)
191 .ok_or_else(|| format!("tableformer: encoder output {name} missing"))
192 };
193 Ok(EncodeOut {
194 ck: grab("cross_k")?,
195 cv: grab("cross_v")?,
196 eo: grab("enc_out")?,
197 })
198 }
199
200 fn decode_step(
209 &mut self,
210 tags: &[i64],
211 enc: &EncodeOut,
212 cache: &mut DecodeCache,
213 empty: &EmptyCache,
214 ) -> Result<(i64, Vec<f32>), String> {
215 let mut dout = if self.kv {
216 let last = *tags.last().expect("decode starts from <start>");
219 let tag_t = Tensor::from_array(([1usize, 1usize], vec![last]))
220 .map_err(|e| format!("tableformer: tag: {e}"))?;
221 match (cache.a.as_ref(), cache.b.as_ref()) {
222 (Some(k), Some(v)) => self.decoder.run(ort::inputs![
223 "tag" => tag_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
224 "cache_k" => k, "cache_v" => v]),
225 _ => self.decoder.run(ort::inputs![
226 "tag" => tag_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
227 "cache_k" => &empty.0,
228 "cache_v" => empty.1.as_ref().expect("kv empty cache has both halves")]),
229 }
230 } else {
231 let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.to_vec()))
232 .map_err(|e| format!("tableformer: tags: {e}"))?;
233 match cache.a.as_ref() {
234 None => self.decoder.run(ort::inputs![
235 "tags" => tags_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
236 "cache" => &empty.0]),
237 Some(c) => self.decoder.run(ort::inputs![
238 "tags" => tags_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
239 "cache" => c]),
240 }
241 }
242 .map_err(|e| format!("tableformer: decode: {e}"))?;
243 let (_, logits) = dout["logits"]
244 .try_extract_tensor::<f32>()
245 .map_err(|e| format!("tableformer: logits: {e}"))?;
246 let raw = argmax(logits) as i64;
247 let (_, hidden) = dout["hidden"]
248 .try_extract_tensor::<f32>()
249 .map_err(|e| format!("tableformer: hidden: {e}"))?;
250 let hidden = hidden.to_vec();
251 if self.kv {
252 cache.a = Some(
253 dout.remove("out_cache_k")
254 .ok_or_else(|| "tableformer: out_cache_k missing".to_string())?,
255 );
256 cache.b = Some(
257 dout.remove("out_cache_v")
258 .ok_or_else(|| "tableformer: out_cache_v missing".to_string())?,
259 );
260 } else {
261 cache.a = Some(
262 dout.remove("out_cache")
263 .ok_or_else(|| "tableformer: decoder output out_cache missing".to_string())?,
264 );
265 }
266 Ok((raw, hidden))
267 }
268
269 fn empty_cache(&self) -> Result<EmptyCache, String> {
273 let alloc = self.decoder.allocator();
274 if self.kv {
275 let mk = || {
276 Tensor::<f32>::new(alloc, [N_LAYERS, 1, KV_HEADS, 0usize, KV_HEAD_DIM])
277 .map_err(|e| format!("tableformer: empty kv cache: {e}"))
278 };
279 Ok((mk()?, Some(mk()?)))
280 } else {
281 let c = Tensor::<f32>::new(alloc, [N_LAYERS, 0usize, 1, EMBED_DIM])
282 .map_err(|e| format!("tableformer: empty cache: {e}"))?;
283 Ok((c, None))
284 }
285 }
286
287 pub fn predict_otsl(&mut self, img: &RgbImage) -> Result<Vec<i64>, String> {
289 let enc = self.encode(img)?;
290 let mut tags: Vec<i64> = vec![START];
293 let mut out: Vec<i64> = Vec::new();
294 let mut prev_ucel = false;
295 let mut cache = DecodeCache::default();
296 let empty = self.empty_cache()?;
297 while out.len() < MAX_STEPS {
298 let (raw, _hidden) = self.decode_step(&tags, &enc, &mut cache, &empty)?;
299 let mut tag = raw;
300 if tag == XCEL {
301 tag = LCEL;
302 }
303 if prev_ucel && tag == LCEL {
304 tag = FCEL;
305 }
306 if tag == END {
307 break;
308 }
309 out.push(tag);
310 tags.push(tag);
311 prev_ucel = tag == UCEL;
312 }
313 Ok(out)
314 }
315
316 pub fn predict_table_structure(&mut self, img: &RgbImage) -> Result<Vec<TableCell>, String> {
322 let enc = self.encode(img)?;
323
324 let mut tags: Vec<i64> = vec![START];
325 let mut otsl: Vec<i64> = Vec::new();
326 let mut hiddens: Vec<f32> = Vec::new(); let mut n = 0usize;
328 let mut prev_ucel = false;
329 let mut skip = true; let mut first_lcel = true;
331 let mut bbox_ind = 0usize;
332 let mut cur_bbox_ind = 0usize;
333 let mut merge: std::collections::HashMap<usize, i64> = std::collections::HashMap::new();
334 let mut cache = DecodeCache::default();
335 let empty = self.empty_cache()?;
336 while otsl.len() < MAX_STEPS {
337 let (raw, hidden) = self.decode_step(&tags, &enc, &mut cache, &empty)?;
338 let mut tag = raw;
339 if tag == XCEL {
340 tag = LCEL;
341 }
342 if prev_ucel && tag == LCEL {
343 tag = FCEL;
344 }
345 if tag == END {
346 break;
347 }
348 if !skip && matches!(tag, FCEL | ECEL | CHED | RHED | SROW | NL | UCEL) {
350 hiddens.extend_from_slice(&hidden);
351 n += 1;
352 if !first_lcel {
353 merge.insert(cur_bbox_ind, bbox_ind as i64);
354 }
355 bbox_ind += 1;
356 }
357 if tag != LCEL {
358 first_lcel = true;
359 } else if first_lcel {
360 hiddens.extend_from_slice(&hidden);
361 n += 1;
362 first_lcel = false;
363 cur_bbox_ind = bbox_ind;
364 merge.insert(cur_bbox_ind, -1);
365 bbox_ind += 1;
366 }
367 skip = matches!(tag, NL | UCEL | XCEL);
368 prev_ucel = tag == UCEL;
369 otsl.push(tag);
370 tags.push(tag);
371 }
372 if n == 0 {
373 return Ok(Vec::new());
374 }
375 let tag_h = Tensor::from_array(([n, 512usize], hiddens))
376 .map_err(|e| format!("tableformer: tag_h: {e}"))?;
377 let bout = self
378 .bbox
379 .run(ort::inputs!["enc_out" => &enc.eo, "tag_h" => tag_h])
380 .map_err(|e| format!("tableformer: bbox: {e}"))?;
381 let (_, raw) = bout["boxes"]
382 .try_extract_tensor::<f32>()
383 .map_err(|e| format!("tableformer: boxes: {e}"))?;
384 let boxes: Vec<[f32; 4]> = raw
385 .chunks_exact(4)
386 .map(|c| [c[0], c[1], c[2], c[3]])
387 .collect();
388 let merged = merge_spans(&boxes, &merge);
389 Ok(build_table_cells(&otsl, &merged))
390 }
391
392 pub fn predict_table_rows(
399 &mut self,
400 page_image: &RgbImage,
401 page_h: f32,
402 region: [f32; 4],
403 words: &[TextCell],
404 ) -> Option<Vec<Vec<String>>> {
405 let sf = 1024.0 / page_image.height() as f32;
407 let pw = (page_image.width() as f32 * sf) as u32;
408 let page1024 = crate::timing::timed("tableformer.inter_area", || {
409 crate::resample::inter_area(page_image, pw, 1024)
410 });
411 let k = 1024.0 / page_h;
412 let x = (region[0] * k).round().max(0.0) as u32;
413 let y = (region[1] * k).round().max(0.0) as u32;
414 let x2 = ((region[2] * k).round() as u32).min(page1024.width());
415 let y2 = ((region[3] * k).round() as u32).min(page1024.height());
416 if x2 <= x || y2 <= y {
417 return None;
418 }
419 let crop = image::imageops::crop_imm(&page1024, x, y, x2 - x, y2 - y).to_image();
420 let cells = crate::timing::timed("tableformer.structure", || {
421 self.predict_table_structure(&crop)
422 })
423 .ok()?;
424 if cells.is_empty() {
425 return None;
426 }
427 let (rw, rh) = (region[2] - region[0], region[3] - region[1]);
428
429 let boxes: Vec<[f32; 4]> = cells
431 .iter()
432 .map(|c| {
433 [
434 region[0] + (c.cx - c.w / 2.0) * rw,
435 region[1] + (c.cy - c.h / 2.0) * rh,
436 region[0] + (c.cx + c.w / 2.0) * rw,
437 region[1] + (c.cy + c.h / 2.0) * rh,
438 ]
439 })
440 .collect();
441
442 let mut cell_words: Vec<Vec<usize>> = vec![Vec::new(); cells.len()];
444 for (wi, w) in words.iter().enumerate() {
445 let wa = ((w.r - w.l) * (w.b - w.t)).max(1.0);
446 let mut best: Option<(f32, usize)> = None;
447 for (ci, b) in boxes.iter().enumerate() {
448 let ix = (w.r.min(b[2]) - w.l.max(b[0])).max(0.0);
449 let iy = (w.b.min(b[3]) - w.t.max(b[1])).max(0.0);
450 let io = ix * iy / wa;
451 if io > 0.0 && best.is_none_or(|(bo, _)| io > bo) {
452 best = Some((io, ci));
453 }
454 }
455 if let Some((_, ci)) = best {
456 cell_words[ci].push(wi);
457 }
458 }
459
460 let num_rows = cells.iter().map(|c| c.row + c.rowspan).max().unwrap_or(0);
461 let num_cols = cells.iter().map(|c| c.col + c.colspan).max().unwrap_or(0);
462 if num_rows == 0 || num_cols == 0 {
463 return None;
464 }
465 let mut grid = vec![vec![String::new(); num_cols]; num_rows];
466 for (ci, c) in cells.iter().enumerate() {
467 let wis = std::mem::take(&mut cell_words[ci]);
471 let text = wis
472 .iter()
473 .map(|&i| words[i].text.trim())
474 .collect::<Vec<_>>()
475 .join(" ");
476 for row in grid.iter_mut().skip(c.row).take(c.rowspan) {
478 for cell in row.iter_mut().skip(c.col).take(c.colspan) {
479 *cell = text.clone();
480 }
481 }
482 }
483 Some(grid)
484 }
485}
486
/// Prints, at most once per process, a stderr notice that the model files
/// were not found, naming the three paths that were checked.
fn warn_missing_once(enc: &str, dec: &str, bbx: &str) {
    use std::sync::Once;

    static ALREADY_WARNED: Once = Once::new();
    let report = || {
        eprintln!(
            "fleischwolf: TableFormer models not found (checked {enc}, {dec}, {bbx}); \
             tables will use geometric reconstruction instead of ML table-structure \
             recognition. Set DOCLING_TABLEFORMER_ENCODER / DOCLING_TABLEFORMER_DECODER \
             / DOCLING_TABLEFORMER_BBOX to enable it (see README.md)."
        );
    };
    ALREADY_WARNED.call_once(report);
}
504
505fn preprocess(img: &RgbImage) -> Result<Tensor<f32>, String> {
510 let nn = (SIDE * SIDE) as usize;
511 let side = SIDE as usize;
512 let (sw, sh) = (img.width() as i32, img.height() as i32);
513 let sxr = sw as f32 / SIDE as f32;
514 let syr = sh as f32 / SIDE as f32;
515 let mut data = vec![0f32; 3 * nn];
516 for h in 0..side {
517 let fy = (h as f32 + 0.5) * syr - 0.5;
518 let wy = fy - fy.floor();
519 let y0c = (fy.floor() as i32).clamp(0, sh - 1) as u32;
520 let y1c = (fy.floor() as i32 + 1).clamp(0, sh - 1) as u32;
521 for w in 0..side {
522 let fx = (w as f32 + 0.5) * sxr - 0.5;
523 let wx = fx - fx.floor();
524 let x0c = (fx.floor() as i32).clamp(0, sw - 1) as u32;
525 let x1c = (fx.floor() as i32 + 1).clamp(0, sw - 1) as u32;
526 let p00 = img.get_pixel(x0c, y0c);
527 let p01 = img.get_pixel(x1c, y0c);
528 let p10 = img.get_pixel(x0c, y1c);
529 let p11 = img.get_pixel(x1c, y1c);
530 let idx = w * side + h; for c in 0..3 {
532 let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
533 let bot = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
534 let v = top * (1.0 - wy) + bot * wy;
535 data[c * nn + idx] = (v / 255.0 - MEAN[c]) / STD[c];
536 }
537 }
538 }
539 Tensor::from_array(([1usize, 3, side, side], data))
540 .map_err(|e| format!("tableformer: input: {e}"))
541}
542
/// Merges two `[cx, cy, w, h]` boxes into one box of the same form.
///
/// The result runs from the left edge of `b1` to the right edge of `b2`. Its
/// height is the distance from the top edge of `b1` to the bottom edge of
/// `b2`, while its top edge is the higher of the two tops.
fn mergebboxes(b1: [f32; 4], b2: [f32; 4]) -> [f32; 4] {
    let [cx1, cy1, w1, h1] = b1;
    let [cx2, cy2, w2, h2] = b2;

    let left1 = cx1 - w1 / 2.0;
    let top1 = cy1 - h1 / 2.0;
    let right2 = cx2 + w2 / 2.0;
    let top2 = cy2 - h2 / 2.0;
    let bottom2 = cy2 + h2 / 2.0;

    let width = right2 - left1;
    let height = bottom2 - top1;
    let top = top2.min(top1);

    [left1 + width / 2.0, top + height / 2.0, width, height]
}
552
553fn merge_spans(boxes: &[[f32; 4]], merge: &std::collections::HashMap<usize, i64>) -> Vec<[f32; 4]> {
556 let skip: std::collections::HashSet<usize> = merge
557 .values()
558 .filter(|&&v| v >= 0)
559 .map(|&v| v as usize)
560 .collect();
561 let mut out = Vec::new();
562 for (i, &b) in boxes.iter().enumerate() {
563 if let Some(&j) = merge.get(&i) {
564 let partner = if j < 0 { boxes.len() - 1 } else { j as usize };
565 out.push(mergebboxes(b, boxes[partner.min(boxes.len() - 1)]));
566 } else if !skip.contains(&i) {
567 out.push(b);
568 }
569 }
570 out
571}
572
/// Tags that start a cell in `build_table_cells`; every other tag is only
/// inspected as a possible span continuation of such a cell.
const CELL_TAGS: [i64; 6] = [FCEL, ECEL, XCEL, CHED, RHED, SROW];
574
575fn build_table_cells(otsl: &[i64], boxes: &[[f32; 4]]) -> Vec<TableCell> {
581 let mut grid: Vec<Vec<i64>> = vec![Vec::new()];
583 for &t in otsl {
584 if t == NL {
585 grid.push(Vec::new());
586 } else {
587 grid.last_mut().unwrap().push(t);
588 }
589 }
590 let mut cells = Vec::new();
591 let mut cell_id = 0usize;
592 for (r, row) in grid.iter().enumerate() {
593 for (c, &tag) in row.iter().enumerate() {
594 if !CELL_TAGS.contains(&tag) {
595 continue;
596 }
597 let mut colspan = 1;
598 while c + colspan < row.len() && matches!(row[c + colspan], LCEL | XCEL) {
599 colspan += 1;
600 }
601 let mut rowspan = 1;
602 while r + rowspan < grid.len()
603 && grid[r + rowspan]
604 .get(c)
605 .is_some_and(|&t| matches!(t, UCEL | XCEL))
606 {
607 rowspan += 1;
608 }
609 let b = boxes.get(cell_id).copied().unwrap_or([0.0; 4]);
610 cells.push(TableCell {
611 row: r,
612 col: c,
613 colspan,
614 rowspan,
615 tag,
616 cx: b[0],
617 cy: b[1],
618 w: b[2],
619 h: b[3],
620 });
621 cell_id += 1;
622 }
623 }
624 cells
625}
626
/// Returns the index of the largest value in `v` under `f32::total_cmp`
/// ordering, or 0 for an empty slice. When several values tie for the
/// maximum, the last one wins.
fn argmax(v: &[f32]) -> usize {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in v.iter().enumerate() {
        match best {
            // Keep the current leader only while it is strictly greater.
            Some((_, top)) if value.total_cmp(&top).is_lt() => {}
            _ => best = Some((idx, value)),
        }
    }
    best.map_or(0, |(idx, _)| idx)
}