//! open_clip_inference 0.1.4
//!
//! Run OpenCLIP compatible embedding models via ONNX Runtime.
use crate::ClipError;
use ort::ep::ExecutionProviderDispatch;
use ort::session::{Session, builder::GraphOptimizationLevel};
use std::env;
use std::path::{Path, PathBuf};

/// Thin wrapper around an ONNX Runtime [`Session`] loaded from a model file.
pub struct OnnxSession {
    /// The underlying ONNX Runtime inference session.
    pub session: Session,
}

impl OnnxSession {
    /// Builds a session for the ONNX model at `path`, registering the given
    /// execution providers, enabling full graph optimization, and using one
    /// intra-op thread per logical CPU.
    ///
    /// # Errors
    ///
    /// Returns a [`ClipError`] if the session builder fails or the model file
    /// cannot be loaded.
    pub fn new(
        path: impl AsRef<Path>,
        execution_providers: &[ExecutionProviderDispatch],
    ) -> Result<Self, ClipError> {
        let builder = Session::builder()?
            .with_execution_providers(execution_providers)?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(num_cpus::get())?;
        Ok(Self {
            session: builder.commit_from_file(path)?,
        })
    }

    /// Returns `true` when the loaded model declares an input called `name`.
    #[must_use]
    pub fn has_input(&self, name: &str) -> bool {
        self.session
            .inputs()
            .iter()
            .any(|input| input.name() == name)
    }

    /// Returns the first candidate from `possibilities` that matches one of
    /// the model's declared input names, or `None` when none match.
    #[must_use]
    pub fn find_input(&self, possibilities: &[&str]) -> Option<String> {
        possibilities
            .iter()
            .find(|&&candidate| self.has_input(candidate))
            .map(|&candidate| candidate.to_string())
    }

    /// Returns the on-disk cache directory for `model_id`:
    /// `$HOME/.cache/open_clip_rs/<model_id>`, falling back to a relative
    /// `.open_clip_cache/<model_id>` when no home directory can be determined.
    #[must_use]
    pub fn get_model_dir(model_id: &str) -> PathBuf {
        let base_folder = match env::home_dir() {
            Some(home) => home.join(".cache/open_clip_rs"),
            None => Path::new(".open_clip_cache").to_owned(),
        };
        base_folder.join(model_id)
    }

    /// Checks that `model_dir` exists on disk.
    ///
    /// # Errors
    ///
    /// Returns [`ClipError::ModelFolderNotFound`] when the directory is
    /// missing.
    pub fn verify_model_dir(model_dir: &Path) -> Result<(), ClipError> {
        if model_dir.exists() {
            Ok(())
        } else {
            Err(ClipError::ModelFolderNotFound(model_dir.to_owned()))
        }
    }
}