use std::fs;
use std::path::PathBuf;
use image::RgbImage;
use ort::session::Session;
use ort::session::builder::{GraphOptimizationLevel, SessionBuilder};
use ort::value::Tensor;
use crate::Result;
use crate::error::KreuzbergError;
/// Hugging Face repository hosting the pre-exported PaddleOCR ONNX models.
const HF_REPO_ID: &str = "Kreuzberg/paddleocr-onnx-models";
/// Path of the PP-LCNet document-orientation classifier within the repo.
const REMOTE_FILENAME: &str = "v2/classifiers/PP-LCNet_x1_0_doc_ori.onnx";
/// Pinned SHA-256 of the model file; downloads failing this check are rejected.
const SHA256: &str = "6b742aebce6f0f7f71f747931ac7becfc7c96c51641e14943b291eeb334e7947";
/// Side length of the square center crop fed to the network.
const INPUT_SIZE: u32 = 224;
/// The image's short side is scaled to this length before center-cropping.
const RESIZE_SHORT: u32 = 256;
/// Maps the classifier's output class index to a rotation in degrees.
const ORIENTATION_LABELS: [u32; 4] = [0, 90, 180, 270];
/// Minimum softmax probability required before a detected rotation is trusted.
pub const MIN_CONFIDENCE: f32 = 0.35;
/// Result of a document-orientation classification.
#[derive(Debug, Clone, Copy)]
pub struct OrientationResult {
    /// Detected rotation: one of 0, 90, 180, or 270 degrees.
    pub degrees: u32,
    /// Softmax probability of the winning orientation class.
    pub confidence: f32,
}
/// Lazily-initialized PP-LCNet document-orientation classifier.
pub struct DocOrientationDetector {
    /// ONNX session, created on first use by `get_or_init_session`.
    session: once_cell::sync::OnceCell<Session>,
    /// Directory under which the downloaded model file is cached.
    cache_dir: PathBuf,
    /// Optional execution-provider configuration applied at session build time.
    acceleration: Option<crate::core::config::acceleration::AccelerationConfig>,
}
impl DocOrientationDetector {
    /// Creates a detector whose model is cached under `cache_dir`; the ONNX
    /// session is built lazily on the first `detect` call.
    pub fn new(cache_dir: PathBuf) -> Self {
        Self {
            session: once_cell::sync::OnceCell::new(),
            cache_dir,
            acceleration: None,
        }
    }

    /// Like [`Self::new`], but carries an optional acceleration config that is
    /// applied when the session's execution providers are configured.
    pub fn with_acceleration(
        cache_dir: PathBuf,
        accel: Option<crate::core::config::acceleration::AccelerationConfig>,
    ) -> Self {
        Self {
            session: once_cell::sync::OnceCell::new(),
            cache_dir,
            acceleration: accel,
        }
    }

    /// Classifies `image`'s orientation as 0/90/180/270 degrees.
    ///
    /// The image is resized and center-cropped, normalized into an NCHW
    /// tensor, and run through the classifier; the four logits are softmaxed
    /// and the winning class is returned with its probability (callers
    /// typically compare it against [`MIN_CONFIDENCE`]).
    ///
    /// # Errors
    /// Returns `KreuzbergError::Ocr` if tensor creation, inference, or output
    /// extraction fails; model download/verification errors surface through
    /// lazy session initialization.
    pub fn detect(&self, image: &RgbImage) -> Result<OrientationResult> {
        let session = self.get_or_init_session()?;
        let preprocessed = preprocess(image);
        let input_tensor = normalize(&preprocessed);
        let tensor = Tensor::from_array(input_tensor).map_err(|e| KreuzbergError::Ocr {
            message: format!("Failed to create doc_ori input tensor: {e}"),
            source: None,
        })?;
        #[allow(unsafe_code)]
        // NOTE(review): `run` evidently requires `&mut Session`, but the
        // session lives in a `OnceCell` that only hands out `&Session`, so it
        // is cast to a mutable pointer here. This is unsound if `detect` is
        // ever called concurrently on the same detector (aliased mutable
        // access) — confirm single-threaded use or guard the session with a
        // lock.
        let outputs = unsafe {
            let session_ptr = session as *const Session as *mut Session;
            (*session_ptr).run(ort::inputs!["x" => tensor])
        }
        .map_err(|e| KreuzbergError::Ocr {
            message: format!("Doc orientation inference failed: {e}"),
            source: None,
        })?;
        // The model exposes a single output; take the first regardless of name.
        let (_, output_value) = outputs.iter().next().ok_or_else(|| KreuzbergError::Ocr {
            message: "No output from doc orientation model".to_string(),
            source: None,
        })?;
        let scores: Vec<f32> = output_value
            .try_extract_tensor::<f32>()
            .map_err(|e| KreuzbergError::Ocr {
                message: format!("Failed to extract doc_ori output: {e}"),
                source: None,
            })?
            .1 // (shape, data) pair: keep the raw logit slice
            .to_vec();
        // Numerically stable softmax: subtract the max logit before exp().
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
        let sum_exp: f32 = exp_scores.iter().sum();
        let probabilities: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
        // Argmax; NaN comparisons fall back to Equal so they never win.
        let (best_idx, &best_prob) = probabilities
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or((0, &0.0));
        let degrees = ORIENTATION_LABELS.get(best_idx).copied().unwrap_or(0);
        Ok(OrientationResult {
            degrees,
            confidence: best_prob,
        })
    }

    /// Returns the local path of the model file, downloading it on first use.
    ///
    /// The model is fetched from Hugging Face, its SHA-256 is checked against
    /// the pinned constant, and it is copied to
    /// `<cache_dir>/doc-orientation/model.onnx`.
    ///
    /// # Errors
    /// `KreuzbergError::Plugin` on download/copy failure and
    /// `KreuzbergError::Validation` on checksum mismatch.
    fn ensure_model(&self) -> Result<PathBuf> {
        let model_dir = self.cache_dir.join("doc-orientation");
        let model_file = model_dir.join("model.onnx");
        // Fast path: already downloaded and copied into the cache.
        if model_file.exists() {
            return Ok(model_file);
        }
        tracing::info!("Downloading document orientation model...");
        fs::create_dir_all(&model_dir)?;
        let cached_path =
            crate::model_download::hf_download(HF_REPO_ID, REMOTE_FILENAME).map_err(|e| KreuzbergError::Plugin {
                message: e,
                plugin_name: "auto-rotate".to_string(),
            })?;
        // Refuse a download whose checksum does not match the pinned hash.
        crate::model_download::verify_sha256(&cached_path, SHA256, "doc_ori").map_err(|e| {
            KreuzbergError::Validation {
                message: e,
                source: None,
            }
        })?;
        fs::copy(&cached_path, &model_file).map_err(|e| KreuzbergError::Plugin {
            message: format!("Failed to copy doc_ori model: {e}"),
            plugin_name: "auto-rotate".to_string(),
        })?;
        tracing::info!("Document orientation model saved");
        Ok(model_file)
    }

    /// Builds the ONNX session on first call and returns the cached instance
    /// afterwards (memoized via `OnceCell::get_or_try_init`).
    fn get_or_init_session(&self) -> Result<&Session> {
        self.session.get_or_try_init(|| {
            let model_path = self.ensure_model()?;
            crate::ort_discovery::ensure_ort_available();
            // Intra-op threads follow the crate-wide thread budget; inter-op
            // parallelism is pinned to 1.
            let num_threads = crate::core::config::concurrency::resolve_thread_budget(None);
            let builder = SessionBuilder::new()
                .map_err(|e| KreuzbergError::Ocr {
                    message: format!("Failed to create doc_ori session builder: {e}"),
                    source: None,
                })?
                .with_optimization_level(GraphOptimizationLevel::All)
                .map_err(|e| KreuzbergError::Ocr {
                    message: format!("Failed to set doc_ori optimization level: {e}"),
                    source: None,
                })?
                .with_intra_threads(num_threads)
                .map_err(|e| KreuzbergError::Ocr {
                    message: format!("Failed to set doc_ori thread count: {e}"),
                    source: None,
                })?
                .with_inter_threads(1)
                .map_err(|e| KreuzbergError::Ocr {
                    message: format!("Failed to set doc_ori inter threads: {e}"),
                    source: None,
                })?;
            // Execution providers (CPU/GPU) come from the optional config.
            let mut builder = crate::ort_discovery::apply_execution_providers(builder, self.acceleration.as_ref())
                .map_err(|e| KreuzbergError::Ocr {
                    message: format!("Failed to set doc_ori execution providers: {e}"),
                    source: None,
                })?;
            let session = builder.commit_from_file(&model_path).map_err(|e| KreuzbergError::Ocr {
                message: format!("Failed to load doc_ori model: {e}"),
                source: None,
            })?;
            tracing::info!("Doc orientation model loaded");
            Ok(session)
        })
    }
}
/// Resolves the on-disk cache directory used by the auto-rotate plugin's
/// model files (delegates to the crate-wide cache-dir resolution).
pub fn resolve_cache_dir() -> PathBuf {
    crate::cache_dir::resolve_cache_dir("auto-rotate")
}
/// Detects the orientation of an encoded image and, when a confident non-zero
/// rotation is found, returns a PNG-encoded upright copy.
///
/// Returns `Ok(None)` when the page is already upright, the confidence is
/// below [`MIN_CONFIDENCE`], or the detected angle is not one of 90/180/270.
///
/// # Errors
/// `KreuzbergError::Ocr` if the bytes cannot be decoded, detection fails, or
/// PNG re-encoding fails.
pub fn detect_and_rotate(detector: &DocOrientationDetector, image_bytes: &[u8]) -> Result<Option<Vec<u8>>> {
    let decoded = image::load_from_memory(image_bytes).map_err(|e| KreuzbergError::Ocr {
        message: format!("Failed to load image for orientation detection: {e}"),
        source: None,
    })?;
    let rgb = decoded.to_rgb8();
    let orientation = detector.detect(&rgb)?;
    tracing::debug!(
        degrees = orientation.degrees,
        confidence = orientation.confidence,
        "Document orientation detected"
    );
    let confident = orientation.confidence >= MIN_CONFIDENCE;
    if !confident || orientation.degrees == 0 {
        return Ok(None);
    }
    // Rotate by the inverse of the detected skew to bring the page upright.
    let upright = match orientation.degrees {
        90 => image::imageops::rotate270(&rgb),
        180 => image::imageops::rotate180(&rgb),
        270 => image::imageops::rotate90(&rgb),
        _ => return Ok(None),
    };
    let mut encoded = std::io::Cursor::new(Vec::new());
    if let Err(e) = upright.write_to(&mut encoded, image::ImageFormat::Png) {
        return Err(KreuzbergError::Ocr {
            message: format!("Failed to encode rotated image: {e}"),
            source: None,
        });
    }
    tracing::info!(
        degrees = orientation.degrees,
        confidence = orientation.confidence,
        "Auto-rotated document page"
    );
    Ok(Some(encoded.into_inner()))
}
/// Scales the image so its short side equals `RESIZE_SHORT` (preserving the
/// aspect ratio), then center-crops an `INPUT_SIZE` square for the network.
fn preprocess(image: &RgbImage) -> RgbImage {
    let width = image.width();
    let height = image.height();
    // Scale factor is chosen from whichever side is shorter.
    let (target_w, target_h) = if width < height {
        let factor = RESIZE_SHORT as f32 / width as f32;
        (RESIZE_SHORT, (height as f32 * factor).round() as u32)
    } else {
        let factor = RESIZE_SHORT as f32 / height as f32;
        ((width as f32 * factor).round() as u32, RESIZE_SHORT)
    };
    let scaled = image::imageops::resize(image, target_w, target_h, image::imageops::FilterType::Triangle);
    // Center crop; saturating_sub and min() guard images whose long side is
    // still smaller than INPUT_SIZE after resizing.
    let left = target_w.saturating_sub(INPUT_SIZE) / 2;
    let top = target_h.saturating_sub(INPUT_SIZE) / 2;
    let crop_width = INPUT_SIZE.min(target_w);
    let crop_height = INPUT_SIZE.min(target_h);
    image::imageops::crop_imm(&scaled, left, top, crop_width, crop_height).to_image()
}
/// Converts an RGB image into a normalized NCHW `(1, 3, H, W)` float tensor
/// with channels stored in BGR order (index 0 = blue).
///
/// Each channel is shifted and scaled by the ImageNet mean/std constants,
/// pre-multiplied by 255 so the raw `u8` pixel values can be used directly.
fn normalize(image: &RgbImage) -> ndarray::Array4<f32> {
    let (w, h) = (image.width() as usize, image.height() as usize);
    // Means/stds are the ImageNet constants in BGR order, scaled to 0..=255.
    const BGR_MEAN: [f32; 3] = [0.406 * 255.0, 0.456 * 255.0, 0.485 * 255.0];
    const BGR_NORM: [f32; 3] = [1.0 / (0.225 * 255.0), 1.0 / (0.224 * 255.0), 1.0 / (0.229 * 255.0)];
    let mut tensor = ndarray::Array4::<f32>::zeros((1, 3, h, w));
    // enumerate_pixels iterates in row-major order without the per-pixel
    // bounds check that get_pixel(x, y) performs.
    for (x, y, pixel) in image.enumerate_pixels() {
        let (x, y) = (x as usize, y as usize);
        let [r, g, b] = pixel.0;
        tensor[[0, 0, y, x]] = (b as f32 - BGR_MEAN[0]) * BGR_NORM[0];
        tensor[[0, 1, y, x]] = (g as f32 - BGR_MEAN[1]) * BGR_NORM[1];
        tensor[[0, 2, y, x]] = (r as f32 - BGR_MEAN[2]) * BGR_NORM[2];
    }
    tensor
}