oxify_connect_vision/providers/
mod.rs

1//! Vision/OCR provider implementations.
2
3mod mock;
4
5#[cfg(feature = "tesseract")]
6mod tesseract;
7
8#[cfg(feature = "surya")]
9mod surya;
10
11#[cfg(feature = "paddle")]
12mod paddle;
13
14#[cfg(feature = "google-vision")]
15mod google_vision;
16
17// Re-exports
18pub use mock::MockVisionProvider;
19
20#[cfg(feature = "tesseract")]
21pub use tesseract::TesseractProvider;
22
23#[cfg(feature = "surya")]
24pub use surya::SuryaClient;
25
26#[cfg(feature = "paddle")]
27pub use paddle::PaddleOcrClient;
28
29#[cfg(feature = "google-vision")]
30pub use google_vision::{CostStats, GoogleVisionConfig, GoogleVisionProvider};
31
32use crate::errors::{Result, VisionError};
33use crate::types::OcrResult;
34use async_trait::async_trait;
35
/// Settings used by [`create_provider`] to select and configure a vision/OCR backend.
///
/// Build one with the preset constructors (`mock`, `tesseract`, `surya`, `paddle`,
/// `google_vision`) or with struct-update syntax over `Default::default()`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VisionProviderConfig {
    /// Provider identifier matched by `create_provider`:
    /// `"mock"`, `"tesseract"`, `"surya"`, `"paddle"` or `"google_vision"`.
    pub provider: String,
    /// Path to model files; required by the `"surya"` and `"paddle"` providers.
    pub model_path: Option<String>,
    /// Preferred output format (`"markdown"` by default).
    pub output_format: String,
    /// Whether to request GPU acceleration (forwarded to the model-backed providers).
    pub use_gpu: bool,
    /// Target language(s) for OCR; forwarded to the Tesseract provider.
    pub language: Option<String>,
    /// Additional provider-specific key/value options.
    pub options: std::collections::HashMap<String, String>,
}
52
53impl Default for VisionProviderConfig {
54    fn default() -> Self {
55        Self {
56            provider: "mock".to_string(),
57            model_path: None,
58            output_format: "markdown".to_string(),
59            use_gpu: false,
60            language: None,
61            options: std::collections::HashMap::new(),
62        }
63    }
64}
65
66impl VisionProviderConfig {
67    /// Create a new configuration for the mock provider.
68    pub fn mock() -> Self {
69        Self::default()
70    }
71
72    /// Create a new configuration for Tesseract.
73    pub fn tesseract(language: Option<&str>) -> Self {
74        Self {
75            provider: "tesseract".to_string(),
76            language: language.map(|s| s.to_string()),
77            ..Default::default()
78        }
79    }
80
81    /// Create a new configuration for Surya.
82    pub fn surya(model_path: &str, use_gpu: bool) -> Self {
83        Self {
84            provider: "surya".to_string(),
85            model_path: Some(model_path.to_string()),
86            use_gpu,
87            ..Default::default()
88        }
89    }
90
91    /// Create a new configuration for PaddleOCR.
92    pub fn paddle(model_path: &str, use_gpu: bool) -> Self {
93        Self {
94            provider: "paddle".to_string(),
95            model_path: Some(model_path.to_string()),
96            use_gpu,
97            ..Default::default()
98        }
99    }
100
101    /// Create a new configuration for Google Cloud Vision.
102    pub fn google_vision() -> Self {
103        Self {
104            provider: "google_vision".to_string(),
105            ..Default::default()
106        }
107    }
108}
109
110/// Core trait for vision/OCR providers.
111#[async_trait]
112pub trait VisionProvider: Send + Sync {
113    /// Process an image and extract text/layout information.
114    async fn process_image(&self, image_data: &[u8]) -> Result<OcrResult>;
115
116    /// Load the model into memory.
117    /// For some providers (like Mock), this is a no-op.
118    async fn load_model(&self) -> Result<()>;
119
120    /// Unload the model from memory.
121    async fn unload_model(&self) -> Result<()> {
122        Ok(()) // Default no-op
123    }
124
125    /// Check if the model is loaded.
126    fn is_model_loaded(&self) -> bool {
127        true // Default: always ready
128    }
129
130    /// Get the provider name.
131    fn provider_name(&self) -> &str;
132
133    /// Get provider capabilities.
134    fn capabilities(&self) -> ProviderCapabilities {
135        ProviderCapabilities::default()
136    }
137}
138
/// Feature set advertised by a [`VisionProvider`] implementation.
///
/// `Default` yields every flag `false` and an empty `languages` list.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Supports table detection and extraction.
    pub table_detection: bool,
    /// Supports layout analysis.
    pub layout_analysis: bool,
    /// Supports handwriting recognition.
    pub handwriting: bool,
    /// Supports multi-language detection.
    pub multi_language: bool,
    /// Supports GPU acceleration.
    pub gpu_acceleration: bool,
    /// Supported languages.
    pub languages: Vec<String>,
}
155
156/// Factory function to create a provider from configuration.
157pub fn create_provider(config: &VisionProviderConfig) -> Result<Box<dyn VisionProvider>> {
158    match config.provider.as_str() {
159        "mock" => Ok(Box::new(MockVisionProvider::new())),
160
161        #[cfg(feature = "tesseract")]
162        "tesseract" => Ok(Box::new(TesseractProvider::new(config.language.as_deref()))),
163
164        #[cfg(feature = "surya")]
165        "surya" => {
166            let model_path = config
167                .model_path
168                .as_ref()
169                .ok_or_else(|| VisionError::config("Surya requires model_path"))?;
170            Ok(Box::new(SuryaClient::new(model_path, config.use_gpu)))
171        }
172
173        #[cfg(feature = "paddle")]
174        "paddle" => {
175            let model_path = config
176                .model_path
177                .as_ref()
178                .ok_or_else(|| VisionError::config("PaddleOCR requires model_path"))?;
179            Ok(Box::new(PaddleOcrClient::new(model_path, config.use_gpu)))
180        }
181
182        #[cfg(feature = "google-vision")]
183        "google_vision" => {
184            let gcp_config = google_vision::GoogleVisionConfig::default();
185            Ok(Box::new(google_vision::GoogleVisionProvider::new(
186                gcp_config,
187            )))
188        }
189
190        provider => Err(VisionError::unsupported_provider(provider)),
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_mock_provider() {
        let config = VisionProviderConfig::mock();
        let provider = create_provider(&config).unwrap();
        assert_eq!(provider.provider_name(), "mock");
    }

    #[test]
    fn test_unsupported_provider() {
        let config = VisionProviderConfig {
            provider: "unsupported".to_string(),
            ..Default::default()
        };
        let result = create_provider(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_default_config_targets_mock() {
        let config = VisionProviderConfig::default();
        assert_eq!(config.provider, "mock");
        assert_eq!(config.output_format, "markdown");
        assert!(config.model_path.is_none());
        assert!(config.language.is_none());
        assert!(!config.use_gpu);
        assert!(config.options.is_empty());
    }

    #[test]
    fn test_config_constructors() {
        let tesseract = VisionProviderConfig::tesseract(Some("eng"));
        assert_eq!(tesseract.provider, "tesseract");
        assert_eq!(tesseract.language.as_deref(), Some("eng"));
        assert!(VisionProviderConfig::tesseract(None).language.is_none());

        let surya = VisionProviderConfig::surya("/models/surya", true);
        assert_eq!(surya.provider, "surya");
        assert_eq!(surya.model_path.as_deref(), Some("/models/surya"));
        assert!(surya.use_gpu);

        let paddle = VisionProviderConfig::paddle("/models/paddle", false);
        assert_eq!(paddle.provider, "paddle");
        assert_eq!(paddle.model_path.as_deref(), Some("/models/paddle"));
        assert!(!paddle.use_gpu);

        let google = VisionProviderConfig::google_vision();
        assert_eq!(google.provider, "google_vision");
        assert_eq!(google.output_format, "markdown");
    }

    #[test]
    #[cfg(not(feature = "tesseract"))]
    fn test_disabled_feature_is_unsupported() {
        let config = VisionProviderConfig::tesseract(None);
        assert!(create_provider(&config).is_err());
    }

    #[test]
    #[cfg(feature = "surya")]
    fn test_surya_requires_model_path() {
        let config = VisionProviderConfig {
            provider: "surya".to_string(),
            ..Default::default()
        };
        assert!(create_provider(&config).is_err());
    }

    #[test]
    #[cfg(feature = "paddle")]
    fn test_paddle_requires_model_path() {
        let config = VisionProviderConfig {
            provider: "paddle".to_string(),
            ..Default::default()
        };
        assert!(create_provider(&config).is_err());
    }
}