//! CLIP-based multimodal embedder backed by candle.
//!
//! Every item in this module is gated behind the `embeddings-multimodal`
//! feature, so the heavy candle/tokenizers dependencies are only compiled in
//! when that feature is enabled.
#[cfg(feature = "embeddings-multimodal")]
use std::any::Any;
#[cfg(feature = "embeddings-multimodal")]
use async_trait::async_trait;
#[cfg(feature = "embeddings-multimodal")]
use candle_core::{DType, Device, Module, Tensor};
#[cfg(feature = "embeddings-multimodal")]
use candle_nn::{Linear, VarBuilder};
#[cfg(feature = "embeddings-multimodal")]
use candle_transformers::models::clip;
#[cfg(feature = "embeddings-multimodal")]
use hf_hub::api::sync::ApiBuilder;
#[cfg(feature = "embeddings-multimodal")]
use tokenizers::Tokenizer;
#[cfg(feature = "embeddings-multimodal")]
use crate::embedding::embedder::{EmbedInput, EmbedInputType, Embedder};
#[cfg(feature = "embeddings-multimodal")]
use crate::error::{LaurusError, Result};
#[cfg(feature = "embeddings-multimodal")]
use crate::vector::core::vector::Vector;
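/// Multimodal embedder backed by a CLIP checkpoint loaded through candle.
///
/// Text and images are projected into the same embedding space (512
/// dimensions for ViT-B/32) and every returned vector is L2-normalized, so
/// cosine similarity, or a plain dot product, between a text embedding and an
/// image embedding is meaningful.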
#[cfg(feature = "embeddings-multimodal")]
pub struct CandleClipEmbedder {
text_model: clip::text_model::ClipTextTransformer,
vision_model: clip::vision_model::ClipVisionTransformer,
text_projection: Linear,
vision_projection: Linear,
tokenizer: Tokenizer,
device: Device,
dimension: usize,
model_name: String,
image_size: usize,
}
#[cfg(feature = "embeddings-multimodal")]
impl std::fmt::Debug for CandleClipEmbedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CandleClipEmbedder")
.field("model_name", &self.model_name)
.field("dimension", &self.dimension)
.field("image_size", &self.image_size)
            .finish_non_exhaustive()
}
}
#[cfg(feature = "embeddings-multimodal")]
impl CandleClipEmbedder {
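    /// Downloads (or loads from the local cache) a CLIP checkpoint from the
    /// Hugging Face Hub and builds the text and vision towers, their
    /// projection heads, and the tokenizer.
    ///
    /// The architecture is fixed to ViT-B/32 via
    /// `ClipConfig::vit_base_patch32()`, so `model_name` must name a ViT-B/32
    /// CLIP checkpoint (e.g. `openai/clip-vit-base-patch32`); a checkpoint
    /// with a different architecture will fail at weight-loading time.
    /// Requires network access on first use.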
pub fn new(model_name: &str) -> Result<Self> {
let device = Device::cuda_if_available(0)
.map_err(|e| LaurusError::InvalidOperation(format!("Device setup failed: {}", e)))?;
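        // Resolve the Hugging Face cache directory: $HF_HOME first, then
        // ~/.cache/huggingface, then /tmp as a last resort.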
let cache_dir = std::env::var("HF_HOME")
.or_else(|_| std::env::var("HOME").map(|home| format!("{}/.cache/huggingface", home)))
.unwrap_or_else(|_| "/tmp/huggingface".to_string());
let api = ApiBuilder::new()
.with_cache_dir(cache_dir.into())
.build()
.map_err(|e| {
LaurusError::InvalidOperation(format!("HF API initialization failed: {}", e))
})?;
let repo = api.model(model_name.to_string());
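        // The architecture is hard-coded to ViT-B/32; see the note on `new`.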
let config = clip::ClipConfig::vit_base_patch32();
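        // Prefer safetensors; fall back to the legacy PyTorch pickle format.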
let weights_filename = repo
.get("model.safetensors")
.or_else(|_| repo.get("pytorch_model.bin"))
.map_err(|e| {
LaurusError::InvalidOperation(format!("Weights download failed: {}", e))
})?;
let vb = if weights_filename.to_string_lossy().ends_with(".safetensors") {
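            // SAFETY: memory-mapping assumes the weights file is not mutated
            // or truncated while mapped; we only read files the hub API just
            // downloaded into the local cache.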
unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)
.map_err(|e| {
LaurusError::InvalidOperation(format!("VarBuilder creation failed: {}", e))
})?
}
} else {
VarBuilder::from_pth(&weights_filename, DType::F32, &device).map_err(|e| {
LaurusError::InvalidOperation(format!("VarBuilder creation failed: {}", e))
})?
};
let text_model =
clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config.text_config)
.map_err(|e| {
LaurusError::InvalidOperation(format!("Text model load failed: {}", e))
})?;
let vision_model = clip::vision_model::ClipVisionTransformer::new(
vb.pp("vision_model"),
&config.vision_config,
)
.map_err(|e| LaurusError::InvalidOperation(format!("Vision model load failed: {}", e)))?;
let projection_dim = config.text_config.projection_dim;
let text_projection = candle_nn::linear_no_bias(
config.text_config.embed_dim,
projection_dim,
vb.pp("text_projection"),
)
.map_err(|e| {
LaurusError::InvalidOperation(format!("Text projection load failed: {}", e))
})?;
let vision_projection = candle_nn::linear_no_bias(
config.vision_config.embed_dim,
projection_dim,
vb.pp("visual_projection"),
)
.map_err(|e| {
LaurusError::InvalidOperation(format!("Vision projection load failed: {}", e))
})?;
let tokenizer_filename = repo.get("tokenizer.json").map_err(|e| {
LaurusError::InvalidOperation(format!("Tokenizer download failed: {}", e))
})?;
let tokenizer = Tokenizer::from_file(tokenizer_filename)
.map_err(|e| LaurusError::InvalidOperation(format!("Tokenizer load failed: {}", e)))?;
let dimension = projection_dim;
let image_size = config.vision_config.image_size;
Ok(Self {
text_model,
vision_model,
text_projection,
vision_projection,
tokenizer,
device,
dimension,
model_name: model_name.to_string(),
image_size,
})
}
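    /// Embeds a text snippet: tokenize, run the text tower, apply the text
    /// projection, then L2-normalize. Despite the `async` signature,
    /// inference runs synchronously on the calling task.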
async fn embed_text(&self, text: &str) -> Result<Vector> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| LaurusError::InvalidOperation(format!("Tokenization failed: {}", e)))?;
let token_ids = encoding.get_ids();
let token_ids_tensor = Tensor::new(token_ids, &self.device)
.map_err(|e| LaurusError::InvalidOperation(format!("Tensor creation failed: {}", e)))?
.unsqueeze(0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let text_features = self.text_model.forward(&token_ids_tensor).map_err(|e| {
LaurusError::InvalidOperation(format!("Text model forward failed: {}", e))
})?;
let projected = self
.text_projection
.forward(&text_features)
.map_err(|e| LaurusError::InvalidOperation(format!("Text projection failed: {}", e)))?;
let normalized = self.normalize(&projected)?;
let vector_data: Vec<f32> = normalized
.squeeze(0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.to_vec1()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
Ok(Vector::new(vector_data))
}
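    /// Embeds raw image bytes: decode and preprocess into a pixel tensor, run
    /// the vision tower, apply the visual projection, then L2-normalize.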
async fn embed_image_bytes(&self, bytes: &[u8]) -> Result<Vector> {
let image_tensor = self.preprocess_image_bytes(bytes)?;
let vision_features = self.vision_model.forward(&image_tensor).map_err(|e| {
LaurusError::InvalidOperation(format!("Vision model forward failed: {}", e))
})?;
let projected = self
.vision_projection
.forward(&vision_features)
.map_err(|e| {
LaurusError::InvalidOperation(format!("Vision projection failed: {}", e))
})?;
let normalized = self.normalize(&projected)?;
let vector_data: Vec<f32> = normalized
.squeeze(0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.to_vec1()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
Ok(Vector::new(vector_data))
}
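    /// Decodes image bytes into the `[1, 3, image_size, image_size]` f32
    /// tensor the vision tower expects: resize, scale to `[0, 1]`, normalize
    /// per channel with CLIP's mean/std, then permute HWC -> CHW and add a
    /// batch dimension.
    ///
    /// `resize_exact` ignores aspect ratio; canonical CLIP preprocessing
    /// resizes the shorter side and center-crops instead.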
fn preprocess_image_bytes(&self, bytes: &[u8]) -> Result<Tensor> {
use image::ImageReader;
use std::io::Cursor;
let img_reader = ImageReader::new(Cursor::new(bytes))
.with_guessed_format()
.map_err(|e| {
LaurusError::InvalidOperation(format!("Image format guess failed: {}", e))
})?;
let img = img_reader
.decode()
.map_err(|e| LaurusError::InvalidOperation(format!("Image decode failed: {}", e)))?;
let img = img.resize_exact(
self.image_size as u32,
self.image_size as u32,
image::imageops::FilterType::Triangle,
);
        // `resize_exact` yields a `DynamicImage`; `into_rgb8` converts it to
        // an RGB8 buffer, moving rather than cloning when it already is one.
        let img = img.into_rgb8();
let img_data = img.into_raw();
let img_tensor = Tensor::from_vec(
img_data,
(self.image_size, self.image_size, 3),
&self.device,
)
.map_err(|e| LaurusError::InvalidOperation(format!("Tensor creation failed: {}", e)))?;
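        // OpenAI's published CLIP normalization constants: the per-channel
        // mean and standard deviation of the model's training images.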
let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &self.device)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.reshape((1, 1, 3))
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
        let std = Tensor::new(&[0.26862954f32, 0.26130258, 0.27577711], &self.device)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.reshape((1, 1, 3))
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let normalized = img_tensor
.to_dtype(DType::F32)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.affine(1.0 / 255.0, 0.0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.broadcast_sub(&mean)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.broadcast_div(&std)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
let normalized = normalized
.permute((2, 0, 1))
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.unsqueeze(0)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
Ok(normalized)
}
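    /// L2-normalizes each row of a `[batch, dim]` tensor (divides by
    /// `sqrt(sum(x_i^2))`), so cosine similarity between embeddings reduces
    /// to a dot product.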
fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
let norm = tensor
.sqr()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.sum_keepdim(1)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?
.sqrt()
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))?;
tensor
.broadcast_div(&norm)
.map_err(|e| LaurusError::InvalidOperation(e.to_string()))
}
}
#[cfg(feature = "embeddings-multimodal")]
#[async_trait]
impl Embedder for CandleClipEmbedder {
async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector> {
match input {
EmbedInput::Text(text) => self.embed_text(text).await,
EmbedInput::Bytes(bytes, mime) => {
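                // Bytes tagged with a `text/*` mime type are decoded as UTF-8
                // and routed through the text tower; everything else is
                // treated as an image.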
if let Some(mime_type) = mime
&& mime_type.starts_with("text/")
{
let text = std::str::from_utf8(bytes).map_err(|e| {
LaurusError::invalid_argument(format!(
"invalid utf-8 for text bytes: {}",
e
))
})?;
return self.embed_text(text).await;
}
self.embed_image_bytes(bytes).await
}
}
}
fn supported_input_types(&self) -> Vec<EmbedInputType> {
vec![EmbedInputType::Text, EmbedInputType::Image]
}
fn name(&self) -> &str {
&self.model_name
}
fn as_any(&self) -> &dyn Any {
self
}
}
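
// Minimal smoke-test sketch. It is `#[ignore]`d because it downloads ViT-B/32
// weights from the Hugging Face Hub on first run. It assumes `tokio` is
// available as a dev-dependency for the async test macro and that
// `EmbedInput::Text` wraps a `&str`; adapt both to the crate's actual setup.
#[cfg(all(test, feature = "embeddings-multimodal"))]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "downloads model weights from the Hugging Face Hub"]
    async fn embeds_text_with_vit_b32() {
        let embedder =
            CandleClipEmbedder::new("openai/clip-vit-base-patch32").expect("model load");
        // ViT-B/32 projects into a 512-dimensional shared text/image space.
        assert_eq!(embedder.dimension, 512);
        embedder
            .embed(&EmbedInput::Text("a photo of a cat"))
            .await
            .expect("text embedding");
    }
}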