1use crate::{Result, error::Error, image_ops, tokenizer::Tokenizer};
2use image::DynamicImage;
3use rten::Model;
4use rten_tensor::Tensor;
5use std::path::Path;
6
7#[cfg(feature = "download")]
8use log::info;
9#[cfg(feature = "download")]
10use reqwest::blocking::Client;
11#[cfg(feature = "download")]
12use std::fs;
13#[cfg(feature = "download")]
14use std::path::PathBuf;
15
/// URL the model file is fetched from by `ensure_model_downloaded`.
///
/// Only compiled when the `download` feature is enabled.
#[cfg(feature = "download")]
const MODEL_URL_HUGGINGFACE: &str =
    "https://huggingface.co/Milang/captcha-solver/resolve/main/captcha.rten";
20
/// Model bytes baked into the binary at compile time from `$OUT_DIR/model.rten`.
///
/// Only compiled when the `embed-model` feature is enabled.
// NOTE(review): presumably the build script places `model.rten` in `OUT_DIR` — confirm.
#[cfg(feature = "embed-model")]
const EMBEDDED_MODEL: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/model.rten"));
24
/// A loaded captcha-recognition model together with the tokenizer used to turn
/// its output tensor into text.
pub struct CaptchaModel {
    /// The `rten` model that inference is run on.
    model: Model,
    /// Decodes the model's output tensor into a `String` (see `predict`).
    tokenizer: Tokenizer,
}
32
33impl std::fmt::Debug for CaptchaModel {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("CaptchaModel")
36 .field("tokenizer", &self.tokenizer)
37 .finish_non_exhaustive()
38 }
39}
40
41impl CaptchaModel {
42 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
48 let model = Model::load_file(path).map_err(|e| Error::ModelLoad(e.to_string()))?;
49
50 Ok(Self {
51 model,
52 tokenizer: Tokenizer::default(),
53 })
54 }
55
56 pub fn load_from_memory(model_bytes: &[u8]) -> Result<Self> {
62 let model =
63 Model::load(model_bytes.to_vec()).map_err(|e| Error::ModelLoad(e.to_string()))?;
64
65 Ok(Self {
66 model,
67 tokenizer: Tokenizer::default(),
68 })
69 }
70
71 #[cfg(feature = "embed-model")]
80 pub fn load_embedded() -> Result<Self> {
81 Self::load_from_memory(EMBEDDED_MODEL)
82 }
83
84 pub fn predict(&self, image: &DynamicImage) -> Result<String> {
90 let input_tensor = image_ops::preprocess(image);
92
93 let input_id = self
105 .model
106 .node_id("input")
107 .map_err(|e| Error::Inference(format!("Input node 'input' error: {e}")))?;
108 let output_id = self
109 .model
110 .node_id("output")
111 .map_err(|e| Error::Inference(format!("Output node 'output' error: {e}")))?;
112
113 let inputs = vec![(input_id, input_tensor.into())];
114
115 let mut outputs = self
116 .model
117 .run(inputs, &[output_id], None)
118 .map_err(|e| Error::Inference(e.to_string()))?;
119
120 let output_value = outputs.remove(0);
121
122 let output_tensor: Tensor<f32> = output_value
126 .try_into()
127 .map_err(|_| Error::Inference("Output is not a float tensor".into()))?;
128
129 Ok(self.tokenizer.decode_rten(&output_tensor))
131 }
132
133 pub fn predict_file<P: AsRef<Path>>(&self, path: P) -> Result<String> {
139 let image = image::open(path)?;
140 self.predict(&image)
141 }
142}
143
/// Makes sure `captcha.rten` exists inside `storage_dir`, downloading it if needed,
/// and returns its path.
///
/// If the file is already present it is returned as-is without any network access.
/// Otherwise the model is fetched from `MODEL_URL_HUGGINGFACE` into a temporary
/// `captcha.rten.part` file and renamed into place only after the whole body has
/// been written. An interrupted download therefore never leaves a truncated
/// `captcha.rten` that later calls would mistake for a valid model.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the request fails, the
/// server answers with a non-success status ([`Error::ModelDownload`]), or the
/// file cannot be written or renamed.
#[cfg(feature = "download")]
pub fn ensure_model_downloaded<P: AsRef<Path>>(storage_dir: P) -> Result<PathBuf> {
    let storage_dir = storage_dir.as_ref();
    // `create_dir_all` succeeds when the directory already exists, so no prior
    // existence check is needed.
    fs::create_dir_all(storage_dir)?;

    let model_path = storage_dir.join("captcha.rten");

    if model_path.exists() {
        return Ok(model_path);
    }

    info!(
        "Downloading captcha model to {path}",
        path = model_path.display()
    );

    let client = Client::new();
    let mut res = client.get(MODEL_URL_HUGGINGFACE).send()?;

    if !res.status().is_success() {
        return Err(Error::ModelDownload(format!(
            "Failed to download model: status {}",
            res.status()
        )));
    }

    // Stream into a sibling temp file first; the final path only ever appears once
    // the download has completed.
    let partial_path = storage_dir.join("captcha.rten.part");
    let written = (|| -> Result<()> {
        let mut file = fs::File::create(&partial_path)?;
        res.copy_to(&mut file)?;
        Ok(())
    })();

    if let Err(err) = written {
        // Best-effort cleanup; the download error is the one worth reporting.
        let _ = fs::remove_file(&partial_path);
        return Err(err);
    }

    if let Err(err) = fs::rename(&partial_path, &model_path) {
        // Best-effort cleanup; the rename error is the one worth reporting.
        let _ = fs::remove_file(&partial_path);
        return Err(err.into());
    }

    Ok(model_path)
}
184
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    // Loads the embedded model; a failure is only printed, never asserted on.
    #[cfg(feature = "embed-model")]
    #[test]
    fn test_embedded_model_loads() {
        let outcome = CaptchaModel::load_embedded();
        match &outcome {
            Err(e) => println!(
                "Embedded model load failed (expected if not building with build.rs): {}",
                e
            ),
            Ok(_) => {}
        }
    }

    // Bytes that are not a model must be rejected by the loader.
    #[test]
    fn test_load_from_invalid_memory() {
        let outcome = CaptchaModel::load_from_memory(b"not a model");
        assert!(outcome.is_err(), "Loading from invalid bytes should fail");
    }

    // Loads `model.rten` from the working directory when present; skips otherwise.
    #[test]
    fn test_load_local_model() {
        let local = Path::new("model.rten");
        if !local.exists() {
            println!("Skipping test_load_local_model: model.rten not found");
            return;
        }

        let loaded = CaptchaModel::load(local);
        assert!(loaded.is_ok(), "Failed to load local model.rten");
    }
}