// open_clip_inference 0.3.0
//
// Run OpenCLIP-compatible embedding models via ONNX Runtime.
// Documentation
use crate::ClipError;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// Model-level settings that can be loaded from a JSON file via
/// `ModelConfig::from_file`.
///
/// Every field may be omitted from the serialized form:
/// `tokenizer_needs_lowercase` falls back to `false` through `#[serde(default)]`,
/// and the remaining fields are `Option`s.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ModelConfig {
    /// Flag stating that the tokenizer needs lowercase input; `false` when the
    /// key is absent.
    // NOTE(review): presumably text is lowercased before tokenization when this
    // is set — confirm at the use site.
    #[serde(default)]
    pub tokenizer_needs_lowercase: bool,
    /// Name of an activation function, if one is specified.
    // NOTE(review): the set of accepted names is not visible in this file.
    pub activation_function: Option<String>,
    /// Logit scale value, if one is specified.
    // NOTE(review): presumably the multiplier applied to similarity logits — confirm.
    pub logit_scale: Option<f32>,
    /// Logit bias value, if one is specified.
    // NOTE(review): presumably an additive offset on similarity logits — confirm.
    pub logit_bias: Option<f32>,
    /// Padding id, if one is specified.
    // NOTE(review): presumably the token id used to pad tokenized text — confirm.
    pub pad_id: Option<u32>,
}

impl ModelConfig {
    /// Loads a [`ModelConfig`] from the JSON file at `path`.
    ///
    /// The file is read completely into a `String` and then handed to
    /// `serde_json::from_str`.
    ///
    /// # Errors
    ///
    /// Returns a [`ClipError`] (converted through `?`) when the file cannot be
    /// read as text or when its contents fail to deserialize.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ClipError> {
        let raw_json = fs::read_to_string(path)?;
        let parsed: Self = serde_json::from_str(raw_json.as_str())?;
        Ok(parsed)
    }
}

/// Top-level configuration document made of a model section and a
/// preprocessing section.
///
/// Neither field is optional or defaulted, so both keys must be present for
/// deserialization to succeed.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenClipConfig {
    /// Model section (embedding size plus vision and text sub-sections).
    pub model_cfg: ModelCfg,
    /// Preprocessing section (normalization statistics and resize settings).
    pub preprocess_cfg: PreprocessCfg,
}

/// The `model_cfg` section of an [`OpenClipConfig`].
///
/// All three fields are required when deserializing.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelCfg {
    /// Embedding dimension.
    // NOTE(review): presumably the length of the output embedding vectors — confirm.
    pub embed_dim: usize,
    /// Settings for the vision side of the model.
    pub vision_cfg: VisionCfg,
    /// Settings for the text side of the model.
    pub text_cfg: TextCfg,
}

/// Vision-side settings found under `vision_cfg` in [`ModelCfg`].
///
/// Only `image_size` is required; `layers` and `width` are optional.
#[derive(Debug, Clone, Deserialize)]
pub struct VisionCfg {
    /// Image size.
    // NOTE(review): presumably the side length in pixels of the square model
    // input — confirm against the preprocessing code.
    pub image_size: u32,
    /// Layer count, if given.
    pub layers: Option<usize>,
    /// Width, if given.
    // NOTE(review): presumably the hidden width of the vision tower, not an
    // image width — confirm.
    pub width: Option<usize>,
}

/// Text-side settings found under `text_cfg` in [`ModelCfg`].
///
/// `context_length` is required; `hf_tokenizer_name` is optional.
#[derive(Debug, Clone, Deserialize)]
pub struct TextCfg {
    /// Context length.
    // NOTE(review): presumably the number of tokens per text input — confirm.
    pub context_length: usize,
    /// Tokenizer name, if given.
    // NOTE(review): the `hf_` prefix suggests a Hugging Face tokenizer
    // identifier — confirm.
    pub hf_tokenizer_name: Option<String>,
}

/// The `preprocess_cfg` section of an [`OpenClipConfig`].
///
/// `mean` and `std` are required; `interpolation` and `resize_mode` fall back
/// to `"bicubic"` and `"shortest"` respectively when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct PreprocessCfg {
    /// Three mean values.
    // NOTE(review): presumably per-channel normalization means — channel order
    // is not visible here; confirm.
    pub mean: [f32; 3],
    /// Three standard-deviation values.
    // NOTE(review): presumably per-channel normalization divisors — confirm.
    pub std: [f32; 3],
    /// Interpolation name; filled in by `default_interpolation` (`"bicubic"`)
    /// when the key is missing.
    #[serde(default = "default_interpolation")]
    pub interpolation: String,
    /// Resize mode name; filled in by `default_resize_mode` (`"shortest"`)
    /// when the key is missing.
    #[serde(default = "default_resize_mode")]
    pub resize_mode: String,
}

/// Serde fallback for `PreprocessCfg::interpolation`: always yields `"bicubic"`.
fn default_interpolation() -> String {
    String::from("bicubic")
}
/// Serde fallback for `PreprocessCfg::resize_mode`: always yields `"shortest"`.
fn default_resize_mode() -> String {
    String::from("shortest")
}

impl OpenClipConfig {
    /// Loads an [`OpenClipConfig`] from the JSON file at `path`.
    ///
    /// The file is read completely into a `String` and then handed to
    /// `serde_json::from_str`.
    ///
    /// # Errors
    ///
    /// Returns a [`ClipError`] (converted through `?`) when the file cannot be
    /// read as text or when its contents fail to deserialize.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ClipError> {
        let json_text = fs::read_to_string(path)?;
        let config: Self = serde_json::from_str(json_text.as_str())?;
        Ok(config)
    }
}