use super::task::{Task, TaskSchema, TaskType};
use crate::core::OCRError;
use std::fmt::Debug;
use std::path::Path;
/// Descriptive metadata that a [`ModelAdapter`] reports through its `info()` method.
///
/// `PartialEq` is derived so two descriptions can be compared directly
/// (e.g. in tests or when de-duplicating registered adapters).
#[derive(Debug, Clone, PartialEq)]
pub struct AdapterInfo {
    /// Name of the underlying model (for example `"DB"`).
    pub model_name: String,
    /// The kind of task this adapter performs.
    pub task_type: TaskType,
    /// Free-form, human-readable description of the adapter.
    pub description: String,
}
impl AdapterInfo {
pub fn new(
model_name: impl Into<String>,
task_type: TaskType,
description: impl Into<String>,
) -> Self {
Self {
model_name: model_name.into(),
task_type,
description: description.into(),
}
}
}
/// A model-specific backend that executes one kind of [`Task`].
///
/// Implementors must provide [`ModelAdapter::info`] and [`ModelAdapter::execute`];
/// every other method has a default that may be overridden.
pub trait ModelAdapter: Send + Sync + Debug {
    /// The task this adapter runs; it fixes the input, config and output types
    /// used by [`ModelAdapter::execute`].
    type Task: Task;

    /// Returns the metadata (model name, task type, description) for this adapter.
    fn info(&self) -> AdapterInfo;

    /// Returns the schema describing this adapter's inputs and outputs.
    ///
    /// The default uses the task type from [`ModelAdapter::info`], a single
    /// `"image"` input and a single `"result"` output. Override it when the
    /// adapter's I/O differs.
    fn schema(&self) -> TaskSchema {
        TaskSchema::new(
            self.info().task_type,
            vec!["image".to_string()],
            vec!["result".to_string()],
        )
    }

    /// Runs the adapter on `input`, optionally with a task-specific `config`.
    ///
    /// # Errors
    ///
    /// Returns an [`OCRError`] if execution fails.
    fn execute(
        &self,
        input: <Self::Task as Task>::Input,
        config: Option<&<Self::Task as Task>::Config>,
    ) -> Result<<Self::Task as Task>::Output, OCRError>;

    /// Checks that this adapter can serve a pipeline stage described by `schema`.
    ///
    /// Only the task type is compared; the input/output type lists are not checked.
    ///
    /// # Errors
    ///
    /// Returns [`OCRError::ConfigError`] when this adapter's task type differs
    /// from `schema.task_type`.
    fn validate_compatibility(&self, schema: &TaskSchema) -> Result<(), OCRError> {
        let adapter_schema = self.schema();
        if adapter_schema.task_type != schema.task_type {
            return Err(OCRError::ConfigError {
                message: format!(
                    "Adapter task type {:?} does not match required task type {:?}",
                    adapter_schema.task_type, schema.task_type
                ),
            });
        }
        Ok(())
    }

    /// Whether the adapter can process several inputs in a single call.
    ///
    /// Defaults to `true`.
    fn supports_batching(&self) -> bool {
        true
    }

    /// Suggested number of inputs per batch.
    ///
    /// Defaults to `6` for batching adapters. When [`ModelAdapter::supports_batching`]
    /// returns `false` it defaults to `1`, so a non-batching adapter never
    /// advertises a batch size it cannot honor.
    fn recommended_batch_size(&self) -> usize {
        if self.supports_batching() {
            6
        } else {
            1
        }
    }
}
/// Builder that produces a [`ModelAdapter`] from a model file, optionally
/// customized with an adapter-specific configuration.
pub trait AdapterBuilder: Sized {
    /// Adapter-specific configuration accepted by [`AdapterBuilder::with_config`].
    type Config: Send + Sync + Debug + Clone;

    /// The adapter type produced by [`AdapterBuilder::build`].
    type Adapter: ModelAdapter;

    /// Consumes the builder and creates the adapter using `model_path`.
    ///
    /// # Errors
    ///
    /// Returns an [`OCRError`] if the adapter cannot be built.
    fn build(self, model_path: &Path) -> Result<Self::Adapter, OCRError>;

    /// Returns the builder with `config` applied.
    ///
    /// The builder is taken by value, so ignoring the returned value discards
    /// the configuration; hence `#[must_use]`.
    #[must_use]
    fn with_config(self, config: Self::Config) -> Self;

    /// Returns a string naming the kind of adapter this builder produces.
    fn adapter_type(&self) -> &str;
}
/// Types that can be customized with an
/// [`OrtSessionConfig`](crate::core::config::OrtSessionConfig).
pub trait OrtConfigurable: Sized {
    /// Returns `self` with the given session configuration applied.
    ///
    /// `self` is consumed, so ignoring the returned value discards the
    /// configuration; hence `#[must_use]`.
    #[must_use]
    fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self;
}
/// Owns a single model adapter.
///
/// The `A: ModelAdapter` requirement is stated on the `impl` blocks that need
/// it rather than on the struct definition, following the Rust API guideline
/// against bounds on struct definitions (C-STRUCT-BOUNDS). This only relaxes
/// where the type may be named; construction still goes through the bounded
/// `impl`.
#[derive(Debug)]
pub struct AdapterTask<A> {
    adapter: A,
}
impl<A: ModelAdapter> AdapterTask<A> {
pub fn new(adapter: A) -> Self {
Self { adapter }
}
pub fn adapter(&self) -> &A {
&self.adapter
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `AdapterInfo::new` must store every argument it is given.
    #[test]
    fn test_adapter_info_creation() {
        let info = AdapterInfo::new(
            "DB",
            TaskType::TextDetection,
            "Differentiable Binarization text detector",
        );
        assert_eq!(info.model_name, "DB");
        assert_eq!(info.task_type, TaskType::TextDetection);
        assert_eq!(
            info.description,
            "Differentiable Binarization text detector"
        );
    }

    /// The `impl Into<String>` parameters also accept owned `String`s, including empty ones.
    #[test]
    fn test_adapter_info_accepts_owned_strings() {
        let info = AdapterInfo::new(
            String::from("DB"),
            TaskType::TextDetection,
            String::new(),
        );
        assert_eq!(info.model_name, "DB");
        assert!(info.description.is_empty());
    }

    /// `TaskSchema::new` keeps the task type and input types it is given.
    /// (Named "creation" because nothing here performs validation.)
    #[test]
    fn test_schema_creation() {
        let schema = TaskSchema::new(
            TaskType::TextDetection,
            vec!["image".to_string()],
            vec!["text_boxes".to_string()],
        );
        assert_eq!(schema.task_type, TaskType::TextDetection);
        assert_eq!(schema.input_types, vec!["image".to_string()]);
    }
}