1use std::sync::Arc;
18
19use async_trait::async_trait;
20
21use cognis_core::prompts::ExampleSelector;
22use cognis_core::{CognisError, Result};
23
24use crate::distance::Distance;
25use crate::embeddings::Embeddings;
26
/// Shared, thread-safe closure that renders an example of type `E` as the
/// text handed to the embedder when the example pool is embedded.
pub type ExampleTextFn<E> = Arc<dyn Fn(&E) -> String + Send + Sync>;
30
/// Controls whether [`SemanticSimilarityExampleSelector`] re-embeds the
/// example pool on every selection or reuses previously computed vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedMode {
    /// Embed the example pool again on every selection call.
    Fresh,
    /// Keep the pool embeddings after they are first computed and reuse them
    /// on later calls (the reuse condition lives in the selector's
    /// `embed_pool`).
    Cached,
}
42
43pub struct SemanticSimilarityExampleSelector<E> {
50 embeddings: Arc<dyn Embeddings>,
51 k: usize,
52 distance: Distance,
53 text_of: ExampleTextFn<E>,
54 mode: EmbedMode,
55 cache: Arc<tokio::sync::Mutex<Option<Vec<Vec<f32>>>>>,
58}
59
60impl<E> SemanticSimilarityExampleSelector<E>
61where
62 E: Send + Sync + 'static,
63{
64 pub fn new<F>(embeddings: Arc<dyn Embeddings>, k: usize, text_of: F) -> Self
66 where
67 F: Fn(&E) -> String + Send + Sync + 'static,
68 {
69 Self {
70 embeddings,
71 k,
72 distance: Distance::Cosine,
73 text_of: Arc::new(text_of),
74 mode: EmbedMode::Cached,
75 cache: Arc::new(tokio::sync::Mutex::new(None)),
76 }
77 }
78
79 pub fn with_distance(mut self, d: Distance) -> Self {
81 self.distance = d;
82 self
83 }
84
85 pub fn with_embed_mode(mut self, m: EmbedMode) -> Self {
87 self.mode = m;
88 self.cache = Arc::new(tokio::sync::Mutex::new(None));
90 self
91 }
92
93 async fn embed_pool(&self, examples: &[E]) -> Result<Vec<Vec<f32>>> {
95 if matches!(self.mode, EmbedMode::Cached) {
96 let mut guard = self.cache.lock().await;
97 if let Some(cached) = guard.as_ref() {
98 if cached.len() == examples.len() {
99 return Ok(cached.clone());
100 }
101 }
102 let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
103 let vecs = self.embeddings.embed_documents(texts).await?;
104 *guard = Some(vecs.clone());
105 return Ok(vecs);
106 }
107 let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
108 self.embeddings.embed_documents(texts).await
109 }
110
111 async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>
114 where
115 E: Clone,
116 {
117 if examples.is_empty() {
118 return Ok(Vec::new());
119 }
120 let q = self.embeddings.embed_query(input.to_string()).await?;
121 let pool_vecs = self.embed_pool(examples).await?;
122 let mut scored: Vec<(usize, f32)> = pool_vecs
123 .iter()
124 .enumerate()
125 .map(|(i, v)| (i, self.distance.similarity(&q, v)))
126 .collect();
127 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
128 Ok(scored
129 .into_iter()
130 .take(self.k.min(examples.len()))
131 .map(|(i, _)| examples[i].clone())
132 .collect())
133 }
134}
135
// Exposes the semantic-similarity selector through the `AsyncExampleSelector`
// trait.
#[async_trait]
impl<E> AsyncExampleSelector<E> for SemanticSimilarityExampleSelector<E>
where
    E: Clone + Send + Sync + 'static,
{
    /// Delegates to the selector's inherent `select_async` and returns its
    /// result unchanged.
    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
        // Type-qualified path: inherent associated functions take precedence
        // over trait methods of the same name, so this does not recurse.
        SemanticSimilarityExampleSelector::select_async(self, input, examples).await
    }
}
145
146impl<E> ExampleSelector<E> for SemanticSimilarityExampleSelector<E>
147where
148 E: Clone + Send + Sync + 'static,
149{
150 fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
151 let handle = tokio::runtime::Handle::try_current().map_err(|_| {
155 CognisError::Configuration(
156 "SemanticSimilarityExampleSelector::select called outside a tokio runtime; \
157 use AsyncExampleSelector::select_async for explicit await"
158 .into(),
159 )
160 })?;
161 tokio::task::block_in_place(|| handle.block_on(self.select_async(input, examples)))
162 }
163}
164
/// Example selector that picks examples greedily, trading similarity to the
/// input against similarity to the examples already picked.
pub struct MmrExampleSelector<E> {
    /// Embedder used for both the query and the example pool.
    embeddings: Arc<dyn Embeddings>,
    /// Maximum number of examples returned per selection.
    k: usize,
    /// Weight of query similarity in the score; the constructor clamps it to
    /// `[0.0, 1.0]`. The remaining `1.0 - lambda` weights the penalty for
    /// similarity to already-picked examples.
    lambda: f32,
    /// Similarity measure used for both query and pairwise comparisons.
    distance: Distance,
    /// Renders an example as the text that gets embedded.
    text_of: ExampleTextFn<E>,
}
180
181impl<E> MmrExampleSelector<E>
182where
183 E: Send + Sync + 'static,
184{
185 pub fn new<F>(embeddings: Arc<dyn Embeddings>, k: usize, lambda: f32, text_of: F) -> Self
188 where
189 F: Fn(&E) -> String + Send + Sync + 'static,
190 {
191 Self {
192 embeddings,
193 k,
194 lambda: lambda.clamp(0.0, 1.0),
195 distance: Distance::Cosine,
196 text_of: Arc::new(text_of),
197 }
198 }
199
200 pub fn with_distance(mut self, d: Distance) -> Self {
202 self.distance = d;
203 self
204 }
205
206 async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>
207 where
208 E: Clone,
209 {
210 if examples.is_empty() {
211 return Ok(Vec::new());
212 }
213 let q = self.embeddings.embed_query(input.to_string()).await?;
214 let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
215 let pool_vecs = self.embeddings.embed_documents(texts).await?;
216 let n = examples.len();
217 let take = self.k.min(n);
218 let mut chosen: Vec<usize> = Vec::with_capacity(take);
219 let mut available: Vec<usize> = (0..n).collect();
220
221 for _ in 0..take {
222 let mut best_idx: Option<usize> = None;
223 let mut best_score = f32::NEG_INFINITY;
224 for &i in &available {
225 let sim_to_query = self.distance.similarity(&q, &pool_vecs[i]);
226 let max_sim_to_chosen = chosen
227 .iter()
228 .map(|&j| self.distance.similarity(&pool_vecs[i], &pool_vecs[j]))
229 .fold(f32::NEG_INFINITY, f32::max);
230 let novelty = if chosen.is_empty() {
231 0.0
232 } else {
233 max_sim_to_chosen
234 };
235 let score = self.lambda * sim_to_query - (1.0 - self.lambda) * novelty;
236 if score > best_score {
237 best_score = score;
238 best_idx = Some(i);
239 }
240 }
241 let pick = match best_idx {
242 Some(i) => i,
243 None => break,
244 };
245 chosen.push(pick);
246 available.retain(|&i| i != pick);
247 }
248
249 Ok(chosen.into_iter().map(|i| examples[i].clone()).collect())
250 }
251}
252
// Exposes the MMR selector through the `AsyncExampleSelector` trait.
#[async_trait]
impl<E> AsyncExampleSelector<E> for MmrExampleSelector<E>
where
    E: Clone + Send + Sync + 'static,
{
    /// Delegates to the selector's inherent `select_async` and returns its
    /// result unchanged.
    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
        // Type-qualified path: inherent associated functions take precedence
        // over trait methods of the same name, so this does not recurse.
        MmrExampleSelector::select_async(self, input, examples).await
    }
}
262
263impl<E> ExampleSelector<E> for MmrExampleSelector<E>
264where
265 E: Clone + Send + Sync + 'static,
266{
267 fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
268 let handle = tokio::runtime::Handle::try_current().map_err(|_| {
269 CognisError::Configuration(
270 "MmrExampleSelector::select called outside a tokio runtime; \
271 use AsyncExampleSelector::select_async for explicit await"
272 .into(),
273 )
274 })?;
275 tokio::task::block_in_place(|| handle.block_on(self.select_async(input, examples)))
276 }
277}
278
/// Asynchronous example-selection interface, implemented in this file by
/// both embedding-based selectors.
#[async_trait]
pub trait AsyncExampleSelector<E>: Send + Sync
where
    E: Send + Sync + 'static,
{
    /// Chooses examples from `examples` for the given `input`, returning
    /// owned copies or an error from the underlying selection.
    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>;
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;

    /// Top-k selection returns exactly `k` items and includes the exact match.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_picks_topk() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(8));
        let selector = SemanticSimilarityExampleSelector::new(
            Arc::clone(&embedder),
            2,
            |text: &String| text.clone(),
        );
        let candidates: Vec<String> = vec![
            String::from("completely different"),
            String::from("rust programming"),
            String::from("python programming"),
        ];
        let chosen = selector
            .select_async("rust programming", &candidates)
            .await
            .unwrap();
        assert_eq!(chosen.len(), 2);
        assert!(chosen.contains(&String::from("rust programming")));
    }

    /// An empty pool produces an empty selection.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_handles_empty_pool() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector =
            SemanticSimilarityExampleSelector::new(embedder, 3, |text: &String| text.clone());
        let chosen = selector.select_async("anything", &[]).await.unwrap();
        assert!(chosen.is_empty());
    }

    /// MMR returns `k` picks and never the same example twice.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn mmr_selector_returns_k_distinct_picks() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(8));
        let selector = MmrExampleSelector::new(embedder, 2, 0.5, |text: &String| text.clone());
        let candidates: Vec<String> =
            vec![String::from("a"), String::from("b"), String::from("c")];
        let chosen = selector.select_async("query", &candidates).await.unwrap();
        assert_eq!(chosen.len(), 2);
        assert_ne!(chosen[0], chosen[1]);
    }

    /// A second selection over the same pool in cached mode still succeeds.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_caches_pool_embeddings() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector =
            SemanticSimilarityExampleSelector::new(embedder, 1, |text: &String| text.clone())
                .with_embed_mode(EmbedMode::Cached);
        let candidates: Vec<String> = vec![String::from("one"), String::from("two")];
        // First call fills the cache; the second one goes through the cached path.
        let _ = selector.select_async("one", &candidates).await.unwrap();
        let chosen = selector.select_async("one", &candidates).await.unwrap();
        assert_eq!(chosen.len(), 1);
    }

    /// The blocking `ExampleSelector::select` works inside a multi-thread runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_sync_select_works_in_runtime() {
        let embedder: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector =
            SemanticSimilarityExampleSelector::new(embedder, 2, |text: &String| text.clone());
        let candidates: Vec<String> =
            vec![String::from("a"), String::from("b"), String::from("c")];
        let chosen = ExampleSelector::select(&selector, "a", &candidates).unwrap();
        assert_eq!(chosen.len(), 2);
    }
}