Skip to main content

cognis_rag/
example_selectors.rs

//! Embedding-driven example selectors for few-shot prompts.
//!
//! Implements [`cognis_core::prompts::ExampleSelector`] for example
//! pools where similarity to the input determines which examples to
//! include.
//!
//! Two strategies are provided:
//! - [`SemanticSimilarityExampleSelector`] — top-k by similarity to the
//!   input. Cheap and effective when examples are diverse.
//! - [`MmrExampleSelector`] — Maximal Marginal Relevance: balances
//!   relevance to the input with novelty among already-picked examples.
//!   Useful when the pool contains near-duplicates.
//!
//! Both selectors also implement [`AsyncExampleSelector`] for use from
//! async code. Both delegate the example→text conversion to a
//! user-supplied closure, so they work with any example type `E` that the
//! closure can render as text.
16
17use std::sync::Arc;
18
19use async_trait::async_trait;
20
21use cognis_core::prompts::ExampleSelector;
22use cognis_core::{CognisError, Result};
23
24use crate::distance::Distance;
25use crate::embeddings::Embeddings;
26
27/// Function that turns an example into the text we embed when selecting.
28/// Often the same renderer used to inject the example into the prompt.
29pub type ExampleTextFn<E> = Arc<dyn Fn(&E) -> String + Send + Sync>;
30
/// Cache mode for embedded examples. The pool's embeddings are the same
/// across calls, so most users want `Cached` (the default).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedMode {
    /// Re-embed the pool on every `select` call. Slower, but always
    /// reflects the pool that was passed in.
    Fresh,
    /// Embed the pool once and reuse the vectors on later calls. Use when
    /// the pool is stable for the lifetime of the selector.
    Cached,
}

impl Default for EmbedMode {
    /// `Cached`, matching the mode selectors are constructed with.
    fn default() -> Self {
        EmbedMode::Cached
    }
}
42
43// ---------------------------------------------------------------------------
44// SemanticSimilarityExampleSelector
45// ---------------------------------------------------------------------------
46
47/// Pick the top-`k` examples whose embeddings are most similar to the
48/// input.
49pub struct SemanticSimilarityExampleSelector<E> {
50    embeddings: Arc<dyn Embeddings>,
51    k: usize,
52    distance: Distance,
53    text_of: ExampleTextFn<E>,
54    mode: EmbedMode,
55    // Cached pool embeddings + original index. Wrapped in tokio::Mutex
56    // because we may need to embed inside the async `select`.
57    cache: Arc<tokio::sync::Mutex<Option<Vec<Vec<f32>>>>>,
58}
59
60impl<E> SemanticSimilarityExampleSelector<E>
61where
62    E: Send + Sync + 'static,
63{
64    /// Build a selector that picks the top-`k` examples by similarity.
65    pub fn new<F>(embeddings: Arc<dyn Embeddings>, k: usize, text_of: F) -> Self
66    where
67        F: Fn(&E) -> String + Send + Sync + 'static,
68    {
69        Self {
70            embeddings,
71            k,
72            distance: Distance::Cosine,
73            text_of: Arc::new(text_of),
74            mode: EmbedMode::Cached,
75            cache: Arc::new(tokio::sync::Mutex::new(None)),
76        }
77    }
78
79    /// Override the distance metric (default: Cosine).
80    pub fn with_distance(mut self, d: Distance) -> Self {
81        self.distance = d;
82        self
83    }
84
85    /// Override the embed mode (default: Cached).
86    pub fn with_embed_mode(mut self, m: EmbedMode) -> Self {
87        self.mode = m;
88        // Reset the cache when switching modes.
89        self.cache = Arc::new(tokio::sync::Mutex::new(None));
90        self
91    }
92
93    /// Embed the pool, optionally reading or populating the cache.
94    async fn embed_pool(&self, examples: &[E]) -> Result<Vec<Vec<f32>>> {
95        if matches!(self.mode, EmbedMode::Cached) {
96            let mut guard = self.cache.lock().await;
97            if let Some(cached) = guard.as_ref() {
98                if cached.len() == examples.len() {
99                    return Ok(cached.clone());
100                }
101            }
102            let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
103            let vecs = self.embeddings.embed_documents(texts).await?;
104            *guard = Some(vecs.clone());
105            return Ok(vecs);
106        }
107        let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
108        self.embeddings.embed_documents(texts).await
109    }
110
111    /// Async select. Trait method must be sync; this is the underlying
112    /// implementation.
113    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>
114    where
115        E: Clone,
116    {
117        if examples.is_empty() {
118            return Ok(Vec::new());
119        }
120        let q = self.embeddings.embed_query(input.to_string()).await?;
121        let pool_vecs = self.embed_pool(examples).await?;
122        let mut scored: Vec<(usize, f32)> = pool_vecs
123            .iter()
124            .enumerate()
125            .map(|(i, v)| (i, self.distance.similarity(&q, v)))
126            .collect();
127        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
128        Ok(scored
129            .into_iter()
130            .take(self.k.min(examples.len()))
131            .map(|(i, _)| examples[i].clone())
132            .collect())
133    }
134}
135
136#[async_trait]
137impl<E> AsyncExampleSelector<E> for SemanticSimilarityExampleSelector<E>
138where
139    E: Clone + Send + Sync + 'static,
140{
141    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
142        SemanticSimilarityExampleSelector::select_async(self, input, examples).await
143    }
144}
145
146impl<E> ExampleSelector<E> for SemanticSimilarityExampleSelector<E>
147where
148    E: Clone + Send + Sync + 'static,
149{
150    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
151        // The trait method is sync; we run the async impl on the current
152        // tokio runtime. This is the same pattern V1's `LengthBasedExampleSelector`
153        // uses to bridge async embeddings into a sync trait.
154        let handle = tokio::runtime::Handle::try_current().map_err(|_| {
155            CognisError::Configuration(
156                "SemanticSimilarityExampleSelector::select called outside a tokio runtime; \
157                 use AsyncExampleSelector::select_async for explicit await"
158                    .into(),
159            )
160        })?;
161        tokio::task::block_in_place(|| handle.block_on(self.select_async(input, examples)))
162    }
163}
164
165// ---------------------------------------------------------------------------
166// MmrExampleSelector
167// ---------------------------------------------------------------------------
168
169/// Maximal Marginal Relevance selector: trades relevance to the input
170/// against novelty among already-selected examples. `lambda` controls
171/// the trade-off — `1.0` is pure similarity (equivalent to the semantic
172/// selector); `0.0` is pure diversity.
173pub struct MmrExampleSelector<E> {
174    embeddings: Arc<dyn Embeddings>,
175    k: usize,
176    lambda: f32,
177    distance: Distance,
178    text_of: ExampleTextFn<E>,
179}
180
181impl<E> MmrExampleSelector<E>
182where
183    E: Send + Sync + 'static,
184{
185    /// Build with `lambda` clamped to `[0, 1]`. `k` is the number of
186    /// examples returned.
187    pub fn new<F>(embeddings: Arc<dyn Embeddings>, k: usize, lambda: f32, text_of: F) -> Self
188    where
189        F: Fn(&E) -> String + Send + Sync + 'static,
190    {
191        Self {
192            embeddings,
193            k,
194            lambda: lambda.clamp(0.0, 1.0),
195            distance: Distance::Cosine,
196            text_of: Arc::new(text_of),
197        }
198    }
199
200    /// Override the distance metric.
201    pub fn with_distance(mut self, d: Distance) -> Self {
202        self.distance = d;
203        self
204    }
205
206    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>
207    where
208        E: Clone,
209    {
210        if examples.is_empty() {
211            return Ok(Vec::new());
212        }
213        let q = self.embeddings.embed_query(input.to_string()).await?;
214        let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
215        let pool_vecs = self.embeddings.embed_documents(texts).await?;
216        let n = examples.len();
217        let take = self.k.min(n);
218        let mut chosen: Vec<usize> = Vec::with_capacity(take);
219        let mut available: Vec<usize> = (0..n).collect();
220
221        for _ in 0..take {
222            let mut best_idx: Option<usize> = None;
223            let mut best_score = f32::NEG_INFINITY;
224            for &i in &available {
225                let sim_to_query = self.distance.similarity(&q, &pool_vecs[i]);
226                let max_sim_to_chosen = chosen
227                    .iter()
228                    .map(|&j| self.distance.similarity(&pool_vecs[i], &pool_vecs[j]))
229                    .fold(f32::NEG_INFINITY, f32::max);
230                let novelty = if chosen.is_empty() {
231                    0.0
232                } else {
233                    max_sim_to_chosen
234                };
235                let score = self.lambda * sim_to_query - (1.0 - self.lambda) * novelty;
236                if score > best_score {
237                    best_score = score;
238                    best_idx = Some(i);
239                }
240            }
241            let pick = match best_idx {
242                Some(i) => i,
243                None => break,
244            };
245            chosen.push(pick);
246            available.retain(|&i| i != pick);
247        }
248
249        Ok(chosen.into_iter().map(|i| examples[i].clone()).collect())
250    }
251}
252
253#[async_trait]
254impl<E> AsyncExampleSelector<E> for MmrExampleSelector<E>
255where
256    E: Clone + Send + Sync + 'static,
257{
258    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
259        MmrExampleSelector::select_async(self, input, examples).await
260    }
261}
262
263impl<E> ExampleSelector<E> for MmrExampleSelector<E>
264where
265    E: Clone + Send + Sync + 'static,
266{
267    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
268        let handle = tokio::runtime::Handle::try_current().map_err(|_| {
269            CognisError::Configuration(
270                "MmrExampleSelector::select called outside a tokio runtime; \
271                 use AsyncExampleSelector::select_async for explicit await"
272                    .into(),
273            )
274        })?;
275        tokio::task::block_in_place(|| handle.block_on(self.select_async(input, examples)))
276    }
277}
278
279// ---------------------------------------------------------------------------
280// AsyncExampleSelector — explicit-async parallel trait.
281// ---------------------------------------------------------------------------
282
283/// Async-first variant of [`cognis_core::prompts::ExampleSelector`].
284/// Use when you want to call selection from within an async context
285/// without going through `block_in_place`.
286#[async_trait]
287pub trait AsyncExampleSelector<E>: Send + Sync
288where
289    E: Send + Sync + 'static,
290{
291    /// Select examples to include for `input`.
292    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>;
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;

    /// Turn string literals into an owned example pool.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_picks_topk() {
        // FakeEmbeddings is deterministic: identical text yields an identical vector.
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(8));
        let selector = SemanticSimilarityExampleSelector::new(embedder, 2, |s: &String| s.clone());
        let pool = owned(&["completely different", "rust programming", "python programming"]);

        let chosen = selector.select_async("rust programming", &pool).await.unwrap();

        assert_eq!(chosen.len(), 2);
        // An exact textual match must land in the top two.
        assert!(chosen.iter().any(|s| s.as_str() == "rust programming"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_handles_empty_pool() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector = SemanticSimilarityExampleSelector::new(embedder, 3, |s: &String| s.clone());

        let chosen = selector.select_async("anything", &[]).await.unwrap();

        assert!(chosen.is_empty());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn mmr_selector_returns_k_distinct_picks() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(8));
        let selector = MmrExampleSelector::new(embedder, 2, 0.5, |s: &String| s.clone());
        let pool = owned(&["a", "b", "c"]);

        let chosen = selector.select_async("query", &pool).await.unwrap();

        assert_eq!(chosen.len(), 2);
        assert_ne!(chosen[0], chosen[1]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_caches_pool_embeddings() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector = SemanticSimilarityExampleSelector::new(embedder, 1, |s: &String| s.clone())
            .with_embed_mode(EmbedMode::Cached);
        let pool = owned(&["one", "two"]);

        // The first call fills the cache; the second must be served from it
        // without error.
        let _warm = selector.select_async("one", &pool).await.unwrap();
        let chosen = selector.select_async("one", &pool).await.unwrap();

        assert_eq!(chosen.len(), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_sync_select_works_in_runtime() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector = SemanticSimilarityExampleSelector::new(embedder, 2, |s: &String| s.clone());
        let pool = owned(&["a", "b", "c"]);

        // Exercise the sync trait method from inside a multi-thread runtime.
        let chosen = ExampleSelector::select(&selector, "a", &pool).unwrap();

        assert_eq!(chosen.len(), 2);
    }
}