use crate::recursive::llm::Llm;
use crate::recursive::shared;
use crate::recursive::validate::Validate;
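/// A single input/expected-output pair used to score candidate prompts.
/// The optional `label` lets callers tag examples (e.g. by topic) without
/// affecting optimization.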
#[derive(Debug, Clone)]
pub struct TrainingExample {
pub input: String,
pub expected: String,
pub label: Option<String>,
}
impl TrainingExample {
pub fn new(input: impl Into<String>, expected: impl Into<String>) -> Self {
Self {
input: input.into(),
expected: expected.into(),
label: None,
}
}
pub fn with_label(mut self, label: impl Into<String>) -> Self {
self.label = Some(label.into());
self
}
}
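/// An ordered collection of [`TrainingExample`]s, built fluently.
///
/// Sketch of typical construction (doc-test ignored; mirrors the unit
/// tests below):
///
/// ```ignore
/// let ds = Dataset::new()
///     .example("What is 2+2?", "4")
///     .labeled_example("Capital of France?", "Paris", "geography");
/// assert_eq!(ds.len(), 2);
/// ```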
#[derive(Debug, Clone, Default)]
pub struct Dataset {
pub examples: Vec<TrainingExample>,
}
impl Dataset {
pub fn new() -> Self {
Self::default()
}
pub fn example(mut self, input: impl Into<String>, expected: impl Into<String>) -> Self {
self.examples.push(TrainingExample::new(input, expected));
self
}
pub fn labeled_example(
mut self,
input: impl Into<String>,
expected: impl Into<String>,
label: impl Into<String>,
) -> Self {
self.examples
.push(TrainingExample::new(input, expected).with_label(label));
self
}
pub fn len(&self) -> usize {
self.examples.len()
}
pub fn is_empty(&self) -> bool {
self.examples.is_empty()
}
}
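/// Scoring callback: `(model_output, expected) -> score`. Scores are
/// averaged across the dataset, so keeping values in `0.0..=1.0` makes
/// them comparable with the built-in substring fallback.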
pub type MetricFn = Box<dyn Fn(&str, &str) -> f64 + Send + Sync>;
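/// How the optimizer searches for a better prompt:
/// - `BootstrapFewShot`: pick the `max_examples` best-scoring examples
///   from the dataset and append them as few-shot demonstrations.
/// - `InstructionSearch`: ask the LLM for `num_candidates` alternative
///   instruction phrasings and keep the best-scoring one.
/// - `Combined`: bootstrap examples first, then search instructions with
///   those examples held fixed.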
#[derive(Debug, Clone)]
pub enum Strategy {
BootstrapFewShot {
max_examples: usize,
},
InstructionSearch {
num_candidates: usize,
},
Combined {
max_examples: usize,
num_candidates: usize,
},
}
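/// Outcome of an optimization run. `prompt` is the full optimized prompt
/// (instruction plus any inlined examples), `instruction` is the winning
/// instruction alone, and `evaluations` counts the LLM calls issued while
/// scoring candidates.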
#[derive(Debug, Clone)]
pub struct OptimizeResult {
pub prompt: String,
pub examples: Vec<TrainingExample>,
pub instruction: String,
pub score: f64,
pub evaluations: u32,
pub candidate_scores: Vec<f64>,
}
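/// Builder-style prompt optimizer, generic over any [`Llm`] backend.
///
/// Illustrative usage (doc-test ignored; `MockLlm` is the test double
/// from `crate::recursive::llm`, with the closure signature the unit
/// tests below assume):
///
/// ```ignore
/// let llm = MockLlm::new(|_prompt, _ctx| "4".to_string());
/// let ds = Dataset::new().example("What is 2+2?", "4");
/// let result = Optimizer::new(&llm, "Answer the arithmetic question.")
///     .dataset(&ds)
///     .strategy(Strategy::BootstrapFewShot { max_examples: 1 })
///     .go();
/// assert!(result.score > 0.0);
/// ```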
pub struct Optimizer<'a, L: Llm> {
llm: &'a L,
base_prompt: String,
dataset: Option<&'a Dataset>,
metric: Option<MetricFn>,
validator: Option<Box<dyn Validate>>,
strategy: Strategy,
}
impl<'a, L: Llm> Optimizer<'a, L> {
pub fn new(llm: &'a L, base_prompt: impl Into<String>) -> Self {
Self {
llm,
base_prompt: base_prompt.into(),
dataset: None,
metric: None,
validator: None,
strategy: Strategy::BootstrapFewShot { max_examples: 3 },
}
}
pub fn dataset(mut self, dataset: &'a Dataset) -> Self {
self.dataset = Some(dataset);
self
}
pub fn metric<F: Fn(&str, &str) -> f64 + Send + Sync + 'static>(mut self, f: F) -> Self {
self.metric = Some(Box::new(f));
self
}
pub fn metric_impl<M: crate::metric::Metric + 'static>(mut self, m: M) -> Self {
self.metric = Some(Box::new(move |pred: &str, expected: &str| {
m.evaluate(pred, expected)
}));
self
}
pub fn validate(mut self, v: impl Validate + 'static) -> Self {
self.validator = Some(Box::new(v));
self
}
pub fn strategy(mut self, strategy: Strategy) -> Self {
self.strategy = strategy;
self
}
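    /// Blocking wrapper around [`Optimizer::run`] for synchronous callers;
    /// drives the future to completion via `shared::block_on`.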
pub fn go(self) -> OptimizeResult {
shared::block_on(self.run())
}
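    /// Runs the configured strategy. With no dataset (or an empty one)
    /// there is nothing to score, so the base prompt is returned unchanged
    /// with a score of 0.0 and zero evaluations.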
pub async fn run(self) -> OptimizeResult {
let dataset = match self.dataset {
Some(d) if !d.is_empty() => d,
_ => {
return OptimizeResult {
prompt: self.base_prompt.clone(),
examples: Vec::new(),
instruction: self.base_prompt,
score: 0.0,
evaluations: 0,
candidate_scores: Vec::new(),
};
}
};
match self.strategy.clone() {
Strategy::BootstrapFewShot { max_examples } => {
self.run_bootstrap(dataset, max_examples).await
}
Strategy::InstructionSearch { num_candidates } => {
self.run_instruction_search(dataset, num_candidates).await
}
Strategy::Combined {
max_examples,
num_candidates,
} => {
self.run_combined(dataset, max_examples, num_candidates)
.await
}
}
}
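    /// Scores an instruction against the whole dataset: few-shot examples
    /// are serialized into the context, each dataset input is sent through
    /// the LLM, and the mean [`Self::score_output`] over successful
    /// generations is returned (0.0 if every call failed).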
async fn evaluate(
&self,
instruction: &str,
examples: &[TrainingExample],
dataset: &Dataset,
) -> f64 {
let mut total_score = 0.0;
let mut count = 0;
let mut context = String::new();
for ex in examples {
context.push_str(&format!("Input: {}\nOutput: {}\n\n", ex.input, ex.expected));
}
for example in &dataset.examples {
let prompt = format!("{}\n\nInput: {}", instruction, example.input);
let output = match self.llm.generate(&prompt, &context, None).await {
Ok(out) => out.text,
Err(_) => continue,
};
let score = self.score_output(&output, &example.expected);
total_score += score;
count += 1;
}
if count > 0 {
total_score / count as f64
} else {
0.0
}
}
fn score_output(&self, output: &str, expected: &str) -> f64 {
let mut score = 0.0;
let mut components = 0;
if let Some(ref metric) = self.metric {
score += metric(output, expected);
components += 1;
}
if let Some(ref validator) = self.validator {
score += validator.validate(output).value;
components += 1;
}
if components == 0 {
return if output.contains(expected) { 1.0 } else { 0.0 };
}
score / components as f64
}
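    /// Bootstrap few-shot: score each training example under the base
    /// prompt, keep the top `max_examples` as demonstrations, and adopt
    /// them only if the few-shot prompt beats the zero-shot baseline.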
async fn run_bootstrap(&self, dataset: &Dataset, max_examples: usize) -> OptimizeResult {
let mut evaluations = 0u32;
let mut candidate_scores = Vec::new();
let mut scored_examples: Vec<(f64, &TrainingExample)> = Vec::new();
for example in &dataset.examples {
let prompt = format!("{}\n\nInput: {}", self.base_prompt, example.input);
if let Ok(output) = self.llm.generate(&prompt, "", None).await {
let score = self.score_output(&output.text, &example.expected);
scored_examples.push((score, example));
evaluations += 1;
}
}
scored_examples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let selected: Vec<TrainingExample> = scored_examples
.iter()
.take(max_examples)
.map(|(_, ex)| (*ex).clone())
.collect();
let base_score = self.evaluate(&self.base_prompt, &[], dataset).await;
evaluations += dataset.len() as u32;
candidate_scores.push(base_score);
let few_shot_score = self.evaluate(&self.base_prompt, &selected, dataset).await;
evaluations += dataset.len() as u32;
candidate_scores.push(few_shot_score);
let (final_examples, final_score) = if few_shot_score > base_score {
(selected, few_shot_score)
} else {
(Vec::new(), base_score)
};
let mut optimized_prompt = self.base_prompt.clone();
if !final_examples.is_empty() {
optimized_prompt.push_str("\n\nExamples:");
for ex in &final_examples {
optimized_prompt
.push_str(&format!("\nInput: {}\nOutput: {}", ex.input, ex.expected));
}
}
OptimizeResult {
prompt: optimized_prompt,
examples: final_examples,
instruction: self.base_prompt.clone(),
score: final_score,
evaluations,
candidate_scores,
}
}
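    /// Instruction search: one meta-prompt asks the LLM for
    /// `num_candidates` rephrasings of the base instruction; each parsed
    /// candidate is evaluated zero-shot and the best scorer wins.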
async fn run_instruction_search(
&self,
dataset: &Dataset,
num_candidates: usize,
) -> OptimizeResult {
let mut evaluations = 0u32;
let mut candidate_scores = Vec::new();
let mut best_instruction = self.base_prompt.clone();
let base_score = self.evaluate(&self.base_prompt, &[], dataset).await;
evaluations += dataset.len() as u32;
candidate_scores.push(base_score);
let mut best_score = base_score;
let meta_prompt = format!(
"Generate {} different instruction phrasings for this task. \
Each instruction should be a complete, self-contained prompt \
that guides an AI to perform the task well.\n\n\
Original instruction: {}\n\n\
Format each candidate on its own line, prefixed with a number:\n\
1. [instruction]\n2. [instruction]\netc.",
num_candidates, self.base_prompt
);
let candidates = match self.llm.generate(&meta_prompt, "", None).await {
Ok(output) => parse_numbered_list(&output.text),
Err(_) => Vec::new(),
};
evaluations += 1;
for candidate in &candidates {
let score = self.evaluate(candidate, &[], dataset).await;
evaluations += dataset.len() as u32;
candidate_scores.push(score);
if score > best_score {
best_score = score;
best_instruction = candidate.clone();
}
}
OptimizeResult {
prompt: best_instruction.clone(),
examples: Vec::new(),
instruction: best_instruction,
score: best_score,
evaluations,
candidate_scores,
}
}
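    /// Combined strategy: run bootstrap first, then search instructions
    /// while holding the bootstrapped examples fixed, keeping whichever
    /// instruction scores highest with those examples.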
async fn run_combined(
&self,
dataset: &Dataset,
max_examples: usize,
num_candidates: usize,
) -> OptimizeResult {
let bootstrap_result = self.run_bootstrap(dataset, max_examples).await;
let mut evaluations = bootstrap_result.evaluations;
let mut candidate_scores = bootstrap_result.candidate_scores;
let mut best_instruction = bootstrap_result.instruction.clone();
let mut best_score = bootstrap_result.score;
let best_examples = bootstrap_result.examples;
let meta_prompt = format!(
"Generate {} different instruction phrasings for this task. \
Each should be a complete prompt.\n\n\
Original: {}\n\n\
Format: 1. [instruction]",
num_candidates, self.base_prompt
);
let candidates = match self.llm.generate(&meta_prompt, "", None).await {
Ok(output) => parse_numbered_list(&output.text),
Err(_) => Vec::new(),
};
evaluations += 1;
for candidate in &candidates {
let score = self.evaluate(candidate, &best_examples, dataset).await;
evaluations += dataset.len() as u32;
candidate_scores.push(score);
if score > best_score {
best_score = score;
best_instruction = candidate.clone();
}
}
let mut optimized_prompt = best_instruction.clone();
if !best_examples.is_empty() {
optimized_prompt.push_str("\n\nExamples:");
for ex in &best_examples {
optimized_prompt
.push_str(&format!("\nInput: {}\nOutput: {}", ex.input, ex.expected));
}
}
OptimizeResult {
prompt: optimized_prompt,
examples: best_examples,
instruction: best_instruction,
score: best_score,
evaluations,
candidate_scores,
}
}
}
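/// Extracts items from a numbered list such as "1. foo" or "2) bar".
/// The full run of leading digits is consumed, so multi-digit indices
/// ("10.", "23)") are handled; lines without a numeric prefix are skipped.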
fn parse_numbered_list(text: &str) -> Vec<String> {
    let mut results = Vec::new();
    for line in text.lines() {
        let trimmed = line.trim();
        // Consume the full run of leading digits so "10." and "23)"
        // parse, not just single-digit prefixes.
        let after_digits = trimmed.trim_start_matches(|c: char| c.is_ascii_digit());
        if after_digits.len() == trimmed.len() {
            continue; // no numeric prefix on this line
        }
        if let Some(rest) = after_digits
            .strip_prefix(|c: char| c == '.' || c == ')')
            .map(|s| s.trim())
        {
            if !rest.is_empty() {
                results.push(rest.to_string());
            }
        }
    }
    results
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dataset_builder() {
let ds = Dataset::new()
.example("What is 2+2?", "4")
.example("What is 3+3?", "6")
.labeled_example("Capital?", "Paris", "geography");
assert_eq!(ds.len(), 3);
assert!(!ds.is_empty());
assert_eq!(ds.examples[2].label.as_deref(), Some("geography"));
}
#[test]
fn test_training_example() {
let ex = TrainingExample::new("input", "output").with_label("test");
assert_eq!(ex.input, "input");
assert_eq!(ex.expected, "output");
assert_eq!(ex.label.as_deref(), Some("test"));
}
#[test]
fn test_parse_numbered_list() {
let text = "1. First instruction\n2. Second instruction\n3. Third one";
let results = parse_numbered_list(text);
assert_eq!(results.len(), 3);
assert_eq!(results[0], "First instruction");
assert_eq!(results[1], "Second instruction");
assert_eq!(results[2], "Third one");
}
#[test]
fn test_parse_numbered_list_with_parentheses() {
let text = "1) First\n2) Second";
let results = parse_numbered_list(text);
assert_eq!(results.len(), 2);
assert_eq!(results[0], "First");
}
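    #[test]
    fn test_parse_numbered_list_multi_digit() {
        // Indices past 9 must parse: the digit run is consumed as a whole.
        let text = "9. Ninth\n10. Tenth\n11) Eleventh";
        let results = parse_numbered_list(text);
        assert_eq!(results.len(), 3);
        assert_eq!(results[1], "Tenth");
        assert_eq!(results[2], "Eleventh");
    }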
#[test]
fn test_empty_dataset() {
let ds = Dataset::new();
assert!(ds.is_empty());
assert_eq!(ds.len(), 0);
}
#[test]
fn test_score_output_contains() {
use crate::recursive::llm::MockLlm;
let llm = MockLlm::new(|_, _| String::new());
let opt = Optimizer::new(&llm, "test");
assert!((opt.score_output("The answer is 42", "42") - 1.0).abs() < f64::EPSILON);
assert!((opt.score_output("The answer is 41", "42") - 0.0).abs() < f64::EPSILON);
}
}