use crate::apply_ort_config;
use crate::core::OCRError;
use crate::core::traits::adapter::{AdapterInfo, ModelAdapter};
use crate::core::traits::task::Task;
use crate::domain::tasks::document_rectification::{
DocumentRectificationConfig, DocumentRectificationOutput, DocumentRectificationTask,
};
use crate::impl_adapter_builder;
use crate::models::rectification::uvdoc::{UVDocModel, UVDocModelBuilder, UVDocPreprocessConfig};
/// Document rectification adapter backed by a UVDoc model.
///
/// Exposes a [`UVDocModel`] through the [`ModelAdapter`] interface for
/// [`DocumentRectificationTask`]. Instances are produced by the `build` step of
/// `UVDocRectifierAdapterBuilder`.
#[derive(Debug)]
pub struct UVDocRectifierAdapter {
/// Loaded UVDoc model whose `forward` pass produces the rectified images.
model: UVDocModel,
/// Adapter metadata returned (cloned) by `info()`.
info: AdapterInfo,
/// Validated task configuration captured at build time.
/// NOTE(review): not read by any code in this file — confirm whether it is kept
/// intentionally (e.g. for `Debug` output) or can be dropped.
_config: DocumentRectificationConfig,
}
impl ModelAdapter for UVDocRectifierAdapter {
type Task = DocumentRectificationTask;
fn info(&self) -> AdapterInfo {
self.info.clone()
}
fn execute(
&self,
input: <Self::Task as Task>::Input,
_config: Option<&<Self::Task as Task>::Config>,
) -> Result<<Self::Task as Task>::Output, OCRError> {
let batch_len = input.images.len();
let model_output = self.model.forward(input.images).map_err(|e| {
OCRError::adapter_execution_error(
"UVDocRectifierAdapter",
format!("model forward (batch_size={})", batch_len),
e,
)
})?;
Ok(DocumentRectificationOutput {
rectified_images: model_output.images,
})
}
fn supports_batching(&self) -> bool {
true
}
fn recommended_batch_size(&self) -> usize {
8
}
}
// Generates `UVDocRectifierAdapterBuilder` for `UVDocRectifierAdapter` through the
// project's `impl_adapter_builder!` macro. Besides the macro-provided `config`
// state (read and written in the closures below), the builder carries:
//   - `preprocess_config`: handed to `UVDocModelBuilder` at build time;
//   - `model_name_override`: when set, replaces the model name in `AdapterInfo`.
// Only plain `//` comments are used inside the invocation: `///` would become
// `#[doc]` attribute tokens that the macro's matcher may not accept.
impl_adapter_builder! {
builder_name: UVDocRectifierAdapterBuilder,
adapter_name: UVDocRectifierAdapter,
config_type: DocumentRectificationConfig,
adapter_type: "uvdoc_rectifier",
adapter_desc: "Corrects geometric distortions in document images",
task_type: DocumentRectification,
fields: {
// Preprocessing settings passed to the UVDoc model builder; its
// `rec_image_shape` is kept in sync with the task config by `with_config`
// (below) and by `input_shape`.
preprocess_config: UVDocPreprocessConfig = UVDocPreprocessConfig::default(),
// Optional replacement for the base model name reported in `AdapterInfo`.
model_name_override: Option<String> = None,
},
methods: {
// Sets `model_name_override`; `build` copies it into `info.model_name`.
pub fn model_name(mut self, model_name: impl Into<String>) -> Self {
self.model_name_override = Some(model_name.into());
self
}
}
overrides: {
// Custom `with_config`: in addition to storing the task config, mirror its
// `rec_image_shape` into `preprocess_config` so the model is built with the
// same shape the task config declares.
with_config: |builder: UVDocRectifierAdapterBuilder, config: DocumentRectificationConfig| -> UVDocRectifierAdapterBuilder {
let mut result = builder;
result.preprocess_config.rec_image_shape = config.rec_image_shape;
result.config = result.config.with_task_config(config);
result
},
}
// Builds the adapter from a model file:
//   1. split the builder config into task and ORT settings (presumably validating
//      it, per the method name — TODO confirm); any error is reported as
//      `OCRError::ConfigError` carrying the error's message;
//   2. build the UVDoc model from `model_path` with `preprocess_config` and the
//      ORT settings applied;
//   3. start from the builder's base `AdapterInfo` and apply `model_name_override`.
build: |builder: UVDocRectifierAdapterBuilder, model_path: &std::path::Path| -> Result<UVDocRectifierAdapter, OCRError> {
let (task_config, ort_config) = builder.config
.into_validated_parts()
.map_err(|err| OCRError::ConfigError {
message: err.to_string(),
})?;
let model = apply_ort_config!(
UVDocModelBuilder::new().preprocess_config(builder.preprocess_config),
ort_config
)
.build(model_path)?;
// `base_adapter_info` is not defined in this file; presumably the macro
// generates it from `adapter_type`/`adapter_desc`/`task_type` — TODO confirm.
let mut info = UVDocRectifierAdapterBuilder::base_adapter_info();
if let Some(model_name) = builder.model_name_override {
info.model_name = model_name;
}
Ok(UVDocRectifierAdapter {
model,
info,
_config: task_config,
})
},
}
impl UVDocRectifierAdapterBuilder {
    /// Sets the model input shape in both the preprocessing and the task config.
    ///
    /// The shape is written to `preprocess_config.rec_image_shape`, which is handed
    /// to the UVDoc model builder at build time. It is also written to the task
    /// config's `rec_image_shape`, so both stay in agreement, as `with_config`
    /// also ensures.
    ///
    /// # Arguments
    ///
    /// * `shape` - Three-element input shape, presumably `[channels, height, width]`
    ///   (the tests use values like `[3, 768, 768]`) — TODO confirm.
    pub fn input_shape(mut self, shape: [usize; 3]) -> Self {
        self.preprocess_config.rec_image_shape = shape;

        // Copy the current task config, patch its shape, and store it back.
        let patched_task = {
            let mut candidate = self.config.task_config().clone();
            candidate.rec_image_shape = shape;
            candidate
        };
        self.config = self.config.with_task_config(patched_task);

        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::traits::adapter::AdapterBuilder;

    /// A freshly created builder reports the registered adapter type.
    #[test]
    fn test_builder_creation() {
        let builder = UVDocRectifierAdapterBuilder::new();
        assert_eq!(builder.adapter_type(), "uvdoc_rectifier");
    }

    /// `with_config` stores the task config and mirrors its shape into the
    /// preprocessing config.
    #[test]
    fn test_builder_with_config() {
        let config = DocumentRectificationConfig {
            rec_image_shape: [3, 1024, 1024],
        };
        // `config` is not used afterwards, so it is moved rather than cloned.
        let builder = UVDocRectifierAdapterBuilder::new().with_config(config);
        assert_eq!(
            builder.config.task_config().rec_image_shape,
            [3, 1024, 1024]
        );
        assert_eq!(builder.preprocess_config.rec_image_shape, [3, 1024, 1024]);
    }

    /// `input_shape` updates both the task config and the preprocessing config.
    #[test]
    fn test_builder_fluent_api() {
        let builder = UVDocRectifierAdapterBuilder::new().input_shape([3, 768, 768]);
        assert_eq!(builder.config.task_config().rec_image_shape, [3, 768, 768]);
        assert_eq!(builder.preprocess_config.rec_image_shape, [3, 768, 768]);
    }

    /// A later `input_shape` call overrides the shape set by `with_config` in
    /// both places, so the two configs never drift apart.
    #[test]
    fn test_input_shape_after_with_config() {
        let builder = UVDocRectifierAdapterBuilder::new()
            .with_config(DocumentRectificationConfig {
                rec_image_shape: [3, 1024, 1024],
            })
            .input_shape([3, 512, 512]);
        assert_eq!(builder.config.task_config().rec_image_shape, [3, 512, 512]);
        assert_eq!(builder.preprocess_config.rec_image_shape, [3, 512, 512]);
    }

    /// `model_name` records an override that `build` later applies to `AdapterInfo`.
    #[test]
    fn test_model_name_override() {
        let builder = UVDocRectifierAdapterBuilder::new();
        assert!(builder.model_name_override.is_none());

        let builder = builder.model_name("custom_uvdoc");
        assert_eq!(builder.model_name_override.as_deref(), Some("custom_uvdoc"));
    }

    /// `Default` yields the same adapter type and the default task shape.
    #[test]
    fn test_default_builder() {
        let builder = UVDocRectifierAdapterBuilder::default();
        assert_eq!(builder.adapter_type(), "uvdoc_rectifier");
        assert_eq!(builder.config.task_config().rec_image_shape, [3, 0, 0]);
    }
}