cognis_rag/embeddings/
router.rs1use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use cognis_core::{CognisError, Result};
15
16use super::Embeddings;
17
/// Strategy deciding which provider an [`EmbeddingsRouter`] dispatches a call to.
///
/// Both methods return an index into the router's provider list. An index that
/// is outside that list is treated as `0` by [`EmbeddingsRouter`].
pub trait EmbeddingRouter: Send + Sync {
    /// Chooses the provider index for a batch of documents.
    fn pick_documents(&self, batch: &[String]) -> usize;

    /// Chooses the provider index for a single query string.
    fn pick_query(&self, query: &str) -> usize;
}
25
26pub struct FnRouter<F> {
28 f: F,
29}
30
31impl<F> FnRouter<F>
32where
33 F: Fn(&[String]) -> usize + Send + Sync,
34{
35 pub fn new(f: F) -> Self {
39 Self { f }
40 }
41}
42
43impl<F> EmbeddingRouter for FnRouter<F>
44where
45 F: Fn(&[String]) -> usize + Send + Sync,
46{
47 fn pick_documents(&self, texts: &[String]) -> usize {
48 (self.f)(texts)
49 }
50 fn pick_query(&self, text: &str) -> usize {
51 let v = vec![text.to_string()];
52 (self.f)(&v)
53 }
54}
55
56pub struct LengthRouter {
59 pub threshold: usize,
61 pub short_idx: usize,
63 pub long_idx: usize,
65}
66
67impl EmbeddingRouter for LengthRouter {
68 fn pick_documents(&self, texts: &[String]) -> usize {
69 let total: usize = texts.iter().map(|t| t.chars().count()).sum();
70 if total > self.threshold {
71 self.long_idx
72 } else {
73 self.short_idx
74 }
75 }
76 fn pick_query(&self, text: &str) -> usize {
77 if text.chars().count() > self.threshold {
78 self.long_idx
79 } else {
80 self.short_idx
81 }
82 }
83}
84
85pub struct EmbeddingsRouter {
87 providers: Vec<Arc<dyn Embeddings>>,
88 router: Arc<dyn EmbeddingRouter>,
89 name: String,
91}
92
93impl EmbeddingsRouter {
94 pub fn new<R: EmbeddingRouter + 'static>(
96 name: impl Into<String>,
97 providers: Vec<Arc<dyn Embeddings>>,
98 router: R,
99 ) -> Result<Self> {
100 if providers.is_empty() {
101 return Err(CognisError::Configuration(
102 "EmbeddingsRouter requires at least one provider".into(),
103 ));
104 }
105 Ok(Self {
106 providers,
107 router: Arc::new(router),
108 name: name.into(),
109 })
110 }
111
112 pub fn providers(&self) -> &[Arc<dyn Embeddings>] {
114 &self.providers
115 }
116
117 fn pick_documents(&self, texts: &[String]) -> &Arc<dyn Embeddings> {
118 let mut idx = self.router.pick_documents(texts);
119 if idx >= self.providers.len() {
120 idx = 0;
121 }
122 &self.providers[idx]
123 }
124
125 fn pick_query(&self, text: &str) -> &Arc<dyn Embeddings> {
126 let mut idx = self.router.pick_query(text);
127 if idx >= self.providers.len() {
128 idx = 0;
129 }
130 &self.providers[idx]
131 }
132}
133
134#[async_trait]
135impl Embeddings for EmbeddingsRouter {
136 async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
137 let p = self.pick_documents(&texts).clone();
138 p.embed_documents(texts).await
139 }
140 async fn embed_query(&self, text: String) -> Result<Vec<f32>> {
141 let p = self.pick_query(&text).clone();
142 p.embed_query(text).await
143 }
144 fn dimensions(&self) -> Option<usize> {
145 None
147 }
148 fn model(&self) -> &str {
149 &self.name
150 }
151}
152
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::embeddings::FakeEmbeddings;

    /// Provider wrapper that counts `embed_documents` calls and reports its
    /// `tag` as the model name, so tests can observe which provider was picked.
    struct Tagged {
        tag: &'static str,
        inner: Arc<dyn Embeddings>,
        seen: AtomicUsize,
    }

    #[async_trait]
    impl Embeddings for Tagged {
        async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            self.seen.fetch_add(1, Ordering::SeqCst);
            self.inner.embed_documents(texts).await
        }
        fn model(&self) -> &str {
            self.tag
        }
    }

    fn tagged(tag: &'static str) -> Arc<Tagged> {
        Arc::new(Tagged {
            tag,
            inner: Arc::new(FakeEmbeddings::new(4)),
            seen: AtomicUsize::new(0),
        })
    }

    /// Number of `embed_documents` calls `t` has received so far.
    fn calls(t: &Tagged) -> usize {
        t.seen.load(Ordering::SeqCst)
    }

    /// Router over `[short, long]` using a `LengthRouter` with `threshold`.
    fn short_long_router(
        short: &Arc<Tagged>,
        long: &Arc<Tagged>,
        threshold: usize,
    ) -> EmbeddingsRouter {
        EmbeddingsRouter::new(
            "router",
            vec![
                Arc::clone(short) as Arc<dyn Embeddings>,
                Arc::clone(long) as Arc<dyn Embeddings>,
            ],
            LengthRouter {
                threshold,
                short_idx: 0,
                long_idx: 1,
            },
        )
        .unwrap()
    }

    #[tokio::test]
    async fn length_router_dispatches_short_and_long() {
        let short = tagged("short");
        let long = tagged("long");
        let r = short_long_router(&short, &long, 50);
        let _ = r.embed_documents(vec!["abc".into()]).await.unwrap();
        let _ = r.embed_documents(vec!["x".repeat(100)]).await.unwrap();
        assert_eq!(calls(&short), 1);
        assert_eq!(calls(&long), 1);
    }

    #[test]
    fn queries_are_routed_through_pick_query() {
        let short = tagged("short");
        let long = tagged("long");
        let r = short_long_router(&short, &long, 5);
        assert_eq!(r.pick_query("hi").model(), "short");
        assert_eq!(r.pick_query(&"x".repeat(10)).model(), "long");
        assert_eq!(r.model(), "router");
    }

    #[test]
    fn length_router_threshold_is_exclusive_and_counts_chars() {
        let r = LengthRouter {
            threshold: 3,
            short_idx: 0,
            long_idx: 1,
        };
        assert_eq!(r.pick_query("abc"), 0);
        assert_eq!(r.pick_query("abcd"), 1);
        // 3 chars but 6 bytes: length is measured in chars, not bytes.
        assert_eq!(r.pick_query("ééé"), 0);
        let docs = vec!["ab".to_string(), "cd".to_string()];
        assert_eq!(r.pick_documents(&docs), 1);
        assert_eq!(r.pick_documents(&docs[..1]), 0);
        assert_eq!(r.pick_documents(&docs[..0]), 0);
    }

    #[tokio::test]
    async fn closure_router_works() {
        let a = tagged("a");
        let b = tagged("b");
        let r = EmbeddingsRouter::new(
            "router",
            vec![
                Arc::clone(&a) as Arc<dyn Embeddings>,
                Arc::clone(&b) as Arc<dyn Embeddings>,
            ],
            FnRouter::new(|texts: &[String]| {
                if texts.iter().any(|t| t.starts_with('!')) {
                    1
                } else {
                    0
                }
            }),
        )
        .unwrap();
        let _ = r.embed_documents(vec!["plain".into()]).await.unwrap();
        let _ = r.embed_documents(vec!["!special".into()]).await.unwrap();
        assert_eq!(calls(&a), 1);
        assert_eq!(calls(&b), 1);
        assert_eq!(r.pick_query("!q").model(), "b");
        assert_eq!(r.pick_query("q").model(), "a");
    }

    #[tokio::test]
    async fn out_of_range_index_clamps_to_zero() {
        let a = tagged("a");
        let r = EmbeddingsRouter::new(
            "router",
            vec![Arc::clone(&a) as Arc<dyn Embeddings>],
            FnRouter::new(|_| 99usize),
        )
        .unwrap();
        let _ = r.embed_documents(vec!["x".into()]).await.unwrap();
        assert_eq!(calls(&a), 1);
        assert_eq!(r.pick_query("x").model(), "a");
    }

    #[test]
    fn empty_providers_errors() {
        let r = EmbeddingsRouter::new(
            "router",
            Vec::<Arc<dyn Embeddings>>::new(),
            LengthRouter {
                threshold: 0,
                short_idx: 0,
                long_idx: 0,
            },
        );
        assert!(matches!(r, Err(CognisError::Configuration(_))));
    }
}