captcha-engine 0.4.10

ONNX-based captcha recognition engine
Documentation
use crate::{Result, error::Error, image_ops, tokenizer::Tokenizer};
use image::DynamicImage;
use rten::Model;
use rten_tensor::Tensor;
use std::path::Path;

#[cfg(feature = "download")]
use log::info;
#[cfg(feature = "download")]
use reqwest::blocking::Client;
#[cfg(feature = "download")]
use std::fs;
#[cfg(feature = "download")]
use std::path::PathBuf;

/// URL to download the model from `HuggingFace`.
///
/// Requested by `ensure_model_downloaded` when no `captcha.rten` file is
/// present in the storage directory. Only compiled with the `download` feature.
#[cfg(feature = "download")]
const MODEL_URL_HUGGINGFACE: &str =
    "https://huggingface.co/Milang/captcha-solver/resolve/main/captcha.rten";

/// Embedded model bytes (only available with `embed-model` feature).
///
/// Included at compile time from `model.rten` inside Cargo's `OUT_DIR`, so that
/// file must exist when the crate is built with this feature enabled.
/// NOTE(review): presumably the build script places the file there — confirm.
#[cfg(feature = "embed-model")]
const EMBEDDED_MODEL: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/model.rten"));

/// The main captcha-breaking model.
///
/// Wraps an `RTen` model and tokenizer for end-to-end captcha recognition.
/// Construct it with one of the `load*` functions, then call `predict` or
/// `predict_file`.
pub struct CaptchaModel {
    // Loaded `rten` model; `predict` runs inference on it.
    model: Model,
    // Turns the model's output tensor into text via `decode_rten`.
    tokenizer: Tokenizer,
}

impl std::fmt::Debug for CaptchaModel {
    /// Formats the struct showing only the tokenizer; the `model` field is
    /// left out and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CaptchaModel");
        builder.field("tokenizer", &self.tokenizer);
        builder.finish_non_exhaustive()
    }
}

impl CaptchaModel {
    /// Load a model from an `RTen` file.
    ///
    /// # Errors
    ///
    /// Returns an error if the model file cannot be read or loaded.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let model = Model::load_file(path).map_err(|e| Error::ModelLoad(e.to_string()))?;

        Ok(Self {
            model,
            tokenizer: Tokenizer::default(),
        })
    }

    /// Load a model from memory (bytes).
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded.
    pub fn load_from_memory(model_bytes: &[u8]) -> Result<Self> {
        let model =
            Model::load(model_bytes.to_vec()).map_err(|e| Error::ModelLoad(e.to_string()))?;

        Ok(Self {
            model,
            tokenizer: Tokenizer::default(),
        })
    }

    /// Load the embedded model (only available with `embed-model` feature).
    ///
    /// This loads the model directly from bytes compiled into the binary,
    /// requiring no network access or external files.
    ///
    /// # Errors
    ///
    /// Returns an error if the embedded model cannot be loaded.
    #[cfg(feature = "embed-model")]
    pub fn load_embedded() -> Result<Self> {
        Self::load_from_memory(EMBEDDED_MODEL)
    }

    /// Predict the text in a captcha image.
    ///
    /// # Errors
    ///
    /// Returns an error if inference fails.
    pub fn predict(&self, image: &DynamicImage) -> Result<String> {
        // Preprocess image to tensor
        let input_tensor = image_ops::preprocess(image);

        // Run inference
        // rten 0.24 Model::run signature:
        // pub fn run(&self, inputs: Vec<(NodeId, Tensor)>, outputs: Vec<NodeId>, options: Option<RunOptions>) -> Result<Vec<Tensor>, RunError>
        // OR
        // run supports input names.
        // Actually, looking at 0.14 vs 0.24, usually strings are supported via helper or conversion.
        // But the error `expected &[NodeId], found Vec<String>` suggests 0.24 might need NodeIDs?
        // Wait, Model usually has a way to find NodeId by name.

        // Let's look up Node IDs from names.
        // node_id returns Result in 0.24
        let input_id = self
            .model
            .node_id("input")
            .map_err(|e| Error::Inference(format!("Input node 'input' error: {e}")))?;
        let output_id = self
            .model
            .node_id("output")
            .map_err(|e| Error::Inference(format!("Output node 'output' error: {e}")))?;

        let inputs = vec![(input_id, input_tensor.into())];

        let mut outputs = self
            .model
            .run(inputs, &[output_id], None)
            .map_err(|e| Error::Inference(e.to_string()))?;

        let output_value = outputs.remove(0);

        // Output should be float tensor.
        // Value::try_into() converts to Tensor if it matches.
        // We use owned Tensor for simplicity.
        let output_tensor: Tensor<f32> = output_value
            .try_into()
            .map_err(|_| Error::Inference("Output is not a float tensor".into()))?;

        // We can pass a view to the tokenizer if it takes ndarray or slice.
        Ok(self.tokenizer.decode_rten(&output_tensor))
    }

    /// Predict the text in a captcha image loaded from a file path.
    ///
    /// # Errors
    ///
    /// Returns an error if the image cannot be loaded or inference fails.
    pub fn predict_file<P: AsRef<Path>>(&self, path: P) -> Result<String> {
        let image = image::open(path)?;
        self.predict(&image)
    }
}

/// Ensures the model exists at the given path. If not, downloads it.
///
/// The model is stored as `captcha.rten` inside `storage_dir`, which is created
/// if needed. When that file already exists its path is returned without any
/// network access.
///
/// The download is written to a temporary sibling file and renamed to
/// `captcha.rten` only after the whole body has been received. An interrupted
/// transfer therefore never leaves a truncated file that a later call would
/// mistake for a valid cached model.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, the download fails,
/// or the file cannot be written.
#[cfg(feature = "download")]
pub fn ensure_model_downloaded<P: AsRef<Path>>(storage_dir: P) -> Result<PathBuf> {
    let storage_dir = storage_dir.as_ref();
    // `create_dir_all` succeeds when the directory already exists, so no
    // separate existence check is required.
    fs::create_dir_all(storage_dir)?;

    let model_path = storage_dir.join("captcha.rten");

    if model_path.exists() {
        return Ok(model_path);
    }

    info!(
        "Downloading captcha model to {path}",
        path = model_path.display()
    );

    let client = Client::new();
    let mut res = client.get(MODEL_URL_HUGGINGFACE).send()?;

    if !res.status().is_success() {
        return Err(Error::ModelDownload(format!(
            "Failed to download model: status {}",
            res.status()
        )));
    }

    // The process id keeps concurrent downloaders from sharing a temp file.
    let partial_path = storage_dir.join(format!("captcha.rten.{}.part", std::process::id()));

    if let Err(err) = write_response_to(&mut res, &partial_path) {
        // Best-effort cleanup; the original error is the one worth reporting.
        let _ = fs::remove_file(&partial_path);
        return Err(err);
    }

    if let Err(err) = fs::rename(&partial_path, &model_path) {
        let _ = fs::remove_file(&partial_path);
        return Err(err.into());
    }

    Ok(model_path)
}

/// Streams the body of `res` into a newly created file at `dest`.
///
/// The file handle is closed before returning, so the caller can rename the file.
#[cfg(feature = "download")]
fn write_response_to(res: &mut reqwest::blocking::Response, dest: &Path) -> Result<()> {
    let mut file = fs::File::create(dest)?;
    res.copy_to(&mut file)?;
    Ok(())
}

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Smoke test for the embedded model: a load failure is printed, not asserted.
    #[cfg(feature = "embed-model")]
    #[test]
    fn test_embedded_model_loads() {
        // Deliberately lenient: the outcome is only reported so the test does
        // not fail in development setups without a usable embedded model.
        match CaptchaModel::load_embedded() {
            Ok(_) => {}
            Err(e) => println!(
                "Embedded model load failed (expected if not building with build.rs): {}",
                e
            ),
        }
    }

    /// Bytes that are not a model must be rejected by the loader.
    #[test]
    fn test_load_from_invalid_memory() {
        let outcome = CaptchaModel::load_from_memory(b"not a model");
        assert!(outcome.is_err(), "Loading from invalid bytes should fail");
    }

    /// Loads `model.rten` from the working directory when present; skips otherwise.
    #[test]
    fn test_load_local_model() {
        let local = Path::new("model.rten");
        if !local.exists() {
            println!("Skipping test_load_local_model: model.rten not found");
            return;
        }

        assert!(
            CaptchaModel::load(local).is_ok(),
            "Failed to load local model.rten"
        );
    }
}