use std::sync::Arc;
use async_trait::async_trait;
use crate::embedding_model::{EmbedOptions, EmbedResult, EmbeddingModel};
use crate::error::Result;
/// Hook points for decorating an [`EmbeddingModel`] without re-implementing it.
///
/// Every hook has a pass-through default, so an implementor overrides only the
/// hooks it needs. Attach middleware to a model with [`wrap_embedding_model`].
#[async_trait]
pub trait EmbeddingModelMiddleware: Send + Sync + std::fmt::Debug {
    /// Provider name the wrapper should report instead of `inner.provider()`.
    ///
    /// `None` (the default) keeps the inner model's provider. The wrapper asks
    /// once, when it is constructed, and caches the answer.
    fn override_provider(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
        None
    }

    /// Model id the wrapper should report instead of `inner.model_id()`.
    ///
    /// `None` (the default) keeps the inner model's id. Like the provider, this
    /// is resolved once, when the wrapper is constructed.
    fn override_model_id(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
        None
    }

    /// Replacement for the inner model's `max_embeddings_per_call`.
    ///
    /// The outer `Option` chooses whether to override at all. The inner value
    /// is what the wrapper reports when it does override, so `Some(None)` makes
    /// the wrapper report `None` whatever the inner model says. The wrapper
    /// consults this hook on every call.
    async fn override_max_embeddings_per_call(
        &self,
        _inner: &dyn EmbeddingModel,
    ) -> Option<Option<u32>> {
        None
    }

    /// Replacement for the inner model's `supports_parallel_calls`.
    ///
    /// `Some(flag)` is reported in place of the inner answer. `None` (the
    /// default) defers to the inner model. Consulted on every call.
    async fn override_supports_parallel_calls(&self, _inner: &dyn EmbeddingModel) -> Option<bool> {
        None
    }

    /// Rewrites the request before it is handed to
    /// [`wrap_embed`](Self::wrap_embed). The default returns it untouched.
    ///
    /// # Errors
    ///
    /// Returning `Err` aborts the embedding call. `wrap_embed` is not invoked.
    async fn transform_params(
        &self,
        params: EmbedOptions,
        _inner: &dyn EmbeddingModel,
    ) -> Result<EmbedOptions> {
        Ok(params)
    }

    /// Performs the embedding call for this layer.
    ///
    /// `next` is the model one layer down, and `params` has already been
    /// through [`transform_params`](Self::transform_params). The default just
    /// delegates to `next.do_embed(params)`. Override it to run code around the
    /// inner call, or to replace that call.
    async fn wrap_embed(
        &self,
        next: &dyn EmbeddingModel,
        params: EmbedOptions,
    ) -> Result<EmbedResult> {
        next.do_embed(params).await
    }
}
/// Layers `middleware` around `model` and returns the decorated model.
///
/// The first middleware in the sequence ends up outermost. Its
/// `transform_params` and `wrap_embed` run before those of later middleware,
/// and the last middleware sits directly on top of `model`. An empty sequence
/// returns `model` itself, with no wrapper.
pub fn wrap_embedding_model<I>(
    model: Arc<dyn EmbeddingModel>,
    middleware: I,
) -> Arc<dyn EmbeddingModel>
where
    I: IntoIterator<Item = Arc<dyn EmbeddingModelMiddleware>>,
{
    // Wrapping proceeds from the inside out: the last middleware must wrap the
    // bare model first. So walk the buffered list back to front.
    let stack: Vec<Arc<dyn EmbeddingModelMiddleware>> = middleware.into_iter().collect();
    let mut current = model;
    for layer in stack.into_iter().rev() {
        current = Arc::new(Wrapped::new(current, layer));
    }
    current
}
/// A single middleware layer around an inner embedding model.
///
/// Provider and model id are resolved once in [`Wrapped::new`] and stored.
/// Every other capability is resolved per call.
struct Wrapped {
    /// The model one layer down. It may itself be another `Wrapped`.
    inner: Arc<dyn EmbeddingModel>,
    /// The middleware whose hooks this layer applies.
    middleware: Arc<dyn EmbeddingModelMiddleware>,
    /// Provider name this layer reports.
    provider: String,
    /// Model id this layer reports.
    model_id: String,
}

impl Wrapped {
    /// Builds the layer.
    ///
    /// The middleware's provider and model-id overrides are consulted first.
    /// The inner model's own values are used wherever it declines.
    fn new(inner: Arc<dyn EmbeddingModel>, middleware: Arc<dyn EmbeddingModelMiddleware>) -> Self {
        let base = inner.as_ref();
        let provider = match middleware.override_provider(base) {
            Some(name) => name,
            None => String::from(base.provider()),
        };
        let model_id = match middleware.override_model_id(base) {
            Some(id) => id,
            None => String::from(base.model_id()),
        };
        Self {
            inner,
            middleware,
            provider,
            model_id,
        }
    }
}
impl std::fmt::Debug for Wrapped {
    // Shows the resolved identity first, then the middleware, then the inner
    // model. The inner model may expand into further `Wrapped` layers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            inner,
            middleware,
            provider,
            model_id,
        } = self;
        let mut out = f.debug_struct("Wrapped");
        out.field("provider", provider);
        out.field("model_id", model_id);
        out.field("middleware", middleware);
        out.field("inner", inner);
        out.finish()
    }
}
/// Routes every `EmbeddingModel` call through this layer's middleware hooks.
#[async_trait]
impl EmbeddingModel for Wrapped {
    fn provider(&self) -> &str {
        self.provider.as_str()
    }

    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// The middleware's override when it gives one (an explicit `Some(None)`
    /// included), otherwise the inner model's value.
    async fn max_embeddings_per_call(&self) -> Option<u32> {
        let inner = self.inner.as_ref();
        match self.middleware.override_max_embeddings_per_call(inner).await {
            Some(limit) => limit,
            None => inner.max_embeddings_per_call().await,
        }
    }

    /// The middleware's override when it gives one, otherwise the inner
    /// model's answer.
    async fn supports_parallel_calls(&self) -> bool {
        let inner = self.inner.as_ref();
        match self.middleware.override_supports_parallel_calls(inner).await {
            Some(flag) => flag,
            None => inner.supports_parallel_calls().await,
        }
    }

    /// Runs `transform_params` and then `wrap_embed`.
    ///
    /// An error from the transform is returned before `wrap_embed` is invoked.
    async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
        let inner = self.inner.as_ref();
        let prepared = self.middleware.transform_params(options, inner).await?;
        self.middleware.wrap_embed(inner, prepared).await
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// In-memory model that counts `do_embed` calls and records the size of
    /// the last batch it received.
    #[derive(Debug, Default)]
    struct MockEmbed {
        provider: String,
        model_id: String,
        calls: AtomicUsize,
        last_input_len: Mutex<usize>,
    }

    impl MockEmbed {
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
                calls: AtomicUsize::new(0),
                last_input_len: Mutex::new(0),
            }
        }
    }

    #[async_trait]
    impl EmbeddingModel for MockEmbed {
        fn provider(&self) -> &str {
            &self.provider
        }

        fn model_id(&self) -> &str {
            &self.model_id
        }

        async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_input_len.lock().expect("mutex") = options.values.len();
            Ok(EmbedResult {
                embeddings: options.values.iter().map(|_| vec![0.0; 3]).collect(),
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Overrides the provider and batch limit, and duplicates every input value.
    #[derive(Debug)]
    struct OverrideAndDoubleInputs;

    #[async_trait]
    impl EmbeddingModelMiddleware for OverrideAndDoubleInputs {
        fn override_provider(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
            Some("wrapped".to_owned())
        }

        async fn override_max_embeddings_per_call(
            &self,
            _inner: &dyn EmbeddingModel,
        ) -> Option<Option<u32>> {
            Some(Some(42))
        }

        async fn transform_params(
            &self,
            mut params: EmbedOptions,
            _inner: &dyn EmbeddingModel,
        ) -> Result<EmbedOptions> {
            let original = params.values.clone();
            params.values.extend(original);
            Ok(params)
        }
    }

    /// Appends its name to a shared log each time it transforms a request, so
    /// tests can observe the order in which layers run.
    #[derive(Debug)]
    struct Recorder {
        name: &'static str,
        log: Arc<Mutex<Vec<&'static str>>>,
    }

    #[async_trait]
    impl EmbeddingModelMiddleware for Recorder {
        async fn transform_params(
            &self,
            params: EmbedOptions,
            _inner: &dyn EmbeddingModel,
        ) -> Result<EmbedOptions> {
            self.log.lock().expect("mutex").push(self.name);
            Ok(params)
        }
    }

    #[tokio::test]
    async fn empty_middleware_returns_unchanged() {
        let model = Arc::new(MockEmbed::new("p", "m"));
        let wrapped: Arc<dyn EmbeddingModel> =
            wrap_embedding_model(Arc::clone(&model) as _, Vec::new());
        // With no layers the very same allocation must come back. A
        // forwarding wrapper would report the same provider and id, so
        // comparing those alone is not enough.
        assert_eq!(
            Arc::as_ptr(&wrapped).cast::<()>(),
            Arc::as_ptr(&model).cast::<()>(),
        );
        assert_eq!(wrapped.provider(), "p");
        assert_eq!(wrapped.model_id(), "m");
    }

    #[tokio::test]
    async fn overrides_and_transform_run() {
        let model = Arc::new(MockEmbed::new("p", "m"));
        let wrapped = wrap_embedding_model(
            Arc::clone(&model) as _,
            [Arc::new(OverrideAndDoubleInputs) as Arc<dyn EmbeddingModelMiddleware>],
        );
        assert_eq!(wrapped.provider(), "wrapped");
        // model_id is not overridden, so it must pass through untouched.
        assert_eq!(wrapped.model_id(), "m");
        assert_eq!(wrapped.max_embeddings_per_call().await, Some(42));
        wrapped
            .do_embed(EmbedOptions {
                values: vec!["a".into(), "b".into()],
                ..Default::default()
            })
            .await
            .expect("embed");
        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_input_len.lock().expect("mutex"), 4);
    }

    #[tokio::test]
    async fn first_middleware_is_outermost() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let model = Arc::new(MockEmbed::new("p", "m"));
        let outer: Arc<dyn EmbeddingModelMiddleware> = Arc::new(Recorder {
            name: "outer",
            log: Arc::clone(&log),
        });
        let inner: Arc<dyn EmbeddingModelMiddleware> = Arc::new(Recorder {
            name: "inner",
            log: Arc::clone(&log),
        });
        let wrapped = wrap_embedding_model(Arc::clone(&model) as _, [outer, inner]);
        wrapped
            .do_embed(EmbedOptions {
                values: vec!["a".into()],
                ..Default::default()
            })
            .await
            .expect("embed");
        // The first middleware listed must see the request first.
        assert_eq!(*log.lock().expect("mutex"), ["outer", "inner"]);
        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
    }
}