// captcha_engine/model.rs
//! Model loading and prediction for captcha recognition.

use crate::{Result, error::Error, image_ops, tokenizer::Tokenizer};
use image::DynamicImage;
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Tensor;
use std::path::Path;
use std::sync::Mutex;

#[cfg(feature = "download")]
use log::info;
#[cfg(feature = "download")]
use reqwest::blocking::Client;
#[cfg(feature = "download")]
use std::fs;
#[cfg(feature = "download")]
use std::path::PathBuf;

/// URL to download the model from `HuggingFace`.
#[cfg(feature = "download")]
const MODEL_URL_HUGGINGFACE: &str =
    "https://huggingface.co/Milang/captcha-solver/resolve/main/captcha.onnx";

/// Embedded model bytes (only available with `embed-model` feature).
///
/// Produced by the build script into `OUT_DIR` and compiled into the binary,
/// so loading it needs no file or network access at runtime.
#[cfg(feature = "embed-model")]
const EMBEDDED_MODEL: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/model.onnx"));

/// The main captcha-breaking model.
///
/// Wraps an ONNX model and tokenizer for end-to-end captcha recognition.
pub struct CaptchaModel {
    // Running inference mutates the session (`predict` locks and calls
    // `session.run`), so it is guarded by a mutex to keep `predict(&self, ..)`
    // callable through shared references.
    session: Mutex<Session>,
    tokenizer: Tokenizer,
}
35
36impl std::fmt::Debug for CaptchaModel {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("CaptchaModel")
39            .field("tokenizer", &self.tokenizer)
40            .finish_non_exhaustive()
41    }
42}
43
44impl CaptchaModel {
45    /// Create a configured ONNX Runtime `SessionBuilder`.
46    ///
47    /// # Errors
48    ///
49    /// Returns an error if the session builder cannot be created or configured.
50    #[allow(unused_mut)]
51    fn create_session_builder() -> Result<ort::session::builder::SessionBuilder> {
52        let mut builder = Session::builder().map_err(|e| Error::ModelLoad(e.to_string()))?;
53
54        // Configure execution providers with fallbacks
55        // Order matters: attempt specialized providers first, then CPU
56        let mut providers = Vec::new();
57
58        #[cfg(target_os = "macos")]
59        {
60            providers.push(ort::execution_providers::CoreMLExecutionProvider::default().build());
61        }
62
63        #[cfg(any(target_os = "windows", target_os = "linux"))]
64        {
65            providers.push(ort::execution_providers::CUDAExecutionProvider::default().build());
66        }
67
68        providers.push(ort::execution_providers::CPUExecutionProvider::default().build());
69
70        builder = builder
71            .with_execution_providers(providers)
72            .map_err(|e| Error::ModelLoad(e.to_string()))?
73            .with_optimization_level(GraphOptimizationLevel::Level3)
74            .map_err(|e| Error::ModelLoad(e.to_string()))?;
75
76        Ok(builder)
77    }
78
79    /// Load a model from an ONNX file.
80    ///
81    /// # Errors
82    ///
83    /// Returns an error if the model file cannot be read or loaded.
84    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
85        let session = Self::create_session_builder()?
86            .commit_from_file(path.as_ref())
87            .map_err(|e| Error::ModelLoad(e.to_string()))?;
88
89        Ok(Self {
90            session: Mutex::new(session),
91            tokenizer: Tokenizer::default(),
92        })
93    }
94
95    /// Load a model from memory (bytes).
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if the model cannot be loaded.
100    pub fn load_from_memory(model_bytes: &[u8]) -> Result<Self> {
101        let session = Self::create_session_builder()?
102            .commit_from_memory(model_bytes)
103            .map_err(|e| Error::ModelLoad(e.to_string()))?;
104
105        Ok(Self {
106            session: Mutex::new(session),
107            tokenizer: Tokenizer::default(),
108        })
109    }
110
111    /// Load the embedded model (only available with `embed-model` feature).
112    ///
113    /// This loads the model directly from bytes compiled into the binary,
114    /// requiring no network access or external files.
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the embedded model cannot be loaded.
119    #[cfg(feature = "embed-model")]
120    pub fn load_embedded() -> Result<Self> {
121        Self::load_from_memory(EMBEDDED_MODEL)
122    }
123
124    /// Predict the text in a captcha image.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error if inference fails.
129    #[allow(clippy::significant_drop_tightening)]
130    pub fn predict(&self, image: &DynamicImage) -> Result<String> {
131        let array = image_ops::preprocess(image);
132
133        // Create ort Tensor from ndarray
134        let tensor = Tensor::from_array(array).map_err(|e| Error::Inference(e.to_string()))?;
135
136        let mut session = self
137            .session
138            .lock()
139            .map_err(|_| Error::Inference("Session mutex poisoned".into()))?;
140
141        let outputs = session
142            .run(ort::inputs!["input" => tensor])
143            .map_err(|e| Error::Inference(e.to_string()))?;
144
145        // Get output by name or use first value
146        let output = outputs
147            .iter()
148            .find(|(name, _)| *name == "output")
149            .map(|(_, v)| v)
150            .or_else(|| outputs.iter().next().map(|(_, v)| v))
151            .ok_or_else(|| Error::Inference("No output tensor".into()))?;
152
153        let probs = output
154            .try_extract_array::<f32>()
155            .map_err(|e| Error::Inference(e.to_string()))?;
156
157        Ok(self.tokenizer.decode(&probs.view()))
158    }
159
160    /// Predict the text in a captcha image loaded from a file path.
161    ///
162    /// # Errors
163    ///
164    /// Returns an error if the image cannot be loaded or inference fails.
165    pub fn predict_file<P: AsRef<Path>>(&self, path: P) -> Result<String> {
166        let image = image::open(path)?;
167        self.predict(&image)
168    }
169}

/// Ensures the model exists at the given path. If not, downloads it.
///
/// The model is streamed into a temporary `captcha.onnx.part` file and only
/// renamed to `captcha.onnx` once the download has fully completed, so an
/// interrupted download can never leave a truncated file that a later call
/// would mistake for a valid model.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the download fails,
/// or the file cannot be written.
#[cfg(feature = "download")]
pub fn ensure_model_downloaded<P: AsRef<Path>>(storage_dir: P) -> Result<PathBuf> {
    let storage_dir = storage_dir.as_ref();
    if !storage_dir.exists() {
        fs::create_dir_all(storage_dir)?;
    }

    let model_path = storage_dir.join("captcha.onnx");

    if model_path.exists() {
        return Ok(model_path);
    }

    info!(
        "Downloading captcha model to {path}",
        path = model_path.display()
    );

    let client = Client::new();
    let mut res = client.get(MODEL_URL_HUGGINGFACE).send()?;

    if !res.status().is_success() {
        return Err(Error::ModelDownload(format!(
            "Failed to download model: status {}",
            res.status()
        )));
    }

    // Write into a sibling temp file, then rename into place. Writing
    // directly to `model_path` would leave a truncated model behind if the
    // process died mid-download, and the `exists()` fast path above would
    // then reuse the corrupt file forever.
    let tmp_path = storage_dir.join("captcha.onnx.part");
    let write_result = (|| -> Result<()> {
        let mut file = fs::File::create(&tmp_path)?;
        res.copy_to(&mut file)?;
        Ok(())
    })();
    if let Err(e) = write_result {
        // Best-effort cleanup of the partial file; the original error wins.
        let _ = fs::remove_file(&tmp_path);
        return Err(e);
    }
    fs::rename(&tmp_path, &model_path)?;

    Ok(model_path)
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// The bytes baked into the binary must produce a working session.
    #[cfg(feature = "embed-model")]
    #[test]
    fn test_embedded_model_loads() {
        let loaded = CaptchaModel::load_embedded();
        assert!(
            loaded.is_ok(),
            "Embedded model should load successfully: {:?}",
            loaded.err()
        );
    }

    /// Garbage bytes must be rejected rather than silently accepted.
    #[test]
    fn test_load_from_invalid_memory() {
        let outcome = CaptchaModel::load_from_memory(b"not a model");
        assert!(outcome.is_err(), "Loading from invalid bytes should fail");
    }

    /// A plain white captcha-sized canvas should run through inference
    /// without error.
    #[cfg(feature = "embed-model")]
    #[test]
    fn test_prediction_dummy_image() {
        use image::{DynamicImage, RgbImage};
        let model = CaptchaModel::load_embedded().expect("Failed to load embedded model");
        // Dummy white image (215x80 RGB).
        let white = RgbImage::from_pixel(215, 80, image::Rgb([255, 255, 255]));
        assert!(model.predict(&DynamicImage::ImageRgb8(white)).is_ok());
    }

    /// Loads `model.onnx` from the working directory when present; otherwise
    /// the test is a no-op.
    #[test]
    fn test_load_local_model() {
        let path = Path::new("model.onnx");
        if !path.exists() {
            println!("Skipping test_load_local_model: model.onnx not found");
            return;
        }
        assert!(
            CaptchaModel::load(path).is_ok(),
            "Failed to load local model.onnx"
        );
    }

    /// End-to-end check against a real captcha image, skipped when either the
    /// model or the sample image is missing.
    #[test]
    fn test_predict_real_image() {
        let model_path = Path::new("model.onnx");
        let image_path = Path::new("../../test-captcha/captcha-3q5hQL.png");

        if !(model_path.exists() && image_path.exists()) {
            println!("Skipping test_predict_real_image: resources not found");
            return;
        }

        let model = CaptchaModel::load(model_path).expect("Failed to load model");
        let result = model.predict_file(image_path);
        assert!(result.is_ok(), "Prediction failed: {:?}", result.err());
        let text = result.unwrap();
        assert!(!text.is_empty(), "Prediction result is empty");
        println!("Predicted text: {text}");
    }
}