//! captcha_engine/model.rs — captcha recognition model: loading (file,
//! memory, or embedded bytes) and end-to-end inference.
1use crate::{Result, error::Error, image_ops, tokenizer::Tokenizer};
2use image::DynamicImage;
3use rten::Model;
4use rten_tensor::Tensor;
5use std::path::Path;
6
7#[cfg(feature = "download")]
8use log::info;
9#[cfg(feature = "download")]
10use reqwest::blocking::Client;
11#[cfg(feature = "download")]
12use std::fs;
13#[cfg(feature = "download")]
14use std::path::PathBuf;
15
/// URL to download the model from `HuggingFace`.
///
/// Fetched by [`ensure_model_downloaded`] when no local copy exists.
#[cfg(feature = "download")]
const MODEL_URL_HUGGINGFACE: &str =
    "https://huggingface.co/Milang/captcha-solver/resolve/main/captcha.rten";
20
/// Embedded model bytes (only available with `embed-model` feature).
///
/// `model.rten` is expected in `OUT_DIR` at compile time — presumably placed
/// there by a build script (see the note in `test_embedded_model_loads`).
#[cfg(feature = "embed-model")]
const EMBEDDED_MODEL: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/model.rten"));
24
/// The main captcha-breaking model.
///
/// Wraps an `RTen` model and tokenizer for end-to-end captcha recognition.
pub struct CaptchaModel {
    // The underlying RTen inference graph; run via `predict`.
    model: Model,
    // Decodes the model's float output tensor into text (`decode_rten`).
    tokenizer: Tokenizer,
}
32
33impl std::fmt::Debug for CaptchaModel {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("CaptchaModel")
36            .field("tokenizer", &self.tokenizer)
37            .finish_non_exhaustive()
38    }
39}
40
41impl CaptchaModel {
42    /// Load a model from an `RTen` file.
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the model file cannot be read or loaded.
47    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
48        let model = Model::load_file(path).map_err(|e| Error::ModelLoad(e.to_string()))?;
49
50        Ok(Self {
51            model,
52            tokenizer: Tokenizer::default(),
53        })
54    }
55
56    /// Load a model from memory (bytes).
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if the model cannot be loaded.
61    pub fn load_from_memory(model_bytes: &[u8]) -> Result<Self> {
62        let model =
63            Model::load(model_bytes.to_vec()).map_err(|e| Error::ModelLoad(e.to_string()))?;
64
65        Ok(Self {
66            model,
67            tokenizer: Tokenizer::default(),
68        })
69    }
70
71    /// Load the embedded model (only available with `embed-model` feature).
72    ///
73    /// This loads the model directly from bytes compiled into the binary,
74    /// requiring no network access or external files.
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the embedded model cannot be loaded.
79    #[cfg(feature = "embed-model")]
80    pub fn load_embedded() -> Result<Self> {
81        Self::load_from_memory(EMBEDDED_MODEL)
82    }
83
84    /// Predict the text in a captcha image.
85    ///
86    /// # Errors
87    ///
88    /// Returns an error if inference fails.
89    pub fn predict(&self, image: &DynamicImage) -> Result<String> {
90        // Preprocess image to tensor
91        let input_tensor = image_ops::preprocess(image);
92
93        // Run inference
94        // rten 0.24 Model::run signature:
95        // pub fn run(&self, inputs: Vec<(NodeId, Tensor)>, outputs: Vec<NodeId>, options: Option<RunOptions>) -> Result<Vec<Tensor>, RunError>
96        // OR
97        // run supports input names.
98        // Actually, looking at 0.14 vs 0.24, usually strings are supported via helper or conversion.
99        // But the error `expected &[NodeId], found Vec<String>` suggests 0.24 might need NodeIDs?
100        // Wait, Model usually has a way to find NodeId by name.
101
102        // Let's look up Node IDs from names.
103        // node_id returns Result in 0.24
104        let input_id = self
105            .model
106            .node_id("input")
107            .map_err(|e| Error::Inference(format!("Input node 'input' error: {e}")))?;
108        let output_id = self
109            .model
110            .node_id("output")
111            .map_err(|e| Error::Inference(format!("Output node 'output' error: {e}")))?;
112
113        let inputs = vec![(input_id, input_tensor.into())];
114
115        let mut outputs = self
116            .model
117            .run(inputs, &[output_id], None)
118            .map_err(|e| Error::Inference(e.to_string()))?;
119
120        let output_value = outputs.remove(0);
121
122        // Output should be float tensor.
123        // Value::try_into() converts to Tensor if it matches.
124        // We use owned Tensor for simplicity.
125        let output_tensor: Tensor<f32> = output_value
126            .try_into()
127            .map_err(|_| Error::Inference("Output is not a float tensor".into()))?;
128
129        // We can pass a view to the tokenizer if it takes ndarray or slice.
130        Ok(self.tokenizer.decode_rten(&output_tensor))
131    }
132
133    /// Predict the text in a captcha image loaded from a file path.
134    ///
135    /// # Errors
136    ///
137    /// Returns an error if the image cannot be loaded or inference fails.
138    pub fn predict_file<P: AsRef<Path>>(&self, path: P) -> Result<String> {
139        let image = image::open(path)?;
140        self.predict(&image)
141    }
142}
143
/// Ensures the model exists at the given path. If not, downloads it.
///
/// The model is streamed to a temporary `.part` file and renamed into place
/// only on success, so an interrupted download never leaves a truncated
/// `captcha.rten` that later calls would mistake for a valid cached model.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the download fails,
/// or the file cannot be written.
#[cfg(feature = "download")]
pub fn ensure_model_downloaded<P: AsRef<Path>>(storage_dir: P) -> Result<PathBuf> {
    let storage_dir = storage_dir.as_ref();
    if !storage_dir.exists() {
        fs::create_dir_all(storage_dir)?;
    }

    // Updated filename for rten
    let model_path = storage_dir.join("captcha.rten");

    if model_path.exists() {
        return Ok(model_path);
    }

    info!(
        "Downloading captcha model to {path}",
        path = model_path.display()
    );

    let client = Client::new();
    let mut res = client.get(MODEL_URL_HUGGINGFACE).send()?;

    if !res.status().is_success() {
        return Err(Error::ModelDownload(format!(
            "Failed to download model: status {}",
            res.status()
        )));
    }

    // Write to a temp path first; a leftover `.part` from a failed attempt is
    // harmless (it never matches the cache check) and is overwritten on retry.
    let tmp_path = storage_dir.join("captcha.rten.part");
    let mut file = fs::File::create(&tmp_path)?;
    res.copy_to(&mut file)?;
    // Close the handle before renaming (required on Windows).
    drop(file);
    fs::rename(&tmp_path, &model_path)?;

    Ok(model_path)
}
184
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[cfg(feature = "embed-model")]
    #[test]
    fn test_embedded_model_loads() {
        // This will fail if model.rten is not in OUT_DIR, which is normal during dev without build script
        match CaptchaModel::load_embedded() {
            Ok(_) => {}
            Err(e) => println!(
                "Embedded model load failed (expected if not building with build.rs): {}",
                e
            ),
        }
    }

    #[test]
    fn test_load_from_invalid_memory() {
        assert!(
            CaptchaModel::load_from_memory(b"not a model").is_err(),
            "Loading from invalid bytes should fail"
        );
    }

    #[test]
    fn test_load_local_model() {
        let path = Path::new("model.rten");
        if !path.exists() {
            println!("Skipping test_load_local_model: model.rten not found");
            return;
        }
        assert!(
            CaptchaModel::load(path).is_ok(),
            "Failed to load local model.rten"
        );
    }
}