use anyhow::{anyhow, Context, Result};
use async_openai::{
config::OpenAIConfig, types::CreateEmbeddingRequestArgs, Client as OpenAIClient,
};
use async_trait::async_trait;
use std::sync::Arc;
/// Abstraction over a backend that can turn text inputs into embedding vectors.
///
/// The `Send + Sync + 'static` bounds allow an implementation to be shared behind
/// an `Arc<dyn EmbeddingsService>`, which is how `EmbeddingsClient` stores it.
#[async_trait]
pub trait EmbeddingsService: Send + Sync + 'static {
/// Creates embeddings for `inputs` using the model identified by `model`.
///
/// Returns the embedding vectors produced by the backend, or an error if the
/// backend could not produce them.
async fn create_embeddings(&self, model: &str, inputs: Vec<String>) -> Result<Vec<Vec<f32>>>;
}
/// Client that requests embeddings for a fixed model through a shared
/// [`EmbeddingsService`] backend.
///
/// Cloning copies the model name and shares the same backend through the `Arc`.
#[derive(Clone)]
pub struct EmbeddingsClient {
// Model identifier forwarded to the service on every request.
model: String,
// Backend that actually produces the embeddings.
service: Arc<dyn EmbeddingsService>,
}
impl EmbeddingsClient {
    /// Builds a client for `model` backed by `OpenAIEmbeddingsService::new()`.
    pub fn new(model: impl Into<String>) -> Self {
        let backend: Arc<dyn EmbeddingsService> = Arc::new(OpenAIEmbeddingsService::new());
        Self::with_service(model, backend)
    }

    /// Builds a client for `model` backed by an OpenAI service that uses `api_key`.
    pub fn with_api_key(model: impl Into<String>, api_key: impl Into<String>) -> Self {
        Self::with_service(
            model,
            Arc::new(OpenAIEmbeddingsService::with_api_key(api_key)),
        )
    }

    /// Builds a client for `model` backed by an OpenAI service created from `config`.
    pub fn with_config(model: impl Into<String>, config: OpenAIConfig) -> Self {
        Self::with_service(model, Arc::new(OpenAIEmbeddingsService::with_config(config)))
    }

    /// Builds a client for `model` that delegates to an arbitrary `service`.
    ///
    /// This is the constructor every other constructor funnels into.
    pub fn with_service(model: impl Into<String>, service: Arc<dyn EmbeddingsService>) -> Self {
        let model = model.into();
        Self { model, service }
    }

    /// Embeds every entry of `inputs` with this client's model.
    ///
    /// An empty slice short-circuits to an empty result without calling the
    /// service. Otherwise each input is passed through
    /// `sanitize_embedding_input` before being handed to the service, and the
    /// service's result (or error) is returned unchanged.
    pub async fn embed_batch<T>(&self, inputs: &[T]) -> Result<Vec<Vec<f32>>>
    where
        T: AsRef<str>,
    {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        let mut payload = Vec::with_capacity(inputs.len());
        for raw in inputs {
            payload.push(sanitize_embedding_input(raw.as_ref()));
        }

        self.service.create_embeddings(&self.model, payload).await
    }

    /// Embeds a single `input` by sending it as a one-element batch.
    ///
    /// Returns the last vector of the service's response; if the service
    /// answers with no vectors at all, an empty vector is returned.
    pub async fn embed(&self, input: &str) -> Result<Vec<f32>> {
        let single = [input];
        let mut vectors = self.embed_batch(&single).await?;
        let last = vectors.pop();
        Ok(last.unwrap_or_default())
    }
}
/// Escapes and length-limits a text before it is sent for embedding.
///
/// Backslashes, carriage returns and line feeds are replaced by the two-character
/// sequences `\\`, `\r` and `\n`, so the result contains no raw line breaks.
/// If the escaped text is longer than 4096 bytes it is cut down to at most
/// 4096 bytes and the marker `\n[truncated]` is appended.
///
/// Truncation never splits a UTF-8 character and never leaves half of an
/// escape sequence (a lone trailing backslash) in front of the marker.
fn sanitize_embedding_input(input: &str) -> String {
    const MAX_LEN: usize = 4096;
    const TRUNCATION_MARKER: &str = "\\n[truncated]";

    // Backslashes must be escaped first so the backslashes introduced by the
    // `\r` / `\n` replacements are not escaped a second time.
    let mut processed = input
        .replace('\\', "\\\\")
        .replace('\r', "\\r")
        .replace('\n', "\\n");

    if processed.len() > MAX_LEN {
        // `String::truncate` panics when the cut is not on a char boundary, so
        // back off to the nearest boundary at or below MAX_LEN (0 always is one).
        let mut cut = MAX_LEN;
        while !processed.is_char_boundary(cut) {
            cut -= 1;
        }
        processed.truncate(cut);

        // After escaping, backslashes only occur as the first half of a
        // two-character escape or as a complete `\\` pair. An odd-length run of
        // trailing backslashes therefore means the cut split an escape; drop the
        // dangling half so it does not fuse with the marker.
        let trailing_backslashes = processed
            .bytes()
            .rev()
            .take_while(|&byte| byte == b'\\')
            .count();
        if trailing_backslashes % 2 == 1 {
            processed.pop();
        }

        processed.push_str(TRUNCATION_MARKER);
    }

    processed
}
#[cfg(test)]
mod embedding_sanitizer_tests {
    use super::sanitize_embedding_input;

    /// Marker the sanitizer appends to payloads it had to shorten.
    const MARKER: &str = "\\n[truncated]";

    /// Line breaks and backslashes must come out as escaped two-character sequences.
    #[test]
    fn sanitizes_newlines_and_backslashes() {
        let expected = "line1\\nline2\\r\\npath\\\\to\\\\file";
        let actual = sanitize_embedding_input("line1\nline2\r\npath\\to\\file");
        assert_eq!(actual, expected);
    }

    /// Oversized payloads must end with the marker and stay within the size bound.
    #[test]
    fn truncates_long_payloads() {
        let oversized = "a".repeat(5000);
        let output = sanitize_embedding_input(&oversized);
        assert!(output.ends_with(MARKER));
        assert!(output.len() <= 4096 + MARKER.len());
    }
}
/// [`EmbeddingsService`] implementation that sends requests through an
/// `async_openai` client.
#[derive(Clone)]
pub struct OpenAIEmbeddingsService {
// Client used for every embeddings request made by this service.
client: OpenAIClient<OpenAIConfig>,
}
impl Default for OpenAIEmbeddingsService {
fn default() -> Self {
Self::new()
}
}
impl OpenAIEmbeddingsService {
    /// Creates a service around a client built with `OpenAIClient::new()`.
    pub fn new() -> Self {
        let client = OpenAIClient::new();
        Self { client }
    }

    /// Creates a service whose client configuration carries the given `api_key`.
    pub fn with_api_key(api_key: impl Into<String>) -> Self {
        Self::with_config(OpenAIConfig::new().with_api_key(api_key))
    }

    /// Creates a service around a client built from an explicit `config`.
    pub fn with_config(config: OpenAIConfig) -> Self {
        let client = OpenAIClient::with_config(config);
        Self { client }
    }
}
#[async_trait]
impl EmbeddingsService for OpenAIEmbeddingsService {
    /// Requests embeddings for `inputs` from the OpenAI embeddings endpoint.
    ///
    /// An empty `inputs` returns an empty result without issuing a request.
    /// On success the returned vectors are ordered by the `index` reported for
    /// each embedding, so position `i` corresponds to `inputs[i]`.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be built, when the API call fails, when
    /// the response contains no embeddings, or when the number of embeddings
    /// differs from the number of inputs.
    async fn create_embeddings(&self, model: &str, inputs: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if inputs.is_empty() {
            return Ok(Vec::new());
        }

        // Remember the batch size before `inputs` is moved into the request.
        let expected = inputs.len();

        let request = CreateEmbeddingRequestArgs::default()
            .model(model)
            .input(inputs)
            .build()
            .context("Failed to build embedding request")?;

        let response = self
            .client
            .embeddings()
            .create(request)
            .await
            .context("OpenAI embeddings request failed")?;

        let mut data = response.data;
        if data.is_empty() {
            return Err(anyhow!("OpenAI embeddings response was empty"));
        }
        // A short or long response cannot be matched back to the inputs, so
        // report it instead of returning misaligned vectors.
        if data.len() != expected {
            return Err(anyhow!(
                "OpenAI embeddings response contained {} embeddings for {} inputs",
                data.len(),
                expected
            ));
        }

        // Each item carries the position of its input; order by it rather than
        // relying on the order of the response array.
        data.sort_by_key(|item| item.index);

        Ok(data.into_iter().map(|item| item.embedding).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Test double that ignores its arguments and replays a canned outcome.
    #[derive(Clone)]
    struct DummyService {
        embeddings: Vec<Vec<f32>>,
        fail: bool,
    }

    impl DummyService {
        /// Service that answers with exactly one embedding.
        fn ok_single(embedding: Vec<f32>) -> Self {
            Self::ok_batch(vec![embedding])
        }

        /// Service that answers with the given embeddings.
        fn ok_batch(embeddings: Vec<Vec<f32>>) -> Self {
            Self {
                embeddings,
                fail: false,
            }
        }

        /// Service that always fails.
        fn err() -> Self {
            Self {
                embeddings: Vec::new(),
                fail: true,
            }
        }
    }

    #[async_trait]
    impl EmbeddingsService for DummyService {
        async fn create_embeddings(
            &self,
            _model: &str,
            _inputs: Vec<String>,
        ) -> Result<Vec<Vec<f32>>> {
            if self.fail {
                Err(anyhow!("boom"))
            } else {
                // Cloning an empty list yields the same empty result.
                Ok(self.embeddings.clone())
            }
        }
    }

    /// Wraps `service` in a client configured for the placeholder model name.
    fn client_with(service: DummyService) -> EmbeddingsClient {
        EmbeddingsClient::with_service("model", Arc::new(service))
    }

    #[tokio::test]
    async fn embed_returns_the_service_embedding() {
        let expected = vec![0.1, 0.2];
        let client = client_with(DummyService::ok_single(expected.clone()));

        let actual = client.embed("input").await.unwrap();

        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn embed_propagates_errors() {
        let client = client_with(DummyService::err());

        assert!(client.embed("input").await.is_err());
    }

    #[tokio::test]
    async fn embed_batch_returns_all_embeddings() {
        let expected = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let client = client_with(DummyService::ok_batch(expected.clone()));

        let actual = client.embed_batch(&["first", "second"]).await.unwrap();

        assert_eq!(actual.len(), 2);
        assert_eq!(actual[0], expected[0]);
        assert_eq!(actual[1], expected[1]);
    }
}