1use crate::pdfium_backend::TextCell;
8use image::RgbImage;
9use ort::session::Session;
10use ort::value::{Tensor, TensorRef};
11
/// Side length, in pixels, of the square image fed to the encoder
/// (the input tensor is `[1, 3, SIDE, SIDE]`).
const SIDE: u32 = 448;

/// Maximum number of OTSL tags decoded for one table before decoding is cut off.
const MAX_STEPS: usize = 1024;
/// Leading dimension of the decoder KV cache (`[N_LAYERS, past, 1, EMBED_DIM]`);
/// presumably the number of decoder layers.
const N_LAYERS: usize = 6;
/// Trailing dimension of the decoder KV cache.
const EMBED_DIM: usize = 512;

// The normalisation literals carry more digits than an f32 can hold; they are
// kept verbatim (presumably copied from the reference model configuration), so
// clippy's excessive-precision lint is silenced rather than rounding them.
#[allow(clippy::excessive_precision)]
/// Per-channel (R, G, B) mean subtracted from pixels scaled to `[0, 1]`.
const MEAN: [f32; 3] = [0.94247851, 0.94254675, 0.94292611];
#[allow(clippy::excessive_precision)]
/// Per-channel (R, G, B) standard deviation paired with [`MEAN`].
const STD: [f32; 3] = [0.17910956, 0.17940403, 0.17931663];
24
// Decoder vocabulary ids for the OTSL (Optimised Table Structure Language)
// tags this module understands.

/// Start-of-sequence token; always the first decoder input.
pub const START: i64 = 2;
/// End-of-sequence token; decoding stops when it is predicted.
pub const END: i64 = 3;
/// OTSL `ecel` (empty cell); starts a new cell.
pub const ECEL: i64 = 4;
/// OTSL `fcel` (full cell); starts a new cell.
pub const FCEL: i64 = 5;
/// OTSL `lcel`: continues the cell to its left (horizontal span).
pub const LCEL: i64 = 6;
/// OTSL `ucel`: continues the cell above (vertical span).
pub const UCEL: i64 = 7;
/// OTSL `xcel`: continues a cell both horizontally and vertically.
pub const XCEL: i64 = 8;
/// OTSL `nl`: ends the current row.
pub const NL: i64 = 9;
/// OTSL `ched` (column header cell); starts a new cell.
pub const CHED: i64 = 10;
/// OTSL `rhed` (row header cell); starts a new cell.
pub const RHED: i64 = 11;
/// OTSL `srow` (section row cell); starts a new cell.
pub const SROW: i64 = 12;

/// One logical table cell recovered from an OTSL tag sequence.
#[derive(Debug, Clone, PartialEq)]
pub struct TableCell {
    /// Zero-based row of the cell's top-left grid slot.
    pub row: usize,
    /// Zero-based column of the cell's top-left grid slot.
    pub col: usize,
    /// Number of grid columns covered (at least 1).
    pub colspan: usize,
    /// Number of grid rows covered (at least 1).
    pub rowspan: usize,
    /// OTSL tag that opened the cell (e.g. [`FCEL`], [`ECEL`], [`CHED`]).
    pub tag: i64,
    /// Predicted box centre x. `predict_table_rows` treats it as a fraction of
    /// the table crop width. All four box fields are 0 when no box was
    /// predicted for the cell.
    pub cx: f32,
    /// Predicted box centre y, as a fraction of the table crop height.
    pub cy: f32,
    /// Predicted box width, as a fraction of the table crop width.
    pub w: f32,
    /// Predicted box height, as a fraction of the table crop height.
    pub h: f32,
}
52
/// TableFormer table-structure recogniser, run as three ONNX Runtime sessions.
///
/// Build one with [`TableFormer::load`] or [`TableFormer::load_with`].
pub struct TableFormer {
    /// Image encoder: input `image`; outputs `cross_k`, `cross_v`, `enc_out`.
    encoder: Session,
    /// Autoregressive tag decoder with a KV cache: inputs `tags`, `cross_k`,
    /// `cross_v`, `cache`; outputs `logits`, `out_cache`, `hidden`.
    decoder: Session,
    /// Bounding-box head: inputs `enc_out`, `tag_h`; output `boxes`.
    bbox: Session,
}
58
/// Owned copies of the encoder outputs, kept for the whole decoding loop.
///
/// Each tensor is stored as flat row-major data plus its shape, so it can be
/// re-wrapped as a tensor view on every decoder / bbox-head call.
struct EncodeOut {
    /// `cross_k` output data.
    ck: Vec<f32>,
    /// Shape of `ck`.
    ck_shape: Vec<usize>,
    /// `cross_v` output data.
    cv: Vec<f32>,
    /// Shape of `cv`.
    cv_shape: Vec<usize>,
    /// `enc_out` output data (consumed by the bbox head).
    eo: Vec<f32>,
    /// Shape of `eo`.
    eo_shape: Vec<usize>,
}
70
71impl TableFormer {
72 pub fn load() -> Option<Self> {
76 Self::load_with(crate::intra_threads())
77 }
78
79 pub fn load_with(intra: usize) -> Option<Self> {
83 let enc = std::env::var("DOCLING_TABLEFORMER_ENCODER")
84 .unwrap_or_else(|_| "models/tableformer/encoder.onnx".to_string());
85 let dec = std::env::var("DOCLING_TABLEFORMER_DECODER")
86 .unwrap_or_else(|_| "models/tableformer/decoder.onnx".to_string());
87 let bbx = std::env::var("DOCLING_TABLEFORMER_BBOX")
88 .unwrap_or_else(|_| "models/tableformer/bbox.onnx".to_string());
89 if [&enc, &dec, &bbx]
90 .iter()
91 .any(|p| !std::path::Path::new(p).exists())
92 {
93 warn_missing_once(&enc, &dec, &bbx);
100 return None;
101 }
102 let build = |path: &str, mem_pattern: bool| -> Result<Session, String> {
108 Session::builder()
109 .map_err(|e| e.to_string())?
110 .with_intra_threads(intra)
111 .map_err(|e| e.to_string())?
112 .with_memory_pattern(mem_pattern)
113 .map_err(|e| e.to_string())?
114 .commit_from_file(path)
115 .map_err(|e| format!("tableformer load {path}: {e}"))
116 };
117 match (build(&enc, true), build(&dec, false), build(&bbx, true)) {
118 (Ok(encoder), Ok(decoder), Ok(bbox)) => Some(Self {
119 encoder,
120 decoder,
121 bbox,
122 }),
123 _ => None,
124 }
125 }
126
127 fn encode(&mut self, img: &RgbImage) -> Result<EncodeOut, String> {
131 let input = preprocess(img)?;
132 let enc_out = self
133 .encoder
134 .run(ort::inputs!["image" => input])
135 .map_err(|e| format!("tableformer: encode: {e}"))?;
136 let grab = |name: &str| -> Result<(Vec<usize>, Vec<f32>), String> {
137 let (sh, data) = enc_out[name]
138 .try_extract_tensor::<f32>()
139 .map_err(|e| format!("tableformer: {name}: {e}"))?;
140 Ok((sh.iter().map(|&x| x as usize).collect(), data.to_vec()))
141 };
142 let (ck_shape, ck) = grab("cross_k")?;
143 let (cv_shape, cv) = grab("cross_v")?;
144 let (eo_shape, eo) = grab("enc_out")?;
145 Ok(EncodeOut {
146 ck_shape,
147 ck,
148 cv_shape,
149 cv,
150 eo_shape,
151 eo,
152 })
153 }
154
155 fn decode_step(
161 &mut self,
162 tags: &[i64],
163 enc: &EncodeOut,
164 cache: &mut Vec<f32>,
165 cache_past: &mut usize,
166 empty_cache: &Tensor<f32>,
167 ) -> Result<(i64, Vec<f32>), String> {
168 let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.to_vec()))
169 .map_err(|e| format!("tableformer: tags: {e}"))?;
170 let ck_t = TensorRef::from_array_view((enc.ck_shape.as_slice(), enc.ck.as_slice()))
172 .map_err(|e| format!("tableformer: cross_k: {e}"))?;
173 let cv_t = TensorRef::from_array_view((enc.cv_shape.as_slice(), enc.cv.as_slice()))
174 .map_err(|e| format!("tableformer: cross_v: {e}"))?;
175 let dout = if *cache_past == 0 {
176 self.decoder.run(ort::inputs![
177 "tags" => tags_t, "cross_k" => ck_t, "cross_v" => cv_t, "cache" => empty_cache])
178 } else {
179 let cache_t = TensorRef::from_array_view((
180 [N_LAYERS, *cache_past, 1, EMBED_DIM],
181 cache.as_slice(),
182 ))
183 .map_err(|e| format!("tableformer: cache: {e}"))?;
184 self.decoder.run(ort::inputs![
185 "tags" => tags_t, "cross_k" => ck_t, "cross_v" => cv_t, "cache" => cache_t])
186 }
187 .map_err(|e| format!("tableformer: decode: {e}"))?;
188 let (_, logits) = dout["logits"]
189 .try_extract_tensor::<f32>()
190 .map_err(|e| format!("tableformer: logits: {e}"))?;
191 let raw = argmax(logits) as i64;
192 let (oshape, ocache) = dout["out_cache"]
193 .try_extract_tensor::<f32>()
194 .map_err(|e| format!("tableformer: out_cache: {e}"))?;
195 let next_cache = ocache.to_vec();
196 let next_past = oshape[1] as usize;
197 let (_, hidden) = dout["hidden"]
198 .try_extract_tensor::<f32>()
199 .map_err(|e| format!("tableformer: hidden: {e}"))?;
200 let hidden = hidden.to_vec();
201 *cache = next_cache;
202 *cache_past = next_past;
203 Ok((raw, hidden))
204 }
205
206 fn empty_cache(&self) -> Result<Tensor<f32>, String> {
209 Tensor::<f32>::new(self.decoder.allocator(), [N_LAYERS, 0usize, 1, EMBED_DIM])
210 .map_err(|e| format!("tableformer: empty cache: {e}"))
211 }
212
213 pub fn predict_otsl(&mut self, img: &RgbImage) -> Result<Vec<i64>, String> {
215 let enc = self.encode(img)?;
216 let mut tags: Vec<i64> = vec![START];
219 let mut out: Vec<i64> = Vec::new();
220 let mut prev_ucel = false;
221 let mut cache: Vec<f32> = Vec::new();
222 let mut cache_past = 0usize;
223 let empty = self.empty_cache()?;
224 while out.len() < MAX_STEPS {
225 let (raw, _hidden) =
226 self.decode_step(&tags, &enc, &mut cache, &mut cache_past, &empty)?;
227 let mut tag = raw;
228 if tag == XCEL {
229 tag = LCEL;
230 }
231 if prev_ucel && tag == LCEL {
232 tag = FCEL;
233 }
234 if tag == END {
235 break;
236 }
237 out.push(tag);
238 tags.push(tag);
239 prev_ucel = tag == UCEL;
240 }
241 Ok(out)
242 }
243
244 pub fn predict_table_structure(&mut self, img: &RgbImage) -> Result<Vec<TableCell>, String> {
250 let enc = self.encode(img)?;
251
252 let mut tags: Vec<i64> = vec![START];
253 let mut otsl: Vec<i64> = Vec::new();
254 let mut hiddens: Vec<f32> = Vec::new(); let mut n = 0usize;
256 let mut prev_ucel = false;
257 let mut skip = true; let mut first_lcel = true;
259 let mut bbox_ind = 0usize;
260 let mut cur_bbox_ind = 0usize;
261 let mut merge: std::collections::HashMap<usize, i64> = std::collections::HashMap::new();
262 let mut cache: Vec<f32> = Vec::new();
263 let mut cache_past = 0usize;
264 let empty = self.empty_cache()?;
265 while otsl.len() < MAX_STEPS {
266 let (raw, hidden) =
267 self.decode_step(&tags, &enc, &mut cache, &mut cache_past, &empty)?;
268 let mut tag = raw;
269 if tag == XCEL {
270 tag = LCEL;
271 }
272 if prev_ucel && tag == LCEL {
273 tag = FCEL;
274 }
275 if tag == END {
276 break;
277 }
278 if !skip && matches!(tag, FCEL | ECEL | CHED | RHED | SROW | NL | UCEL) {
280 hiddens.extend_from_slice(&hidden);
281 n += 1;
282 if !first_lcel {
283 merge.insert(cur_bbox_ind, bbox_ind as i64);
284 }
285 bbox_ind += 1;
286 }
287 if tag != LCEL {
288 first_lcel = true;
289 } else if first_lcel {
290 hiddens.extend_from_slice(&hidden);
291 n += 1;
292 first_lcel = false;
293 cur_bbox_ind = bbox_ind;
294 merge.insert(cur_bbox_ind, -1);
295 bbox_ind += 1;
296 }
297 skip = matches!(tag, NL | UCEL | XCEL);
298 prev_ucel = tag == UCEL;
299 otsl.push(tag);
300 tags.push(tag);
301 }
302 if n == 0 {
303 return Ok(Vec::new());
304 }
305 let tag_h = Tensor::from_array(([n, 512usize], hiddens))
306 .map_err(|e| format!("tableformer: tag_h: {e}"))?;
307 let eo_t = Tensor::from_array((enc.eo_shape.clone(), enc.eo.clone()))
308 .map_err(|e| format!("tableformer: eo: {e}"))?;
309 let bout = self
310 .bbox
311 .run(ort::inputs!["enc_out" => eo_t, "tag_h" => tag_h])
312 .map_err(|e| format!("tableformer: bbox: {e}"))?;
313 let (_, raw) = bout["boxes"]
314 .try_extract_tensor::<f32>()
315 .map_err(|e| format!("tableformer: boxes: {e}"))?;
316 let boxes: Vec<[f32; 4]> = raw
317 .chunks_exact(4)
318 .map(|c| [c[0], c[1], c[2], c[3]])
319 .collect();
320 let merged = merge_spans(&boxes, &merge);
321 Ok(build_table_cells(&otsl, &merged))
322 }
323
324 pub fn predict_table_rows(
331 &mut self,
332 page_image: &RgbImage,
333 page_h: f32,
334 region: [f32; 4],
335 words: &[TextCell],
336 ) -> Option<Vec<Vec<String>>> {
337 let sf = 1024.0 / page_image.height() as f32;
339 let pw = (page_image.width() as f32 * sf) as u32;
340 let page1024 = crate::resample::inter_area(page_image, pw, 1024);
341 let k = 1024.0 / page_h;
342 let x = (region[0] * k).round().max(0.0) as u32;
343 let y = (region[1] * k).round().max(0.0) as u32;
344 let x2 = ((region[2] * k).round() as u32).min(page1024.width());
345 let y2 = ((region[3] * k).round() as u32).min(page1024.height());
346 if x2 <= x || y2 <= y {
347 return None;
348 }
349 let crop = image::imageops::crop_imm(&page1024, x, y, x2 - x, y2 - y).to_image();
350 let cells = self.predict_table_structure(&crop).ok()?;
351 if cells.is_empty() {
352 return None;
353 }
354 let (rw, rh) = (region[2] - region[0], region[3] - region[1]);
355
356 let boxes: Vec<[f32; 4]> = cells
358 .iter()
359 .map(|c| {
360 [
361 region[0] + (c.cx - c.w / 2.0) * rw,
362 region[1] + (c.cy - c.h / 2.0) * rh,
363 region[0] + (c.cx + c.w / 2.0) * rw,
364 region[1] + (c.cy + c.h / 2.0) * rh,
365 ]
366 })
367 .collect();
368
369 let mut cell_words: Vec<Vec<usize>> = vec![Vec::new(); cells.len()];
371 for (wi, w) in words.iter().enumerate() {
372 let wa = ((w.r - w.l) * (w.b - w.t)).max(1.0);
373 let mut best: Option<(f32, usize)> = None;
374 for (ci, b) in boxes.iter().enumerate() {
375 let ix = (w.r.min(b[2]) - w.l.max(b[0])).max(0.0);
376 let iy = (w.b.min(b[3]) - w.t.max(b[1])).max(0.0);
377 let io = ix * iy / wa;
378 if io > 0.0 && best.is_none_or(|(bo, _)| io > bo) {
379 best = Some((io, ci));
380 }
381 }
382 if let Some((_, ci)) = best {
383 cell_words[ci].push(wi);
384 }
385 }
386
387 let num_rows = cells.iter().map(|c| c.row + c.rowspan).max().unwrap_or(0);
388 let num_cols = cells.iter().map(|c| c.col + c.colspan).max().unwrap_or(0);
389 if num_rows == 0 || num_cols == 0 {
390 return None;
391 }
392 let mut grid = vec![vec![String::new(); num_cols]; num_rows];
393 for (ci, c) in cells.iter().enumerate() {
394 let wis = std::mem::take(&mut cell_words[ci]);
398 let text = wis
399 .iter()
400 .map(|&i| words[i].text.trim())
401 .collect::<Vec<_>>()
402 .join(" ");
403 for row in grid.iter_mut().skip(c.row).take(c.rowspan) {
405 for cell in row.iter_mut().skip(c.col).take(c.colspan) {
406 *cell = text.clone();
407 }
408 }
409 }
410 Some(grid)
411 }
412}
413
/// Prints, at most once per process, a hint that the TableFormer model files
/// were not found.
///
/// `enc`, `dec` and `bbx` are the paths that were checked and are included in
/// the message. Every call after the first does nothing.
fn warn_missing_once(enc: &str, dec: &str, bbx: &str) {
    static ONCE: std::sync::Once = std::sync::Once::new();
    let report = || {
        eprintln!(
            "fleischwolf: TableFormer models not found (checked {enc}, {dec}, {bbx}); \
             tables will use geometric reconstruction instead of ML table-structure \
             recognition. Set DOCLING_TABLEFORMER_ENCODER / DOCLING_TABLEFORMER_DECODER \
             / DOCLING_TABLEFORMER_BBOX to enable it (see README.md)."
        );
    };
    ONCE.call_once(report);
}
431
432fn preprocess(img: &RgbImage) -> Result<Tensor<f32>, String> {
437 let nn = (SIDE * SIDE) as usize;
438 let side = SIDE as usize;
439 let (sw, sh) = (img.width() as i32, img.height() as i32);
440 let sxr = sw as f32 / SIDE as f32;
441 let syr = sh as f32 / SIDE as f32;
442 let mut data = vec![0f32; 3 * nn];
443 for h in 0..side {
444 let fy = (h as f32 + 0.5) * syr - 0.5;
445 let wy = fy - fy.floor();
446 let y0c = (fy.floor() as i32).clamp(0, sh - 1) as u32;
447 let y1c = (fy.floor() as i32 + 1).clamp(0, sh - 1) as u32;
448 for w in 0..side {
449 let fx = (w as f32 + 0.5) * sxr - 0.5;
450 let wx = fx - fx.floor();
451 let x0c = (fx.floor() as i32).clamp(0, sw - 1) as u32;
452 let x1c = (fx.floor() as i32 + 1).clamp(0, sw - 1) as u32;
453 let p00 = img.get_pixel(x0c, y0c);
454 let p01 = img.get_pixel(x1c, y0c);
455 let p10 = img.get_pixel(x0c, y1c);
456 let p11 = img.get_pixel(x1c, y1c);
457 let idx = w * side + h; for c in 0..3 {
459 let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
460 let bot = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
461 let v = top * (1.0 - wy) + bot * wy;
462 data[c * nn + idx] = (v / 255.0 - MEAN[c]) / STD[c];
463 }
464 }
465 }
466 Tensor::from_array(([1usize, 3, side, side], data))
467 .map_err(|e| format!("tableformer: input: {e}"))
468}
469
/// Merges two `[cx, cy, w, h]` boxes into one spanning box.
///
/// The result runs from `b1`'s left edge to `b2`'s right edge. Its height is
/// measured from `b1`'s top to `b2`'s bottom. Its top is the higher of the
/// two tops.
fn mergebboxes(b1: [f32; 4], b2: [f32; 4]) -> [f32; 4] {
    let [cx1, cy1, w1, h1] = b1;
    let [cx2, cy2, w2, h2] = b2;
    let left = cx1 - w1 / 2.0;
    let top1 = cy1 - h1 / 2.0;
    let width = (cx2 + w2 / 2.0) - left;
    let height = (cy2 + h2 / 2.0) - top1;
    let top = (cy2 - h2 / 2.0).min(top1);
    [left + width / 2.0, top + height / 2.0, width, height]
}
479
480fn merge_spans(boxes: &[[f32; 4]], merge: &std::collections::HashMap<usize, i64>) -> Vec<[f32; 4]> {
483 let skip: std::collections::HashSet<usize> = merge
484 .values()
485 .filter(|&&v| v >= 0)
486 .map(|&v| v as usize)
487 .collect();
488 let mut out = Vec::new();
489 for (i, &b) in boxes.iter().enumerate() {
490 if let Some(&j) = merge.get(&i) {
491 let partner = if j < 0 { boxes.len() - 1 } else { j as usize };
492 out.push(mergebboxes(b, boxes[partner.min(boxes.len() - 1)]));
493 } else if !skip.contains(&i) {
494 out.push(b);
495 }
496 }
497 out
498}
499
/// OTSL tags that open a new cell in `build_table_cells`. `LCEL`, `UCEL` and
/// `NL` are absent: they are span continuations or row breaks.
///
/// NOTE(review): `XCEL` is listed here, yet `build_table_cells` also counts it
/// as a span continuation. The decoding loops rewrite `XCEL` to `LCEL`, so it
/// should not reach this check today. Confirm whether it belongs here.
const CELL_TAGS: [i64; 6] = [FCEL, ECEL, XCEL, CHED, RHED, SROW];
501
502fn build_table_cells(otsl: &[i64], boxes: &[[f32; 4]]) -> Vec<TableCell> {
508 let mut grid: Vec<Vec<i64>> = vec![Vec::new()];
510 for &t in otsl {
511 if t == NL {
512 grid.push(Vec::new());
513 } else {
514 grid.last_mut().unwrap().push(t);
515 }
516 }
517 let mut cells = Vec::new();
518 let mut cell_id = 0usize;
519 for (r, row) in grid.iter().enumerate() {
520 for (c, &tag) in row.iter().enumerate() {
521 if !CELL_TAGS.contains(&tag) {
522 continue;
523 }
524 let mut colspan = 1;
525 while c + colspan < row.len() && matches!(row[c + colspan], LCEL | XCEL) {
526 colspan += 1;
527 }
528 let mut rowspan = 1;
529 while r + rowspan < grid.len()
530 && grid[r + rowspan]
531 .get(c)
532 .is_some_and(|&t| matches!(t, UCEL | XCEL))
533 {
534 rowspan += 1;
535 }
536 let b = boxes.get(cell_id).copied().unwrap_or([0.0; 4]);
537 cells.push(TableCell {
538 row: r,
539 col: c,
540 colspan,
541 rowspan,
542 tag,
543 cx: b[0],
544 cy: b[1],
545 w: b[2],
546 h: b[3],
547 });
548 cell_id += 1;
549 }
550 }
551 cells
552}
553
554fn argmax(v: &[f32]) -> usize {
555 v.iter()
556 .enumerate()
557 .max_by(|a, b| a.1.total_cmp(b.1))
558 .map(|(i, _)| i)
559 .unwrap_or(0)
560}