use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::{CognisError, Result};
use super::Embeddings;
/// Strategy that decides which embedding provider handles a request.
///
/// Both methods return a provider index. `EmbeddingsRouter` uses that value
/// to index into its provider list and falls back to index `0` when the
/// value is out of range.
pub trait EmbeddingRouter: Send + Sync {
    /// Returns the provider index to use for a batch of document texts.
    fn pick_documents(&self, texts: &[String]) -> usize;

    /// Returns the provider index to use for a single query text.
    fn pick_query(&self, text: &str) -> usize;
}
/// Router that delegates the routing decision to a caller-supplied function.
///
/// The struct itself places no bounds on `F`; the constructor and the
/// `EmbeddingRouter` impl require `F: Fn(&[String]) -> usize + Send + Sync`.
pub struct FnRouter<F> {
    /// The routing function: receives the request texts, returns a provider index.
    f: F,
}
impl<F> FnRouter<F>
where
F: Fn(&[String]) -> usize + Send + Sync,
{
pub fn new(f: F) -> Self {
Self { f }
}
}
impl<F: Fn(&[String]) -> usize + Send + Sync> EmbeddingRouter for FnRouter<F> {
    /// Passes the batch straight to the wrapped function.
    fn pick_documents(&self, texts: &[String]) -> usize {
        let route = &self.f;
        route(texts)
    }

    /// Presents the query to the wrapped function as a one-element batch.
    fn pick_query(&self, text: &str) -> usize {
        // The function only accepts `&[String]`, so the query is copied into
        // an owned single-item slice first.
        let single = [text.to_owned()];
        (self.f)(&single)
    }
}
/// Router that chooses between two providers based on text length.
///
/// Length is measured in `char`s (Unicode scalar values), not bytes. For a
/// document batch the lengths of all texts are added together before the
/// comparison. A length strictly greater than `threshold` selects
/// `long_idx`; anything else selects `short_idx`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthRouter {
    /// Largest length (in `char`s) still routed to `short_idx`.
    pub threshold: usize,
    /// Provider index used when the length is at most `threshold`.
    pub short_idx: usize,
    /// Provider index used when the length exceeds `threshold`.
    pub long_idx: usize,
}
impl EmbeddingRouter for LengthRouter {
    /// Routes a batch by the combined `char` count of all its texts.
    fn pick_documents(&self, texts: &[String]) -> usize {
        let mut combined = 0usize;
        for text in texts {
            combined += text.chars().count();
        }
        // Exactly `threshold` chars still counts as "short".
        if combined <= self.threshold {
            return self.short_idx;
        }
        self.long_idx
    }

    /// Routes a single query by its `char` count.
    fn pick_query(&self, text: &str) -> usize {
        let length = text.chars().count();
        // Exactly `threshold` chars still counts as "short".
        if length <= self.threshold {
            return self.short_idx;
        }
        self.long_idx
    }
}
/// An `Embeddings` implementation that forwards each request to one of
/// several providers, chosen per call by an `EmbeddingRouter`.
pub struct EmbeddingsRouter {
    /// Candidate providers; the constructor rejects an empty list.
    providers: Vec<Arc<dyn Embeddings>>,
    /// Strategy that yields the index of the provider to use.
    router: Arc<dyn EmbeddingRouter>,
    /// Name reported by `model()`.
    name: String,
}
impl EmbeddingsRouter {
pub fn new<R: EmbeddingRouter + 'static>(
name: impl Into<String>,
providers: Vec<Arc<dyn Embeddings>>,
router: R,
) -> Result<Self> {
if providers.is_empty() {
return Err(CognisError::Configuration(
"EmbeddingsRouter requires at least one provider".into(),
));
}
Ok(Self {
providers,
router: Arc::new(router),
name: name.into(),
})
}
pub fn providers(&self) -> &[Arc<dyn Embeddings>] {
&self.providers
}
fn pick_documents(&self, texts: &[String]) -> &Arc<dyn Embeddings> {
let mut idx = self.router.pick_documents(texts);
if idx >= self.providers.len() {
idx = 0;
}
&self.providers[idx]
}
fn pick_query(&self, text: &str) -> &Arc<dyn Embeddings> {
let mut idx = self.router.pick_query(text);
if idx >= self.providers.len() {
idx = 0;
}
&self.providers[idx]
}
}
#[async_trait]
impl Embeddings for EmbeddingsRouter {
async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
let p = self.pick_documents(&texts).clone();
p.embed_documents(texts).await
}
async fn embed_query(&self, text: String) -> Result<Vec<f32>> {
let p = self.pick_query(&text).clone();
p.embed_query(text).await
}
fn dimensions(&self) -> Option<usize> {
None
}
fn model(&self) -> &str {
&self.name
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that counts `embed_documents` calls before delegating to
    /// an inner provider.
    struct Tagged {
        tag: &'static str,
        inner: Arc<dyn Embeddings>,
        seen: AtomicUsize,
    }

    impl Tagged {
        /// Number of `embed_documents` calls observed so far.
        fn calls(&self) -> usize {
            self.seen.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Embeddings for Tagged {
        async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            self.seen.fetch_add(1, Ordering::SeqCst);
            self.inner.embed_documents(texts).await
        }

        fn model(&self) -> &str {
            self.tag
        }
    }

    /// Creates a counting provider backed by `FakeEmbeddings::new(4)`.
    fn tagged(tag: &'static str) -> Arc<Tagged> {
        let inner: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        Arc::new(Tagged {
            tag,
            inner,
            seen: AtomicUsize::new(0),
        })
    }

    /// Returns another handle to `provider`, typed as a trait object.
    fn erased(provider: &Arc<Tagged>) -> Arc<dyn Embeddings> {
        let handle: Arc<Tagged> = Arc::clone(provider);
        handle
    }

    #[tokio::test]
    async fn length_router_dispatches_short_and_long() {
        let short = tagged("short");
        let long = tagged("long");
        let strategy = LengthRouter {
            threshold: 50,
            short_idx: 0,
            long_idx: 1,
        };
        let router =
            EmbeddingsRouter::new("router", vec![erased(&short), erased(&long)], strategy)
                .unwrap();

        // One batch below the threshold, one above it.
        let _ = router.embed_documents(vec!["abc".into()]).await.unwrap();
        let _ = router.embed_documents(vec!["x".repeat(100)]).await.unwrap();

        assert_eq!(short.calls(), 1);
        assert_eq!(long.calls(), 1);
    }

    #[tokio::test]
    async fn closure_router_works() {
        let a = tagged("a");
        let b = tagged("b");
        let strategy = FnRouter::new(|texts: &[String]| {
            if texts.iter().any(|t| t.starts_with('!')) {
                1
            } else {
                0
            }
        });
        let router =
            EmbeddingsRouter::new("router", vec![erased(&a), erased(&b)], strategy).unwrap();

        let _ = router.embed_documents(vec!["plain".into()]).await.unwrap();
        let _ = router
            .embed_documents(vec!["!special".into()])
            .await
            .unwrap();

        assert_eq!(a.calls(), 1);
        assert_eq!(b.calls(), 1);
    }

    #[tokio::test]
    async fn out_of_range_index_clamps_to_zero() {
        let a = tagged("a");
        // The strategy always answers with an index past the single provider.
        let router =
            EmbeddingsRouter::new("router", vec![erased(&a)], FnRouter::new(|_| 99usize))
                .unwrap();

        let _ = router.embed_documents(vec!["x".into()]).await.unwrap();

        assert_eq!(a.calls(), 1);
    }

    #[test]
    fn empty_providers_errors() {
        let no_providers: Vec<Arc<dyn Embeddings>> = Vec::new();
        let strategy = LengthRouter {
            threshold: 0,
            short_idx: 0,
            long_idx: 0,
        };
        let outcome = EmbeddingsRouter::new("router", no_providers, strategy);
        assert!(outcome.is_err());
    }
}