use crate::{Result, error::Error, image_ops, tokenizer::Tokenizer};
use image::DynamicImage;
use rten::Model;
use rten_tensor::Tensor;
use std::path::Path;
#[cfg(feature = "download")]
use log::info;
#[cfg(feature = "download")]
use reqwest::blocking::Client;
#[cfg(feature = "download")]
use std::fs;
#[cfg(feature = "download")]
use std::path::PathBuf;
/// URL on Hugging Face from which `ensure_model_downloaded` fetches the `captcha.rten`
/// model file.
///
/// Only compiled when the `download` feature is enabled.
#[cfg(feature = "download")]
const MODEL_URL_HUGGINGFACE: &str =
"https://huggingface.co/Milang/captcha-solver/resolve/main/captcha.rten";
/// Model bytes included into the binary at compile time from `$OUT_DIR/model.rten`.
///
/// Only compiled when the `embed-model` feature is enabled.
// NOTE(review): the file is presumably placed in `OUT_DIR` by the crate's build script
// before compilation — confirm against build.rs.
#[cfg(feature = "embed-model")]
const EMBEDDED_MODEL: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/model.rten"));
/// A loaded captcha model together with the tokenizer that decodes its output.
///
/// Construct it with one of the `load*` associated functions, then call `predict` or
/// `predict_file` to turn an image into text.
pub struct CaptchaModel {
/// The `rten` model that inference is run on.
model: Model,
/// Turns the model's output tensor into the predicted string.
tokenizer: Tokenizer,
}
/// Manual `Debug` implementation: prints the type name and the `tokenizer` field only.
/// The `model` field is left out and the output is marked as non-exhaustive.
impl std::fmt::Debug for CaptchaModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CaptchaModel");
        builder.field("tokenizer", &self.tokenizer);
        builder.finish_non_exhaustive()
    }
}
impl CaptchaModel {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let model = Model::load_file(path).map_err(|e| Error::ModelLoad(e.to_string()))?;
Ok(Self {
model,
tokenizer: Tokenizer::default(),
})
}
pub fn load_from_memory(model_bytes: &[u8]) -> Result<Self> {
let model =
Model::load(model_bytes.to_vec()).map_err(|e| Error::ModelLoad(e.to_string()))?;
Ok(Self {
model,
tokenizer: Tokenizer::default(),
})
}
#[cfg(feature = "embed-model")]
pub fn load_embedded() -> Result<Self> {
Self::load_from_memory(EMBEDDED_MODEL)
}
pub fn predict(&self, image: &DynamicImage) -> Result<String> {
let input_tensor = image_ops::preprocess(image);
let input_id = self
.model
.node_id("input")
.map_err(|e| Error::Inference(format!("Input node 'input' error: {e}")))?;
let output_id = self
.model
.node_id("output")
.map_err(|e| Error::Inference(format!("Output node 'output' error: {e}")))?;
let inputs = vec![(input_id, input_tensor.into())];
let mut outputs = self
.model
.run(inputs, &[output_id], None)
.map_err(|e| Error::Inference(e.to_string()))?;
let output_value = outputs.remove(0);
let output_tensor: Tensor<f32> = output_value
.try_into()
.map_err(|_| Error::Inference("Output is not a float tensor".into()))?;
Ok(self.tokenizer.decode_rten(&output_tensor))
}
pub fn predict_file<P: AsRef<Path>>(&self, path: P) -> Result<String> {
let image = image::open(path)?;
self.predict(&image)
}
}
/// Makes sure `captcha.rten` exists inside `storage_dir`, downloading it if it is missing.
///
/// The directory is created if needed. When the model file is already present its path is
/// returned without any network access. Otherwise the file is fetched from
/// `MODEL_URL_HUGGINGFACE` into a temporary `captcha.rten.part` file, which is renamed to
/// `captcha.rten` only after the whole body has been written. An interrupted download
/// therefore never leaves a truncated file that a later call would treat as a valid model.
///
/// Only available when the `download` feature is enabled.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the HTTP request fails, the server
/// answers with a non-success status ([`Error::ModelDownload`]), or the file cannot be
/// written or renamed. On a write or rename failure the partial file is removed on a
/// best-effort basis.
#[cfg(feature = "download")]
pub fn ensure_model_downloaded<P: AsRef<Path>>(storage_dir: P) -> Result<PathBuf> {
    let storage_dir = storage_dir.as_ref();
    // `create_dir_all` succeeds when the directory already exists, so a separate
    // existence check would only add a race window.
    fs::create_dir_all(storage_dir)?;

    let model_path = storage_dir.join("captcha.rten");
    if model_path.exists() {
        return Ok(model_path);
    }

    info!(
        "Downloading captcha model to {path}",
        path = model_path.display()
    );
    let client = Client::new();
    let mut res = client.get(MODEL_URL_HUGGINGFACE).send()?;
    if !res.status().is_success() {
        return Err(Error::ModelDownload(format!(
            "Failed to download model: status {}",
            res.status()
        )));
    }

    // Stream into a sibling temporary file first; the final name only appears once the
    // download has completed.
    let partial_path = storage_dir.join("captcha.rten.part");
    let outcome = (|| -> Result<()> {
        let mut file = fs::File::create(&partial_path)?;
        res.copy_to(&mut file)?;
        // Close the handle before renaming; some platforms refuse to rename open files.
        drop(file);
        fs::rename(&partial_path, &model_path)?;
        Ok(())
    })();

    if let Err(err) = outcome {
        // Best-effort cleanup: the original failure is what the caller needs to see.
        let _ = fs::remove_file(&partial_path);
        return Err(err);
    }
    Ok(model_path)
}
#[cfg(test)]
mod tests {
    // Kept from the original module so tests may call `unwrap` without tripping the lint.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Tries to load the embedded model. A failure is only reported on stdout, not
    /// asserted, because the embedded model may not be a real one in every build.
    #[cfg(feature = "embed-model")]
    #[test]
    fn test_embedded_model_loads() {
        if let Err(load_error) = CaptchaModel::load_embedded() {
            println!(
                "Embedded model load failed (expected if not building with build.rs): {}",
                load_error
            );
        }
    }

    /// Bytes that are not a model must be rejected by `load_from_memory`.
    #[test]
    fn test_load_from_invalid_memory() {
        let outcome = CaptchaModel::load_from_memory(b"not a model");
        assert!(outcome.is_err(), "Loading from invalid bytes should fail");
    }

    /// Loads `model.rten` from the working directory when it exists; skipped otherwise.
    #[test]
    fn test_load_local_model() {
        let candidate = Path::new("model.rten");
        if !candidate.exists() {
            println!("Skipping test_load_local_model: model.rten not found");
            return;
        }
        assert!(
            CaptchaModel::load(candidate).is_ok(),
            "Failed to load local model.rten"
        );
    }
}