1use crate::pdfium_backend::TextCell;
8use image::RgbImage;
9use ort::session::Session;
10use ort::value::Tensor;
11
/// Side length, in pixels, of the square image handed to the encoder (see `preprocess`).
const SIDE: u32 = 448;

// The literals below carry more digits than an f32 can represent; keep them verbatim and
// silence clippy rather than risk altering the stored values.
/// Per-channel (R, G, B) mean subtracted from pixels scaled to [0, 1] in `preprocess`.
#[allow(clippy::excessive_precision)]
const MEAN: [f32; 3] = [0.94247851, 0.94254675, 0.94292611];
/// Per-channel (R, G, B) standard deviation dividing the mean-centred pixels in `preprocess`.
#[allow(clippy::excessive_precision)]
const STD: [f32; 3] = [0.17910956, 0.17940403, 0.17931663];

/// Hard cap on decoded OTSL tags per table; decoding stops here even if `END` never appears.
const MAX_STEPS: usize = 1 << 10;
20
/// Tag that seeds the decoder input sequence before any tag has been decoded.
pub const START: i64 = 2;
/// Tag that terminates decoding.
pub const END: i64 = 3;
/// Cell-starting tag (OTSL `ecel`).
pub const ECEL: i64 = 4;
/// Cell-starting tag (OTSL `fcel`).
pub const FCEL: i64 = 5;
/// Continuation of the cell to its left: extends that cell's colspan.
pub const LCEL: i64 = 6;
/// Continuation of the cell above: extends that cell's rowspan.
pub const UCEL: i64 = 7;
/// Continuation both leftwards and upwards (OTSL `xcel`).
pub const XCEL: i64 = 8;
/// Row separator.
pub const NL: i64 = 9;
/// Cell-starting tag (OTSL `ched`, a column-header cell in the OTSL vocabulary).
pub const CHED: i64 = 10;
/// Cell-starting tag (OTSL `rhed`, a row-header cell in the OTSL vocabulary).
pub const RHED: i64 = 11;
/// Cell-starting tag (OTSL `srow`, a section-row cell in the OTSL vocabulary).
pub const SROW: i64 = 12;

/// One table cell recovered from an OTSL sequence, with its predicted box.
#[derive(Debug, Clone)]
pub struct TableCell {
    /// Zero-based grid row where the cell starts.
    pub row: usize,
    /// Zero-based grid column where the cell starts.
    pub col: usize,
    /// Number of grid columns covered (at least 1).
    pub colspan: usize,
    /// Number of grid rows covered (at least 1).
    pub rowspan: usize,
    /// OTSL tag that started the cell.
    pub tag: i64,
    /// Box centre x, relative to the table crop (scaled by the region width by callers).
    pub cx: f32,
    /// Box centre y, relative to the table crop (scaled by the region height by callers).
    pub cy: f32,
    /// Box width, relative to the table crop.
    pub w: f32,
    /// Box height, relative to the table crop.
    pub h: f32,
}
48
49pub struct TableFormer {
50 encoder: Session,
51 decoder: Session,
52 bbox: Session,
53}
54
55impl TableFormer {
56 pub fn load() -> Option<Self> {
60 let enc = std::env::var("DOCLING_TABLEFORMER_ENCODER")
61 .unwrap_or_else(|_| "models/tableformer/encoder.onnx".to_string());
62 let dec = std::env::var("DOCLING_TABLEFORMER_DECODER")
63 .unwrap_or_else(|_| "models/tableformer/decoder.onnx".to_string());
64 let bbx = std::env::var("DOCLING_TABLEFORMER_BBOX")
65 .unwrap_or_else(|_| "models/tableformer/bbox.onnx".to_string());
66 if [&enc, &dec, &bbx]
67 .iter()
68 .any(|p| !std::path::Path::new(p).exists())
69 {
70 return None;
71 }
72 let build = |path: &str| -> Result<Session, String> {
73 Session::builder()
74 .map_err(|e| e.to_string())?
75 .with_intra_threads(crate::intra_threads())
76 .map_err(|e| e.to_string())?
77 .commit_from_file(path)
78 .map_err(|e| format!("tableformer load {path}: {e}"))
79 };
80 match (build(&enc), build(&dec), build(&bbx)) {
81 (Ok(encoder), Ok(decoder), Ok(bbox)) => Some(Self {
82 encoder,
83 decoder,
84 bbox,
85 }),
86 _ => None,
87 }
88 }
89
90 pub fn predict_otsl(&mut self, img: &RgbImage) -> Result<Vec<i64>, String> {
92 let input = preprocess(img)?;
93 let enc_out = self
94 .encoder
95 .run(ort::inputs!["image" => input])
96 .map_err(|e| format!("tableformer: encode: {e}"))?;
97 let (mshape, mem) = enc_out["memory"]
98 .try_extract_tensor::<f32>()
99 .map_err(|e| format!("tableformer: memory: {e}"))?;
100 let mshape: Vec<usize> = mshape.iter().map(|&x| x as usize).collect();
101 let mem: Vec<f32> = mem.to_vec();
102
103 let mut tags: Vec<i64> = vec![START];
109 let mut out: Vec<i64> = Vec::new();
110 let mut prev_ucel = false;
111 while out.len() < MAX_STEPS {
112 let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.clone()))
113 .map_err(|e| format!("tableformer: tags: {e}"))?;
114 let mem_t = Tensor::from_array((mshape.clone(), mem.clone()))
115 .map_err(|e| format!("tableformer: mem: {e}"))?;
116 let dout = self
117 .decoder
118 .run(ort::inputs!["tags" => tags_t, "memory" => mem_t])
119 .map_err(|e| format!("tableformer: decode: {e}"))?;
120 let (_, logits) = dout["logits"]
121 .try_extract_tensor::<f32>()
122 .map_err(|e| format!("tableformer: logits: {e}"))?;
123 let mut tag = argmax(logits) as i64;
124 if tag == XCEL {
125 tag = LCEL;
126 }
127 if prev_ucel && tag == LCEL {
128 tag = FCEL;
129 }
130 if tag == END {
131 break;
132 }
133 out.push(tag);
134 tags.push(tag);
135 prev_ucel = tag == UCEL;
136 }
137 Ok(out)
138 }
139
140 pub fn predict_table_structure(&mut self, img: &RgbImage) -> Result<Vec<TableCell>, String> {
146 let input = preprocess(img)?;
147 let enc_out = self
148 .encoder
149 .run(ort::inputs!["image" => input])
150 .map_err(|e| format!("tableformer: encode: {e}"))?;
151 let (mshape, mem) = enc_out["memory"]
152 .try_extract_tensor::<f32>()
153 .map_err(|e| format!("tableformer: memory: {e}"))?;
154 let mshape: Vec<usize> = mshape.iter().map(|&x| x as usize).collect();
155 let mem: Vec<f32> = mem.to_vec();
156 let (eshape, eo) = enc_out["enc_out"]
157 .try_extract_tensor::<f32>()
158 .map_err(|e| format!("tableformer: enc_out: {e}"))?;
159 let eshape: Vec<usize> = eshape.iter().map(|&x| x as usize).collect();
160 let eo: Vec<f32> = eo.to_vec();
161
162 let mut tags: Vec<i64> = vec![START];
163 let mut otsl: Vec<i64> = Vec::new();
164 let mut hiddens: Vec<f32> = Vec::new(); let mut n = 0usize;
166 let mut prev_ucel = false;
167 let mut skip = true; let mut first_lcel = true;
169 let mut bbox_ind = 0usize;
170 let mut cur_bbox_ind = 0usize;
171 let mut merge: std::collections::HashMap<usize, i64> = std::collections::HashMap::new();
172 while otsl.len() < MAX_STEPS {
173 let tags_t = Tensor::from_array(([tags.len(), 1usize], tags.clone()))
174 .map_err(|e| format!("tableformer: tags: {e}"))?;
175 let mem_t = Tensor::from_array((mshape.clone(), mem.clone()))
176 .map_err(|e| format!("tableformer: mem: {e}"))?;
177 let dout = self
178 .decoder
179 .run(ort::inputs!["tags" => tags_t, "memory" => mem_t])
180 .map_err(|e| format!("tableformer: decode: {e}"))?;
181 let (_, logits) = dout["logits"]
182 .try_extract_tensor::<f32>()
183 .map_err(|e| format!("tableformer: logits: {e}"))?;
184 let mut tag = argmax(logits) as i64;
185 if tag == XCEL {
186 tag = LCEL;
187 }
188 if prev_ucel && tag == LCEL {
189 tag = FCEL;
190 }
191 if tag == END {
192 break;
193 }
194 let (_, hidden) = dout["hidden"]
195 .try_extract_tensor::<f32>()
196 .map_err(|e| format!("tableformer: hidden: {e}"))?;
197 if !skip && matches!(tag, FCEL | ECEL | CHED | RHED | SROW | NL | UCEL) {
199 hiddens.extend_from_slice(hidden);
200 n += 1;
201 if !first_lcel {
202 merge.insert(cur_bbox_ind, bbox_ind as i64);
203 }
204 bbox_ind += 1;
205 }
206 if tag != LCEL {
207 first_lcel = true;
208 } else if first_lcel {
209 hiddens.extend_from_slice(hidden);
210 n += 1;
211 first_lcel = false;
212 cur_bbox_ind = bbox_ind;
213 merge.insert(cur_bbox_ind, -1);
214 bbox_ind += 1;
215 }
216 skip = matches!(tag, NL | UCEL | XCEL);
217 prev_ucel = tag == UCEL;
218 otsl.push(tag);
219 tags.push(tag);
220 }
221 if n == 0 {
222 return Ok(Vec::new());
223 }
224 let tag_h = Tensor::from_array(([n, 512usize], hiddens))
225 .map_err(|e| format!("tableformer: tag_h: {e}"))?;
226 let eo_t = Tensor::from_array((eshape, eo)).map_err(|e| format!("tableformer: eo: {e}"))?;
227 let bout = self
228 .bbox
229 .run(ort::inputs!["enc_out" => eo_t, "tag_h" => tag_h])
230 .map_err(|e| format!("tableformer: bbox: {e}"))?;
231 let (_, raw) = bout["boxes"]
232 .try_extract_tensor::<f32>()
233 .map_err(|e| format!("tableformer: boxes: {e}"))?;
234 let boxes: Vec<[f32; 4]> = raw
235 .chunks_exact(4)
236 .map(|c| [c[0], c[1], c[2], c[3]])
237 .collect();
238 let merged = merge_spans(&boxes, &merge);
239 Ok(build_table_cells(&otsl, &merged))
240 }
241
242 pub fn predict_table_rows(
249 &mut self,
250 page_image: &RgbImage,
251 page_h: f32,
252 region: [f32; 4],
253 words: &[TextCell],
254 ) -> Option<Vec<Vec<String>>> {
255 let sf = 1024.0 / page_image.height() as f32;
257 let pw = (page_image.width() as f32 * sf) as u32;
258 let page1024 = crate::resample::inter_area(page_image, pw, 1024);
259 let k = 1024.0 / page_h;
260 let x = (region[0] * k).round().max(0.0) as u32;
261 let y = (region[1] * k).round().max(0.0) as u32;
262 let x2 = ((region[2] * k).round() as u32).min(page1024.width());
263 let y2 = ((region[3] * k).round() as u32).min(page1024.height());
264 if x2 <= x || y2 <= y {
265 return None;
266 }
267 let crop = image::imageops::crop_imm(&page1024, x, y, x2 - x, y2 - y).to_image();
268 let cells = self.predict_table_structure(&crop).ok()?;
269 if cells.is_empty() {
270 return None;
271 }
272 let (rw, rh) = (region[2] - region[0], region[3] - region[1]);
273
274 let boxes: Vec<[f32; 4]> = cells
276 .iter()
277 .map(|c| {
278 [
279 region[0] + (c.cx - c.w / 2.0) * rw,
280 region[1] + (c.cy - c.h / 2.0) * rh,
281 region[0] + (c.cx + c.w / 2.0) * rw,
282 region[1] + (c.cy + c.h / 2.0) * rh,
283 ]
284 })
285 .collect();
286
287 let mut cell_words: Vec<Vec<usize>> = vec![Vec::new(); cells.len()];
289 for (wi, w) in words.iter().enumerate() {
290 let wa = ((w.r - w.l) * (w.b - w.t)).max(1.0);
291 let mut best: Option<(f32, usize)> = None;
292 for (ci, b) in boxes.iter().enumerate() {
293 let ix = (w.r.min(b[2]) - w.l.max(b[0])).max(0.0);
294 let iy = (w.b.min(b[3]) - w.t.max(b[1])).max(0.0);
295 let io = ix * iy / wa;
296 if io > 0.0 && best.is_none_or(|(bo, _)| io > bo) {
297 best = Some((io, ci));
298 }
299 }
300 if let Some((_, ci)) = best {
301 cell_words[ci].push(wi);
302 }
303 }
304
305 let num_rows = cells.iter().map(|c| c.row + c.rowspan).max().unwrap_or(0);
306 let num_cols = cells.iter().map(|c| c.col + c.colspan).max().unwrap_or(0);
307 if num_rows == 0 || num_cols == 0 {
308 return None;
309 }
310 let mut grid = vec![vec![String::new(); num_cols]; num_rows];
311 for (ci, c) in cells.iter().enumerate() {
312 let wis = std::mem::take(&mut cell_words[ci]);
316 let text = wis
317 .iter()
318 .map(|&i| words[i].text.trim())
319 .collect::<Vec<_>>()
320 .join(" ");
321 for row in grid.iter_mut().skip(c.row).take(c.rowspan) {
323 for cell in row.iter_mut().skip(c.col).take(c.colspan) {
324 *cell = text.clone();
325 }
326 }
327 }
328 Some(grid)
329 }
330}
331
332fn preprocess(img: &RgbImage) -> Result<Tensor<f32>, String> {
337 let nn = (SIDE * SIDE) as usize;
338 let side = SIDE as usize;
339 let (sw, sh) = (img.width() as i32, img.height() as i32);
340 let sxr = sw as f32 / SIDE as f32;
341 let syr = sh as f32 / SIDE as f32;
342 let mut data = vec![0f32; 3 * nn];
343 for h in 0..side {
344 let fy = (h as f32 + 0.5) * syr - 0.5;
345 let wy = fy - fy.floor();
346 let y0c = (fy.floor() as i32).clamp(0, sh - 1) as u32;
347 let y1c = (fy.floor() as i32 + 1).clamp(0, sh - 1) as u32;
348 for w in 0..side {
349 let fx = (w as f32 + 0.5) * sxr - 0.5;
350 let wx = fx - fx.floor();
351 let x0c = (fx.floor() as i32).clamp(0, sw - 1) as u32;
352 let x1c = (fx.floor() as i32 + 1).clamp(0, sw - 1) as u32;
353 let p00 = img.get_pixel(x0c, y0c);
354 let p01 = img.get_pixel(x1c, y0c);
355 let p10 = img.get_pixel(x0c, y1c);
356 let p11 = img.get_pixel(x1c, y1c);
357 let idx = w * side + h; for c in 0..3 {
359 let top = p00[c] as f32 * (1.0 - wx) + p01[c] as f32 * wx;
360 let bot = p10[c] as f32 * (1.0 - wx) + p11[c] as f32 * wx;
361 let v = top * (1.0 - wy) + bot * wy;
362 data[c * nn + idx] = (v / 255.0 - MEAN[c]) / STD[c];
363 }
364 }
365 }
366 Tensor::from_array(([1usize, 3, side, side], data))
367 .map_err(|e| format!("tableformer: input: {e}"))
368}
369
/// Merges two `[cx, cy, w, h]` boxes of a horizontal span into one `[cx, cy, w, h]` box.
///
/// The merged box spans from `b1`'s left edge to `b2`'s right edge. Its height is measured
/// from `b1`'s top to `b2`'s bottom, but it is anchored at the smaller of the two top
/// coordinates. When the tops differ, height and anchor therefore come from different boxes.
/// NOTE(review): kept as-is; confirm this is intended.
fn mergebboxes(b1: [f32; 4], b2: [f32; 4]) -> [f32; 4] {
    let [cx1, cy1, w1, h1] = b1;
    let [cx2, cy2, w2, h2] = b2;
    let left1 = cx1 - w1 / 2.0;
    let top1 = cy1 - h1 / 2.0;
    let width = (cx2 + w2 / 2.0) - left1;
    let height = (cy2 + h2 / 2.0) - top1;
    let top = (cy2 - h2 / 2.0).min(top1);
    [left1 + width / 2.0, top + height / 2.0, width, height]
}
379
380fn merge_spans(boxes: &[[f32; 4]], merge: &std::collections::HashMap<usize, i64>) -> Vec<[f32; 4]> {
383 let skip: std::collections::HashSet<usize> = merge
384 .values()
385 .filter(|&&v| v >= 0)
386 .map(|&v| v as usize)
387 .collect();
388 let mut out = Vec::new();
389 for (i, &b) in boxes.iter().enumerate() {
390 if let Some(&j) = merge.get(&i) {
391 let partner = if j < 0 { boxes.len() - 1 } else { j as usize };
392 out.push(mergebboxes(b, boxes[partner.min(boxes.len() - 1)]));
393 } else if !skip.contains(&i) {
394 out.push(b);
395 }
396 }
397 out
398}
399
400const CELL_TAGS: [i64; 6] = [FCEL, ECEL, XCEL, CHED, RHED, SROW];
401
402fn build_table_cells(otsl: &[i64], boxes: &[[f32; 4]]) -> Vec<TableCell> {
408 let mut grid: Vec<Vec<i64>> = vec![Vec::new()];
410 for &t in otsl {
411 if t == NL {
412 grid.push(Vec::new());
413 } else {
414 grid.last_mut().unwrap().push(t);
415 }
416 }
417 let mut cells = Vec::new();
418 let mut cell_id = 0usize;
419 for (r, row) in grid.iter().enumerate() {
420 for (c, &tag) in row.iter().enumerate() {
421 if !CELL_TAGS.contains(&tag) {
422 continue;
423 }
424 let mut colspan = 1;
425 while c + colspan < row.len() && matches!(row[c + colspan], LCEL | XCEL) {
426 colspan += 1;
427 }
428 let mut rowspan = 1;
429 while r + rowspan < grid.len()
430 && grid[r + rowspan]
431 .get(c)
432 .is_some_and(|&t| matches!(t, UCEL | XCEL))
433 {
434 rowspan += 1;
435 }
436 let b = boxes.get(cell_id).copied().unwrap_or([0.0; 4]);
437 cells.push(TableCell {
438 row: r,
439 col: c,
440 colspan,
441 rowspan,
442 tag,
443 cx: b[0],
444 cy: b[1],
445 w: b[2],
446 h: b[3],
447 });
448 cell_id += 1;
449 }
450 }
451 cells
452}
453
/// Index of the largest element of `v` under `f32::total_cmp`, or 0 when `v` is empty.
///
/// Ties resolve to the last maximal index; a positive NaN ranks above every number.
fn argmax(v: &[f32]) -> usize {
    let mut best = 0;
    for (i, x) in v.iter().enumerate().skip(1) {
        if x.total_cmp(&v[best]).is_ge() {
            best = i;
        }
    }
    best
}