use crate::embed::{Embedding, EmbeddingProvider};
use crate::error::ImgFprintError;
use std::path::Path;
use tract_onnx::prelude::*;
/// Concrete runnable-model type stored by [`LocalProvider`]: a `RunnableModel`
/// over a typed graph (`TypedFact` facts, boxed `TypedOp` operators).
type RunnableOnnxModel =
RunnableModel<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
/// Preprocessing and output settings used by [`LocalProvider`].
#[derive(Debug, Clone)]
pub struct LocalProviderConfig {
/// Side length, in pixels, of the square model input; decoded images are
/// resized to exactly `input_size` x `input_size`.
pub input_size: usize,
/// Per-channel (R, G, B) mean subtracted from pixel values after they have
/// been scaled to the `[0, 1]` range.
pub normalize_mean: [f32; 3],
/// Per-channel (R, G, B) divisor applied after the mean is subtracted.
pub normalize_std: [f32; 3],
/// When `true`, the model output is L2-normalized before it is returned.
pub normalize_output: bool,
}
impl Default for LocalProviderConfig {
    /// Builds the default configuration: a 224-pixel square input, the
    /// per-channel mean/std constants listed below, and L2-normalized output.
    fn default() -> Self {
        let input_size = 224;
        let normalize_mean = [0.481_454_66, 0.457_827_5, 0.408_210_73];
        let normalize_std = [0.268_629_54, 0.261_302_6, 0.275_777_1];
        let normalize_output = true;

        Self {
            input_size,
            normalize_mean,
            normalize_std,
            normalize_output,
        }
    }
}
impl LocalProviderConfig {
    /// Per-channel (R, G, B) mean shared by both CLIP presets.
    const CLIP_MEAN: [f32; 3] = [0.481_454_66, 0.457_827_5, 0.408_210_73];
    /// Per-channel (R, G, B) standard deviation shared by both CLIP presets.
    const CLIP_STD: [f32; 3] = [0.268_629_54, 0.261_302_6, 0.275_777_1];

    /// Preset for CLIP ViT-B/32: 224-pixel input, CLIP mean/std, and
    /// L2-normalized output.
    #[must_use]
    pub fn clip_vit_base_patch32() -> Self {
        Self::clip_preset(224)
    }

    /// Preset for CLIP ViT-L/14: 336-pixel input, CLIP mean/std, and
    /// L2-normalized output.
    ///
    /// NOTE(review): 336 px presumably targets the `-336` variant of the
    /// ViT-L/14 checkpoint (the stock one is commonly exported at 224 px) —
    /// confirm against the model file actually shipped.
    #[must_use]
    pub fn clip_vit_large_patch14() -> Self {
        Self::clip_preset(336)
    }

    /// Builds a CLIP-style configuration that differs only in input side length.
    fn clip_preset(input_size: usize) -> Self {
        Self {
            input_size,
            normalize_mean: Self::CLIP_MEAN,
            normalize_std: Self::CLIP_STD,
            normalize_output: true,
        }
    }
}
/// Embedding provider backed by an ONNX model loaded through `tract_onnx`.
pub struct LocalProvider {
/// Optimized, runnable model that inference is executed on.
model: RunnableOnnxModel,
/// Preprocessing and output settings applied around each inference call.
config: LocalProviderConfig,
}
impl std::fmt::Debug for LocalProvider {
    /// Formats the provider with its configuration and a fixed placeholder
    /// string standing in for the model field.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("LocalProvider");
        builder.field("config", &self.config);
        builder.field("model", &"<RunnableModel>");
        builder.finish()
    }
}
impl LocalProvider {
    /// Loads an ONNX model from `path` using [`LocalProviderConfig::default`].
    ///
    /// # Errors
    /// Returns [`ImgFprintError::ProviderError`] if the model cannot be
    /// loaded, optimized, or made runnable.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ImgFprintError> {
        Self::from_file_with_config(path, LocalProviderConfig::default())
    }

    /// Loads an ONNX model from `path` and pairs it with `config`.
    ///
    /// NOTE(review): no input fact is set before optimization; models with
    /// dynamic input dimensions may need one — confirm against the models
    /// this is used with.
    ///
    /// # Errors
    /// Returns [`ImgFprintError::ProviderError`] if the model cannot be
    /// loaded, optimized, or made runnable.
    pub fn from_file_with_config<P: AsRef<Path>>(
        path: P,
        config: LocalProviderConfig,
    ) -> Result<Self, ImgFprintError> {
        let path = path.as_ref();
        let model = tract_onnx::onnx()
            .model_for_path(path)
            .map_err(|e| {
                ImgFprintError::ProviderError(format!(
                    "Failed to load ONNX model from {}: {}",
                    path.display(),
                    e
                ))
            })?
            .into_optimized()
            .map_err(|e| ImgFprintError::ProviderError(format!("Failed to optimize model: {}", e)))?
            .into_runnable()
            .map_err(|e| {
                ImgFprintError::ProviderError(format!("Failed to make model runnable: {}", e))
            })?;
        Ok(Self { model, config })
    }

    /// Parses an ONNX model from an in-memory buffer using
    /// [`LocalProviderConfig::default`].
    ///
    /// # Errors
    /// Returns [`ImgFprintError::ProviderError`] if the bytes cannot be
    /// parsed, or the model cannot be optimized or made runnable.
    pub fn from_bytes(model_bytes: &[u8]) -> Result<Self, ImgFprintError> {
        Self::from_bytes_with_config(model_bytes, LocalProviderConfig::default())
    }

    /// Parses an ONNX model from an in-memory buffer and pairs it with `config`.
    ///
    /// # Errors
    /// Returns [`ImgFprintError::ProviderError`] if the bytes cannot be
    /// parsed, or the model cannot be optimized or made runnable.
    pub fn from_bytes_with_config(
        model_bytes: &[u8],
        config: LocalProviderConfig,
    ) -> Result<Self, ImgFprintError> {
        let mut cursor = std::io::Cursor::new(model_bytes);
        let model = tract_onnx::onnx()
            .model_for_read(&mut cursor)
            .map_err(|e| {
                ImgFprintError::ProviderError(format!("Failed to parse ONNX model: {}", e))
            })?
            .into_optimized()
            .map_err(|e| ImgFprintError::ProviderError(format!("Failed to optimize model: {}", e)))?
            .into_runnable()
            .map_err(|e| {
                ImgFprintError::ProviderError(format!("Failed to make model runnable: {}", e))
            })?;
        Ok(Self { model, config })
    }

    /// Returns the configuration this provider was built with.
    #[must_use]
    pub fn config(&self) -> &LocalProviderConfig {
        &self.config
    }

    /// Decodes `image_bytes`, resizes the image to the configured square
    /// size, and packs it into a channel-planar `[1, 3, size, size]` `f32`
    /// tensor with each channel scaled to `[0, 1]` and then mean/std
    /// normalized.
    ///
    /// # Errors
    /// - [`ImgFprintError::ProcessingError`] if `input_size` is zero or does
    ///   not fit in `u32`, or if the tensor cannot be created.
    /// - [`ImgFprintError::DecodeError`] if the bytes are not a decodable image.
    fn preprocess_image(&self, image_bytes: &[u8]) -> Result<Tensor, ImgFprintError> {
        let size = self.config.input_size;
        // Checked conversion: a plain `as u32` cast would silently truncate an
        // oversized value, and a zero side length cannot describe an image.
        let side = <u32 as std::convert::TryFrom<usize>>::try_from(size)
            .ok()
            .filter(|&s| s > 0)
            .ok_or_else(|| {
                ImgFprintError::ProcessingError(format!(
                    "Invalid input_size {}: must be between 1 and {}",
                    size,
                    u32::MAX
                ))
            })?;

        let img = image::load_from_memory(image_bytes)
            .map_err(|e| ImgFprintError::DecodeError(format!("Failed to decode image: {}", e)))?;
        // `resize_exact` ignores the aspect ratio, so the result is always side x side.
        let rgb_img = img
            .resize_exact(side, side, image::imageops::FilterType::Lanczos3)
            .to_rgb8();

        let mut tensor_data: Vec<f32> = Vec::with_capacity(3 * size * size);
        let channel_stats = self
            .config
            .normalize_mean
            .iter()
            .zip(self.config.normalize_std.iter())
            .enumerate();
        // One full pass over the pixels per channel: `pixels()` walks the
        // image row by row, which yields the planar (channel, y, x) order.
        for (c, (&mean, &std_dev)) in channel_stats {
            tensor_data.extend(
                rgb_img
                    .pixels()
                    .map(|px| (f32::from(px[c]) / 255.0 - mean) / std_dev),
            );
        }

        Tensor::from_shape(&[1, 3, size, size], &tensor_data).map_err(|e| {
            ImgFprintError::ProcessingError(format!("Failed to create tensor: {}", e))
        })
    }

    /// Scales `vector` in place to unit Euclidean length.
    ///
    /// The vector is left untouched when its norm is not strictly positive
    /// (all zeros, or a NaN norm).
    fn l2_normalize(vector: &mut [f32]) {
        let norm: f32 = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > 0.0 {
            for v in vector.iter_mut() {
                *v /= norm;
            }
        }
    }
}
impl EmbeddingProvider for LocalProvider {
    /// Computes an embedding for an encoded image.
    ///
    /// The image is preprocessed into an input tensor, run through the model,
    /// and the first output is read as a flat `f32` vector. That vector is
    /// L2-normalized when `normalize_output` is set, then passed to
    /// `Embedding::new`.
    ///
    /// # Errors
    /// Propagates preprocessing errors, returns
    /// [`ImgFprintError::ProviderError`] when inference fails, the model
    /// yields no output, or the output cannot be read as `f32`, and forwards
    /// whatever `Embedding::new` returns.
    fn embed(&self, image: &[u8]) -> Result<Embedding, ImgFprintError> {
        let input = self.preprocess_image(image)?;

        let outputs = self
            .model
            .run(tvec!(input.into()))
            .map_err(|e| ImgFprintError::ProviderError(format!("Inference failed: {}", e)))?;

        let first_output = outputs
            .first()
            .ok_or_else(|| ImgFprintError::ProviderError(String::from("Empty model output")))?;

        let values = first_output.as_slice::<f32>().map_err(|e| {
            ImgFprintError::ProviderError(format!("Failed to extract output: {}", e))
        })?;

        let mut embedding: Vec<f32> = values.to_vec();
        if self.config.normalize_output {
            Self::l2_normalize(&mut embedding);
        }

        Embedding::new(embedding)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a 224-pixel input and normalizes output.
    #[test]
    fn test_config_default() {
        let LocalProviderConfig {
            input_size,
            normalize_output,
            ..
        } = LocalProviderConfig::default();
        assert_eq!(input_size, 224);
        assert!(normalize_output);
    }

    /// The ViT-B/32 preset uses a 224-pixel input.
    #[test]
    fn test_config_clip_vit_base() {
        let preset = LocalProviderConfig::clip_vit_base_patch32();
        assert_eq!(preset.input_size, 224);
    }

    /// The ViT-L/14 preset uses a 336-pixel input.
    #[test]
    fn test_config_clip_vit_large() {
        let preset = LocalProviderConfig::clip_vit_large_patch14();
        assert_eq!(preset.input_size, 336);
    }

    /// A 3-4-5 triangle normalizes to (0.6, 0.8).
    #[test]
    fn test_l2_normalize() {
        let mut values = [3.0_f32, 4.0];
        LocalProvider::l2_normalize(&mut values);
        let expected = [0.6_f32, 0.8];
        for (got, want) in values.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    /// An all-zero vector has no direction and must be left unchanged.
    #[test]
    fn test_l2_normalize_zero_vector() {
        let mut values = [0.0_f32; 3];
        LocalProvider::l2_normalize(&mut values);
        assert_eq!(values, [0.0_f32; 3]);
    }
}