object_detector 0.5.0

Object detection using ORT and the yoloe-26-seg model. The model can detect multiple objects per image, each with a tag, a pixel-level mask, and a bounding box. It is pretrained and has a vocabulary of 4000+ objects.
Documentation
use crate::ObjectDetectorError;
#[cfg(feature = "hf-hub")]
use crate::model_manager::{HfModel, get_hf_model};
use crate::predictor::PromptFreeDetector;
#[cfg(feature = "promptable")]
use crate::predictor::PromptableDetector;
use crate::structs::{DetectedObject, DetectorType, ModelScale};
use bon::bon;
use image::DynamicImage;
use ort::ep::ExecutionProviderDispatch;

/// High-level object detector that wraps one of the concrete detector implementations.
///
/// Instances are created through the builder-style constructor(s) on this type and used
/// through `predict`, which dispatches to the wrapped detector.
pub struct ObjectDetector {
    /// The concrete detector that calls are dispatched to.
    inner: ObjectDetectorInner,
}

/// The concrete detector backing an [`ObjectDetector`].
enum ObjectDetectorInner {
    /// Detector that is driven by caller-supplied labels. Only compiled when the
    /// `promptable` feature is enabled.
    // NOTE(review): boxed presumably to keep this enum small — confirm against the
    // size of `PromptableDetector`.
    #[cfg(feature = "promptable")]
    Promptable(Box<PromptableDetector>),
    /// Detector that is run without labels.
    PromptFree(PromptFreeDetector),
}

#[bon]
impl ObjectDetector {
    /// Builds an [`ObjectDetector`] from model files hosted on Hugging Face.
    ///
    /// The model file is selected from `detector_type`, `scale` and `include_mask` and fetched
    /// from `HfModel::DEFAULT_REPO_ID` together with its companion `<model file>.data` file.
    /// The promptable detector additionally builds a CLIP text embedder; the prompt-free
    /// detector additionally fetches the default vocabulary.
    ///
    /// # Arguments
    /// * `detector_type` - which detector to construct (builder start argument).
    /// * `scale` - model scale; defaults to [`ModelScale::Large`].
    /// * `include_mask` - forwarded to the model file selection; defaults to `true`.
    /// * `with_execution_providers` - execution providers handed to the detector (and to the
    ///   CLIP text embedder for the promptable detector); defaults to an empty slice.
    ///
    /// # Errors
    /// * [`ObjectDetectorError::InvalidModel`] if `detector_type` is not supported by this
    ///   build (the `promptable` feature is disabled). This is reported before any file is
    ///   fetched.
    /// * [`ObjectDetectorError::Ort`] if the CLIP text embedder cannot be built.
    /// * Any error propagated from fetching the model files or building the detector.
    #[cfg(feature = "hf-hub")]
    #[builder(finish_fn = build)]
    pub async fn from_hf(
        #[builder(start_fn)] detector_type: DetectorType,
        #[builder(default = ModelScale::Large)] scale: ModelScale,
        #[builder(default = true)] include_mask: bool,
        #[builder(default = &[])] with_execution_providers: &[ExecutionProviderDispatch],
    ) -> Result<Self, ObjectDetectorError> {
        // Fail fast: without the `promptable` feature only the prompt-free detector can be
        // built, so reject anything else before spending time fetching model files.
        if cfg!(not(feature = "promptable"))
            && !matches!(detector_type, DetectorType::PromptFree)
        {
            return Err(ObjectDetectorError::InvalidModel(
                "Promptable detector is disabled".into(),
            ));
        }

        let model_file = HfModel::get_model_file_path(detector_type, scale, include_mask);
        // Derive the companion file name first so `model_file` can be moved instead of cloned.
        let data_model = HfModel {
            id: HfModel::DEFAULT_REPO_ID.to_owned(),
            file: format!("{model_file}.data"),
        };
        let model = HfModel {
            id: HfModel::DEFAULT_REPO_ID.to_owned(),
            file: model_file,
        };

        let model_path_local = get_hf_model(model).await?;
        // Only the side effect of fetching matters here; the returned path is not used.
        get_hf_model(data_model).await?;

        let inner = match detector_type {
            #[cfg(feature = "promptable")]
            DetectorType::Promptable => {
                let text_embedder =
                    open_clip_inference::TextEmbedder::from_hf(&HfModel::default_clip_embedder())
                        .with_execution_providers(with_execution_providers)
                        .build()
                        .await
                        .map_err(|e| ObjectDetectorError::Ort(format!("CLIP error: {e}")))?;

                let detector = PromptableDetector::builder(model_path_local, text_embedder)
                    .with_execution_providers(with_execution_providers)
                    .build()?;
                ObjectDetectorInner::Promptable(Box::new(detector))
            }
            DetectorType::PromptFree => {
                let vocab_model = HfModel::default_vocabulary();
                let vocab_path = get_hf_model(vocab_model).await?;

                let detector = PromptFreeDetector::builder(model_path_local, vocab_path)
                    .with_execution_providers(with_execution_providers)
                    .build()?;
                ObjectDetectorInner::PromptFree(detector)
            }
            // Needed for exhaustiveness when the `promptable` arm above is compiled out;
            // with the feature enabled every variant is already covered, hence the `allow`.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(ObjectDetectorError::InvalidModel(
                    "Promptable detector is disabled".into(),
                ));
            }
        };

        Ok(Self { inner })
    }

    /// Runs detection on `img` with the wrapped detector and returns the detected objects.
    ///
    /// # Arguments
    /// * `img` - image to run detection on (builder start argument).
    /// * `labels` - labels to detect. Must be non-empty for the promptable detector and must
    ///   be empty for the prompt-free detector; defaults to an empty slice.
    /// * `confidence_threshold` - forwarded to the wrapped detector; defaults to `0.3`.
    /// * `intersection_over_union` - forwarded to the wrapped detector; defaults to `0.7`.
    ///
    /// # Errors
    /// * [`ObjectDetectorError::InvalidModel`] if `labels` is empty for the promptable detector
    ///   or non-empty for the prompt-free detector.
    /// * Any error returned by the wrapped detector.
    #[builder]
    pub fn predict(
        &self,
        #[builder(start_fn)] img: &DynamicImage,
        #[builder(default = &[])] labels: &[&str],
        #[builder(default = 0.3)] confidence_threshold: f32,
        #[builder(default = 0.7)] intersection_over_union: f32,
    ) -> Result<Vec<DetectedObject>, ObjectDetectorError> {
        // No catch-all arm: `ObjectDetectorInner` only ever has the variants matched below
        // (the `Promptable` variant and its arm are gated by the same feature), so the
        // compiler checks exhaustiveness in every feature configuration.
        match &self.inner {
            #[cfg(feature = "promptable")]
            ObjectDetectorInner::Promptable(detector) => {
                if labels.is_empty() {
                    return Err(ObjectDetectorError::InvalidModel(
                        "Labels are required for Promptable detector".into(),
                    ));
                }
                detector
                    .predict(img, labels)
                    .confidence_threshold(confidence_threshold)
                    .intersection_over_union(intersection_over_union)
                    .call()
            }
            ObjectDetectorInner::PromptFree(detector) => {
                if !labels.is_empty() {
                    return Err(ObjectDetectorError::InvalidModel(
                        "Labels are not supported for PromptFree detector".into(),
                    ));
                }
                detector
                    .predict(img)
                    .confidence_threshold(confidence_threshold)
                    .intersection_over_union(intersection_over_union)
                    .call()
            }
        }
    }
}