use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use parking_lot::RwLock;
use crate::embedding::embedder::{EmbedInput, EmbedInputType, Embedder};
use crate::error::Result;
use crate::vector::core::vector::Vector;
/// Embedder wrapper that selects an embedder by field name.
///
/// Fields registered in `field_embedders` use their own embedder; every other
/// field is served by `default_embedder`.
pub struct PerFieldEmbedder {
/// Embedder used for any field that has no entry in `field_embedders`.
default_embedder: Arc<dyn Embedder>,
/// Field name -> embedder overrides, kept behind a lock so they can be
/// added and removed through a shared `&self` reference.
field_embedders: RwLock<HashMap<String, Arc<dyn Embedder>>>,
}
impl Clone for PerFieldEmbedder {
fn clone(&self) -> Self {
Self {
default_embedder: self.default_embedder.clone(),
field_embedders: RwLock::new(self.field_embedders.read().clone()),
}
}
}
impl std::fmt::Debug for PerFieldEmbedder {
    /// Formats the default embedder's name and the configured field names.
    ///
    /// Field names are sorted so the output is stable between calls; the
    /// backing `HashMap` does not guarantee any iteration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Copy the names out so the read lock is released before anything is
        // written to the formatter.
        let mut fields: Vec<String> = self.field_embedders.read().keys().cloned().collect();
        fields.sort_unstable();
        f.debug_struct("PerFieldEmbedder")
            .field("default_embedder", &self.default_embedder.name())
            .field("fields", &fields)
            .finish()
    }
}
impl PerFieldEmbedder {
    /// Creates a per-field embedder with no field overrides; every field is
    /// served by `default_embedder` until overrides are added.
    pub fn new(default_embedder: Arc<dyn Embedder>) -> Self {
        let field_embedders = RwLock::new(HashMap::new());
        Self {
            default_embedder,
            field_embedders,
        }
    }

    /// Registers `embedder` for `field`, replacing any embedder previously
    /// registered under the same name.
    pub fn add_embedder(&self, field: impl Into<String>, embedder: Arc<dyn Embedder>) {
        let mut overrides = self.field_embedders.write();
        overrides.insert(field.into(), embedder);
    }

    /// Drops the override for `field`, if there is one. Afterwards the field
    /// resolves to the default embedder.
    pub fn remove_embedder(&self, field: &str) {
        let mut overrides = self.field_embedders.write();
        overrides.remove(field);
    }

    /// Returns the embedder registered for `field`, or the default embedder
    /// when the field has no override.
    pub fn get_embedder(&self, field: &str) -> Arc<dyn Embedder> {
        let overrides = self.field_embedders.read();
        if let Some(specific) = overrides.get(field) {
            return Arc::clone(specific);
        }
        Arc::clone(&self.default_embedder)
    }

    /// Returns a reference to the default (fallback) embedder.
    pub fn default_embedder(&self) -> &Arc<dyn Embedder> {
        &self.default_embedder
    }

    /// Embeds `input` with the embedder that `get_embedder` resolves for `field`.
    ///
    /// The embedder is resolved to an owned `Arc` first, so the override lock
    /// is not held while the embedding future is awaited.
    pub async fn embed_field(&self, field: &str, input: &EmbedInput<'_>) -> Result<Vector> {
        let embedder = self.get_embedder(field);
        embedder.embed(input).await
    }

    /// Returns the names of all fields that currently have an override.
    ///
    /// The order follows the backing `HashMap` and is therefore unspecified.
    pub fn configured_fields(&self) -> Vec<String> {
        let overrides = self.field_embedders.read();
        let mut names = Vec::with_capacity(overrides.len());
        names.extend(overrides.keys().cloned());
        names
    }

    /// Reports whether the embedder resolved for `field` supports `input_type`.
    pub fn field_supports(&self, field: &str, input_type: EmbedInputType) -> bool {
        let embedder = self.get_embedder(field);
        embedder.supports(input_type)
    }
}
#[async_trait]
impl Embedder for PerFieldEmbedder {
    /// Embeds `input` with the default embedder.
    ///
    /// No field name is available through this trait method, so per-field
    /// overrides are not consulted here; use `embed_field` for that.
    async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector> {
        self.default_embedder.embed(input).await
    }

    /// Embeds a batch of inputs with the default embedder.
    async fn embed_batch(&self, inputs: &[EmbedInput<'_>]) -> Result<Vec<Vector>> {
        self.default_embedder.embed_batch(inputs).await
    }

    /// Returns every input type supported by the default embedder or by any
    /// field override, without duplicates.
    ///
    /// The result is ordered deterministically: the default embedder's types
    /// come first, followed by types first introduced by overrides visited in
    /// ascending field-name order. Note that `embed` only uses the default
    /// embedder, so a type listed here solely because of an override is only
    /// usable through `embed_field`.
    fn supported_input_types(&self) -> Vec<EmbedInputType> {
        let default_types = self.default_embedder.supported_input_types();

        let guard = self.field_embedders.read();
        // The map's iteration order is unspecified, so visit the overrides in
        // field-name order to keep the output stable between calls.
        let mut overrides: Vec<_> = guard.iter().collect();
        overrides.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

        let candidates = default_types.into_iter().chain(
            overrides
                .into_iter()
                .flat_map(|(_, embedder)| embedder.supported_input_types()),
        );

        // Deduplicate while keeping first-seen order. The list of input types
        // is tiny, so a linear `contains` check is sufficient.
        let mut types: Vec<EmbedInputType> = Vec::new();
        for input_type in candidates {
            if !types.contains(&input_type) {
                types.push(input_type);
            }
        }
        types
    }

    /// Returns the fixed name `"PerFieldEmbedder"`.
    fn name(&self) -> &str {
        "PerFieldEmbedder"
    }

    /// Exposes `self` as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LaurusError;

    /// Test double: accepts text input only and returns a zero vector with
    /// `dim` components.
    #[derive(Debug)]
    struct MockEmbedder {
        name: String,
        dim: usize,
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector> {
            if let EmbedInput::Text(_) = input {
                Ok(Vector::new(vec![0.0; self.dim]))
            } else {
                Err(LaurusError::invalid_argument("only text supported"))
            }
        }

        fn supported_input_types(&self) -> Vec<EmbedInputType> {
            vec![EmbedInputType::Text]
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// Builds a shared mock embedder with the given name and output dimension.
    fn mock(name: &str, dim: usize) -> Arc<dyn Embedder> {
        Arc::new(MockEmbedder {
            name: name.into(),
            dim,
        })
    }

    /// Fields with an override use it; other fields fall back to the default.
    #[tokio::test]
    async fn test_per_field_embedder() {
        let per_field = PerFieldEmbedder::new(mock("default", 384));
        let title_embedder = mock("title", 768);
        per_field.add_embedder("title", Arc::clone(&title_embedder));
        per_field.add_embedder("description", title_embedder);

        let input = EmbedInput::Text("hello");

        let title_vec = per_field.embed_field("title", &input).await.unwrap();
        assert_eq!(title_vec.dimension(), 768);

        let desc_vec = per_field.embed_field("description", &input).await.unwrap();
        assert_eq!(desc_vec.dimension(), 768);

        let content_vec = per_field.embed_field("content", &input).await.unwrap();
        assert_eq!(content_vec.dimension(), 384);
    }

    /// With no overrides at all, every field resolves to the default embedder.
    #[tokio::test]
    async fn test_default_embedder_when_field_not_configured() {
        let per_field = PerFieldEmbedder::new(mock("default", 384));
        let input = EmbedInput::Text("hello");

        let embedded = per_field
            .embed_field("unknown_field", &input)
            .await
            .unwrap();
        assert_eq!(embedded.dimension(), 384);

        let resolved = per_field.get_embedder("unknown_field");
        assert_eq!(resolved.name(), "default");
    }

    /// Through the `Embedder` trait object, requests go to the default embedder.
    #[tokio::test]
    async fn test_as_embedder_trait() {
        let per_field = PerFieldEmbedder::new(mock("default", 384));
        let embedder: &dyn Embedder = &per_field;

        assert!(embedder.supports_text());

        let embedded = embedder.embed(&EmbedInput::Text("hello")).await.unwrap();
        assert_eq!(embedded.dimension(), 384);
    }

    /// `embed_field` picks the override for a known field and the default otherwise.
    #[tokio::test]
    async fn test_embed_field() {
        let per_field = PerFieldEmbedder::new(mock("default", 384));
        per_field.add_embedder("title", mock("title", 768));

        let input = EmbedInput::Text("hello");

        let from_override = per_field.embed_field("title", &input).await.unwrap();
        assert_eq!(from_override.dimension(), 768);

        let from_default = per_field.embed_field("unknown", &input).await.unwrap();
        assert_eq!(from_default.dimension(), 384);
    }

    /// `configured_fields` lists exactly the fields that have an override.
    #[test]
    fn test_configured_fields() {
        let per_field = PerFieldEmbedder::new(mock("default", 384));
        let embedder = mock("special", 512);
        per_field.add_embedder("title", Arc::clone(&embedder));
        per_field.add_embedder("body", embedder);

        let fields = per_field.configured_fields();
        assert!(fields.contains(&"title".to_string()));
        assert!(fields.contains(&"body".to_string()));
        assert!(!fields.contains(&"unknown".to_string()));
    }

    /// `field_supports` reflects the capabilities of the resolved embedder.
    #[test]
    fn test_field_supports() {
        let per_field = PerFieldEmbedder::new(mock("default", 384));

        assert!(per_field.field_supports("any", EmbedInputType::Text));
        assert!(!per_field.field_supports("any", EmbedInputType::Image));
    }
}