Skip to main content

oar_ocr_core/core/traits/
adapter.rs

1//! Model adapter trait definitions for the OCR pipeline.
2//!
3//! This module defines the `ModelAdapter` trait and related types that adapt
4//! various model implementations to conform to task interfaces. Adapters handle
5//! preprocessing, inference, and postprocessing for specific models.
6
7use super::task::{Task, TaskSchema, TaskType};
8use crate::core::OCRError;
9use std::fmt::Debug;
10use std::path::Path;
11
12/// Information about a model adapter.
13#[derive(Debug, Clone)]
14pub struct AdapterInfo {
15    /// Name of the model (e.g., "DB", "CRNN", "RT-DETR")
16    pub model_name: String,
17    /// Task type this adapter supports
18    pub task_type: TaskType,
19    /// Description of the model
20    pub description: String,
21}
22
23impl AdapterInfo {
24    /// Creates a new adapter info.
25    pub fn new(
26        model_name: impl Into<String>,
27        task_type: TaskType,
28        description: impl Into<String>,
29    ) -> Self {
30        Self {
31            model_name: model_name.into(),
32            task_type,
33            description: description.into(),
34        }
35    }
36}
37
38/// Core trait for model adapters.
39///
40/// Adapters bridge the gap between task interfaces and concrete model implementations.
41/// They handle model-specific preprocessing, inference, and postprocessing while
42/// conforming to the task's input/output schema.
43pub trait ModelAdapter: Send + Sync + Debug {
44    /// The task type this adapter executes
45    type Task: Task;
46
47    /// Returns information about this adapter.
48    fn info(&self) -> AdapterInfo;
49
50    /// Returns the schema that this adapter conforms to.
51    fn schema(&self) -> TaskSchema {
52        TaskSchema::new(
53            self.info().task_type,
54            vec!["image".to_string()], // Most adapters work with images
55            vec!["result".to_string()],
56        )
57    }
58
59    /// Executes the model on the given input.
60    ///
61    /// This method handles the complete pipeline:
62    /// 1. Validate input
63    /// 2. Preprocess
64    /// 3. Run inference
65    /// 4. Postprocess
66    /// 5. Validate output
67    ///
68    /// # Arguments
69    ///
70    /// * `input` - The task input to process
71    /// * `config` - Optional configuration for execution
72    ///
73    /// # Returns
74    ///
75    /// The task output or an error
76    fn execute(
77        &self,
78        input: <Self::Task as Task>::Input,
79        config: Option<&<Self::Task as Task>::Config>,
80    ) -> Result<<Self::Task as Task>::Output, OCRError>;
81
82    /// Validates that this adapter is compatible with the given task schema.
83    ///
84    /// # Arguments
85    ///
86    /// * `schema` - The schema to check compatibility with
87    ///
88    /// # Returns
89    ///
90    /// Result indicating success or incompatibility error
91    fn validate_compatibility(&self, schema: &TaskSchema) -> Result<(), OCRError> {
92        let adapter_schema = self.schema();
93        if adapter_schema.task_type != schema.task_type {
94            return Err(OCRError::ConfigError {
95                message: format!(
96                    "Adapter task type {:?} does not match required task type {:?}",
97                    adapter_schema.task_type, schema.task_type
98                ),
99            });
100        }
101        Ok(())
102    }
103
104    /// Returns whether this adapter can handle batched inputs efficiently.
105    fn supports_batching(&self) -> bool {
106        true // Most models support batching
107    }
108
109    /// Returns the recommended batch size for this adapter.
110    fn recommended_batch_size(&self) -> usize {
111        6 // Default from constants
112    }
113}
114
115/// Builder trait for creating model adapters.
116///
117/// This trait defines the interface for building adapters with specific configurations.
118pub trait AdapterBuilder: Sized {
119    /// The configuration type for this builder
120    type Config: Send + Sync + Debug + Clone;
121
122    /// The adapter type that this builder creates
123    type Adapter: ModelAdapter;
124
125    /// Builds an adapter from a model file.
126    ///
127    /// # Arguments
128    ///
129    /// * `model_path` - Path to the model file (e.g., ONNX file)
130    ///
131    /// # Returns
132    ///
133    /// The built adapter or an error
134    fn build(self, model_path: &Path) -> Result<Self::Adapter, OCRError>;
135
136    /// Configures the builder with the given configuration.
137    ///
138    /// # Arguments
139    ///
140    /// * `config` - The configuration to use
141    ///
142    /// # Returns
143    ///
144    /// The configured builder
145    fn with_config(self, config: Self::Config) -> Self;
146
147    /// Returns the adapter type identifier.
148    fn adapter_type(&self) -> &str;
149}
150
151/// Trait for adapter builders that support ONNX Runtime session configuration.
152///
153/// This trait is implemented by builders that can be configured with ORT session
154/// settings like execution providers, thread count, and memory optimization.
155pub trait OrtConfigurable: Sized {
156    /// Configures the builder with ONNX Runtime session settings.
157    fn with_ort_config(self, config: crate::core::config::OrtSessionConfig) -> Self;
158}
159
/// Wrapper that holds a model adapter so it can be handled through the task
/// layer.
///
/// The intent is to let adapters be used polymorphically through the `Task`
/// trait.
// NOTE(review): no `Task` impl for this type is visible in this file. Confirm
// where it lives.
//
// The `ModelAdapter` bound is intentionally not placed on the struct itself.
// Bounds on a type definition are not implied elsewhere, so every impl has to
// restate them anyway. The impls that need `A: ModelAdapter` declare it, and
// `#[derive(Debug)]` still provides `Debug` for any adapter, because
// `ModelAdapter: Debug`.
#[derive(Debug)]
pub struct AdapterTask<A> {
    adapter: A,
}
167
168impl<A: ModelAdapter> AdapterTask<A> {
169    /// Creates a new adapter task.
170    pub fn new(adapter: A) -> Self {
171        Self { adapter }
172    }
173
174    /// Returns a reference to the adapter.
175    pub fn adapter(&self) -> &A {
176        &self.adapter
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// `AdapterInfo::new` stores every argument in the matching field.
    #[test]
    fn test_adapter_info_creation() {
        let info = AdapterInfo::new(
            "DB",
            TaskType::TextDetection,
            "Differentiable Binarization text detector",
        );

        assert_eq!(info.model_name, "DB");
        assert_eq!(info.task_type, TaskType::TextDetection);
        // Previously unchecked: the description must round-trip as well.
        assert_eq!(
            info.description,
            "Differentiable Binarization text detector"
        );
    }

    /// `AdapterInfo::new` also accepts owned `String`s through `impl Into<String>`.
    #[test]
    fn test_adapter_info_accepts_owned_strings() {
        let info = AdapterInfo::new(
            String::from("CRNN"),
            TaskType::TextDetection,
            String::new(),
        );

        assert_eq!(info.model_name, "CRNN");
        assert_eq!(info.task_type, TaskType::TextDetection);
        assert!(info.description.is_empty());
    }

    /// Sanity check of `TaskSchema` construction.
    ///
    /// This does not exercise `ModelAdapter::validate_compatibility`, which
    /// needs a concrete adapter.
    #[test]
    fn test_schema_validation() {
        let schema = TaskSchema::new(
            TaskType::TextDetection,
            vec!["image".to_string()],
            vec!["text_boxes".to_string()],
        );

        assert_eq!(schema.task_type, TaskType::TextDetection);
        assert_eq!(schema.input_types, vec!["image".to_string()]);
    }
}