cognis_core/prompts/
example_selector.rs1use std::sync::Arc;
20
21use crate::tokenizer::Tokenizer;
22use crate::{CognisError, Result};
23
24#[allow(clippy::ptr_arg)] pub trait ExampleSelector<E>: Send + Sync
30where
31 E: Send + Sync + 'static,
32{
33 fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>>;
37}
38
39#[derive(Debug, Clone, Default)]
47pub struct StaticExampleSelector {
48 max: Option<usize>,
49}
50
51impl StaticExampleSelector {
52 pub fn all() -> Self {
54 Self { max: None }
55 }
56
57 pub fn at_most(n: usize) -> Self {
59 Self { max: Some(n) }
60 }
61}
62
63impl<E> ExampleSelector<E> for StaticExampleSelector
64where
65 E: Clone + Send + Sync + 'static,
66{
67 fn select(&self, _input: &str, examples: &[E]) -> Result<Vec<E>> {
68 Ok(match self.max {
69 Some(n) => examples.iter().take(n).cloned().collect(),
70 None => examples.to_vec(),
71 })
72 }
73}
74
75pub type ExampleRenderFn<E> = Arc<dyn Fn(&E) -> String + Send + Sync>;
84
85#[derive(Clone)]
90pub struct LengthBasedExampleSelector<E> {
91 max_tokens: usize,
92 tokenizer: Arc<dyn Tokenizer>,
93 render: ExampleRenderFn<E>,
94}
95
96impl<E> std::fmt::Debug for LengthBasedExampleSelector<E> {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("LengthBasedExampleSelector")
99 .field("max_tokens", &self.max_tokens)
100 .finish()
101 }
102}
103
104impl<E> LengthBasedExampleSelector<E>
105where
106 E: Send + Sync + 'static,
107{
108 pub fn new<F>(max_tokens: usize, tokenizer: Arc<dyn Tokenizer>, render: F) -> Self
113 where
114 F: Fn(&E) -> String + Send + Sync + 'static,
115 {
116 Self {
117 max_tokens,
118 tokenizer,
119 render: Arc::new(render),
120 }
121 }
122
123 pub fn with_render<F>(mut self, render: F) -> Self
125 where
126 F: Fn(&E) -> String + Send + Sync + 'static,
127 {
128 self.render = Arc::new(render);
129 self
130 }
131
132 pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
134 self.tokenizer = tokenizer;
135 self
136 }
137}
138
139impl<E> ExampleSelector<E> for LengthBasedExampleSelector<E>
140where
141 E: Clone + Send + Sync + 'static,
142{
143 fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
144 let mut budget = self
145 .max_tokens
146 .checked_sub(self.tokenizer.count(input))
147 .ok_or_else(|| {
148 CognisError::Configuration(
149 "LengthBasedExampleSelector: input alone exceeds max_tokens".into(),
150 )
151 })?;
152 let mut out = Vec::new();
153 for ex in examples {
154 let cost = self.tokenizer.count(&(self.render)(ex));
155 if cost > budget {
156 break;
157 }
158 budget -= cost;
159 out.push(ex.clone());
160 }
161 Ok(out)
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::CharTokenizer;

    // NOTE: the length-based expectations below assume `CharTokenizer` charges
    // one token per ASCII character, as the original expectations already did.

    #[test]
    fn static_selector_returns_all_by_default() {
        let s = StaticExampleSelector::all();
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, pool);
    }

    #[test]
    fn static_selector_caps_at_most() {
        let s = StaticExampleSelector::at_most(2);
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, vec!["a", "b"]);
    }

    #[test]
    fn length_based_stops_at_budget() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(20, tokenizer, |s: &String| s.clone());
        // Budget after the input: 20 - 5 = 15, exactly three "five5"s.
        let pool = vec![
            "five5".to_string(),
            "five5".to_string(),
            "five5".to_string(),
            "ovrflw".to_string(),
        ];
        let picked = sel.select("input", &pool).unwrap();
        assert_eq!(picked, &pool[..3]);
    }

    #[test]
    fn length_based_does_not_skip_past_oversized_example() {
        // "abcdefghijk" (11 tokens) exceeds the budget of 10, so selection stops
        // there even though the later "a" would fit on its own.
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(10, tokenizer, |s: &String| s.clone());
        let pool = vec!["abcdefghijk".to_string(), "a".to_string()];
        let picked = sel.select("", &pool).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn length_based_rejects_input_alone_too_big() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(3, tokenizer, |s: &String| s.clone());
        let err = sel.select("longer-than-budget", &[]).unwrap_err();
        assert!(matches!(err, CognisError::Configuration(_)));
    }

    #[test]
    fn length_based_with_custom_renderer() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        // The renderer doubles each example: "ab" costs 4, "abc" costs 6.
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(10, tokenizer, |s: &String| s.clone() + s);
        let pool = vec!["ab".to_string(), "ab".to_string(), "abc".to_string()];
        let picked = sel.select("", &pool).unwrap();
        assert_eq!(picked.len(), 2);
    }

    #[test]
    fn length_based_with_render_replaces_renderer() {
        // With the identity renderer all three would fit (2 + 2 + 3 <= 10);
        // after `with_render` doubles them, only the first two do.
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(10, tokenizer, |s: &String| s.clone())
                .with_render(|s: &String| s.clone() + s);
        let pool = vec!["ab".to_string(), "ab".to_string(), "abc".to_string()];
        let picked = sel.select("", &pool).unwrap();
        assert_eq!(picked.len(), 2);
    }
}