use crate::{FierrosError, FierrosResult};
use async_trait::async_trait;
/// Asynchronously turns a batch of text inputs into embedding vectors.
///
/// The `Send + Sync` supertraits allow one embedder instance to be shared
/// between threads / async tasks.
///
/// NOTE(review): the result presumably contains exactly one vector per entry
/// of `inputs`, in the same order (`MockEmbedder` behaves this way) — confirm
/// this contract for real provider implementations.
#[async_trait]
pub trait Embedder: Send + Sync {
/// Embeds every string in `inputs`.
///
/// # Errors
///
/// Returns an error (see `FierrosResult`) when the batch cannot be embedded;
/// which `FierrosError` variants are produced is implementation-defined.
async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>>;
}
/// Deterministic, in-process [`Embedder`] intended as a test double.
///
/// Each input's UTF-8 bytes are scaled by `1/255` and summed round-robin
/// into `dimension` buckets, so identical inputs always yield identical
/// vectors and no model or network is involved. A configured
/// `forced_error` makes every `embed` call fail instead.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
/// Length of every produced vector. Zero is accepted by the constructors,
/// but `embed` rejects it with `FierrosError::InvalidInput`.
dimension: usize,
/// When `Some`, `embed` returns a clone of this error without embedding
/// anything (checked before the zero-dimension validation).
forced_error: Option<FierrosError>,
}
impl MockEmbedder {
pub fn new(dimension: usize) -> Self {
Self {
dimension,
forced_error: None,
}
}
pub fn failing(dimension: usize, error: FierrosError) -> Self {
Self {
dimension,
forced_error: Some(error),
}
}
pub fn with_error(mut self, error: FierrosError) -> Self {
self.forced_error = Some(error);
self
}
fn embed_one(&self, input: &str) -> Vec<f32> {
let mut out = vec![0.0; self.dimension];
if self.dimension == 0 {
return out;
}
for (idx, byte) in input.bytes().enumerate() {
out[idx % self.dimension] += byte as f32 / 255.0;
}
out
}
}
#[async_trait]
impl Embedder for MockEmbedder {
    /// Embeds each input with the mock's byte-bucketing scheme.
    ///
    /// # Errors
    ///
    /// - A clone of the configured forced error, if one is set. This takes
    ///   precedence over all other checks.
    /// - `FierrosError::InvalidInput` when the mock was built with a zero
    ///   dimension.
    async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
        match (&self.forced_error, self.dimension) {
            (Some(forced), _) => Err(forced.clone()),
            (None, 0) => Err(FierrosError::InvalidInput(
                "embedding dimension must be greater than zero".into(),
            )),
            (None, _) => Ok(inputs.iter().map(|text| self.embed_one(text)).collect()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{Embedder, MockEmbedder};
    use crate::FierrosError;

    #[tokio::test]
    async fn mock_embedder_returns_expected_dimensions() {
        let embedder = MockEmbedder::new(4);
        let embeddings = embedder
            .embed(&["hello".to_string(), "world".to_string()])
            .await
            .unwrap();
        assert_eq!(embeddings.len(), 2);
        // Every vector, not just the first, must have the configured length.
        assert!(embeddings.iter().all(|vector| vector.len() == 4));
    }

    #[tokio::test]
    async fn mock_embedder_folds_bytes_round_robin() {
        // With dimension 2, bytes at even positions land in bucket 0 and
        // bytes at odd positions in bucket 1, each scaled by 1/255.
        let embedder = MockEmbedder::new(2);
        let embeddings = embedder.embed(&["abc".to_string()]).await.unwrap();
        let expected = vec![
            f32::from(b'a') / 255.0 + f32::from(b'c') / 255.0,
            f32::from(b'b') / 255.0,
        ];
        assert_eq!(embeddings, vec![expected]);
    }

    #[tokio::test]
    async fn mock_embedder_is_deterministic() {
        let embedder = MockEmbedder::new(8);
        let inputs = vec!["same text".to_string(), "other".to_string()];
        let first = embedder.embed(&inputs).await.unwrap();
        let second = embedder.embed(&inputs).await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn mock_embedder_accepts_empty_batch() {
        let embedder = MockEmbedder::new(4);
        let no_inputs: Vec<String> = Vec::new();
        let embeddings = embedder.embed(&no_inputs).await.unwrap();
        assert!(embeddings.is_empty());
    }

    #[tokio::test]
    async fn mock_embedder_rejects_zero_dimension() {
        let embedder = MockEmbedder::new(0);
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        // Pin the variant so an unrelated failure cannot satisfy this test.
        assert!(matches!(error, FierrosError::InvalidInput(_)));
    }

    #[tokio::test]
    async fn mock_embedder_can_return_configured_error() {
        let embedder = MockEmbedder::failing(
            4,
            FierrosError::Provider("embedding endpoint timeout".into()),
        );
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(
            error,
            FierrosError::Provider("embedding endpoint timeout".into())
        );
    }

    #[tokio::test]
    async fn with_error_takes_precedence_over_dimension_check() {
        // A zero dimension would normally yield InvalidInput; the forced
        // error configured via the builder must win.
        let embedder =
            MockEmbedder::new(0).with_error(FierrosError::Provider("quota exceeded".into()));
        let error = embedder.embed(&["x".to_string()]).await.unwrap_err();
        assert_eq!(error, FierrosError::Provider("quota exceeded".into()));
    }
}