use image::RgbImage;
use ndarray::Array4;
use ort::{inputs, session::Session, value::Tensor};
use crate::layout::error::LayoutError;
use crate::layout::session::build_session;
const INPUT_SIZE: u32 = 224;
const IMAGENET_MEAN_BGR: [f32; 3] = [0.485, 0.456, 0.406];
const IMAGENET_STD_BGR: [f32; 3] = [0.229, 0.224, 0.225];
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TableType {
Wired,
Wireless,
}
impl TableType {
pub fn name(&self) -> &'static str {
match self {
Self::Wired => "wired",
Self::Wireless => "wireless",
}
}
}
pub struct TableClassifier {
session: Session,
input_name: String,
}
impl TableClassifier {
pub fn from_file(path: &str) -> Result<Self, LayoutError> {
let budget = crate::core::config::concurrency::resolve_thread_budget(None);
let session = match build_session(path, None, budget) {
Ok(s) => s,
Err(first_err) => {
tracing::warn!("TableClassifier: platform EP failed ({first_err}), retrying CPU-only");
Self::build_cpu_session(path, budget)?
}
};
let input_name = session.inputs()[0].name().to_string();
Ok(Self { session, input_name })
}
fn build_cpu_session(path: &str, thread_budget: usize) -> Result<Session, LayoutError> {
use ort::session::builder::GraphOptimizationLevel;
let mut builder = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::All)
.map_err(|e| LayoutError::Ort(ort::Error::new(e.message())))?
.with_intra_threads(thread_budget)
.map_err(|e| LayoutError::Ort(ort::Error::new(e.message())))?
.with_inter_threads(1)
.map_err(|e| LayoutError::Ort(ort::Error::new(e.message())))?;
Ok(builder.commit_from_file(path)?)
}
pub fn classify(&mut self, table_img: &RgbImage) -> Result<TableType, LayoutError> {
tracing::trace!(
input_width = table_img.width(),
input_height = table_img.height(),
"TableClassifier: starting classification"
);
let input_tensor = preprocess_lcnet(table_img);
let tensor = Tensor::from_array(input_tensor)?;
let inference_start = std::time::Instant::now();
let outputs = self.session.run(inputs![
self.input_name.as_str() => tensor
])?;
let inference_ms = inference_start.elapsed().as_secs_f64() * 1000.0;
tracing::trace!(
inference_ms = format!("{:.1}", inference_ms),
"TableClassifier: inference complete"
);
for (_name, value) in outputs.iter() {
if let Ok(view) = value.try_extract_tensor::<f32>() {
let data = view.1;
if data.len() >= 2 {
let raw_wired = data[0];
let raw_wireless = data[1];
let max_val = raw_wired.max(raw_wireless);
let exp_wired = (raw_wired - max_val).exp();
let exp_wireless = (raw_wireless - max_val).exp();
let sum_exp = exp_wired + exp_wireless;
let prob_wired = exp_wired / sum_exp;
let prob_wireless = exp_wireless / sum_exp;
let result = if prob_wired >= prob_wireless {
TableType::Wired
} else {
TableType::Wireless
};
tracing::debug!(
raw_wired = format!("{:.3}", raw_wired),
raw_wireless = format!("{:.3}", raw_wireless),
prob_wired = format!("{:.3}", prob_wired),
prob_wireless = format!("{:.3}", prob_wireless),
result = result.name(),
"Table classification result"
);
return Ok(result);
}
}
}
tracing::warn!("TableClassifier: could not parse output, defaulting to wireless");
Ok(TableType::Wireless)
}
}
const MIN_EDGE: u32 = 256;
fn preprocess_lcnet(img: &RgbImage) -> Array4<f32> {
let orig_w = img.width();
let orig_h = img.height();
let scale = MIN_EDGE as f32 / orig_w.min(orig_h) as f32;
let new_w = (orig_w as f32 * scale).round().max(1.0) as u32;
let new_h = (orig_h as f32 * scale).round().max(1.0) as u32;
let resized = image::imageops::resize(img, new_w, new_h, image::imageops::FilterType::Triangle);
let crop_x = (new_w.saturating_sub(INPUT_SIZE)) / 2;
let crop_y = (new_h.saturating_sub(INPUT_SIZE)) / 2;
let crop_w = INPUT_SIZE.min(new_w);
let crop_h = INPUT_SIZE.min(new_h);
let cropped = image::imageops::crop_imm(&resized, crop_x, crop_y, crop_w, crop_h).to_image();
let w = INPUT_SIZE as usize;
let h = INPUT_SIZE as usize;
let hw = h * w;
const INV_255: f32 = 1.0 / 255.0;
let alpha_b = INV_255 / IMAGENET_STD_BGR[0]; let alpha_g = INV_255 / IMAGENET_STD_BGR[1]; let alpha_r = INV_255 / IMAGENET_STD_BGR[2]; let beta_b = -IMAGENET_MEAN_BGR[0] / IMAGENET_STD_BGR[0];
let beta_g = -IMAGENET_MEAN_BGR[1] / IMAGENET_STD_BGR[1];
let beta_r = -IMAGENET_MEAN_BGR[2] / IMAGENET_STD_BGR[2];
let mut data = vec![0.0f32; 3 * hw];
let pixels = cropped.as_raw();
for (i, pixel) in pixels.chunks_exact(3).enumerate() {
let r = pixel[0] as f32;
let g = pixel[1] as f32;
let b = pixel[2] as f32;
data[i] = b * alpha_b + beta_b;
data[hw + i] = g * alpha_g + beta_g;
data[2 * hw + i] = r * alpha_r + beta_r;
}
Array4::from_shape_vec((1, 3, h, w), data).expect("shape mismatch in preprocess_lcnet")
}