use std::sync::Arc;
use crate::error::{EngramError, Result};
use crate::multimodal::vision::{VisionInput, VisionOptions, VisionProviderFactory};
use super::{Embedder, OpenAIEmbedder};
/// Extension of [`Embedder`] for providers that can also produce embeddings
/// from images, not just text.
pub trait MultimodalEmbedder: Embedder {
/// Synchronously embed raw image bytes (`mime_type` describes the format,
/// e.g. "image/png") into a vector compatible with this embedder's text
/// embedding space.
fn embed_image_sync(&self, image_bytes: &[u8], mime_type: &str) -> Result<Vec<f32>>;
/// Short identifier for the multimodal provider (e.g. "clip").
fn multimodal_provider_name(&self) -> &str;
}
/// "CLIP-style" multimodal embedder.
///
/// NOTE(review): despite the name, this does NOT run an actual CLIP model.
/// Images are first described by a vision provider and the textual
/// description is then embedded with the wrapped OpenAI text embedder
/// (see `embed_image_async`); `model_name()` reflects this as
/// "clip-description-mediated".
pub struct ClipEmbedder {
// Used both for plain text embedding and for embedding image descriptions.
text_embedder: Arc<OpenAIEmbedder>,
}
impl ClipEmbedder {
    /// Build an embedder backed by OpenAI, using the given API key.
    pub fn new(api_key: String) -> Self {
        let text_embedder = Arc::new(OpenAIEmbedder::new(api_key));
        Self { text_embedder }
    }

    /// Construct from the `OPENAI_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns `EngramError::Config` when the variable is unset or unreadable.
    pub fn from_env() -> Result<Self> {
        match std::env::var("OPENAI_API_KEY") {
            Ok(key) => Ok(Self::new(key)),
            Err(_) => Err(EngramError::Config(
                "OPENAI_API_KEY is required for CLIP embeddings".to_string(),
            )),
        }
    }

    /// Embed an image by asking a vision provider for a detailed description
    /// and then embedding that description as text.
    ///
    /// # Errors
    /// Fails when no vision provider is configured in the environment, when
    /// the vision call fails, or when the text embedding request fails.
    pub async fn embed_image_async(
        &self,
        image_bytes: &[u8],
        mime_type: &str,
    ) -> Result<Vec<f32>> {
        // A vision provider is mandatory for the description step; surface a
        // configuration error if none can be built from the environment.
        let provider = VisionProviderFactory::from_env().map_err(|e| {
            EngramError::Config(format!(
                "Vision provider required for image embedding: {}",
                e
            ))
        })?;

        let request = VisionInput {
            image_bytes: image_bytes.to_vec(),
            mime_type: mime_type.to_string(),
        };
        let options = VisionOptions {
            prompt: Some(
                "Describe this image in detail, including objects, colors, layout, and any text visible. Be precise and comprehensive.".to_string(),
            ),
            max_tokens: Some(512),
        };

        // Description-mediated pipeline: image -> text -> embedding.
        let caption = provider.describe_image(request, options).await?;
        self.text_embedder.embed_async(&caption.text).await
    }
}
// Text embedding is pure delegation to the wrapped OpenAI embedder; only the
// reported model name is specific to this type.
impl Embedder for ClipEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.text_embedder.embed(text)
    }

    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.text_embedder.embed_batch(texts)
    }

    fn dimensions(&self) -> usize {
        // Same dimensionality as the underlying text model, so text and
        // image-description embeddings live in one space.
        self.text_embedder.dimensions()
    }

    fn model_name(&self) -> &str {
        // Signals that image support is description-mediated rather than a
        // genuine CLIP model.
        "clip-description-mediated"
    }
}
impl MultimodalEmbedder for ClipEmbedder {
fn embed_image_sync(&self, image_bytes: &[u8], mime_type: &str) -> Result<Vec<f32>> {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(self.embed_image_async(image_bytes, mime_type))
})
}
fn multimodal_provider_name(&self) -> &str {
"clip"
}
}
/// Convenience constructor: build a shared [`ClipEmbedder`] from the
/// environment.
///
/// # Errors
/// Propagates the `Config` error from [`ClipEmbedder::from_env`] when
/// `OPENAI_API_KEY` is unset.
pub fn create_clip_embedder() -> Result<Arc<ClipEmbedder>> {
    let embedder = ClipEmbedder::from_env()?;
    Ok(Arc::new(embedder))
}
/// Registry/config name under which this embedder is exposed; matches
/// `MultimodalEmbedder::multimodal_provider_name` for `ClipEmbedder`.
pub const CLIP_PROVIDER_NAME: &str = "clip";
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Several tests below mutate the process-wide OPENAI_API_KEY environment
    // variable. Rust runs tests in parallel by default, so unserialized
    // mutation races between tests (and the original code did exactly that).
    // All env-mutating tests take this lock first.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    // Ignore lock poisoning: a panic in one env test shouldn't cascade.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_clip_embedder_implements_multimodal_embedder() {
        // Compile-time check that ClipEmbedder satisfies both traits.
        fn assert_multimodal<T: MultimodalEmbedder + Embedder>() {}
        assert_multimodal::<ClipEmbedder>();
    }

    #[test]
    fn test_clip_provider_name() {
        let embedder = ClipEmbedder::new("dummy-key".to_string());
        assert_eq!(embedder.multimodal_provider_name(), "clip");
        assert_eq!(embedder.model_name(), "clip-description-mediated");
    }

    #[test]
    fn test_clip_dimensions_match_openai() {
        let embedder = ClipEmbedder::new("dummy-key".to_string());
        assert_eq!(embedder.dimensions(), 1536);
    }

    #[test]
    fn test_from_env_fails_without_api_key() {
        let _guard = env_guard();
        let saved = std::env::var("OPENAI_API_KEY").ok();
        std::env::remove_var("OPENAI_API_KEY");

        let result = ClipEmbedder::from_env();
        assert!(result.is_err(), "should fail without OPENAI_API_KEY");
        match result.err().unwrap() {
            EngramError::Config(msg) => {
                assert!(
                    msg.contains("OPENAI_API_KEY"),
                    "error should mention OPENAI_API_KEY"
                );
            }
            e => panic!("expected Config error, got: {:?}", e),
        }

        // Restore the previous state so other tests see the original env.
        if let Some(key) = saved {
            std::env::set_var("OPENAI_API_KEY", key);
        }
    }

    #[test]
    fn test_clip_provider_constant() {
        assert_eq!(CLIP_PROVIDER_NAME, "clip");
    }

    #[test]
    fn test_create_clip_embedder_type() {
        let _guard = env_guard();
        let saved = std::env::var("OPENAI_API_KEY").ok();
        std::env::set_var("OPENAI_API_KEY", "test-key");

        let result = create_clip_embedder();
        assert!(
            result.is_ok(),
            "create_clip_embedder should succeed when OPENAI_API_KEY is set"
        );
        let arc = result.unwrap();
        assert_eq!(arc.multimodal_provider_name(), "clip");

        // Restore rather than unconditionally removing — the original code
        // clobbered any real key present in the developer's environment.
        match saved {
            Some(key) => std::env::set_var("OPENAI_API_KEY", key),
            None => std::env::remove_var("OPENAI_API_KEY"),
        }
    }
}