use crate::config::{ModelConfig, OpenClipConfig};
use crate::error::ClipError;
use crate::model_manager;
use crate::model_manager::get_default_base_folder;
use crate::onnx::OnnxSession;
use bon::bon;
#[cfg(feature = "fast_image_resize")]
use fast_image_resize::{
FilterType as FirFilterType, PixelType, ResizeAlg, ResizeOptions, Resizer, images::Image,
};
#[cfg(not(feature = "fast_image_resize"))]
use image::imageops::FilterType;
use image::{DynamicImage, GenericImageView};
use ndarray::{Array2, Array4, ArrayView, Axis, IxDyn};
use ort::ep::ExecutionProviderDispatch;
use ort::value::Value;
use rayon::prelude::*;
use std::path::Path;
/// Image-side ("visual") encoder of an OpenCLIP model, backed by an ONNX
/// Runtime session.
///
/// Construct it via one of the `bon`-generated builders on the `impl`
/// (`from_hf`, `from_local_id`, `from_local_dir`), then call
/// `embed_image` / `embed_images` to turn images into embedding vectors.
pub struct VisionEmbedder {
    // Wrapped ONNX Runtime session for `visual.onnx`; guarded by an internal
    // lock (a write lock is taken per inference run).
    pub session: OnnxSession,
    // Parsed `open_clip_config.json`: preprocessing parameters (image size,
    // interpolation, resize mode, mean/std) and model dimensions.
    pub config: OpenClipConfig,
    // Parsed `model_config.json` (local, repo-specific metadata).
    pub model_config: ModelConfig,
    // Name of the graph's image input node — whichever of "pixel_values" or
    // "input" the model exposes (resolved at load time).
    pub input_name: String,
}
#[bon]
impl VisionEmbedder {
    /// Builds a `VisionEmbedder` by fetching `model_id` from the Hugging Face
    /// Hub (or reusing the cache — see `model_manager::get_hf_model`) and then
    /// loading it from the resulting local directory.
    ///
    /// Only compiled with the `hf-hub` feature. Finish the builder with
    /// `.build()`.
    ///
    /// # Errors
    /// Propagates download errors and any load error from
    /// [`Self::from_local_dir`].
    #[builder(finish_fn = build)]
    #[cfg(feature = "hf-hub")]
    pub async fn from_hf(
        #[builder(start_fn)] model_id: &str,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        let model_dir = model_manager::get_hf_model(model_id).await?;
        // Delegate to the local-directory builder, forwarding the optional
        // execution providers unchanged.
        Self::from_local_dir(&model_dir)
            .maybe_with_execution_providers(with_execution_providers)
            .build()
    }
    /// Builds a `VisionEmbedder` from a model id resolved under a base folder
    /// (`base_folder/model_id`). When `base_folder` is `None`, the default
    /// from `get_default_base_folder()` is used.
    ///
    /// # Errors
    /// Propagates any load error from [`Self::from_local_dir`].
    #[builder(finish_fn = build)]
    pub fn from_local_id(
        #[builder(start_fn)] model_id: &str,
        base_folder: Option<&Path>,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        let base_folder = base_folder.map_or_else(get_default_base_folder, ToOwned::to_owned);
        Self::from_local_dir(&base_folder.join(model_id))
            .maybe_with_execution_providers(with_execution_providers)
            .build()
    }
    /// Builds a `VisionEmbedder` from a local model directory containing
    /// `visual.onnx`, `open_clip_config.json`, and `model_config.json`.
    ///
    /// # Errors
    /// Fails if the directory does not pass `verify_model_dir`, if the ONNX
    /// session or either config file cannot be loaded, or if the model has no
    /// input node named `pixel_values` or `input`.
    #[builder(finish_fn = build)]
    pub fn from_local_dir(
        #[builder(start_fn)] model_dir: &Path,
        with_execution_providers: Option<&[ExecutionProviderDispatch]>,
    ) -> Result<Self, ClipError> {
        model_manager::verify_model_dir(model_dir)?;
        let model_path = model_dir.join("visual.onnx");
        let config_path = model_dir.join("open_clip_config.json");
        let local_config_path = model_dir.join("model_config.json");
        // No providers supplied -> empty slice, i.e. ORT's default provider.
        let execution_providers = with_execution_providers.unwrap_or_default();
        let session = OnnxSession::new(model_path, execution_providers)?;
        let config = OpenClipConfig::from_file(config_path)?;
        let model_config = ModelConfig::from_file(local_config_path)?;
        // Different exports name the image input differently; accept either.
        let input_name = session
            .find_input(&["pixel_values", "input"])?
            .ok_or_else(|| ClipError::Config("Could not find vision input node".to_string()))?;
        Ok(Self {
            session,
            config,
            model_config,
            input_name,
        })
    }
    /// Embeds a single image, returning a flat embedding vector.
    ///
    /// Convenience wrapper over [`Self::embed_images`]: runs a batch of one
    /// and reshapes the `1 x D` result into a length-`D` `Array1`.
    ///
    /// # Errors
    /// Propagates any preprocessing or inference error.
    pub fn embed_image(&self, image: &DynamicImage) -> Result<ndarray::Array1<f32>, ClipError> {
        let embs = self.embed_images(std::slice::from_ref(image))?;
        let len = embs.len();
        Ok(embs.into_shape_with_order(len)?)
    }
    /// Embeds a batch of images, returning an `(N, D)` matrix (one embedding
    /// row per input image).
    ///
    /// The session write lock is held only for the inner scope: run the
    /// graph, extract the first output tensor, and copy it into an owned
    /// `Array2` before releasing the lock.
    ///
    /// # Errors
    /// Propagates preprocessing errors, lock poisoning, ORT inference errors,
    /// and shape errors if the output is not 2-dimensional.
    #[allow(clippy::significant_drop_tightening)]
    pub fn embed_images(&self, images: &[DynamicImage]) -> Result<Array2<f32>, ClipError> {
        let batch_tensor = self.preprocess_batch(images)?;
        let input_tensor = Value::from_array(batch_tensor)?;
        let array = {
            let mut session = self.session.session.write()?;
            let outputs = session.run(ort::inputs![&self.input_name => input_tensor])?;
            // ORT reports the shape as signed ints; convert to usize for ndarray.
            let (shape, data) = outputs[0].try_extract_tensor::<f32>()?;
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let shape_usize: Vec<usize> = shape.iter().map(|&x| x as usize).collect();
            let view = ArrayView::from_shape(IxDyn(&shape_usize), data)?;
            // Copy out of the ORT-owned buffer so the lock can be dropped.
            view.into_dimensionality::<ndarray::Ix2>()?.to_owned()
        };
        Ok(array)
    }
    /// Preprocesses a batch of images into an `(N, 3, S, S)` NCHW float
    /// tensor, where `S` is the configured `image_size`.
    ///
    /// Each image slot is filled in parallel via rayon; the first
    /// per-image failure aborts the batch.
    ///
    /// # Errors
    /// Returns `ClipError::Inference("Empty batch")` for an empty slice, and
    /// propagates per-image resize/normalize errors.
    pub fn preprocess_batch(&self, images: &[DynamicImage]) -> Result<Array4<f32>, ClipError> {
        if images.is_empty() {
            return Err(ClipError::Inference("Empty batch".to_string()));
        }
        let batch_size = images.len();
        let size = self.config.model_cfg.vision_cfg.image_size as usize;
        let mut batch_tensor = Array4::<f32>::zeros((batch_size, 3, size, size));
        // Pair each (mutable) batch slot with its source image and fill them
        // concurrently.
        batch_tensor
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .zip(images.par_iter())
            .try_for_each(|(mut slot, img)| self.preprocess_into(img, &mut slot))?;
        Ok(batch_tensor)
    }
    /// Preprocesses a single image; equivalent to a one-element
    /// [`Self::preprocess_batch`], returning a `(1, 3, S, S)` tensor.
    pub fn preprocess(&self, image: &DynamicImage) -> Result<Array4<f32>, ClipError> {
        self.preprocess_batch(std::slice::from_ref(image))
    }
    /// Resizes `image` to `S x S` (via whichever resize backend is compiled
    /// in) and writes the normalized planar CHW floats into `out_view`.
    fn preprocess_into(
        &self,
        image: &DynamicImage,
        out_view: &mut ndarray::ArrayViewMut3<f32>,
    ) -> Result<(), ClipError> {
        let size = self.config.model_cfg.vision_cfg.image_size;
        // Fast path: SIMD-accelerated resize producing an interleaved RGB byte vec.
        #[cfg(feature = "fast_image_resize")]
        let pixels_vec = self.resize_with_fast_image_resize(image, size)?;
        #[cfg(feature = "fast_image_resize")]
        let pixels = &pixels_vec;
        // Fallback path: the `image` crate's resize, borrowing its raw RGB buffer.
        #[cfg(not(feature = "fast_image_resize"))]
        let resized = self.resize_with_image(image, size);
        #[cfg(not(feature = "fast_image_resize"))]
        let pixels = resized.as_raw();
        self.normalize_pixels(pixels, size, out_view)?;
        Ok(())
    }
    /// Resizes with the `fast_image_resize` crate, returning an interleaved
    /// RGB (`U8x3`) byte buffer of exactly `size * size` pixels.
    ///
    /// Interpolation is chosen from the preprocess config: "bicubic" ->
    /// Catmull-Rom convolution, "bilinear" -> bilinear convolution, anything
    /// else -> nearest neighbor. Unless `resize_mode` is "squash", the
    /// resizer is given a centered square crop of side `min(width, height)`
    /// (since `crop_w = size / scale = min(width, height)`), preserving
    /// aspect ratio; "squash" resizes the full image, distorting aspect.
    #[cfg(feature = "fast_image_resize")]
    fn resize_with_fast_image_resize(
        &self,
        image: &DynamicImage,
        size: u32,
    ) -> Result<Vec<u8>, ClipError> {
        let (width, height) = image.dimensions();
        let rgb_image = image.to_rgb8();
        let src_image = Image::from_vec_u8(width, height, rgb_image.into_raw(), PixelType::U8x3)?;
        let mut dst_image = Image::new(size, size, PixelType::U8x3);
        let resize_alg = match self.config.preprocess_cfg.interpolation.as_str() {
            "bicubic" => ResizeAlg::Convolution(FirFilterType::CatmullRom),
            "bilinear" => ResizeAlg::Convolution(FirFilterType::Bilinear),
            _ => ResizeAlg::Nearest,
        };
        let mut options = ResizeOptions::new().resize_alg(resize_alg);
        if self.config.preprocess_cfg.resize_mode.as_str() != "squash" {
            // Scale so the shorter side maps to `size`; the crop box is then
            // a centered square of the shorter side's length.
            #[allow(clippy::cast_precision_loss)]
            let scale = f64::from(size) / f64::from(width.min(height));
            let crop_w = f64::from(size) / scale;
            let crop_h = f64::from(size) / scale;
            let crop_x = (f64::from(width) - crop_w) / 2.0;
            let crop_y = (f64::from(height) - crop_h) / 2.0;
            options = options.crop(crop_x, crop_y, crop_w, crop_h);
        }
        let mut resizer = Resizer::new();
        resizer.resize(&src_image, &mut dst_image, &options)?;
        Ok(dst_image.into_vec())
    }
    /// Fallback resize using the `image` crate; same semantics as the fast
    /// path: "squash" stretches to `size x size`, otherwise scale the shorter
    /// side to `size` and take a centered `size x size` crop.
    ///
    /// Interpolation: "bicubic" -> Catmull-Rom, "bilinear" -> triangle,
    /// anything else -> nearest neighbor.
    #[cfg(not(feature = "fast_image_resize"))]
    fn resize_with_image(
        &self,
        image: &DynamicImage,
        size: u32,
    ) -> image::ImageBuffer<image::Rgb<u8>, Vec<u8>> {
        let interp = match self.config.preprocess_cfg.interpolation.as_str() {
            "bicubic" => FilterType::CatmullRom,
            "bilinear" => FilterType::Triangle,
            _ => FilterType::Nearest,
        };
        #[allow(
            clippy::single_match_else,
            clippy::cast_precision_loss,
            clippy::cast_possible_truncation,
            clippy::cast_sign_loss
        )]
        let img_resized = match self.config.preprocess_cfg.resize_mode.as_str() {
            "squash" => image.resize_exact(size, size, interp),
            _ => {
                let (width, height) = image.dimensions();
                // Shorter side -> `size`; the longer side overshoots and is
                // trimmed symmetrically by the centered crop below.
                let scale = size as f32 / width.min(height) as f32;
                let scaled_width = (width as f32 * scale).round() as u32;
                let scaled_height = (height as f32 * scale).round() as u32;
                let resized = image.resize_exact(scaled_width, scaled_height, interp);
                let x = ((scaled_width as f32 - size as f32) / 2.0).round() as u32;
                let y = ((scaled_height as f32 - size as f32) / 2.0).round() as u32;
                resized.crop_imm(x, y, size, size)
            }
        };
        img_resized.to_rgb8()
    }
    /// Converts interleaved RGB bytes into planar CHW floats in `out_view`,
    /// applying per-channel normalization: `(v / 255 - mean[c]) / std[c]`.
    ///
    /// Assumes `pixels` holds interleaved RGB with at least `size^2 * 3`
    /// bytes (guaranteed by both resize paths; indexing panics otherwise).
    ///
    /// # Errors
    /// Returns `ClipError::Inference("Layout mismatch")` if a channel plane
    /// of `out_view` is not contiguous (so `into_slice` fails).
    fn normalize_pixels(
        &self,
        pixels: &[u8],
        size: u32,
        out_view: &mut ndarray::ArrayViewMut3<f32>,
    ) -> Result<(), ClipError> {
        let (mean, std) = (
            self.config.preprocess_cfg.mean,
            self.config.preprocess_cfg.std,
        );
        let channel_len = (size as usize).pow(2);
        for c in 0..3 {
            // Borrow channel `c`'s plane as a flat &mut [f32] for fast writes.
            let channel_slice = out_view.index_axis_mut(Axis(0), c);
            let flat_channel = channel_slice
                .into_slice()
                .ok_or_else(|| ClipError::Inference("Layout mismatch".into()))?;
            for i in 0..channel_len {
                // Interleaved source: pixel i's channel c lives at i * 3 + c.
                let val = f32::from(pixels[i * 3 + c]) / 255.0;
                flat_channel[i] = (val - mean[c]) / std[c];
            }
        }
        Ok(())
    }
}