use std::sync::Arc;
use crate::embedding::embedder::Embedder;
use crate::embedding::precomputed::PrecomputedEmbedder;
use crate::engine::schema::embedder::EmbedderDefinition;
use crate::error::{LaurusError, Result};
pub async fn create_embedder_from_definition(
_name: &str,
definition: &EmbedderDefinition,
) -> Result<Arc<dyn Embedder>> {
match definition {
EmbedderDefinition::Precomputed => Ok(Arc::new(PrecomputedEmbedder::new())),
#[cfg(feature = "embeddings-candle")]
EmbedderDefinition::CandleBert { model } => {
use crate::embedding::candle_bert_embedder::CandleBertEmbedder;
let embedder = CandleBertEmbedder::new(model)?;
Ok(Arc::new(embedder))
}
#[cfg(not(feature = "embeddings-candle"))]
EmbedderDefinition::CandleBert { .. } => Err(LaurusError::invalid_argument(
"candle_bert embedder requires the 'embeddings-candle' feature to be enabled",
)),
#[cfg(feature = "embeddings-multimodal")]
EmbedderDefinition::CandleClip { model } => {
use crate::embedding::candle_clip_embedder::CandleClipEmbedder;
let embedder = CandleClipEmbedder::new(model)?;
Ok(Arc::new(embedder))
}
#[cfg(not(feature = "embeddings-multimodal"))]
EmbedderDefinition::CandleClip { .. } => Err(LaurusError::invalid_argument(
"candle_clip embedder requires the 'embeddings-multimodal' feature to be enabled",
)),
#[cfg(feature = "embeddings-openai")]
EmbedderDefinition::Openai { model } => {
use crate::embedding::openai_embedder::OpenAIEmbedder;
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
LaurusError::invalid_argument(
"OpenAI embedder requires the OPENAI_API_KEY environment variable to be set",
)
})?;
let embedder = OpenAIEmbedder::new(api_key, model.clone()).await?;
Ok(Arc::new(embedder))
}
#[cfg(not(feature = "embeddings-openai"))]
EmbedderDefinition::Openai { .. } => Err(LaurusError::invalid_argument(
"openai embedder requires the 'embeddings-openai' feature to be enabled",
)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_create_precomputed() {
let def = EmbedderDefinition::Precomputed;
let embedder = create_embedder_from_definition("test", &def).await.unwrap();
assert_eq!(embedder.name(), "PrecomputedEmbedder");
}
#[tokio::test]
async fn test_precomputed_serde_roundtrip() {
let json = r#"{"type": "precomputed"}"#;
let def: EmbedderDefinition = serde_json::from_str(json).unwrap();
let embedder = create_embedder_from_definition("test", &def).await.unwrap();
assert_eq!(embedder.name(), "PrecomputedEmbedder");
}
}