use crate::tuner::{error::Result, types::{Evaluator, Sample}};
/// Runs the model on a single sample using `prompt` as the instruction.
///
/// The request sent to `llm` is `prompt`, a blank line, the sample's input
/// under an `Input:` label, and a trailing `Output:` cue for the model to
/// complete.
///
/// # Errors
///
/// Returns whatever error `llm` returns; nothing is added or translated here.
pub fn model_output(
    prompt: &str,
    sample: &Sample,
    llm: &dyn Fn(&str) -> Result<String>,
) -> Result<String> {
    llm(&format!("{}\n\nInput: {}\nOutput:", prompt, sample.input))
}
/// Scores `prompt` against `dataset` and returns the mean score.
///
/// Each sample is sent through [`model_output`] and the result is handed to
/// `evaluator`. A sample is left out of the mean when the LLM call fails or
/// when the evaluator returns `None` for it.
///
/// Returns [`f32::NEG_INFINITY`] when no sample produced a score (including
/// an empty `dataset`).
pub fn evaluate_prompt(
    prompt: &str,
    dataset: &[Sample],
    evaluator: &dyn Evaluator,
    llm: &dyn Fn(&str) -> Result<String>,
) -> f32 {
    let mut scores: Vec<f32> = Vec::new();
    for sample in dataset {
        // Best-effort: a failed model call simply drops this sample.
        let Ok(output) = model_output(prompt, sample, llm) else {
            continue;
        };
        if let Some(score) = evaluator.score(prompt, sample, &output) {
            scores.push(score);
        }
    }

    match scores.len() {
        0 => f32::NEG_INFINITY,
        scored => scores.iter().sum::<f32>() / scored as f32,
    }
}
/// Asks the LLM for one rewritten version of `base_prompt`.
///
/// The request wraps `base_prompt` in a fixed "rewrite to improve clarity and
/// performance" instruction and ends with an `Improved prompt:` cue.
///
/// # Errors
///
/// Returns whatever error `llm` returns.
pub fn generate_variant(
    base_prompt: &str,
    llm: &dyn Fn(&str) -> Result<String>,
) -> Result<String> {
    // The trailing `\` joins the two source lines without inserting anything
    // between "performance.\n" and "Prompt:".
    llm(&format!(
        "Rewrite the following prompt to improve its clarity and performance.\n\
         Prompt:\n{}\n\nImproved prompt:",
        base_prompt
    ))
}
/// Collects up to `count` distinct rewrites of `base_prompt`.
///
/// The returned list always starts with `base_prompt` itself. The LLM is
/// queried `count` times via [`generate_variant`]; a response is appended only
/// if the call succeeded and the text is not already in the list, so the
/// result may hold fewer than `count + 1` entries.
pub fn generate_variants(
    base_prompt: &str,
    count: usize,
    llm: &dyn Fn(&str) -> Result<String>,
) -> Vec<String> {
    let mut variants = vec![base_prompt.to_string()];
    for _ in 0..count {
        match generate_variant(base_prompt, llm) {
            Ok(candidate) if !variants.contains(&candidate) => variants.push(candidate),
            // Failed calls and duplicates are skipped on purpose.
            _ => {}
        }
    }
    variants
}
/// Draws `count` demonstration sets from `dataset`, deterministically for a
/// given `seed`.
///
/// Each set holds `size` samples (capped at `dataset.len()`) picked without
/// repetition inside that set by a partial shuffle of the dataset indices.
/// Randomness comes from a 64-bit linear congruential generator seeded with
/// `seed + 1` (wrapping); its state carries over from one set to the next, so
/// consecutive sets generally differ.
pub fn build_demo_sets(
    dataset: &[Sample],
    count: usize,
    size: usize,
    seed: u64,
) -> Vec<Vec<Sample>> {
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    let total = dataset.len();
    let take = size.min(total);
    let mut state = seed.wrapping_add(1);

    let mut demo_sets = Vec::with_capacity(count);
    for _ in 0..count {
        // Every set starts from the identity ordering of indices.
        let mut order: Vec<usize> = (0..total).collect();
        for slot in 0..take {
            state = state
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            // Use the upper bits of the state to pick from the not-yet-fixed
            // tail `slot..total`.
            let offset = (state >> 33) as usize % (total - slot);
            order.swap(slot, slot + offset);
        }
        let chosen: Vec<Sample> = order[..take]
            .iter()
            .map(|&index| dataset[index].clone())
            .collect();
        demo_sets.push(chosen);
    }
    demo_sets
}
/// Renders one sample as a demonstration snippet.
///
/// Produces `Input: <input>` and, when the sample carries an expected output,
/// a second line `Output: <expected>`.
pub fn format_demo(sample: &Sample) -> String {
    if let Some(expected) = &sample.expected_output {
        format!("Input: {}\nOutput: {}", sample.input, expected)
    } else {
        format!("Input: {}", sample.input)
    }
}
pub fn assemble_prompt(instruction: &str, demos: &[Sample]) -> String {
let demo_text: Vec<String> = demos.iter().map(format_demo).collect();
let parts: Vec<&str> = std::iter::once(instruction)
.chain(demo_text.iter().map(String::as_str))
.filter(|s| !s.is_empty())
.collect();
parts.join("\n\n")
}
/// Returns the candidate with the highest score, or `None` if `candidates` is
/// empty.
///
/// When several candidates share the highest score, the last of them is
/// returned. A NaN score ranks below every real score, so a NaN candidate is
/// only returned when every candidate's score is NaN.
pub fn select_best(candidates: &[(String, f32)]) -> Option<&(String, f32)> {
    use std::cmp::Ordering;

    candidates
        .iter()
        .max_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            // Treating every NaN comparison as `Equal` is not transitive and
            // lets a NaN in the list hide the true maximum; order NaN strictly
            // lowest instead.
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Neither side is NaN, so `partial_cmp` always yields an ordering.
            (false, false) => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tuner::types::{ExactMatchEvaluator, Sample};

    /// Stub LLM that answers with the last whitespace-separated token of the
    /// prompt, or an empty string if the prompt has no tokens.
    fn echo_llm(prompt: &str) -> Result<String> {
        let last_token = prompt.split_whitespace().last().unwrap_or("");
        Ok(last_token.to_string())
    }

    /// Two-sample fixture, each sample carrying an expected output.
    fn samples() -> Vec<Sample> {
        vec![
            Sample::new("Q1").with_output("A1"),
            Sample::new("Q2").with_output("A2"),
        ]
    }

    #[test]
    fn assemble_prompt_no_demos() {
        assert_eq!(assemble_prompt("Do the thing", &[]), "Do the thing");
    }

    #[test]
    fn assemble_prompt_with_demos() {
        let demos = [Sample::new("hi").with_output("hello")];
        let assembled = assemble_prompt("Instruction", &demos);
        for fragment in ["Instruction", "Input: hi", "Output: hello"] {
            assert!(assembled.contains(fragment));
        }
    }

    #[test]
    fn format_demo_with_output() {
        let rendered = format_demo(&Sample::new("q").with_output("a"));
        assert_eq!(rendered, "Input: q\nOutput: a");
    }

    #[test]
    fn format_demo_without_output() {
        let rendered = format_demo(&Sample::new("q"));
        assert_eq!(rendered, "Input: q");
    }

    #[test]
    fn build_demo_sets_correct_count_and_size() {
        let sets = build_demo_sets(&samples(), 3, 1, 42);
        assert_eq!(sets.len(), 3);
        for set in &sets {
            assert_eq!(set.len(), 1);
        }
    }

    #[test]
    fn build_demo_sets_size_capped_at_dataset_len() {
        let data = samples();
        let limit = data.len();
        for set in &build_demo_sets(&data, 2, 100, 1) {
            assert!(set.len() <= limit);
        }
    }

    #[test]
    fn select_best_picks_highest_score() {
        let candidates = [
            ("a".to_string(), 0.5_f32),
            ("b".to_string(), 0.9_f32),
            ("c".to_string(), 0.3_f32),
        ];
        let winner = select_best(&candidates).expect("candidates are non-empty");
        assert_eq!(winner.0, "b");
    }

    #[test]
    fn select_best_empty_returns_none() {
        assert_eq!(select_best(&[]), None);
    }

    #[test]
    fn evaluate_prompt_returns_mean() {
        let score = evaluate_prompt("p", &samples(), &ExactMatchEvaluator, &echo_llm);
        // Only rules out NaN and +infinity; the exact mean is not pinned here.
        assert!(!score.is_nan() && score != f32::INFINITY);
    }

    #[test]
    fn generate_variants_includes_original() {
        fn fixed_llm(_: &str) -> Result<String> {
            Ok("improved".to_string())
        }
        let variants = generate_variants("base", 2, &fixed_llm);
        assert!(variants.iter().any(|v| v.as_str() == "base"));
    }

    #[test]
    fn generate_variants_deduplicates() {
        fn fixed_llm(_: &str) -> Result<String> {
            Ok("same variant".to_string())
        }
        // Five identical responses collapse to one, plus the base prompt.
        let variants = generate_variants("base", 5, &fixed_llm);
        assert_eq!(variants.len(), 2);
    }
}