#![allow(dead_code, clippy::unused_self, clippy::unnecessary_wraps)]
use std::path::Path;
use ndarray::Array4;
#[cfg(feature = "docling-ffi")]
use ort::{
session::{Session, builder::SessionBuilder},
value::Tensor,
};
use crate::error::{Result, TransmutationError};
use crate::ml::{DocumentModel, preprocessing};
/// One cell of a recognised table grid.
#[derive(Clone, Debug)]
#[allow(missing_docs)]
pub struct TableCell {
    /// Zero-based row index of the cell within the grid.
    pub row: usize,
    /// Zero-based column index of the cell within the grid.
    pub col: usize,
    /// Number of rows covered by the cell.
    pub row_span: usize,
    /// Number of columns covered by the cell.
    pub col_span: usize,
    /// Bounding box stored as `(x0, y0, x1, y1)`.
    pub bbox: (f32, f32, f32, f32),
    /// Whether the cell is treated as a header cell.
    pub is_header: bool,
}
/// Result of table-structure recognition: the cell list plus grid dimensions.
#[derive(Clone, Debug)]
#[allow(missing_docs)]
pub struct TableStructure {
    /// Every cell of the grid.
    pub cells: Vec<TableCell>,
    /// Number of rows in the grid.
    pub num_rows: usize,
    /// Number of columns in the grid.
    pub num_cols: usize,
}
/// Input handed to the table-structure model: an image and the table region in it.
#[derive(Clone, Debug)]
#[allow(missing_docs)]
pub struct TableInput {
    /// Image that contains the table.
    pub image: image::DynamicImage,
    /// Table region as `(x0, y0, x1, y1)`; the model crops the image to this box.
    pub table_bbox: (f32, f32, f32, f32),
}
/// Table-structure recognition model backed by an ONNX file.
#[derive(Debug)]
pub struct TableStructureModel {
    /// Inference session; the field only exists when the `docling-ffi` feature is enabled.
    #[cfg(feature = "docling-ffi")]
    session: Session,
    /// Path of the model file passed to the constructor.
    model_path: std::path::PathBuf,
    /// Scale factor forwarded to the table preprocessing step.
    scale: f32,
}
impl TableStructureModel {
pub fn new<P: AsRef<Path>>(model_path: P, scale: f32) -> Result<Self> {
let model_path = model_path.as_ref().to_path_buf();
#[cfg(feature = "docling-ffi")]
{
let session = SessionBuilder::new()?
.with_intra_threads(4)?
.commit_from_file(&model_path)
.map_err(|e| TransmutationError::EngineError {
engine: "table-structure-model".to_string(),
message: format!("Failed to load ONNX model: {e}"),
source: None,
})?;
Ok(Self {
session,
model_path,
scale,
})
}
#[cfg(not(feature = "docling-ffi"))]
{
Err(TransmutationError::EngineError(
"table-structure-model".to_string(),
"docling-ffi feature not enabled".to_string(),
))
}
}
#[cfg(feature = "docling-ffi")]
fn run_inference(&mut self, input: &Array4<f32>) -> Result<TableStructure> {
let shape = input.shape().to_vec();
let data = input.iter().copied().collect::<Vec<f32>>();
let input_tensor = Tensor::from_array((shape, data))?;
let (row_data, row_shape, col_data, col_shape, cell_data, cell_shape) = {
let outputs = self.session.run(ort::inputs![input_tensor])?;
let (rs, rd) = outputs[0].try_extract_tensor::<f32>()?;
let (cs, cd) = outputs[1].try_extract_tensor::<f32>()?;
let (cells, celld) = outputs[2].try_extract_tensor::<f32>()?;
(
rd.to_vec(),
rs.to_vec(),
cd.to_vec(),
cs.to_vec(),
celld.to_vec(),
cells.to_vec(),
)
};
self.post_process_from_data(
&row_shape,
&row_data,
&col_shape,
&col_data,
&cell_shape,
&cell_data,
)
}
#[cfg(feature = "docling-ffi")]
fn post_process_from_data(
&self,
row_shape: &[i64],
row_data: &[f32],
col_shape: &[i64],
col_data: &[f32],
cell_shape: &[i64],
cell_data: &[f32],
) -> Result<TableStructure> {
use ndarray::{ArrayD, IxDyn};
let row_logits_array = ArrayD::from_shape_vec(
IxDyn(
row_shape
.iter()
.map(|&d| d as usize)
.collect::<Vec<_>>()
.as_slice(),
),
row_data.to_vec(),
)
.map_err(|e| crate::TransmutationError::EngineError {
engine: "table-structure-model".to_string(),
message: format!("Failed to reshape row tensor: {e}"),
source: None,
})?;
let col_logits_array = ArrayD::from_shape_vec(
IxDyn(
col_shape
.iter()
.map(|&d| d as usize)
.collect::<Vec<_>>()
.as_slice(),
),
col_data.to_vec(),
)
.map_err(|e| crate::TransmutationError::EngineError {
engine: "table-structure-model".to_string(),
message: format!("Failed to reshape col tensor: {e}"),
source: None,
})?;
let cell_logits_array = ArrayD::from_shape_vec(
IxDyn(
cell_shape
.iter()
.map(|&d| d as usize)
.collect::<Vec<_>>()
.as_slice(),
),
cell_data.to_vec(),
)
.map_err(|e| crate::TransmutationError::EngineError {
engine: "table-structure-model".to_string(),
message: format!("Failed to reshape cell tensor: {e}"),
source: None,
})?;
let rows = self.parse_structure_logits(&row_logits_array.view())?;
let cols = self.parse_structure_logits(&col_logits_array.view())?;
let cells = self.build_cell_grid(&rows, &cols, &cell_logits_array.view())?;
Ok(TableStructure {
cells,
num_rows: rows.len(),
num_cols: cols.len(),
})
}
#[cfg(feature = "docling-ffi")]
fn parse_structure_logits(
&self,
logits: &ndarray::ArrayView<f32, ndarray::Dim<ndarray::IxDynImpl>>,
) -> Result<Vec<f32>> {
let shape = logits.shape();
if shape.len() < 2 {
return Ok(Vec::new());
}
let seq_length = shape[1];
let threshold = 0.5;
let mut positions = Vec::new();
for i in 0..seq_length {
let value = logits[[0, i]];
if value > threshold {
positions.push(i as f32);
}
}
if positions.is_empty() {
for i in 0..seq_length.min(10) {
positions.push(i as f32);
}
}
Ok(positions)
}
#[cfg(feature = "docling-ffi")]
fn build_cell_grid(
&self,
rows: &[f32],
cols: &[f32],
cell_logits: &ndarray::ArrayView<f32, ndarray::Dim<ndarray::IxDynImpl>>,
) -> Result<Vec<TableCell>> {
let mut cells = Vec::new();
let num_rows = rows.len();
let num_cols = cols.len();
if num_rows == 0 || num_cols == 0 {
return Ok(cells);
}
for row in 0..num_rows {
for col in 0..num_cols {
let y0 = if row > 0 { rows[row - 1] } else { 0.0 };
let y1 = rows[row];
let x0 = if col > 0 { cols[col - 1] } else { 0.0 };
let x1 = cols[col];
let (row_span, col_span) =
self.detect_cell_spans(row, col, num_rows, num_cols, cell_logits);
let is_header = row == 0;
cells.push(TableCell {
row,
col,
row_span,
col_span,
bbox: (x0, y0, x1, y1),
is_header,
});
}
}
Ok(cells)
}
#[cfg(feature = "docling-ffi")]
fn detect_cell_spans(
&self,
_row: usize,
_col: usize,
_num_rows: usize,
_num_cols: usize,
_cell_logits: &ndarray::ArrayView<f32, ndarray::Dim<ndarray::IxDynImpl>>,
) -> (usize, usize) {
(1, 1) }
}
#[cfg(feature = "docling-ffi")]
impl DocumentModel for TableStructureModel {
    type Input = TableInput;
    type Output = TableStructure;

    /// Crops the input image to the table box, preprocesses the crop and runs inference.
    ///
    /// The box is read as `(x0, y0, x1, y1)`; the crop starts at `(x0, y0)` with width
    /// `x1 - x0` and height `y1 - y0`, each converted to `u32` with an `as` cast.
    ///
    /// # Errors
    ///
    /// Propagates any error from preprocessing or inference.
    fn predict(&mut self, input: &Self::Input) -> Result<Self::Output> {
        let (left, top, right, bottom) = input.table_bbox;
        let crop_width = (right - left) as u32;
        let crop_height = (bottom - top) as u32;
        let cropped = input
            .image
            .crop_imm(left as u32, top as u32, crop_width, crop_height);

        let tensor = preprocessing::preprocess_for_table(&cropped, self.scale)?;
        self.run_inference(&tensor)
    }

    /// Human-readable name of this model.
    fn name(&self) -> &str {
        "TableStructureModel"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: calling the constructor with a relative model path must not panic.
    ///
    /// The outcome is discarded, so an `Err` does not fail the test. The test is marked
    /// `#[ignore]` and only runs when ignored tests are requested explicitly.
    #[test]
    #[ignore]
    fn test_load_model() {
        let _outcome = TableStructureModel::new("models/tableformer_fast.onnx", 2.0);
    }
}