use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::prompts::ExampleSelector;
use cognis_core::{CognisError, Result};
use crate::distance::Distance;
use crate::embeddings::Embeddings;
/// Shared, thread-safe closure that renders an example of type `E` as a `String`.
///
/// The selectors in this file call it on every pool example to obtain the
/// text that is handed to the embeddings backend.
pub type ExampleTextFn<E> = Arc<dyn Fn(&E) -> String + Send + Sync>;
/// Controls whether [`SemanticSimilarityExampleSelector`] re-embeds the
/// example pool on every selection or reuses embeddings from an earlier call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedMode {
/// Embed the whole pool again on every selection; nothing is remembered.
Fresh,
/// Keep the pool embeddings from a previous selection and reuse them for
/// as long as the selector still considers them valid for the pool.
Cached,
}
pub struct SemanticSimilarityExampleSelector<E> {
embeddings: Arc<dyn Embeddings>,
k: usize,
distance: Distance,
text_of: ExampleTextFn<E>,
mode: EmbedMode,
cache: Arc<tokio::sync::Mutex<Option<Vec<Vec<f32>>>>>,
}
impl<E> SemanticSimilarityExampleSelector<E>
where
E: Send + Sync + 'static,
{
pub fn new<F>(embeddings: Arc<dyn Embeddings>, k: usize, text_of: F) -> Self
where
F: Fn(&E) -> String + Send + Sync + 'static,
{
Self {
embeddings,
k,
distance: Distance::Cosine,
text_of: Arc::new(text_of),
mode: EmbedMode::Cached,
cache: Arc::new(tokio::sync::Mutex::new(None)),
}
}
pub fn with_distance(mut self, d: Distance) -> Self {
self.distance = d;
self
}
pub fn with_embed_mode(mut self, m: EmbedMode) -> Self {
self.mode = m;
self.cache = Arc::new(tokio::sync::Mutex::new(None));
self
}
async fn embed_pool(&self, examples: &[E]) -> Result<Vec<Vec<f32>>> {
if matches!(self.mode, EmbedMode::Cached) {
let mut guard = self.cache.lock().await;
if let Some(cached) = guard.as_ref() {
if cached.len() == examples.len() {
return Ok(cached.clone());
}
}
let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
let vecs = self.embeddings.embed_documents(texts).await?;
*guard = Some(vecs.clone());
return Ok(vecs);
}
let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
self.embeddings.embed_documents(texts).await
}
async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>
where
E: Clone,
{
if examples.is_empty() {
return Ok(Vec::new());
}
let q = self.embeddings.embed_query(input.to_string()).await?;
let pool_vecs = self.embed_pool(examples).await?;
let mut scored: Vec<(usize, f32)> = pool_vecs
.iter()
.enumerate()
.map(|(i, v)| (i, self.distance.similarity(&q, v)))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(scored
.into_iter()
.take(self.k.min(examples.len()))
.map(|(i, _)| examples[i].clone())
.collect())
}
}
#[async_trait]
impl<E> AsyncExampleSelector<E> for SemanticSimilarityExampleSelector<E>
where
    E: Clone + Send + Sync + 'static,
{
    /// Trait entry point; forwards to the inherent `select_async` of the
    /// selector and awaits its result.
    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
        let selection =
            SemanticSimilarityExampleSelector::<E>::select_async(self, input, examples);
        selection.await
    }
}
impl<E> ExampleSelector<E> for SemanticSimilarityExampleSelector<E>
where
    E: Clone + Send + Sync + 'static,
{
    /// Blocking adapter around the async selection.
    ///
    /// Must be called from inside a multi-thread tokio runtime: the current
    /// worker is handed over with `block_in_place` and the async selection is
    /// driven to completion on it.
    ///
    /// # Errors
    /// Returns `CognisError::Configuration` when no tokio runtime is active or
    /// when the active runtime is current-thread (where `block_in_place`
    /// would panic). Otherwise propagates errors from the async selection.
    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
        let handle = tokio::runtime::Handle::try_current().map_err(|_| {
            CognisError::Configuration(
                "SemanticSimilarityExampleSelector::select called outside a tokio runtime; \
                 use AsyncExampleSelector::select_async for explicit await"
                    .into(),
            )
        })?;
        // `block_in_place` panics on a current-thread runtime; report that as
        // a configuration problem instead of aborting the caller.
        if matches!(
            handle.runtime_flavor(),
            tokio::runtime::RuntimeFlavor::CurrentThread
        ) {
            return Err(CognisError::Configuration(
                "SemanticSimilarityExampleSelector::select requires a multi-thread tokio runtime; \
                 use AsyncExampleSelector::select_async for explicit await"
                    .into(),
            ));
        }
        tokio::task::block_in_place(|| handle.block_on(self.select_async(input, examples)))
    }
}
/// Example selector that greedily picks examples by a weighted trade-off
/// between similarity to the query and similarity to examples already picked.
pub struct MmrExampleSelector<E> {
/// Backend used to embed the query and the pool texts.
embeddings: Arc<dyn Embeddings>,
/// Maximum number of examples returned per selection.
k: usize,
/// Weight of query similarity in the score; the constructor clamps it to `[0.0, 1.0]`.
lambda: f32,
/// Similarity measure used for both query and pairwise comparisons.
distance: Distance,
/// Renders an example into the text that gets embedded.
text_of: ExampleTextFn<E>,
}
impl<E> MmrExampleSelector<E>
where
E: Send + Sync + 'static,
{
pub fn new<F>(embeddings: Arc<dyn Embeddings>, k: usize, lambda: f32, text_of: F) -> Self
where
F: Fn(&E) -> String + Send + Sync + 'static,
{
Self {
embeddings,
k,
lambda: lambda.clamp(0.0, 1.0),
distance: Distance::Cosine,
text_of: Arc::new(text_of),
}
}
pub fn with_distance(mut self, d: Distance) -> Self {
self.distance = d;
self
}
async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>
where
E: Clone,
{
if examples.is_empty() {
return Ok(Vec::new());
}
let q = self.embeddings.embed_query(input.to_string()).await?;
let texts: Vec<String> = examples.iter().map(|e| (self.text_of)(e)).collect();
let pool_vecs = self.embeddings.embed_documents(texts).await?;
let n = examples.len();
let take = self.k.min(n);
let mut chosen: Vec<usize> = Vec::with_capacity(take);
let mut available: Vec<usize> = (0..n).collect();
for _ in 0..take {
let mut best_idx: Option<usize> = None;
let mut best_score = f32::NEG_INFINITY;
for &i in &available {
let sim_to_query = self.distance.similarity(&q, &pool_vecs[i]);
let max_sim_to_chosen = chosen
.iter()
.map(|&j| self.distance.similarity(&pool_vecs[i], &pool_vecs[j]))
.fold(f32::NEG_INFINITY, f32::max);
let novelty = if chosen.is_empty() {
0.0
} else {
max_sim_to_chosen
};
let score = self.lambda * sim_to_query - (1.0 - self.lambda) * novelty;
if score > best_score {
best_score = score;
best_idx = Some(i);
}
}
let pick = match best_idx {
Some(i) => i,
None => break,
};
chosen.push(pick);
available.retain(|&i| i != pick);
}
Ok(chosen.into_iter().map(|i| examples[i].clone()).collect())
}
}
#[async_trait]
impl<E> AsyncExampleSelector<E> for MmrExampleSelector<E>
where
    E: Clone + Send + Sync + 'static,
{
    /// Trait entry point; forwards to the inherent `select_async` of the
    /// selector and awaits its result.
    async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
        let selection = MmrExampleSelector::<E>::select_async(self, input, examples);
        selection.await
    }
}
impl<E> ExampleSelector<E> for MmrExampleSelector<E>
where
    E: Clone + Send + Sync + 'static,
{
    /// Blocking adapter around the async selection.
    ///
    /// Must be called from inside a multi-thread tokio runtime: the current
    /// worker is handed over with `block_in_place` and the async selection is
    /// driven to completion on it.
    ///
    /// # Errors
    /// Returns `CognisError::Configuration` when no tokio runtime is active or
    /// when the active runtime is current-thread (where `block_in_place`
    /// would panic). Otherwise propagates errors from the async selection.
    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
        let handle = tokio::runtime::Handle::try_current().map_err(|_| {
            CognisError::Configuration(
                "MmrExampleSelector::select called outside a tokio runtime; \
                 use AsyncExampleSelector::select_async for explicit await"
                    .into(),
            )
        })?;
        // `block_in_place` panics on a current-thread runtime; report that as
        // a configuration problem instead of aborting the caller.
        if matches!(
            handle.runtime_flavor(),
            tokio::runtime::RuntimeFlavor::CurrentThread
        ) {
            return Err(CognisError::Configuration(
                "MmrExampleSelector::select requires a multi-thread tokio runtime; \
                 use AsyncExampleSelector::select_async for explicit await"
                    .into(),
            ));
        }
        tokio::task::block_in_place(|| handle.block_on(self.select_async(input, examples)))
    }
}
/// Awaitable example selection over a caller-supplied pool.
///
/// Both selectors in this file implement it by forwarding to their inherent
/// async selection; their blocking `ExampleSelector::select` impls point
/// callers here when an explicit `.await` is wanted.
#[async_trait]
pub trait AsyncExampleSelector<E>: Send + Sync
where
E: Send + Sync + 'static,
{
/// Selects examples from `examples` for the given `input`.
///
/// Returns the selected examples as owned values, or the error produced
/// while selecting.
async fn select_async(&self, input: &str, examples: &[E]) -> Result<Vec<E>>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;

    /// Top-k selection returns exactly `k` items and includes the pool entry
    /// that is identical to the query.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_picks_topk() {
        let backend: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(8));
        let selector = SemanticSimilarityExampleSelector::new(
            Arc::clone(&backend),
            2,
            |text: &String| text.clone(),
        );
        let candidates: Vec<String> = vec![
            String::from("completely different"),
            String::from("rust programming"),
            String::from("python programming"),
        ];
        let chosen = selector
            .select_async("rust programming", &candidates)
            .await
            .unwrap();
        assert_eq!(chosen.len(), 2);
        assert!(chosen.iter().any(|text| text == "rust programming"));
    }

    /// An empty pool produces an empty selection.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_handles_empty_pool() {
        let backend: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector =
            SemanticSimilarityExampleSelector::new(backend, 3, |text: &String| text.clone());
        let chosen = selector.select_async("anything", &[]).await.unwrap();
        assert!(chosen.is_empty());
    }

    /// MMR returns `k` picks and never the same example twice.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn mmr_selector_returns_k_distinct_picks() {
        let backend: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(8));
        let selector = MmrExampleSelector::new(backend, 2, 0.5, |text: &String| text.clone());
        let candidates: Vec<String> =
            vec![String::from("a"), String::from("b"), String::from("c")];
        let chosen = selector.select_async("query", &candidates).await.unwrap();
        assert_eq!(chosen.len(), 2);
        assert_ne!(chosen[0], chosen[1]);
    }

    /// Selecting twice from the same pool in cached mode still yields `k` picks.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_caches_pool_embeddings() {
        let backend: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector =
            SemanticSimilarityExampleSelector::new(backend, 1, |text: &String| text.clone())
                .with_embed_mode(EmbedMode::Cached);
        let candidates: Vec<String> = vec![String::from("one"), String::from("two")];
        // First call fills the cache; its result is irrelevant here.
        let _ = selector.select_async("one", &candidates).await.unwrap();
        let chosen = selector.select_async("one", &candidates).await.unwrap();
        assert_eq!(chosen.len(), 1);
    }

    /// The blocking `ExampleSelector::select` works on a multi-thread runtime.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn semantic_selector_sync_select_works_in_runtime() {
        let backend: Arc<dyn Embeddings> = Arc::new(FakeEmbeddings::new(4));
        let selector =
            SemanticSimilarityExampleSelector::new(backend, 2, |text: &String| text.clone());
        let candidates: Vec<String> =
            vec![String::from("a"), String::from("b"), String::from("c")];
        let chosen = ExampleSelector::select(&selector, "a", &candidates).unwrap();
        assert_eq!(chosen.len(), 2);
    }
}