#[cfg(feature = "hf-hub")]
use hf_hub::api::sync::ApiRepo;
use image::DynamicImage;
use ndarray::{Array3, ArrayView3};
use ort::{
session::{builder::GraphOptimizationLevel, Session},
value::Value,
};
#[cfg(feature = "hf-hub")]
use std::path::PathBuf;
use std::{io::Cursor, path::Path, thread::available_parallelism};
use crate::{
common::normalize, models::image_embedding::models_list, Embedding, ImageEmbeddingModel,
ModelInfo,
};
use anyhow::anyhow;
#[cfg(feature = "hf-hub")]
use anyhow::Context;
#[cfg(feature = "hf-hub")]
use super::ImageInitOptions;
use super::{
init::{ImageInitOptionsUserDefined, UserDefinedImageEmbeddingModel},
utils::{Compose, Transform, TransformData},
ImageEmbedding, DEFAULT_BATCH_SIZE,
};
impl ImageEmbedding {
#[cfg(feature = "hf-hub")]
pub fn try_new(options: ImageInitOptions) -> anyhow::Result<Self> {
let ImageInitOptions {
model_name,
execution_providers,
cache_dir,
show_download_progress,
} = options;
let threads = available_parallelism()?.get();
let model_repo = ImageEmbedding::retrieve_model(
model_name.clone(),
cache_dir.clone(),
show_download_progress,
)?;
let preprocessor_file = model_repo
.get("preprocessor_config.json")
.context("Failed to retrieve preprocessor_config.json")?;
let preprocessor = Compose::from_file(preprocessor_file)?;
let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file;
let model_file_reference = model_repo
.get(&model_file_name)
.context(format!("Failed to retrieve {}", model_file_name))?;
#[cfg(feature = "directml")]
let has_directml = execution_providers
.iter()
.any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
#[cfg(not(feature = "directml"))]
let has_directml = false;
let mut builder = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?;
if has_directml {
builder = builder
.with_memory_pattern(false)?
.with_parallel_execution(false)?;
}
let session = builder.commit_from_file(model_file_reference)?;
Ok(Self::new(preprocessor, session))
}
pub fn try_new_from_user_defined(
model: UserDefinedImageEmbeddingModel,
options: ImageInitOptionsUserDefined,
) -> anyhow::Result<Self> {
let ImageInitOptionsUserDefined {
execution_providers,
} = options;
let threads = available_parallelism()?.get();
let preprocessor = Compose::from_bytes(model.preprocessor_file)?;
#[cfg(feature = "directml")]
let has_directml = execution_providers
.iter()
.any(|ep| ep.downcast_ref::<ort::ep::DirectML>().is_some());
#[cfg(not(feature = "directml"))]
let has_directml = false;
let mut builder = Session::builder()?
.with_execution_providers(execution_providers)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(threads)?;
if has_directml {
builder = builder
.with_memory_pattern(false)?
.with_parallel_execution(false)?;
}
let session = builder.commit_from_memory(&model.onnx_file)?;
Ok(Self::new(preprocessor, session))
}
fn new(preprocessor: Compose, session: Session) -> Self {
Self {
preprocessor,
session,
}
}
#[cfg(feature = "hf-hub")]
fn retrieve_model(
model: ImageEmbeddingModel,
cache_dir: PathBuf,
show_download_progress: bool,
) -> anyhow::Result<ApiRepo> {
use crate::common::pull_from_hf;
pull_from_hf(model.to_string(), cache_dir, show_download_progress)
}
pub fn list_supported_models() -> Vec<ModelInfo<ImageEmbeddingModel>> {
models_list()
}
pub fn get_model_info(model: &ImageEmbeddingModel) -> ModelInfo<ImageEmbeddingModel> {
ImageEmbedding::list_supported_models()
.into_iter()
.find(|m| &m.model == model)
.expect("Model not found in supported models list. This is a bug - please report it.")
}
pub fn embed_bytes(
&mut self,
images: &[&[u8]],
batch_size: Option<usize>,
) -> anyhow::Result<Vec<Embedding>> {
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let output = images
.chunks(batch_size)
.map(|batch| {
let inputs = batch
.iter()
.map(|img| {
image::ImageReader::new(Cursor::new(img))
.with_guessed_format()?
.decode()
.map_err(|err| anyhow!("image decode: {}", err))
})
.collect::<Result<_, _>>()?;
self.embed_images(inputs)
})
.collect::<anyhow::Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();
Ok(output)
}
pub fn embed<S: AsRef<Path> + Send + Sync>(
&mut self,
images: impl AsRef<[S]>,
batch_size: Option<usize>,
) -> anyhow::Result<Vec<Embedding>> {
let images = images.as_ref();
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let output = images
.chunks(batch_size)
.map(|batch| {
let inputs = batch
.iter()
.map(|img| {
image::ImageReader::open(img)?
.decode()
.map_err(|err| anyhow!("image decode: {}", err))
})
.collect::<Result<_, _>>()?;
self.embed_images(inputs)
})
.collect::<anyhow::Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();
Ok(output)
}
    /// Embed a batch of already-decoded images.
    ///
    /// Each image is run through the configured preprocessor, the resulting
    /// per-image arrays are stacked into a single batch tensor, and one ONNX
    /// session run produces the embeddings, which are normalized via
    /// [`normalize`] before being returned.
    ///
    /// # Errors
    /// Fails if preprocessing does not yield an ndarray, if stacking fails
    /// (e.g. preprocessed images differ in shape), if the session run fails,
    /// or if no known output key yields an f32 tensor of rank 2 or 3.
    pub fn embed_images(&mut self, imgs: Vec<DynamicImage>) -> anyhow::Result<Vec<Embedding>> {
        // Preprocess every image; the pipeline must end in an ndarray stage,
        // otherwise the Compose configuration is considered invalid.
        let inputs = imgs
            .into_iter()
            .map(|img| {
                let pixels = self.preprocessor.transform(TransformData::Image(img))?;
                match pixels {
                    TransformData::NdArray(array) => Ok(array),
                    _ => Err(anyhow!("Preprocessor configuration error!")),
                }
            })
            .collect::<anyhow::Result<Vec<Array3<f32>>>>()?;
        // Stack the per-image arrays along a new leading batch axis.
        // `ndarray::stack` errors if the arrays' shapes are not all equal.
        let inputs_view: Vec<ArrayView3<f32>> = inputs.iter().map(|img| img.view()).collect();
        let pixel_values_array = ndarray::stack(ndarray::Axis(0), &inputs_view)?;
        // Feed the batch under whatever name the model declares for its
        // first input (models differ, e.g. "pixel_values").
        let input_name = self.session.inputs()[0].name().to_string();
        let session_inputs = ort::inputs![
            input_name => Value::from_array(pixel_values_array)?,
        ];
        let outputs = self.session.run(session_inputs)?;
        // Pick the output to read: with exactly one output, use it whatever
        // its name; otherwise fall back to the known keys in priority order.
        let last_hidden_state_key = match outputs.len() {
            1 => vec![outputs
                .keys()
                .next()
                .ok_or_else(|| anyhow!("Expected one output but found none"))?],
            _ => vec!["image_embeds", "last_hidden_state"],
        };
        // Take the first candidate key that exists and extracts as f32.
        let (shape, data) = last_hidden_state_key
            .iter()
            .find_map(|&key| {
                outputs
                    .get(key)
                    .and_then(|v| v.try_extract_tensor::<f32>().ok())
            })
            .ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?;
        // Rebuild an ndarray view over the raw tensor data (shape dims come
        // back as i64-like values, hence the usize conversion).
        let shape: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
        let output_array = ndarray::ArrayViewD::from_shape(shape.as_slice(), data)?;
        let embeddings = match output_array.ndim() {
            // Rank 3: per-token hidden states — take token 0 of each batch
            // item (presumably the CLS token; confirm per model).
            3 => {
                (0..output_array.shape()[0])
                    .map(|batch_idx| {
                        let cls_embedding = output_array
                            .slice(ndarray::s![batch_idx, 0, ..])
                            .to_owned()
                            .to_vec();
                        normalize(&cls_embedding)
                    })
                    .collect()
            }
            // Rank 2: rows appear to be already-pooled per-image embeddings;
            // normalize each row as-is.
            2 => {
                output_array
                    .outer_iter()
                    .map(|row| {
                        row.as_slice()
                            .ok_or_else(|| anyhow!("Failed to convert array row to slice"))
                            .map(normalize)
                    })
                    .collect::<anyhow::Result<Vec<_>>>()?
            }
            // Any other rank is unsupported.
            _ => {
                return Err(anyhow!(
                    "Unexpected output tensor shape: {:?}",
                    output_array.shape()
                ))
            }
        };
        Ok(embeddings)
    }
}