oxify_connect_vision/providers/
mod.rs1mod mock;
4
5#[cfg(feature = "tesseract")]
6mod tesseract;
7
8#[cfg(feature = "surya")]
9mod surya;
10
11#[cfg(feature = "paddle")]
12mod paddle;
13
14#[cfg(feature = "google-vision")]
15mod google_vision;
16
17pub use mock::MockVisionProvider;
19
20#[cfg(feature = "tesseract")]
21pub use tesseract::TesseractProvider;
22
23#[cfg(feature = "surya")]
24pub use surya::SuryaClient;
25
26#[cfg(feature = "paddle")]
27pub use paddle::PaddleOcrClient;
28
29#[cfg(feature = "google-vision")]
30pub use google_vision::{CostStats, GoogleVisionConfig, GoogleVisionProvider};
31
32use crate::errors::{Result, VisionError};
33use crate::types::OcrResult;
34use async_trait::async_trait;
35
/// Settings used by [`create_provider`] to choose and construct a vision provider.
///
/// You can start from [`Default`], which selects the `"mock"` provider. You can
/// also use one of the named constructors such as
/// [`VisionProviderConfig::tesseract`]. Either way, adjust individual fields with
/// struct-update syntax if needed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VisionProviderConfig {
    /// Provider identifier matched by [`create_provider`]: `"mock"`, `"tesseract"`,
    /// `"surya"`, `"paddle"` or `"google_vision"`.
    pub provider: String,
    /// Model location; required by the `"surya"` and `"paddle"` providers.
    pub model_path: Option<String>,
    /// Requested output format (`"markdown"` by default). Not interpreted by
    /// [`create_provider`] itself.
    pub output_format: String,
    /// GPU flag forwarded to the `"surya"` and `"paddle"` providers.
    pub use_gpu: bool,
    /// Optional language hint forwarded to the `"tesseract"` provider.
    pub language: Option<String>,
    /// Extra provider-specific key/value options. Not interpreted by
    /// [`create_provider`] itself.
    pub options: std::collections::HashMap<String, String>,
}

impl Default for VisionProviderConfig {
    /// Returns the default configuration.
    ///
    /// That is: the `"mock"` provider, `"markdown"` output and `use_gpu = false`,
    /// with no model path, no language and no options.
    fn default() -> Self {
        Self {
            provider: "mock".to_string(),
            model_path: None,
            output_format: "markdown".to_string(),
            use_gpu: false,
            language: None,
            options: std::collections::HashMap::new(),
        }
    }
}

impl VisionProviderConfig {
    /// Configuration for the built-in mock provider (identical to [`Default`]).
    #[must_use]
    pub fn mock() -> Self {
        Self::default()
    }

    /// Configuration for the Tesseract provider with an optional language hint.
    #[must_use]
    pub fn tesseract(language: Option<&str>) -> Self {
        Self {
            provider: "tesseract".to_string(),
            language: language.map(str::to_owned),
            ..Self::default()
        }
    }

    /// Configuration for the Surya provider using the model at `model_path`.
    #[must_use]
    pub fn surya(model_path: &str, use_gpu: bool) -> Self {
        Self::with_model("surya", model_path, use_gpu)
    }

    /// Configuration for the PaddleOCR provider using the model at `model_path`.
    #[must_use]
    pub fn paddle(model_path: &str, use_gpu: bool) -> Self {
        Self::with_model("paddle", model_path, use_gpu)
    }

    /// Configuration for the Google Cloud Vision provider.
    ///
    /// NOTE(review): [`create_provider`] builds this provider from its own default
    /// settings, so only `provider` from this config currently has an effect.
    #[must_use]
    pub fn google_vision() -> Self {
        Self {
            provider: "google_vision".to_string(),
            ..Self::default()
        }
    }

    /// Shared builder for providers that take a model path and a GPU flag.
    fn with_model(provider: &str, model_path: &str, use_gpu: bool) -> Self {
        Self {
            provider: provider.to_string(),
            model_path: Some(model_path.to_string()),
            use_gpu,
            ..Self::default()
        }
    }
}
109
110#[async_trait]
112pub trait VisionProvider: Send + Sync {
113 async fn process_image(&self, image_data: &[u8]) -> Result<OcrResult>;
115
116 async fn load_model(&self) -> Result<()>;
119
120 async fn unload_model(&self) -> Result<()> {
122 Ok(()) }
124
125 fn is_model_loaded(&self) -> bool {
127 true }
129
130 fn provider_name(&self) -> &str;
132
133 fn capabilities(&self) -> ProviderCapabilities {
135 ProviderCapabilities::default()
136 }
137}
138
/// Optional features a provider reports through [`VisionProvider::capabilities`].
///
/// `Default` yields every flag `false` and an empty `languages` list. This is what
/// providers that do not override `capabilities` report.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Whether table detection is supported.
    pub table_detection: bool,
    /// Whether layout analysis is supported.
    pub layout_analysis: bool,
    /// Whether handwriting recognition is supported.
    pub handwriting: bool,
    /// Whether more than one recognition language is supported.
    pub multi_language: bool,
    /// Whether GPU acceleration is available.
    pub gpu_acceleration: bool,
    /// Languages the provider accepts. The code format is provider-defined
    /// (TODO confirm).
    pub languages: Vec<String>,
}
155
156pub fn create_provider(config: &VisionProviderConfig) -> Result<Box<dyn VisionProvider>> {
158 match config.provider.as_str() {
159 "mock" => Ok(Box::new(MockVisionProvider::new())),
160
161 #[cfg(feature = "tesseract")]
162 "tesseract" => Ok(Box::new(TesseractProvider::new(config.language.as_deref()))),
163
164 #[cfg(feature = "surya")]
165 "surya" => {
166 let model_path = config
167 .model_path
168 .as_ref()
169 .ok_or_else(|| VisionError::config("Surya requires model_path"))?;
170 Ok(Box::new(SuryaClient::new(model_path, config.use_gpu)))
171 }
172
173 #[cfg(feature = "paddle")]
174 "paddle" => {
175 let model_path = config
176 .model_path
177 .as_ref()
178 .ok_or_else(|| VisionError::config("PaddleOCR requires model_path"))?;
179 Ok(Box::new(PaddleOcrClient::new(model_path, config.use_gpu)))
180 }
181
182 #[cfg(feature = "google-vision")]
183 "google_vision" => {
184 let gcp_config = google_vision::GoogleVisionConfig::default();
185 Ok(Box::new(google_vision::GoogleVisionProvider::new(
186 gcp_config,
187 )))
188 }
189
190 provider => Err(VisionError::unsupported_provider(provider)),
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_mock_provider() {
        let config = VisionProviderConfig::mock();
        let provider = create_provider(&config).unwrap();
        assert_eq!(provider.provider_name(), "mock");
    }

    #[test]
    fn test_unsupported_provider() {
        let config = VisionProviderConfig {
            provider: "unsupported".to_string(),
            ..Default::default()
        };
        let result = create_provider(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_constructors_set_provider_fields() {
        let tesseract = VisionProviderConfig::tesseract(Some("eng"));
        assert_eq!(tesseract.provider, "tesseract");
        assert_eq!(tesseract.language.as_deref(), Some("eng"));

        let surya = VisionProviderConfig::surya("/models/surya", true);
        assert_eq!(surya.provider, "surya");
        assert_eq!(surya.model_path.as_deref(), Some("/models/surya"));
        assert!(surya.use_gpu);

        let paddle = VisionProviderConfig::paddle("/models/paddle", false);
        assert_eq!(paddle.provider, "paddle");
        assert_eq!(paddle.model_path.as_deref(), Some("/models/paddle"));
        assert!(!paddle.use_gpu);

        assert_eq!(VisionProviderConfig::google_vision().provider, "google_vision");
    }

    // The model-path check must fire before any backend client is constructed.
    #[cfg(feature = "surya")]
    #[test]
    fn test_surya_requires_model_path() {
        let config = VisionProviderConfig {
            provider: "surya".to_string(),
            ..Default::default()
        };
        assert!(create_provider(&config).is_err());
    }

    #[cfg(feature = "paddle")]
    #[test]
    fn test_paddle_requires_model_path() {
        let config = VisionProviderConfig {
            provider: "paddle".to_string(),
            ..Default::default()
        };
        assert!(create_provider(&config).is_err());
    }

    // Without the Cargo feature, the match arm is compiled out and the identifier
    // falls through to the unsupported-provider error.
    #[cfg(not(feature = "tesseract"))]
    #[test]
    fn test_tesseract_unavailable_without_feature() {
        let config = VisionProviderConfig::tesseract(None);
        assert!(create_provider(&config).is_err());
    }
}