1use crate::{Result, error::Error, image_ops, tokenizer::Tokenizer};
4use image::DynamicImage;
5use ort::session::{Session, builder::GraphOptimizationLevel};
6use ort::value::Tensor;
7use std::path::Path;
8use std::sync::Mutex;
9
10#[cfg(feature = "download")]
11use log::info;
12#[cfg(feature = "download")]
13use reqwest::blocking::Client;
14#[cfg(feature = "download")]
15use std::fs;
16#[cfg(feature = "download")]
17use std::path::PathBuf;
18
19#[cfg(feature = "download")]
21const MODEL_URL_HUGGINGFACE: &str =
22 "https://huggingface.co/Milang/captcha-solver/resolve/main/captcha.onnx";
23
24#[cfg(feature = "embed-model")]
26const EMBEDDED_MODEL: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/model.onnx"));
27
/// A captcha-solving model backed by an ONNX Runtime session.
///
/// Construct via [`CaptchaModel::load`], [`CaptchaModel::load_from_memory`],
/// or (with the `embed-model` feature) `CaptchaModel::load_embedded`.
pub struct CaptchaModel {
    /// The ONNX Runtime inference session. Wrapped in a `Mutex` so that
    /// `predict` can take `&self` even though `Session::run` requires
    /// mutable access (the guard is locked per prediction).
    session: Mutex<Session>,
    /// Decodes the model's output probability array into the captcha string.
    tokenizer: Tokenizer,
}
35
36impl std::fmt::Debug for CaptchaModel {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("CaptchaModel")
39 .field("tokenizer", &self.tokenizer)
40 .finish_non_exhaustive()
41 }
42}
43
44impl CaptchaModel {
45 #[allow(unused_mut)]
51 fn create_session_builder() -> Result<ort::session::builder::SessionBuilder> {
52 let mut builder = Session::builder().map_err(|e| Error::ModelLoad(e.to_string()))?;
53
54 let mut providers = Vec::new();
57
58 #[cfg(target_os = "macos")]
59 {
60 providers.push(ort::execution_providers::CoreMLExecutionProvider::default().build());
61 }
62
63 #[cfg(any(target_os = "windows", target_os = "linux"))]
64 {
65 providers.push(ort::execution_providers::CUDAExecutionProvider::default().build());
66 }
67
68 providers.push(ort::execution_providers::CPUExecutionProvider::default().build());
69
70 builder = builder
71 .with_execution_providers(providers)
72 .map_err(|e| Error::ModelLoad(e.to_string()))?
73 .with_optimization_level(GraphOptimizationLevel::Level3)
74 .map_err(|e| Error::ModelLoad(e.to_string()))?;
75
76 Ok(builder)
77 }
78
79 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
85 let session = Self::create_session_builder()?
86 .commit_from_file(path.as_ref())
87 .map_err(|e| Error::ModelLoad(e.to_string()))?;
88
89 Ok(Self {
90 session: Mutex::new(session),
91 tokenizer: Tokenizer::default(),
92 })
93 }
94
95 pub fn load_from_memory(model_bytes: &[u8]) -> Result<Self> {
101 let session = Self::create_session_builder()?
102 .commit_from_memory(model_bytes)
103 .map_err(|e| Error::ModelLoad(e.to_string()))?;
104
105 Ok(Self {
106 session: Mutex::new(session),
107 tokenizer: Tokenizer::default(),
108 })
109 }
110
111 #[cfg(feature = "embed-model")]
120 pub fn load_embedded() -> Result<Self> {
121 Self::load_from_memory(EMBEDDED_MODEL)
122 }
123
124 #[allow(clippy::significant_drop_tightening)]
130 pub fn predict(&self, image: &DynamicImage) -> Result<String> {
131 let array = image_ops::preprocess(image);
132
133 let tensor = Tensor::from_array(array).map_err(|e| Error::Inference(e.to_string()))?;
135
136 let mut session = self
137 .session
138 .lock()
139 .map_err(|_| Error::Inference("Session mutex poisoned".into()))?;
140
141 let outputs = session
142 .run(ort::inputs!["input" => tensor])
143 .map_err(|e| Error::Inference(e.to_string()))?;
144
145 let output = outputs
147 .iter()
148 .find(|(name, _)| *name == "output")
149 .map(|(_, v)| v)
150 .or_else(|| outputs.iter().next().map(|(_, v)| v))
151 .ok_or_else(|| Error::Inference("No output tensor".into()))?;
152
153 let probs = output
154 .try_extract_array::<f32>()
155 .map_err(|e| Error::Inference(e.to_string()))?;
156
157 Ok(self.tokenizer.decode(&probs.view()))
158 }
159
160 pub fn predict_file<P: AsRef<Path>>(&self, path: P) -> Result<String> {
166 let image = image::open(path)?;
167 self.predict(&image)
168 }
169}
170
/// Ensures the captcha model exists under `storage_dir`, downloading it from
/// Hugging Face on first use, and returns the path to `captcha.onnx`.
///
/// The download is streamed to a temporary `.part` file and renamed into
/// place only after it completes, so an interrupted transfer can never leave
/// a truncated `captcha.onnx` that later calls would mistake for a valid
/// model (the previous implementation wrote directly to the final path).
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the HTTP request
/// fails or returns a non-success status, or the file cannot be written.
#[cfg(feature = "download")]
pub fn ensure_model_downloaded<P: AsRef<Path>>(storage_dir: P) -> Result<PathBuf> {
    let storage_dir = storage_dir.as_ref();
    if !storage_dir.exists() {
        fs::create_dir_all(storage_dir)?;
    }

    let model_path = storage_dir.join("captcha.onnx");

    if model_path.exists() {
        return Ok(model_path);
    }

    info!(
        "Downloading captcha model to {path}",
        path = model_path.display()
    );

    let client = Client::new();
    let mut res = client.get(MODEL_URL_HUGGINGFACE).send()?;

    if !res.status().is_success() {
        return Err(Error::ModelDownload(format!(
            "Failed to download model: status {}",
            res.status()
        )));
    }

    // Stream to a sibling temp file first; only a fully-written file is
    // renamed to the final name, so a crash mid-download is recoverable.
    let tmp_path = model_path.with_extension("onnx.part");
    let mut file = fs::File::create(&tmp_path)?;
    if let Err(e) = res.copy_to(&mut file) {
        // Best-effort cleanup of the partial file; the download error is
        // what the caller needs to see.
        drop(file);
        let _ = fs::remove_file(&tmp_path);
        return Err(e.into());
    }
    drop(file);
    fs::rename(&tmp_path, &model_path)?;

    Ok(model_path)
}
210
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    // Loading the compiled-in model bytes should always succeed.
    #[cfg(feature = "embed-model")]
    #[test]
    fn test_embedded_model_loads() {
        let loaded = CaptchaModel::load_embedded();
        assert!(
            loaded.is_ok(),
            "Embedded model should load successfully: {:?}",
            loaded.err()
        );
    }

    // Garbage bytes must be rejected rather than producing a session.
    #[test]
    fn test_load_from_invalid_memory() {
        let garbage: &[u8] = b"not a model";
        assert!(
            CaptchaModel::load_from_memory(garbage).is_err(),
            "Loading from invalid bytes should fail"
        );
    }

    // A blank image should still run the full pipeline without erroring.
    #[cfg(feature = "embed-model")]
    #[test]
    fn test_prediction_dummy_image() {
        use image::{DynamicImage, RgbImage};
        let model = CaptchaModel::load_embedded().expect("Failed to load embedded model");
        let white = RgbImage::from_pixel(215, 80, image::Rgb([255, 255, 255]));
        let img = DynamicImage::ImageRgb8(white);

        assert!(model.predict(&img).is_ok());
    }

    // Optional smoke test against a locally-present model file.
    #[test]
    fn test_load_local_model() {
        let path = Path::new("model.onnx");
        if !path.exists() {
            println!("Skipping test_load_local_model: model.onnx not found");
            return;
        }
        let model = CaptchaModel::load(path);
        assert!(model.is_ok(), "Failed to load local model.onnx");
    }

    // End-to-end prediction against checked-in test resources, when present.
    #[test]
    fn test_predict_real_image() {
        let model_path = Path::new("model.onnx");
        let image_path = Path::new("../../test-captcha/captcha-3q5hQL.png");

        if !(model_path.exists() && image_path.exists()) {
            println!("Skipping test_predict_real_image: resources not found");
            return;
        }

        let model = CaptchaModel::load(model_path).expect("Failed to load model");
        let result = model.predict_file(image_path);
        assert!(result.is_ok(), "Prediction failed: {:?}", result.err());
        let text = result.unwrap();
        assert!(!text.is_empty(), "Prediction result is empty");
        println!("Predicted text: {text}");
    }
}