oar_ocr/pipeline/stages/extensible.rs

//! Extensible pipeline stage traits and types.
//!
//! This module provides the core traits and types for building extensible
//! pipeline stages that can be dynamically registered and executed.
5
6use image::RgbImage;
7use serde::{Deserialize, Serialize};
8use std::any::Any;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::sync::Arc;
12
13use super::types::StageResult;
14use crate::core::OCRError;
15
16/// Unique identifier for a pipeline stage.
17#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
18pub struct StageId(pub String);
19
20impl std::fmt::Display for StageId {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        write!(f, "{}", self.0)
23    }
24}
25
26impl StageId {
27    /// Create a new stage ID.
28    pub fn new(id: impl Into<String>) -> Self {
29        Self(id.into())
30    }
31
32    /// Get the string representation of the stage ID.
33    pub fn as_str(&self) -> &str {
34        &self.0
35    }
36}
37
38impl From<&str> for StageId {
39    fn from(s: &str) -> Self {
40        Self(s.to_string())
41    }
42}
43
44impl From<String> for StageId {
45    fn from(s: String) -> Self {
46        Self(s)
47    }
48}
49
50/// Dependency specification for a pipeline stage.
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
52pub enum StageDependency {
53    /// Stage must run after the specified stage
54    After(StageId),
55    /// Stage must run before the specified stage
56    Before(StageId),
57    /// Stage requires output from the specified stage
58    Requires(StageId),
59    /// Stage provides input to the specified stage
60    Provides(StageId),
61}
62
63/// Context information available to pipeline stages during execution.
64#[derive(Debug)]
65pub struct StageContext {
66    /// Current image being processed
67    pub current_image: Arc<RgbImage>,
68    /// Original input image (before any transformations)
69    pub original_image: Arc<RgbImage>,
70    /// Index of the current image in batch processing
71    pub image_index: usize,
72    /// Global pipeline configuration
73    pub global_config: HashMap<String, serde_json::Value>,
74    /// Results from previous stages
75    pub stage_results: HashMap<StageId, Box<dyn Any + Send + Sync>>,
76}
77
78impl StageContext {
79    /// Create a new stage context.
80    pub fn new(
81        current_image: Arc<RgbImage>,
82        original_image: Arc<RgbImage>,
83        image_index: usize,
84    ) -> Self {
85        Self {
86            current_image,
87            original_image,
88            image_index,
89            global_config: HashMap::new(),
90            stage_results: HashMap::new(),
91        }
92    }
93
94    /// Get a result from a previous stage.
95    pub fn get_stage_result<T: 'static>(&self, stage_id: &StageId) -> Option<&T> {
96        self.stage_results
97            .get(stage_id)
98            .and_then(|result| result.downcast_ref::<T>())
99    }
100
101    /// Set a result from a stage.
102    pub fn set_stage_result<T: 'static + Send + Sync>(&mut self, stage_id: StageId, result: T) {
103        self.stage_results.insert(stage_id, Box::new(result));
104    }
105
106    /// Get a global configuration value.
107    pub fn get_config<T>(&self, key: &str) -> Option<T>
108    where
109        T: for<'de> Deserialize<'de>,
110    {
111        self.global_config
112            .get(key)
113            .and_then(|value| serde_json::from_value(value.clone()).ok())
114    }
115
116    /// Set a global configuration value.
117    pub fn set_config<T>(&mut self, key: String, value: T)
118    where
119        T: Serialize,
120    {
121        if let Ok(json_value) = serde_json::to_value(value) {
122            self.global_config.insert(key, json_value);
123        }
124    }
125}
126
127/// Data that flows between pipeline stages.
128#[derive(Debug, Clone)]
129pub struct StageData {
130    /// The current processed image
131    pub image: RgbImage,
132    /// Metadata associated with the stage processing
133    pub metadata: HashMap<String, serde_json::Value>,
134}
135
136impl StageData {
137    /// Create new stage data with an image.
138    pub fn new(image: RgbImage) -> Self {
139        Self {
140            image,
141            metadata: HashMap::new(),
142        }
143    }
144
145    /// Add metadata to the stage data.
146    pub fn with_metadata<T: Serialize>(mut self, key: String, value: T) -> Self {
147        if let Ok(json_value) = serde_json::to_value(value) {
148            self.metadata.insert(key, json_value);
149        }
150        self
151    }
152
153    /// Get metadata from the stage data.
154    pub fn get_metadata<T>(&self, key: &str) -> Option<T>
155    where
156        T: for<'de> Deserialize<'de>,
157    {
158        self.metadata
159            .get(key)
160            .and_then(|value| serde_json::from_value(value.clone()).ok())
161    }
162}
163
164/// Trait for extensible pipeline stages.
165///
166/// This trait defines the interface that all pipeline stages must implement
167/// to participate in the extensible pipeline system.
168///
169/// # Default Contract
170///
171/// All pipeline stages must implement a **default contract** to ensure consistency,
172/// reliability, and maintainability. This contract eliminates silent failures and
173/// ensures all stages provide meaningful defaults and proper validation.
174///
175/// ## Contract Requirements:
176///
177/// 1. **Mandatory Default Configuration**: `default_config()` must return a valid `Self::Config`
178/// 2. **ConfigValidator Implementation**: All config types must implement `ConfigValidator`
179/// 3. **Validation Integration**: Stages must use ConfigValidator in `validate_config()`
180/// 4. **Required Traits**: Config types must implement `ConfigValidator + Default + Clone + Debug + Send + Sync`
181///
182/// ## Example Implementation:
183///
184/// ```rust
185/// use oar_ocr::pipeline::stages::{PipelineStage, StageId, StageContext, StageData, StageResult};
186/// use oar_ocr::core::config::{ConfigValidator, ConfigError};
187/// use oar_ocr::core::OCRError;
188/// use serde::{Serialize, Deserialize};
189///
190/// #[derive(Debug, Clone, Serialize, Deserialize, Default)]
191/// pub struct MyStageConfig {
192///     pub threshold: f32,
193///     pub enabled: bool,
194/// }
195///
196/// impl ConfigValidator for MyStageConfig {
197///     fn validate(&self) -> Result<(), ConfigError> {
198///         if !(0.0..=1.0).contains(&self.threshold) {
199///             return Err(ConfigError::InvalidConfig {
200///                 message: "threshold must be between 0.0 and 1.0".to_string(),
201///             });
202///         }
203///         Ok(())
204///     }
205///
206///     fn get_defaults() -> Self {
207///         Self { threshold: 0.5, enabled: true }
208///     }
209/// }
210///
211/// #[derive(Debug)]
212/// pub struct MyStage;
213///
214/// impl PipelineStage for MyStage {
215///     type Config = MyStageConfig;
216///     type Result = String;
217///
218///     fn stage_id(&self) -> StageId { StageId::new("my_stage") }
219///     fn stage_name(&self) -> &str { "My Stage" }
220///
221///     fn validate_config(&self, config: &Self::Config) -> Result<(), OCRError> {
222///         config.validate().map_err(|e| OCRError::ConfigError {
223///             message: format!("MyStageConfig validation failed: {}", e),
224///         })
225///     }
226///
227///     fn default_config(&self) -> Self::Config {
228///         MyStageConfig::get_defaults()
229///     }
230///
231///     fn process(
232///         &self,
233///         _context: &mut StageContext,
234///         _data: StageData,
235///         config: Option<&Self::Config>,
236///     ) -> Result<StageResult<Self::Result>, OCRError> {
237///         let config = config.cloned().unwrap_or_else(|| self.default_config());
238///         self.validate_config(&config)?;
239///         // Process with validated configuration...
240///         # Ok(StageResult::new("result".to_string(), Default::default()))
241///     }
242/// }
243/// ```
244///
245/// See [`DEFAULT_CONTRACT.md`](./DEFAULT_CONTRACT.md) for detailed documentation.
246pub trait PipelineStage: Send + Sync + Debug {
247    /// The configuration type for this stage.
248    ///
249    /// Must implement ConfigValidator to ensure validation is never skipped.
250    type Config: Send + Sync + Debug + crate::core::config::ConfigValidator + Default;
251
252    /// The result type produced by this stage.
253    type Result: Send + Sync + Debug + 'static;
254
255    /// Get the unique identifier for this stage.
256    fn stage_id(&self) -> StageId;
257
258    /// Get the human-readable name of this stage.
259    fn stage_name(&self) -> &str;
260
261    /// Get the dependencies for this stage.
262    fn dependencies(&self) -> Vec<StageDependency> {
263        Vec::new()
264    }
265
266    /// Check if this stage is enabled based on the context and configuration.
267    fn is_enabled(&self, context: &StageContext, config: Option<&Self::Config>) -> bool {
268        let _ = (context, config);
269        true
270    }
271
272    /// Process the stage with the given context and configuration.
273    ///
274    /// # Arguments
275    ///
276    /// * `context` - The stage execution context
277    /// * `data` - The input data for this stage
278    /// * `config` - Optional stage-specific configuration
279    ///
280    /// # Returns
281    ///
282    /// A StageResult containing the processed data and metrics
283    fn process(
284        &self,
285        context: &mut StageContext,
286        data: StageData,
287        config: Option<&Self::Config>,
288    ) -> Result<StageResult<Self::Result>, OCRError>;
289
290    /// Validate the stage configuration.
291    fn validate_config(&self, config: &Self::Config) -> Result<(), OCRError> {
292        let _ = config;
293        Ok(())
294    }
295
296    /// Get default configuration for this stage.
297    ///
298    /// This method must return a valid configuration. The default contract
299    /// ensures all stages provide meaningful defaults.
300    fn default_config(&self) -> Self::Config {
301        Self::Config::default()
302    }
303}