1use crate::pdfium_backend::TextCell;
8use image::RgbImage;
9use ort::session::Session;
10use ort::value::{DynValue, Tensor};
11
// --- Decoder limits and cache geometry -------------------------------------

/// Upper bound on the number of tags emitted by one decoding loop.
const MAX_STEPS: usize = 1024;
/// First dimension of the decoder cache tensor (see `TableFormer::empty_cache`).
const N_LAYERS: usize = 6;
/// Last dimension of the decoder cache tensor and width of one hidden-state row.
const EMBED_DIM: usize = 512;

// --- Image preprocessing ---------------------------------------------------

/// Edge length, in pixels, of the square input tensor produced by `preprocess`.
const SIDE: u32 = 448;
/// Per-channel mean subtracted from pixel values after scaling them to `[0, 1]`.
// The literals have more digits than `f32` can represent; they are kept verbatim
// (presumably copied from the model's configuration — TODO confirm), hence the allow.
#[allow(clippy::excessive_precision)]
const MEAN: [f32; 3] = [0.94247851, 0.94254675, 0.94292611];
/// Per-channel divisor applied after the mean has been subtracted.
// Same reason for the allow as on `MEAN`.
#[allow(clippy::excessive_precision)]
const STD: [f32; 3] = [0.17910956, 0.17940403, 0.17931663];
24
/// Token id that seeds the decoder's tag sequence.
pub const START: i64 = 2;
/// Token id that stops decoding; it is never included in the returned sequence.
pub const END: i64 = 3;
/// Cell-opening tag (presumably OTSL "empty cell" — confirm against the vocabulary).
pub const ECEL: i64 = 4;
/// Cell-opening tag (presumably OTSL "full cell" — confirm against the vocabulary).
pub const FCEL: i64 = 5;
/// Continuation tag: extends the column span of the cell to its left.
pub const LCEL: i64 = 6;
/// Continuation tag: extends the row span of the cell above it.
pub const UCEL: i64 = 7;
/// Tag counted as a continuation both to the left and upwards when spans are measured.
pub const XCEL: i64 = 8;
/// Row separator: starts a new grid row.
pub const NL: i64 = 9;
/// Cell-opening tag (presumably "column header" — confirm against the vocabulary).
pub const CHED: i64 = 10;
/// Cell-opening tag (presumably "row header" — confirm against the vocabulary).
pub const RHED: i64 = 11;
/// Cell-opening tag (presumably "section row" — confirm against the vocabulary).
pub const SROW: i64 = 12;

/// One logical table cell: its grid placement, span, tag and predicted box.
#[derive(Debug, Clone)]
pub struct TableCell {
    /// Zero-based grid row of the cell's top-left slot.
    pub row: usize,
    /// Zero-based grid column of the cell's top-left slot.
    pub col: usize,
    /// Number of grid columns covered (at least 1).
    pub colspan: usize,
    /// Number of grid rows covered (at least 1).
    pub rowspan: usize,
    /// The tag that opened this cell.
    pub tag: i64,
    /// Horizontal centre of the box, as a fraction of the table crop's width.
    pub cx: f32,
    /// Vertical centre of the box, as a fraction of the table crop's height.
    pub cy: f32,
    /// Box width, as a fraction of the table crop's width.
    pub w: f32,
    /// Box height, as a fraction of the table crop's height.
    pub h: f32,
}
52
53pub struct TableFormer {
54 encoder: Session,
55 decoder: Session,
56 bbox: Session,
57}
58
59struct EncodeOut {
64 ck: DynValue,
65 cv: DynValue,
66 eo: DynValue,
67}
68
69impl TableFormer {
70 pub fn load() -> Option<Self> {
74 Self::load_with(crate::intra_threads())
75 }
76
77 pub fn load_with(intra: usize) -> Option<Self> {
81 let enc = std::env::var("DOCLING_TABLEFORMER_ENCODER")
82 .unwrap_or_else(|_| "models/tableformer/encoder.onnx".to_string());
83 let dec = crate::model_path(
86 "DOCLING_TABLEFORMER_DECODER",
87 "models/tableformer/decoder.onnx",
88 "models/tableformer/decoder_int8.onnx",
89 );
90 let bbx = std::env::var("DOCLING_TABLEFORMER_BBOX")
91 .unwrap_or_else(|_| "models/tableformer/bbox.onnx".to_string());
92 if [&enc, &dec, &bbx]
93 .iter()
94 .any(|p| !std::path::Path::new(p).exists())
95 {
96 warn_missing_once(&enc, &dec, &bbx);
103 return None;
104 }
105 let build = |path: &str, mem_pattern: bool| -> Result<Session, String> {
111 Session::builder()
112 .map_err(|e| e.to_string())?
113 .with_intra_threads(intra)
114 .map_err(|e| e.to_string())?
115 .with_memory_pattern(mem_pattern)
116 .map_err(|e| e.to_string())?
117 .commit_from_file(path)
118 .map_err(|e| format!("tableformer load {path}: {e}"))
119 };
120 match (build(&enc, true), build(&dec, false), build(&bbx, true)) {
121 (Ok(encoder), Ok(decoder), Ok(bbox)) => Some(Self {
122 encoder,
123 decoder,
124 bbox,
125 }),
126 _ => None,
127 }
128 }
129
130 fn encode(&mut self, img: &RgbImage) -> Result<EncodeOut, String> {
134 let input = preprocess(img)?;
135 let mut enc_out = self
136 .encoder
137 .run(ort::inputs!["image" => input])
138 .map_err(|e| format!("tableformer: encode: {e}"))?;
139 let mut grab = |name: &str| -> Result<DynValue, String> {
140 enc_out
141 .remove(name)
142 .ok_or_else(|| format!("tableformer: encoder output {name} missing"))
143 };
144 Ok(EncodeOut {
145 ck: grab("cross_k")?,
146 cv: grab("cross_v")?,
147 eo: grab("enc_out")?,
148 })
149 }
150
151 fn decode_step(
160 &mut self,
161 tags: &[i64],
162 enc: &EncodeOut,
163 cache: &mut Option<DynValue>,
164 empty_cache: &Tensor<f32>,
165 ) -> Result<(i64, Vec<f32>), String> {
166 let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.to_vec()))
167 .map_err(|e| format!("tableformer: tags: {e}"))?;
168 let mut dout = match cache.as_ref() {
169 None => self.decoder.run(ort::inputs![
170 "tags" => tags_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
171 "cache" => empty_cache]),
172 Some(c) => self.decoder.run(ort::inputs![
173 "tags" => tags_t, "cross_k" => &enc.ck, "cross_v" => &enc.cv,
174 "cache" => c]),
175 }
176 .map_err(|e| format!("tableformer: decode: {e}"))?;
177 let (_, logits) = dout["logits"]
178 .try_extract_tensor::<f32>()
179 .map_err(|e| format!("tableformer: logits: {e}"))?;
180 let raw = argmax(logits) as i64;
181 let (_, hidden) = dout["hidden"]
182 .try_extract_tensor::<f32>()
183 .map_err(|e| format!("tableformer: hidden: {e}"))?;
184 let hidden = hidden.to_vec();
185 *cache = Some(
186 dout.remove("out_cache")
187 .ok_or_else(|| "tableformer: decoder output out_cache missing".to_string())?,
188 );
189 Ok((raw, hidden))
190 }
191
192 fn empty_cache(&self) -> Result<Tensor<f32>, String> {
195 Tensor::<f32>::new(self.decoder.allocator(), [N_LAYERS, 0usize, 1, EMBED_DIM])
196 .map_err(|e| format!("tableformer: empty cache: {e}"))
197 }
198
199 pub fn predict_otsl(&mut self, img: &RgbImage) -> Result<Vec<i64>, String> {
201 let enc = self.encode(img)?;
202 let mut tags: Vec<i64> = vec![START];
205 let mut out: Vec<i64> = Vec::new();
206 let mut prev_ucel = false;
207 let mut cache: Option<DynValue> = None;
208 let empty = self.empty_cache()?;
209 while out.len() < MAX_STEPS {
210 let (raw, _hidden) = self.decode_step(&tags, &enc, &mut cache, &empty)?;
211 let mut tag = raw;
212 if tag == XCEL {
213 tag = LCEL;
214 }
215 if prev_ucel && tag == LCEL {
216 tag = FCEL;
217 }
218 if tag == END {
219 break;
220 }
221 out.push(tag);
222 tags.push(tag);
223 prev_ucel = tag == UCEL;
224 }
225 Ok(out)
226 }
227
228 pub fn predict_table_structure(&mut self, img: &RgbImage) -> Result<Vec<TableCell>, String> {
234 let enc = self.encode(img)?;
235
236 let mut tags: Vec<i64> = vec![START];
237 let mut otsl: Vec<i64> = Vec::new();
238 let mut hiddens: Vec<f32> = Vec::new(); let mut n = 0usize;
240 let mut prev_ucel = false;
241 let mut skip = true; let mut first_lcel = true;
243 let mut bbox_ind = 0usize;
244 let mut cur_bbox_ind = 0usize;
245 let mut merge: std::collections::HashMap<usize, i64> = std::collections::HashMap::new();
246 let mut cache: Option<DynValue> = None;
247 let empty = self.empty_cache()?;
248 while otsl.len() < MAX_STEPS {
249 let (raw, hidden) = self.decode_step(&tags, &enc, &mut cache, &empty)?;
250 let mut tag = raw;
251 if tag == XCEL {
252 tag = LCEL;
253 }
254 if prev_ucel && tag == LCEL {
255 tag = FCEL;
256 }
257 if tag == END {
258 break;
259 }
260 if !skip && matches!(tag, FCEL | ECEL | CHED | RHED | SROW | NL | UCEL) {
262 hiddens.extend_from_slice(&hidden);
263 n += 1;
264 if !first_lcel {
265 merge.insert(cur_bbox_ind, bbox_ind as i64);
266 }
267 bbox_ind += 1;
268 }
269 if tag != LCEL {
270 first_lcel = true;
271 } else if first_lcel {
272 hiddens.extend_from_slice(&hidden);
273 n += 1;
274 first_lcel = false;
275 cur_bbox_ind = bbox_ind;
276 merge.insert(cur_bbox_ind, -1);
277 bbox_ind += 1;
278 }
279 skip = matches!(tag, NL | UCEL | XCEL);
280 prev_ucel = tag == UCEL;
281 otsl.push(tag);
282 tags.push(tag);
283 }
284 if n == 0 {
285 return Ok(Vec::new());
286 }
287 let tag_h = Tensor::from_array(([n, 512usize], hiddens))
288 .map_err(|e| format!("tableformer: tag_h: {e}"))?;
289 let bout = self
290 .bbox
291 .run(ort::inputs!["enc_out" => &enc.eo, "tag_h" => tag_h])
292 .map_err(|e| format!("tableformer: bbox: {e}"))?;
293 let (_, raw) = bout["boxes"]
294 .try_extract_tensor::<f32>()
295 .map_err(|e| format!("tableformer: boxes: {e}"))?;
296 let boxes: Vec<[f32; 4]> = raw
297 .chunks_exact(4)
298 .map(|c| [c[0], c[1], c[2], c[3]])
299 .collect();
300 let merged = merge_spans(&boxes, &merge);
301 Ok(build_table_cells(&otsl, &merged))
302 }
303
304 pub fn predict_table_rows(
311 &mut self,
312 page_image: &RgbImage,
313 page_h: f32,
314 region: [f32; 4],
315 words: &[TextCell],
316 ) -> Option<Vec<Vec<String>>> {
317 let sf = 1024.0 / page_image.height() as f32;
319 let pw = (page_image.width() as f32 * sf) as u32;
320 let page1024 = crate::timing::timed("tableformer.inter_area", || {
321 crate::resample::inter_area(page_image, pw, 1024)
322 });
323 let k = 1024.0 / page_h;
324 let x = (region[0] * k).round().max(0.0) as u32;
325 let y = (region[1] * k).round().max(0.0) as u32;
326 let x2 = ((region[2] * k).round() as u32).min(page1024.width());
327 let y2 = ((region[3] * k).round() as u32).min(page1024.height());
328 if x2 <= x || y2 <= y {
329 return None;
330 }
331 let crop = image::imageops::crop_imm(&page1024, x, y, x2 - x, y2 - y).to_image();
332 let cells = crate::timing::timed("tableformer.structure", || {
333 self.predict_table_structure(&crop)
334 })
335 .ok()?;
336 if cells.is_empty() {
337 return None;
338 }
339 let (rw, rh) = (region[2] - region[0], region[3] - region[1]);
340
341 let boxes: Vec<[f32; 4]> = cells
343 .iter()
344 .map(|c| {
345 [
346 region[0] + (c.cx - c.w / 2.0) * rw,
347 region[1] + (c.cy - c.h / 2.0) * rh,
348 region[0] + (c.cx + c.w / 2.0) * rw,
349 region[1] + (c.cy + c.h / 2.0) * rh,
350 ]
351 })
352 .collect();
353
354 let mut cell_words: Vec<Vec<usize>> = vec![Vec::new(); cells.len()];
356 for (wi, w) in words.iter().enumerate() {
357 let wa = ((w.r - w.l) * (w.b - w.t)).max(1.0);
358 let mut best: Option<(f32, usize)> = None;
359 for (ci, b) in boxes.iter().enumerate() {
360 let ix = (w.r.min(b[2]) - w.l.max(b[0])).max(0.0);
361 let iy = (w.b.min(b[3]) - w.t.max(b[1])).max(0.0);
362 let io = ix * iy / wa;
363 if io > 0.0 && best.is_none_or(|(bo, _)| io > bo) {
364 best = Some((io, ci));
365 }
366 }
367 if let Some((_, ci)) = best {
368 cell_words[ci].push(wi);
369 }
370 }
371
372 let num_rows = cells.iter().map(|c| c.row + c.rowspan).max().unwrap_or(0);
373 let num_cols = cells.iter().map(|c| c.col + c.colspan).max().unwrap_or(0);
374 if num_rows == 0 || num_cols == 0 {
375 return None;
376 }
377 let mut grid = vec![vec![String::new(); num_cols]; num_rows];
378 for (ci, c) in cells.iter().enumerate() {
379 let wis = std::mem::take(&mut cell_words[ci]);
383 let text = wis
384 .iter()
385 .map(|&i| words[i].text.trim())
386 .collect::<Vec<_>>()
387 .join(" ");
388 for row in grid.iter_mut().skip(c.row).take(c.rowspan) {
390 for cell in row.iter_mut().skip(c.col).take(c.colspan) {
391 *cell = text.clone();
392 }
393 }
394 }
395 Some(grid)
396 }
397}
398
/// Prints, at most once per process, a stderr notice that the TableFormer model files
/// were not found, listing the three paths that were checked.
fn warn_missing_once(enc: &str, dec: &str, bbx: &str) {
    use std::sync::Once;

    static ALREADY_WARNED: Once = Once::new();

    let emit = || {
        eprintln!(
            "fleischwolf: TableFormer models not found (checked {enc}, {dec}, {bbx}); \
             tables will use geometric reconstruction instead of ML table-structure \
             recognition. Set DOCLING_TABLEFORMER_ENCODER / DOCLING_TABLEFORMER_DECODER \
             / DOCLING_TABLEFORMER_BBOX to enable it (see README.md)."
        );
    };
    ALREADY_WARNED.call_once(emit);
}
416
417fn preprocess(img: &RgbImage) -> Result<Tensor<f32>, String> {
422 let nn = (SIDE * SIDE) as usize;
423 let side = SIDE as usize;
424 let (sw, sh) = (img.width() as i32, img.height() as i32);
425 let sxr = sw as f32 / SIDE as f32;
426 let syr = sh as f32 / SIDE as f32;
427 let mut data = vec![0f32; 3 * nn];
428 for h in 0..side {
429 let fy = (h as f32 + 0.5) * syr - 0.5;
430 let wy = fy - fy.floor();
431 let y0c = (fy.floor() as i32).clamp(0, sh - 1) as u32;
432 let y1c = (fy.floor() as i32 + 1).clamp(0, sh - 1) as u32;
433 for w in 0..side {
434 let fx = (w as f32 + 0.5) * sxr - 0.5;
435 let wx = fx - fx.floor();
436 let x0c = (fx.floor() as i32).clamp(0, sw - 1) as u32;
437 let x1c = (fx.floor() as i32 + 1).clamp(0, sw - 1) as u32;
438 let p00 = img.get_pixel(x0c, y0c);
439 let p01 = img.get_pixel(x1c, y0c);
440 let p10 = img.get_pixel(x0c, y1c);
441 let p11 = img.get_pixel(x1c, y1c);
442 let idx = w * side + h; for c in 0..3 {
444 let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
445 let bot = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
446 let v = top * (1.0 - wy) + bot * wy;
447 data[c * nn + idx] = (v / 255.0 - MEAN[c]) / STD[c];
448 }
449 }
450 }
451 Tensor::from_array(([1usize, 3, side, side], data))
452 .map_err(|e| format!("tableformer: input: {e}"))
453}
454
/// Merges two `[cx, cy, w, h]` boxes into one box of the same form.
///
/// The result runs horizontally from `b1`'s left edge to `b2`'s right edge. Its height
/// is measured from `b1`'s top edge to `b2`'s bottom edge, while its top edge is the
/// higher of the two boxes' top edges.
fn mergebboxes(b1: [f32; 4], b2: [f32; 4]) -> [f32; 4] {
    let [cx1, cy1, w1, h1] = b1;
    let [cx2, cy2, w2, h2] = b2;

    let left = cx1 - w1 / 2.0;
    let first_top = cy1 - h1 / 2.0;

    let width = (cx2 + w2 / 2.0) - left;
    let height = (cy2 + h2 / 2.0) - first_top;
    let top = (cy2 - h2 / 2.0).min(first_top);

    [left + width / 2.0, top + height / 2.0, width, height]
}
464
465fn merge_spans(boxes: &[[f32; 4]], merge: &std::collections::HashMap<usize, i64>) -> Vec<[f32; 4]> {
468 let skip: std::collections::HashSet<usize> = merge
469 .values()
470 .filter(|&&v| v >= 0)
471 .map(|&v| v as usize)
472 .collect();
473 let mut out = Vec::new();
474 for (i, &b) in boxes.iter().enumerate() {
475 if let Some(&j) = merge.get(&i) {
476 let partner = if j < 0 { boxes.len() - 1 } else { j as usize };
477 out.push(mergebboxes(b, boxes[partner.min(boxes.len() - 1)]));
478 } else if !skip.contains(&i) {
479 out.push(b);
480 }
481 }
482 out
483}
484
485const CELL_TAGS: [i64; 6] = [FCEL, ECEL, XCEL, CHED, RHED, SROW];
486
487fn build_table_cells(otsl: &[i64], boxes: &[[f32; 4]]) -> Vec<TableCell> {
493 let mut grid: Vec<Vec<i64>> = vec![Vec::new()];
495 for &t in otsl {
496 if t == NL {
497 grid.push(Vec::new());
498 } else {
499 grid.last_mut().unwrap().push(t);
500 }
501 }
502 let mut cells = Vec::new();
503 let mut cell_id = 0usize;
504 for (r, row) in grid.iter().enumerate() {
505 for (c, &tag) in row.iter().enumerate() {
506 if !CELL_TAGS.contains(&tag) {
507 continue;
508 }
509 let mut colspan = 1;
510 while c + colspan < row.len() && matches!(row[c + colspan], LCEL | XCEL) {
511 colspan += 1;
512 }
513 let mut rowspan = 1;
514 while r + rowspan < grid.len()
515 && grid[r + rowspan]
516 .get(c)
517 .is_some_and(|&t| matches!(t, UCEL | XCEL))
518 {
519 rowspan += 1;
520 }
521 let b = boxes.get(cell_id).copied().unwrap_or([0.0; 4]);
522 cells.push(TableCell {
523 row: r,
524 col: c,
525 colspan,
526 rowspan,
527 tag,
528 cx: b[0],
529 cy: b[1],
530 w: b[2],
531 h: b[3],
532 });
533 cell_id += 1;
534 }
535 }
536 cells
537}
538
/// Returns the index of the largest value in `v`, or `0` for an empty slice.
///
/// Values are ordered with `f32::total_cmp`, so a positive NaN ranks above every
/// number. When the maximum occurs more than once, the first occurrence is returned,
/// which is the usual argmax convention. `Iterator::max_by` would return the last one.
fn argmax(v: &[f32]) -> usize {
    let mut best = 0usize;
    for (i, x) in v.iter().enumerate().skip(1) {
        // Strictly greater only, so an earlier equal value keeps its place.
        if x.total_cmp(&v[best]).is_gt() {
            best = i;
        }
    }
    best
}