use std::any::Any;
use std::fmt::Debug;
use async_trait::async_trait;
use crate::error::Result;
use crate::vector::core::vector::Vector;
/// A single input to be embedded.
///
/// Inputs borrow their payload, so constructing one never copies the
/// underlying text or bytes.
#[derive(Debug, Clone)]
pub enum EmbedInput<'a> {
    /// A text input.
    Text(&'a str),
    /// Raw bytes with an optional MIME type (e.g. `"image/png"`).
    ///
    /// Bytes whose MIME type is `text/*` are classified as text; all other
    /// bytes, including those with no MIME type, are classified as images.
    Bytes(&'a [u8], Option<&'a str>),
}

impl<'a> EmbedInput<'a> {
    /// Classifies this input as [`EmbedInputType::Text`] or [`EmbedInputType::Image`].
    ///
    /// [`EmbedInput::Text`] is always text. [`EmbedInput::Bytes`] is text when its
    /// MIME type has the top-level type `text` (compared ASCII case-insensitively,
    /// because MIME types are case-insensitive), and an image otherwise —
    /// including when no MIME type is given.
    ///
    /// This is the single source of truth used by [`is_text`](Self::is_text) and
    /// [`is_image`](Self::is_image).
    pub fn input_type(&self) -> EmbedInputType {
        match *self {
            EmbedInput::Text(_) => EmbedInputType::Text,
            EmbedInput::Bytes(_, mime) if Self::is_text_mime(mime) => EmbedInputType::Text,
            EmbedInput::Bytes(_, _) => EmbedInputType::Image,
        }
    }

    /// Returns `true` if [`input_type`](Self::input_type) is [`EmbedInputType::Text`].
    pub fn is_text(&self) -> bool {
        self.input_type() == EmbedInputType::Text
    }

    /// Returns `true` if [`input_type`](Self::input_type) is [`EmbedInputType::Image`].
    ///
    /// Always the negation of [`is_text`](Self::is_text), so every input is
    /// classified as exactly one of the two modalities.
    pub fn is_image(&self) -> bool {
        self.input_type() == EmbedInputType::Image
    }

    /// Returns the borrowed text of an [`EmbedInput::Text`] input.
    ///
    /// Returns `None` for [`EmbedInput::Bytes`], even when its MIME type is
    /// `text/*`; such bytes are not decoded here.
    pub fn as_text(&self) -> Option<&'a str> {
        match *self {
            EmbedInput::Text(text) => Some(text),
            EmbedInput::Bytes(..) => None,
        }
    }

    /// Returns `true` if `mime` is present and starts with `text/`, ignoring ASCII case.
    fn is_text_mime(mime: Option<&str>) -> bool {
        // `str::get` yields `None` (instead of panicking) when the string is shorter
        // than 5 bytes or byte 5 is not a char boundary.
        mime.and_then(|m| m.get(..5))
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case("text/"))
    }
}

/// The modality of an [`EmbedInput`], as reported by [`EmbedInput::input_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbedInputType {
    /// Textual input.
    Text,
    /// Image input; also used for any non-text bytes.
    Image,
}
/// A model that turns [`EmbedInput`]s into [`Vector`]s.
///
/// Implementors must provide [`embed`](Embedder::embed),
/// [`supported_input_types`](Embedder::supported_input_types) and
/// [`as_any`](Embedder::as_any); every other method has a default built on
/// top of those.
#[async_trait]
pub trait Embedder: Send + Sync + Debug {
    /// Embeds a single input.
    async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector>;

    /// Embeds several inputs, yielding one vector per input in input order.
    ///
    /// The default awaits [`embed`](Embedder::embed) for each input one after
    /// another and returns the first error it encounters, discarding any
    /// vectors produced before it.
    async fn embed_batch(&self, inputs: &[EmbedInput<'_>]) -> Result<Vec<Vector>> {
        let mut vectors = Vec::with_capacity(inputs.len());
        for item in inputs.iter() {
            let vector = self.embed(item).await?;
            vectors.push(vector);
        }
        Ok(vectors)
    }

    /// Lists the input modalities this embedder accepts.
    fn supported_input_types(&self) -> Vec<EmbedInputType>;

    /// Returns `true` when `input_type` is listed by
    /// [`supported_input_types`](Embedder::supported_input_types).
    fn supports(&self, input_type: EmbedInputType) -> bool {
        self.supported_input_types()
            .iter()
            .any(|candidate| *candidate == input_type)
    }

    /// Shorthand for `supports(EmbedInputType::Text)`.
    fn supports_text(&self) -> bool {
        let wanted = EmbedInputType::Text;
        self.supports(wanted)
    }

    /// Shorthand for `supports(EmbedInputType::Image)`.
    fn supports_image(&self) -> bool {
        let wanted = EmbedInputType::Image;
        self.supports(wanted)
    }

    /// Returns `true` only when both text and image inputs are supported.
    ///
    /// Goes through [`supports_text`](Embedder::supports_text) first, then
    /// [`supports_image`](Embedder::supports_image), so overrides of either are honored.
    fn is_multimodal(&self) -> bool {
        if !self.supports_text() {
            return false;
        }
        self.supports_image()
    }

    /// A human-readable name for this embedder; defaults to `"unknown"`.
    fn name(&self) -> &str {
        "unknown"
    }

    /// Exposes this embedder as [`Any`], e.g. for downcasting to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
#[cfg(test)]
mod tests {
    use crate::error::LaurusError;

    use super::*;

    /// Text-only mock: returns a zero vector of length `dimension` for
    /// `EmbedInput::Text` and an invalid-argument error for anything else.
    #[derive(Debug)]
    struct MockTextEmbedder {
        dimension: usize,
    }

    #[async_trait]
    impl Embedder for MockTextEmbedder {
        async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector> {
            match input {
                EmbedInput::Text(_) => Ok(Vector::new(vec![0.0; self.dimension])),
                _ => Err(LaurusError::invalid_argument(
                    "this embedder only supports text input",
                )),
            }
        }

        fn supported_input_types(&self) -> Vec<EmbedInputType> {
            vec![EmbedInputType::Text]
        }

        fn name(&self) -> &str {
            "mock-text"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// Multimodal mock: returns a 3-element zero vector for every input.
    #[derive(Debug)]
    struct MockMultimodalEmbedder;

    #[async_trait]
    impl Embedder for MockMultimodalEmbedder {
        async fn embed(&self, input: &EmbedInput<'_>) -> Result<Vector> {
            // Exhaustive on purpose: adding an `EmbedInput` variant should force
            // this mock to be revisited.
            match input {
                EmbedInput::Text(_) | EmbedInput::Bytes(_, _) => Ok(Vector::new(vec![0.0; 3])),
            }
        }

        fn supported_input_types(&self) -> Vec<EmbedInputType> {
            vec![EmbedInputType::Text, EmbedInputType::Image]
        }

        fn name(&self) -> &str {
            "mock-multimodal"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    #[test]
    fn test_embed_input_type() {
        assert_eq!(EmbedInput::Text("hello").input_type(), EmbedInputType::Text);
        assert_eq!(
            EmbedInput::Bytes(&[0, 1, 2], None).input_type(),
            EmbedInputType::Image
        );
        assert_eq!(
            EmbedInput::Bytes(&[0, 1, 2], Some("text/plain")).input_type(),
            EmbedInputType::Text
        );
    }

    #[test]
    fn test_text_embedder_supports() {
        let embedder = MockTextEmbedder { dimension: 384 };
        assert!(embedder.supports_text());
        assert!(!embedder.supports_image());
        assert!(!embedder.is_multimodal());
    }

    #[test]
    fn test_multimodal_embedder_supports() {
        let embedder = MockMultimodalEmbedder;
        assert!(embedder.supports_text());
        assert!(embedder.supports_image());
        assert!(embedder.is_multimodal());
    }

    #[test]
    fn test_name_and_as_any() {
        let embedder = MockTextEmbedder { dimension: 8 };
        assert_eq!(embedder.name(), "mock-text");
        assert_eq!(MockMultimodalEmbedder.name(), "mock-multimodal");
        assert!(embedder.as_any().downcast_ref::<MockTextEmbedder>().is_some());
        assert!(embedder.as_any().downcast_ref::<MockMultimodalEmbedder>().is_none());
    }

    #[tokio::test]
    async fn test_text_embedder_embed() {
        let embedder = MockTextEmbedder { dimension: 384 };
        let vector = embedder
            .embed(&EmbedInput::Text("hello"))
            .await
            .expect("text input should embed");
        assert_eq!(vector.data.len(), 384);
        let result = embedder.embed(&EmbedInput::Bytes(&[], None)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_multimodal_embedder_embed() {
        let embedder = MockMultimodalEmbedder;
        let text_result = embedder.embed(&EmbedInput::Text("hello")).await;
        assert!(text_result.is_ok());
        let image_result = embedder.embed(&EmbedInput::Bytes(&[], None)).await;
        assert!(image_result.is_ok());
    }

    #[tokio::test]
    async fn test_default_embed_batch_returns_one_vector_per_input() {
        let embedder = MockTextEmbedder { dimension: 4 };
        let inputs = [
            EmbedInput::Text("a"),
            EmbedInput::Text("b"),
            EmbedInput::Text("c"),
        ];
        let vectors = embedder
            .embed_batch(&inputs)
            .await
            .expect("all-text batch should embed");
        assert_eq!(vectors.len(), inputs.len());
        assert!(vectors.iter().all(|v| v.data.len() == 4));
    }

    #[tokio::test]
    async fn test_default_embed_batch_empty_input() {
        let embedder = MockTextEmbedder { dimension: 4 };
        let vectors = embedder
            .embed_batch(&[])
            .await
            .expect("empty batch should succeed");
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn test_default_embed_batch_propagates_first_error() {
        let embedder = MockTextEmbedder { dimension: 4 };
        let inputs = [EmbedInput::Text("a"), EmbedInput::Bytes(&[], None)];
        assert!(embedder.embed_batch(&inputs).await.is_err());
    }
}