Skip to main content

cognis_rag/embeddings/
router.rs

//! Embeddings router — dispatch each call to one of N inner providers
//! based on a pluggable strategy.
//!
//! Use cases:
//! - Long text → high-context embedder; short text → cheaper embedder.
//! - Per-tenant routing for billing isolation.
//! - Failover during a provider outage (paired with [`super::Embeddings`]
//!   wrappers such as `CachedEmbeddings`).
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use cognis_core::{CognisError, Result};
15
16use super::Embeddings;
17
/// Routing strategy used by [`EmbeddingsRouter`]: every call is answered
/// with an index into the router's configured provider list.
///
/// An index that falls outside the provider list is treated as `0` by
/// [`EmbeddingsRouter`].
pub trait EmbeddingRouter: Send + Sync {
    /// Choose the provider index for a batch (`embed_documents`) call.
    fn pick_documents(&self, texts: &[String]) -> usize;
    /// Choose the provider index for a single-query (`embed_query`) call.
    fn pick_query(&self, text: &str) -> usize;
}
25
/// Router backed by one closure, which serves both the batch and the
/// single-query decision.
pub struct FnRouter<F> {
    f: F,
}

impl<F: Fn(&[String]) -> usize + Send + Sync> FnRouter<F> {
    /// Wrap `f`, a closure that maps a slice of input texts to a provider
    /// index. For `pick_query` the closure receives the query text as a
    /// one-element slice.
    pub fn new(f: F) -> Self {
        FnRouter { f }
    }
}
42
43impl<F> EmbeddingRouter for FnRouter<F>
44where
45    F: Fn(&[String]) -> usize + Send + Sync,
46{
47    fn pick_documents(&self, texts: &[String]) -> usize {
48        (self.f)(texts)
49    }
50    fn pick_query(&self, text: &str) -> usize {
51        let v = vec![text.to_string()];
52        (self.f)(&v)
53    }
54}
55
/// Stock router that routes by input length.
///
/// A call whose total character count (summed across the batch for
/// document calls) exceeds `threshold` is routed to `long_idx`; every other
/// call is routed to `short_idx`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LengthRouter {
    /// Character count above which an input is considered "long".
    pub threshold: usize,
    /// Provider index used for "short" inputs.
    pub short_idx: usize,
    /// Provider index used for "long" inputs.
    pub long_idx: usize,
}
66
67impl EmbeddingRouter for LengthRouter {
68    fn pick_documents(&self, texts: &[String]) -> usize {
69        let total: usize = texts.iter().map(|t| t.chars().count()).sum();
70        if total > self.threshold {
71            self.long_idx
72        } else {
73            self.short_idx
74        }
75    }
76    fn pick_query(&self, text: &str) -> usize {
77        if text.chars().count() > self.threshold {
78            self.long_idx
79        } else {
80            self.short_idx
81        }
82    }
83}
84
85/// Embedding-provider router.
86pub struct EmbeddingsRouter {
87    providers: Vec<Arc<dyn Embeddings>>,
88    router: Arc<dyn EmbeddingRouter>,
89    /// Reported model id (used as the union name).
90    name: String,
91}
92
93impl EmbeddingsRouter {
94    /// Build with a list of providers and a routing strategy.
95    pub fn new<R: EmbeddingRouter + 'static>(
96        name: impl Into<String>,
97        providers: Vec<Arc<dyn Embeddings>>,
98        router: R,
99    ) -> Result<Self> {
100        if providers.is_empty() {
101            return Err(CognisError::Configuration(
102                "EmbeddingsRouter requires at least one provider".into(),
103            ));
104        }
105        Ok(Self {
106            providers,
107            router: Arc::new(router),
108            name: name.into(),
109        })
110    }
111
112    /// Borrow the configured providers (read-only).
113    pub fn providers(&self) -> &[Arc<dyn Embeddings>] {
114        &self.providers
115    }
116
117    fn pick_documents(&self, texts: &[String]) -> &Arc<dyn Embeddings> {
118        let mut idx = self.router.pick_documents(texts);
119        if idx >= self.providers.len() {
120            idx = 0;
121        }
122        &self.providers[idx]
123    }
124
125    fn pick_query(&self, text: &str) -> &Arc<dyn Embeddings> {
126        let mut idx = self.router.pick_query(text);
127        if idx >= self.providers.len() {
128            idx = 0;
129        }
130        &self.providers[idx]
131    }
132}
133
134#[async_trait]
135impl Embeddings for EmbeddingsRouter {
136    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
137        let p = self.pick_documents(&texts).clone();
138        p.embed_documents(texts).await
139    }
140    async fn embed_query(&self, text: String) -> Result<Vec<f32>> {
141        let p = self.pick_query(&text).clone();
142        p.embed_query(text).await
143    }
144    fn dimensions(&self) -> Option<usize> {
145        // Unknown — providers may have different dims.
146        None
147    }
148    fn model(&self) -> &str {
149        &self.name
150    }
151}
152
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::embeddings::FakeEmbeddings;

    /// Fake embedder that counts its `embed_documents` calls, so a test can
    /// tell which provider the router dispatched to.
    struct Tagged {
        tag: &'static str,
        inner: Arc<dyn Embeddings>,
        seen: AtomicUsize,
    }

    impl Tagged {
        /// Number of `embed_documents` calls observed so far.
        fn calls(&self) -> usize {
            self.seen.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Embeddings for Tagged {
        async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            self.seen.fetch_add(1, Ordering::SeqCst);
            self.inner.embed_documents(texts).await
        }
        fn model(&self) -> &str {
            self.tag
        }
    }

    /// Build a counting fake around a 4-dimensional `FakeEmbeddings`.
    fn tagged(tag: &'static str) -> Arc<Tagged> {
        let inner: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        Arc::new(Tagged {
            tag,
            inner,
            seen: AtomicUsize::new(0),
        })
    }

    /// Share a `Tagged` fake with the router as a trait object.
    fn as_dyn(provider: &Arc<Tagged>) -> Arc<dyn Embeddings> {
        provider.clone()
    }

    #[tokio::test]
    async fn length_router_dispatches_short_and_long() {
        let short = tagged("short");
        let long = tagged("long");
        let strategy = LengthRouter {
            threshold: 50,
            short_idx: 0,
            long_idx: 1,
        };
        let router =
            EmbeddingsRouter::new("router", vec![as_dyn(&short), as_dyn(&long)], strategy)
                .unwrap();

        router.embed_documents(vec!["abc".into()]).await.unwrap();
        router.embed_documents(vec!["x".repeat(100)]).await.unwrap();

        assert_eq!(short.calls(), 1);
        assert_eq!(long.calls(), 1);
    }

    #[tokio::test]
    async fn closure_router_works() {
        let a = tagged("a");
        let b = tagged("b");
        // Any text starting with '!' sends the batch to provider 1.
        let strategy = FnRouter::new(|texts: &[String]| {
            usize::from(texts.iter().any(|t| t.starts_with('!')))
        });
        let router =
            EmbeddingsRouter::new("router", vec![as_dyn(&a), as_dyn(&b)], strategy).unwrap();

        router.embed_documents(vec!["plain".into()]).await.unwrap();
        router
            .embed_documents(vec!["!special".into()])
            .await
            .unwrap();

        assert_eq!(a.calls(), 1);
        assert_eq!(b.calls(), 1);
    }

    #[tokio::test]
    async fn out_of_range_index_clamps_to_zero() {
        let a = tagged("a");
        let router =
            EmbeddingsRouter::new("router", vec![as_dyn(&a)], FnRouter::new(|_| 99usize))
                .unwrap();

        router.embed_documents(vec!["x".into()]).await.unwrap();

        assert_eq!(a.calls(), 1);
    }

    #[test]
    fn empty_providers_errors() {
        let no_providers: Vec<Arc<dyn Embeddings>> = Vec::new();
        let strategy = LengthRouter {
            threshold: 0,
            short_idx: 0,
            long_idx: 0,
        };
        let result = EmbeddingsRouter::new("router", no_providers, strategy);
        assert!(result.is_err());
    }
}