use std::sync::Arc;
use crate::tokenizer::Tokenizer;
use crate::{CognisError, Result};
/// Chooses which few-shot examples should accompany a given input.
///
/// Implementors receive the full candidate pool on every call and return the
/// examples they decided to keep.
pub trait ExampleSelector<E>: Send + Sync
where
    E: Send + Sync + 'static,
{
    /// Returns the examples from `examples` that should be used for `input`.
    ///
    /// # Errors
    ///
    /// Implementation-defined. For example, `LengthBasedExampleSelector`
    /// fails when `input` alone is larger than its token budget.
    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>>;
}
/// Selector that ignores the input and yields examples in their original
/// order, optionally truncated to a fixed count.
#[derive(Debug, Clone, Default)]
pub struct StaticExampleSelector {
    /// Maximum number of examples to return; `None` means no limit.
    max: Option<usize>,
}

impl StaticExampleSelector {
    /// Builds a selector that returns every example it is given.
    pub fn all() -> Self {
        Self::default()
    }

    /// Builds a selector that returns no more than `n` examples.
    pub fn at_most(n: usize) -> Self {
        let max = Some(n);
        Self { max }
    }
}
impl<E> ExampleSelector<E> for StaticExampleSelector
where
    E: Clone + Send + Sync + 'static,
{
    /// Ignores `_input` and returns a leading prefix of `examples`: the whole
    /// slice when uncapped, otherwise at most `max` items (fewer if the slice
    /// is shorter).
    fn select(&self, _input: &str, examples: &[E]) -> Result<Vec<E>> {
        let available = examples.len();
        let keep = match self.max {
            Some(cap) if cap < available => cap,
            _ => available,
        };
        Ok(examples[..keep].to_vec())
    }
}
/// Shared, thread-safe callback that renders an example to a `String`.
pub type ExampleRenderFn<E> = Arc<dyn Fn(&E) -> String + Send + Sync>;

/// Selector that takes examples in order while their rendered token cost fits
/// in the budget left over after the input.
pub struct LengthBasedExampleSelector<E> {
    /// Total token budget shared by the input and the selected examples.
    max_tokens: usize,
    /// Tokenizer used to measure both the input and each rendered example.
    tokenizer: Arc<dyn Tokenizer>,
    /// Renders an example to the text whose tokens are counted.
    render: ExampleRenderFn<E>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `E: Clone` bound, but every field is a `usize` or an `Arc`, so cloning the
// selector never needs to clone an `E`.
impl<E> Clone for LengthBasedExampleSelector<E> {
    fn clone(&self) -> Self {
        Self {
            max_tokens: self.max_tokens,
            tokenizer: Arc::clone(&self.tokenizer),
            render: Arc::clone(&self.render),
        }
    }
}
impl<E> std::fmt::Debug for LengthBasedExampleSelector<E> {
    // Only `max_tokens` is printed. The tokenizer and the render callback are
    // trait objects that are not formatted here, so `finish_non_exhaustive`
    // appends `..` to show that fields were omitted rather than absent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LengthBasedExampleSelector")
            .field("max_tokens", &self.max_tokens)
            .finish_non_exhaustive()
    }
}
impl<E> LengthBasedExampleSelector<E>
where
E: Send + Sync + 'static,
{
pub fn new<F>(max_tokens: usize, tokenizer: Arc<dyn Tokenizer>, render: F) -> Self
where
F: Fn(&E) -> String + Send + Sync + 'static,
{
Self {
max_tokens,
tokenizer,
render: Arc::new(render),
}
}
pub fn with_render<F>(mut self, render: F) -> Self
where
F: Fn(&E) -> String + Send + Sync + 'static,
{
self.render = Arc::new(render);
self
}
pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
self.tokenizer = tokenizer;
self
}
}
impl<E> ExampleSelector<E> for LengthBasedExampleSelector<E>
where
    E: Clone + Send + Sync + 'static,
{
    /// Greedily takes examples from the front of `examples` while each one's
    /// rendered token count fits in what remains of `max_tokens` after the
    /// input. Selection stops at the first example that does not fit.
    ///
    /// # Errors
    ///
    /// Returns `CognisError::Configuration` when the input's token count alone
    /// exceeds `max_tokens`.
    fn select(&self, input: &str, examples: &[E]) -> Result<Vec<E>> {
        let input_cost = self.tokenizer.count(input);
        let mut remaining = match self.max_tokens.checked_sub(input_cost) {
            Some(left) => left,
            None => {
                return Err(CognisError::Configuration(
                    "LengthBasedExampleSelector: input alone exceeds max_tokens".into(),
                ));
            }
        };

        // `take_while` ends at the first example that overflows, so later
        // (possibly cheaper) examples are never considered.
        let chosen: Vec<E> = examples
            .iter()
            .take_while(|&example| {
                let cost = self.tokenizer.count(&(self.render)(example));
                let fits = cost <= remaining;
                if fits {
                    remaining -= cost;
                }
                fits
            })
            .cloned()
            .collect();
        Ok(chosen)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the static and length-based example selectors.

    use super::*;
    use crate::tokenizer::CharTokenizer;

    #[test]
    fn static_selector_returns_all_by_default() {
        let s = StaticExampleSelector::all();
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, pool);
    }

    #[test]
    fn static_selector_caps_at_most() {
        let s = StaticExampleSelector::at_most(2);
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, vec!["a", "b"]);
    }

    #[test]
    fn static_selector_at_most_beyond_pool_returns_whole_pool() {
        let s = StaticExampleSelector::at_most(10);
        let pool = vec!["a", "b", "c"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert_eq!(out, pool);
    }

    #[test]
    fn static_selector_at_most_zero_returns_nothing() {
        let s = StaticExampleSelector::at_most(0);
        let pool = vec!["a", "b"];
        let out: Vec<&str> = ExampleSelector::select(&s, "ignored", &pool).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn length_based_stops_at_budget() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(20, tokenizer, |s: &String| s.clone());
        let pool = vec![
            "five5".to_string(),
            "five5".to_string(),
            "five5".to_string(),
            "ovrflw".to_string(),
        ];
        let picked = sel.select("input", &pool).unwrap();
        assert_eq!(picked.len(), 3);
    }

    #[test]
    fn length_based_rejects_input_alone_too_big() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(3, tokenizer, |s: &String| s.clone());
        let err = sel.select("longer-than-budget", &[]).unwrap_err();
        assert!(matches!(err, CognisError::Configuration(_)));
    }

    #[test]
    fn length_based_input_filling_budget_selects_nothing() {
        // The input consumes the whole budget, which is not an error, but no
        // example can be added afterwards.
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(5, tokenizer, |s: &String| s.clone());
        let picked = sel.select("input", &["a".to_string()]).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn length_based_stops_at_first_overflow_without_skipping_ahead() {
        // Selection is a greedy prefix: a too-large first example blocks the
        // cheaper one after it.
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(4, tokenizer, |s: &String| s.clone());
        let pool = vec!["toolong".to_string(), "ok".to_string()];
        let picked = sel.select("", &pool).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn length_based_with_custom_renderer() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let sel: LengthBasedExampleSelector<String> =
            LengthBasedExampleSelector::new(10, tokenizer, |s: &String| s.clone() + s);
        let pool = vec!["ab".to_string(), "ab".to_string(), "abc".to_string()];
        let picked = sel.select("", &pool).unwrap();
        assert_eq!(picked.len(), 2);
    }
}