Skip to main content

cognis_core/prompts/
example_selector.rs

1//! Few-shot example selection — pluggable strategies for picking a
2//! subset of examples to include in a prompt.
3//!
4//! [`ExampleSelector<E>`] is the core trait. Two stock implementations
5//! ship in `cognis-core`:
6//!
7//! - [`StaticExampleSelector`] — passes through a fixed list (useful as
8//!   a baseline / for tests).
9//! - [`LengthBasedExampleSelector`] — picks examples in order until a
10//!   token budget is exhausted, using any [`crate::tokenizer::Tokenizer`].
11//!
12//! Embedding-driven selectors (e.g. semantic similarity, MMR) live in
13//! `cognis-rag` because they require an `Embeddings` trait.
14//!
15//! Custom selectors plug in by implementing `ExampleSelector<E>`. The
16//! crate ships traits plus default impls; nothing about the selector
17//! contract is closed.
18
19use std::sync::Arc;
20
21use crate::tokenizer::Tokenizer;
22use crate::{CognisError, Result};
23
24/// Strategy for choosing which examples to include in a few-shot prompt.
25///
26/// Implementations receive an `input` value (typed) and the full pool of
27/// available examples, and return the examples to use, in render order.
28#[allow(clippy::ptr_arg)] // examples are stored as Vec to enable owned slicing
29pub trait ExampleSelector<E>: Send + Sync
30where
31    E: Send + Sync + 'static,
32{
33    /// Select examples from `examples` to include for `input`. Returns
34    /// owned clones (or constructs new examples) so the caller doesn't
35    /// have to manage lifetimes against the pool.
36    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>>;
37}
38
39// ---------------------------------------------------------------------------
40// StaticExampleSelector
41// ---------------------------------------------------------------------------
42
43/// Trivial selector that always returns the entire example list,
44/// optionally truncated to `max`. Useful when a fixed shot count is
45/// desired without any dynamic logic.
46#[derive(Debug, Clone, Default)]
47pub struct StaticExampleSelector {
48    max: Option<usize>,
49}
50
51impl StaticExampleSelector {
52    /// Pass-through: every example is returned.
53    pub fn all() -> Self {
54        Self { max: None }
55    }
56
57    /// Cap the returned examples to at most `n`.
58    pub fn at_most(n: usize) -> Self {
59        Self { max: Some(n) }
60    }
61}
62
63impl<E> ExampleSelector<E> for StaticExampleSelector
64where
65    E: Clone + Send + Sync + 'static,
66{
67    fn select(&self, _input: &str, examples: &[E]) -> Result<Vec<E>> {
68        Ok(match self.max {
69            Some(n) => examples.iter().take(n).cloned().collect(),
70            None => examples.to_vec(),
71        })
72    }
73}
74
75// ---------------------------------------------------------------------------
76// LengthBasedExampleSelector
77// ---------------------------------------------------------------------------
78
/// Shared, thread-safe callback that turns one example into the exact text
/// whose token count is charged against a budget.
///
/// Usually this mirrors the template used to place the example into the
/// prompt, so the budget tracks the real prompt cost rather than a stand-in.
pub type ExampleRenderFn<E> = Arc<dyn Fn(&E) -> String + Sync + Send>;
84
85/// Greedy: include each example in order until adding the next would
86/// exceed the configured `max_tokens` budget. The `input` itself is
87/// counted against the budget so callers know how much head-room is
88/// left for completion.
89#[derive(Clone)]
90pub struct LengthBasedExampleSelector<E> {
91    max_tokens: usize,
92    tokenizer: Arc<dyn Tokenizer>,
93    render: ExampleRenderFn<E>,
94}
95
96impl<E> std::fmt::Debug for LengthBasedExampleSelector<E> {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        f.debug_struct("LengthBasedExampleSelector")
99            .field("max_tokens", &self.max_tokens)
100            .finish()
101    }
102}
103
104impl<E> LengthBasedExampleSelector<E>
105where
106    E: Send + Sync + 'static,
107{
108    /// Build a length-budgeted selector.
109    ///
110    /// `render` converts an example to the text that will be substituted
111    /// into the prompt — used by the tokenizer to size each example.
112    pub fn new<F>(max_tokens: usize, tokenizer: Arc<dyn Tokenizer>, render: F) -> Self
113    where
114        F: Fn(&E) -> String + Send + Sync + 'static,
115    {
116        Self {
117            max_tokens,
118            tokenizer,
119            render: Arc::new(render),
120        }
121    }
122
123    /// Replace the render function (e.g. when the prompt template changes).
124    pub fn with_render<F>(mut self, render: F) -> Self
125    where
126        F: Fn(&E) -> String + Send + Sync + 'static,
127    {
128        self.render = Arc::new(render);
129        self
130    }
131
132    /// Replace the tokenizer (e.g. swap `CharTokenizer` for `tiktoken`).
133    pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
134        self.tokenizer = tokenizer;
135        self
136    }
137}
138
139impl<E> ExampleSelector<E> for LengthBasedExampleSelector<E>
140where
141    E: Clone + Send + Sync + 'static,
142{
143    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
144        let mut budget = self
145            .max_tokens
146            .checked_sub(self.tokenizer.count(input))
147            .ok_or_else(|| {
148                CognisError::Configuration(
149                    "LengthBasedExampleSelector: input alone exceeds max_tokens".into(),
150                )
151            })?;
152        let mut out = Vec::new();
153        for ex in examples {
154            let cost = self.tokenizer.count(&(self.render)(ex));
155            if cost > budget {
156                break;
157            }
158            budget -= cost;
159            out.push(ex.clone());
160        }
161        Ok(out)
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::CharTokenizer;

    #[test]
    fn static_selector_returns_all_by_default() {
        let s = StaticExampleSelector::all();
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, pool);
    }

    #[test]
    fn static_selector_default_matches_all() {
        let s = StaticExampleSelector::default();
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, pool);
    }

    #[test]
    fn static_selector_caps_at_most() {
        let s = StaticExampleSelector::at_most(2);
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, vec!["a", "b"]);
    }

    #[test]
    fn static_selector_cap_larger_than_pool_returns_all() {
        let s = StaticExampleSelector::at_most(10);
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, pool);
    }

    #[test]
    fn static_selector_zero_cap_returns_nothing() {
        let s = StaticExampleSelector::at_most(0);
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn length_based_stops_at_budget() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(20, tokenizer, |s: &String| s.clone());
        let pool = vec![
            "five5".to_string(),  // 5 chars
            "five5".to_string(),  // 5 chars (cumulative 10)
            "five5".to_string(),  // 5 chars (cumulative 15) — input is 5 → budget 15
            "ovrflw".to_string(), // 6 chars — would exceed
        ];
        let picked = sel.select("input", &pool).unwrap();
        // Budget is 20 - 5(input) = 15. First three 5-char examples fit
        // exactly; "ovrflw" is rejected.
        assert_eq!(picked.len(), 3);
    }

    #[test]
    fn length_based_stops_at_first_oversized_example() {
        // Greedy-stop semantics: once an example overflows, later examples
        // are not considered even if they would fit.
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(10, tokenizer, |s: &String| s.clone());
        let pool = vec![
            "aaaa".to_string(),       // 4 → remaining 6
            "bbbbbbbbbb".to_string(), // 10 > 6 → stop here
            "c".to_string(),          // would fit, but is never reached
        ];
        let picked = sel.select("", &pool).unwrap();
        assert_eq!(picked, vec!["aaaa".to_string()]);
    }

    #[test]
    fn length_based_input_filling_budget_yields_no_examples() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(5, tokenizer, |s: &String| s.clone());
        // "input" is 5 chars, leaving a budget of exactly 0: not an error,
        // but nothing else fits.
        let picked = sel.select("input", &["a".to_string()]).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn length_based_empty_pool_is_ok() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(10, tokenizer, |s: &String| s.clone());
        let picked = sel.select("hi", &[]).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn length_based_rejects_input_alone_too_big() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(3, tokenizer, |s: &String| s.clone());
        let err = sel.select("longer-than-budget", &[]).unwrap_err();
        assert!(matches!(err, CognisError::Configuration(_)));
    }

    #[test]
    fn length_based_with_custom_renderer() {
        // Render fn doubles the cost of each example.
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(10, tokenizer, |s: &String| s.clone() + s);
        let pool = vec![
            "ab".to_string(),  // rendered as "abab" → 4
            "ab".to_string(),  // 4
            "abc".to_string(), // 6 — would push to 14, rejected (budget 10-input=10)
        ];
        let picked = sel.select("", &pool).unwrap();
        assert_eq!(picked.len(), 2);
    }

    #[test]
    fn length_based_with_render_replaces_renderer() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(4, tokenizer, |s: &String| s.clone())
                .with_render(|s: &String| s.repeat(3));
        // With the replacement renderer each 1-char example costs 3, so only
        // the first fits in a budget of 4. The original renderer would admit both.
        let pool = vec!["a".to_string(), "b".to_string()];
        let picked = sel.select("", &pool).unwrap();
        assert_eq!(picked, vec!["a".to_string()]);
    }
}