use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use serde::{Deserialize, Serialize};
/// Errors reported while parsing signatures or running prompt optimization.
///
/// Every payload is a `String` or `usize`, so the type is fully `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OptimizationError {
    /// A signature specification, or data checked against a signature, was
    /// malformed. The payload is a human-readable description.
    InvalidSignature(String),
    /// A field type name was not recognized. The payload is the offending name.
    InvalidFieldType(String),
    /// A required part of a signature was absent. The payload describes it.
    MissingComponent(String),
    /// The optimization run itself failed. The payload describes why.
    OptimizationFailed(String),
    /// Fewer training examples were supplied than the run requires.
    InsufficientData {
        /// Minimum number of examples needed.
        required: usize,
        /// Number of examples actually supplied.
        provided: usize,
    },
}
impl fmt::Display for OptimizationError {
    /// Writes a single-line, human-readable message for the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All string-carrying variants share the "<label>: <detail>" layout;
        // only `InsufficientData` needs its own wording.
        let (label, detail) = match self {
            Self::InsufficientData { required, provided } => {
                return write!(
                    f,
                    "Insufficient training data: required {required}, provided {provided}"
                );
            }
            Self::InvalidSignature(detail) => ("Invalid signature", detail),
            Self::InvalidFieldType(detail) => ("Invalid field type", detail),
            Self::MissingComponent(detail) => ("Missing component", detail),
            Self::OptimizationFailed(detail) => ("Optimization failed", detail),
        };
        write!(f, "{label}: {detail}")
    }
}
// Marker impl so `OptimizationError` can be used wherever a
// `std::error::Error` is expected; no trait methods are overridden.
impl std::error::Error for OptimizationError {}
/// Result alias whose error type is fixed to [`OptimizationError`].
pub type Result<T> = std::result::Result<T, OptimizationError>;
/// Value type attached to a signature field.
///
/// The textual names listed on each variant are the ones accepted by the
/// `FromStr` implementation (matching is case-insensitive).
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum FieldType {
    /// Plain text (`str`, `string`, `text`). This is the default type.
    #[default]
    String,
    /// Whole number (`int`, `integer`).
    Integer,
    /// Floating-point number (`float`, `number`, `decimal`).
    Float,
    /// Truth value (`bool`, `boolean`).
    Boolean,
    /// Structured value (`json`, `object`, `dict`).
    Json,
    /// Homogeneous list; the boxed type is the element type
    /// (`list[...]`, or bare `list` / `array` for a list of strings).
    List(Box<FieldType>),
}
impl fmt::Display for FieldType {
    /// Writes the canonical short name of the type.
    ///
    /// A list of strings is rendered as bare `list`; any other list is
    /// rendered as `list[<element>]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let scalar_name = match self {
            Self::List(element) => {
                return match element.as_ref() {
                    FieldType::String => f.write_str("list"),
                    other => write!(f, "list[{other}]"),
                };
            }
            Self::String => "str",
            Self::Integer => "int",
            Self::Float => "float",
            Self::Boolean => "bool",
            Self::Json => "json",
        };
        f.write_str(scalar_name)
    }
}
impl FromStr for FieldType {
    type Err = OptimizationError;

    /// Parses a type name, ignoring surrounding whitespace and letter case.
    ///
    /// `list[<inner>]` is parsed recursively; bare `list` / `array` means a
    /// list of strings.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizationError::InvalidFieldType`] carrying the trimmed,
    /// lower-cased input when the name is not recognized.
    fn from_str(s: &str) -> Result<Self> {
        let normalized = s.trim().to_lowercase();

        // `list[...]`: peel off the wrapper and parse what is inside.
        let bracketed = normalized
            .strip_prefix("list[")
            .and_then(|rest| rest.strip_suffix(']'));
        if let Some(element_spec) = bracketed {
            let element = element_spec.parse::<FieldType>()?;
            return Ok(Self::List(Box::new(element)));
        }

        let parsed = match normalized.as_str() {
            "str" | "string" | "text" => Self::String,
            "int" | "integer" => Self::Integer,
            "float" | "number" | "decimal" => Self::Float,
            "bool" | "boolean" => Self::Boolean,
            "json" | "object" | "dict" => Self::Json,
            "list" | "array" => Self::List(Box::new(FieldType::String)),
            _ => return Err(OptimizationError::InvalidFieldType(normalized)),
        };
        Ok(parsed)
    }
}
/// One named, typed slot on the input or output side of a signature.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SignatureField {
    /// Field identifier; also used as the `{name}` placeholder in templates.
    pub name: String,
    /// Declared value type of the field.
    pub field_type: FieldType,
    /// Label printed in front of the field's value in rendered prompts.
    pub prefix: String,
    /// Optional free-form description of the field.
    pub description: Option<String>,
}
impl SignatureField {
pub fn new(name: impl Into<String>, field_type: FieldType) -> Self {
let name = name.into();
let prefix = Self::generate_prefix(&name);
Self {
name,
field_type,
prefix,
description: None,
}
}
pub fn string(name: impl Into<String>) -> Self {
Self::new(name, FieldType::String)
}
#[must_use]
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
#[must_use]
pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
self.prefix = prefix.into();
self
}
fn generate_prefix(name: &str) -> String {
let words: Vec<&str> = name.split('_').collect();
let titled: Vec<String> = words
.iter()
.map(|w| {
let mut chars = w.chars();
match chars.next() {
Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
None => String::new(),
}
})
.collect();
format!("{}:", titled.join(" "))
}
pub fn parse(spec: &str) -> Result<Self> {
let spec = spec.trim();
if let Some((name, type_str)) = spec.split_once(':') {
let name = name.trim();
let field_type = type_str.trim().parse::<FieldType>()?;
Ok(Self::new(name, field_type))
} else {
Ok(Self::string(spec))
}
}
}
/// Declarative description of a task: optional instructions plus the ordered
/// input and output fields.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct Signature {
    /// Free-form task instructions; empty when none were given.
    pub instructions: String,
    /// Fields the caller supplies, in declaration order.
    pub input_fields: Vec<SignatureField>,
    /// Fields to be produced, in declaration order.
    pub output_fields: Vec<SignatureField>,
}
impl Signature {
    /// Returns an empty signature: no instructions and no fields.
    pub fn new() -> Self {
        Self::default()
    }

    /// Parses a specification of the form `inputs -> outputs`, where each
    /// side is a comma-separated list of field specs.
    ///
    /// The resulting signature has empty instructions.
    ///
    /// # Errors
    ///
    /// * [`OptimizationError::InvalidSignature`] unless the spec contains
    ///   exactly one `->`.
    /// * [`OptimizationError::MissingComponent`] if either side is empty.
    /// * Any error produced while parsing an individual field.
    pub fn parse(spec: &str) -> Result<Self> {
        let spec = spec.trim();

        // Exactly two pieces must come out of the split: one arrow, no more.
        let mut sides = spec.split("->");
        let (input_spec, output_spec) = match (sides.next(), sides.next(), sides.next()) {
            (Some(lhs), Some(rhs), None) => (lhs.trim(), rhs.trim()),
            _ => {
                return Err(OptimizationError::InvalidSignature(format!(
                    "Expected 'inputs -> outputs' format, got: {spec}"
                )));
            }
        };

        let input_fields = Self::parse_field_list(input_spec)?;
        if input_fields.is_empty() {
            return Err(OptimizationError::MissingComponent(
                "At least one input field is required".to_string(),
            ));
        }

        let output_fields = Self::parse_field_list(output_spec)?;
        if output_fields.is_empty() {
            return Err(OptimizationError::MissingComponent(
                "At least one output field is required".to_string(),
            ));
        }

        Ok(Self {
            instructions: String::new(),
            input_fields,
            output_fields,
        })
    }

    /// Parses a comma-separated list of field specs; an empty string yields
    /// an empty list. Stops at the first field that fails to parse.
    fn parse_field_list(spec: &str) -> Result<Vec<SignatureField>> {
        let mut fields = Vec::new();
        if !spec.is_empty() {
            for item in spec.split(',') {
                fields.push(SignatureField::parse(item.trim())?);
            }
        }
        Ok(fields)
    }

    /// Returns the signature with its instructions replaced.
    #[must_use]
    pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
        self.instructions = instructions.into();
        self
    }

    /// Returns the signature with `field` appended to the inputs.
    #[must_use]
    pub fn with_input(mut self, field: SignatureField) -> Self {
        self.input_fields.push(field);
        self
    }

    /// Returns the signature with `field` appended to the outputs.
    #[must_use]
    pub fn with_output(mut self, field: SignatureField) -> Self {
        self.output_fields.push(field);
        self
    }

    /// Names of the input fields, in declaration order.
    pub fn input_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.input_fields.len());
        for field in &self.input_fields {
            names.push(field.name.as_str());
        }
        names
    }

    /// Names of the output fields, in declaration order.
    pub fn output_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.output_fields.len());
        for field in &self.output_fields {
            names.push(field.name.as_str());
        }
        names
    }

    /// Renders a newline-separated prompt template.
    ///
    /// Layout: the instructions followed by a blank line (only when
    /// instructions are non-empty), then one `<prefix> {<name>}` line per
    /// input field, then one bare `<prefix>` line per output field.
    pub fn to_prompt_template(&self) -> String {
        let mut lines: Vec<String> = Vec::new();
        if !self.instructions.is_empty() {
            lines.push(self.instructions.clone());
            lines.push(String::new());
        }
        lines.extend(
            self.input_fields
                .iter()
                .map(|field| format!("{} {{{}}}", field.prefix, field.name)),
        );
        lines.extend(self.output_fields.iter().map(|field| field.prefix.clone()));
        lines.join("\n")
    }

    /// Renders the compact `inputs -> outputs` form.
    ///
    /// String-typed fields are written as a bare name; every other field is
    /// written as `name:type`.
    pub fn to_spec_string(&self) -> String {
        let render = |fields: &[SignatureField]| -> String {
            let mut rendered = Vec::with_capacity(fields.len());
            for field in fields {
                rendered.push(if field.field_type == FieldType::String {
                    field.name.clone()
                } else {
                    format!("{}:{}", field.name, field.field_type)
                });
            }
            rendered.join(", ")
        };
        format!(
            "{} -> {}",
            render(&self.input_fields),
            render(&self.output_fields)
        )
    }
}
impl fmt::Display for Signature {
    /// Writes the compact `inputs -> outputs` form produced by
    /// `to_spec_string`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let spec = self.to_spec_string();
        f.write_str(&spec)
    }
}
impl FromStr for Signature {
type Err = OptimizationError;
fn from_str(s: &str) -> Result<Self> {
Self::parse(s)
}
}
/// Tunable settings for a prompt-optimization run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig {
    /// Upper bound on the number of trials executed.
    pub num_trials: usize,
    /// Minibatch size; the optimizer also uses it as the minimum number of
    /// training examples it accepts.
    pub minibatch_size: usize,
    /// Largest number of demonstrations placed in any single trial prompt.
    pub max_bootstrapped_demos: usize,
    /// Largest number of training examples converted into demonstrations.
    pub max_labeled_demos: usize,
    /// Requested number of instruction candidates.
    // NOTE(review): not read by the optimizer code in this file — confirm
    // whether it is consumed elsewhere.
    pub num_instruction_candidates: usize,
    /// Optional random seed.
    // NOTE(review): not read by the optimizer code in this file.
    pub seed: Option<u64>,
    /// When `true`, progress messages are written to stderr.
    pub verbose: bool,
    /// Stop after this many consecutive trials without improvement;
    /// `None` disables early stopping.
    pub early_stopping_patience: Option<usize>,
    /// Margin by which a score must exceed the current best to count as an
    /// improvement.
    pub min_improvement_threshold: f64,
}
impl Default for OptimizerConfig {
    /// Default settings: 20 trials, minibatches of 25, at most 4 bootstrapped
    /// and 16 labeled demonstrations, 10 instruction candidates, no seed,
    /// quiet output, early stopping after 5 stale trials, and a 0.01
    /// improvement threshold.
    fn default() -> Self {
        Self {
            // Search budget.
            num_trials: 20,
            num_instruction_candidates: 10,
            // Data and demonstration limits.
            minibatch_size: 25,
            max_bootstrapped_demos: 4,
            max_labeled_demos: 16,
            // Stopping rule.
            early_stopping_patience: Some(5),
            min_improvement_threshold: 0.01,
            // Miscellaneous.
            seed: None,
            verbose: false,
        }
    }
}
impl OptimizerConfig {
    /// Returns the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of trials.
    #[must_use]
    pub fn with_num_trials(self, num_trials: usize) -> Self {
        Self { num_trials, ..self }
    }

    /// Sets the minibatch size.
    #[must_use]
    pub fn with_minibatch_size(self, minibatch_size: usize) -> Self {
        Self {
            minibatch_size,
            ..self
        }
    }

    /// Sets the maximum number of bootstrapped demonstrations.
    #[must_use]
    pub fn with_max_bootstrapped_demos(self, max_demos: usize) -> Self {
        Self {
            max_bootstrapped_demos: max_demos,
            ..self
        }
    }

    /// Sets the maximum number of labeled demonstrations.
    #[must_use]
    pub fn with_max_labeled_demos(self, max_demos: usize) -> Self {
        Self {
            max_labeled_demos: max_demos,
            ..self
        }
    }

    /// Sets the number of instruction candidates.
    #[must_use]
    pub fn with_num_instruction_candidates(self, num_candidates: usize) -> Self {
        Self {
            num_instruction_candidates: num_candidates,
            ..self
        }
    }

    /// Sets the random seed.
    #[must_use]
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Turns verbose progress output on or off.
    #[must_use]
    pub fn with_verbose(self, verbose: bool) -> Self {
        Self { verbose, ..self }
    }

    /// Enables early stopping after `patience` trials without improvement.
    #[must_use]
    pub fn with_early_stopping(self, patience: usize) -> Self {
        Self {
            early_stopping_patience: Some(patience),
            ..self
        }
    }

    /// Disables early stopping.
    #[must_use]
    pub fn without_early_stopping(self) -> Self {
        Self {
            early_stopping_patience: None,
            ..self
        }
    }
}
/// Record of a single optimization trial.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialResult {
    /// Index of the trial within its run.
    pub trial_number: usize,
    /// Full prompt text that was evaluated.
    pub prompt: String,
    /// Instruction text used for this trial.
    pub instruction: String,
    /// Number of demonstrations included in the prompt.
    pub num_demos: usize,
    /// Score assigned to the prompt.
    pub score: f64,
    /// Time spent on the trial, in milliseconds.
    pub duration_ms: u64,
    /// Arbitrary extra data attached to the trial.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Aggregate statistics for an optimization run.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OptimizationMetrics {
    /// Duration of the whole run, in milliseconds.
    pub total_duration_ms: u64,
    /// Number of trials that were executed.
    pub trials_completed: usize,
    /// Number of LLM calls made.
    pub llm_calls: usize,
    /// Number of tokens consumed.
    pub tokens_used: usize,
    /// Best score minus the score of the first trial.
    pub improvement: f64,
    /// Whether the run ended because the early-stopping patience ran out.
    pub early_stopped: bool,
    /// Score of every trial, in execution order.
    pub score_history: Vec<f64>,
}
/// Final outcome of an optimization run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationResult {
    /// Prompt text of the best-scoring configuration.
    pub best_prompt: String,
    /// Instruction used by the best-scoring configuration.
    pub best_instruction: String,
    /// Demonstrations used by the best-scoring configuration.
    pub best_demos: Vec<Demonstration>,
    /// Score of the best-scoring configuration.
    pub best_score: f64,
    /// Every trial that was run, in execution order.
    pub trials_history: Vec<TrialResult>,
    /// Aggregate statistics for the run.
    pub metrics: OptimizationMetrics,
    /// Signature that was optimized.
    pub signature: Signature,
    /// Configuration the run was executed with.
    pub config: OptimizerConfig,
}
impl OptimizationResult {
    /// Returns the recorded trial with the highest score, or `None` when no
    /// trials were recorded.
    ///
    /// A NaN score ranks below every real score, so a NaN trial is returned
    /// only when every recorded score is NaN. When several trials tie for the
    /// highest score, the latest one is returned.
    pub fn best_trial(&self) -> Option<&TrialResult> {
        use std::cmp::Ordering;

        self.trials_history.iter().max_by(|a, b| {
            // `partial_cmp` is `None` only when a NaN is involved; order NaN
            // explicitly instead of unwrapping, which would panic.
            match (a.score.is_nan(), b.score.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Less,
                (false, true) => Ordering::Greater,
                (false, false) => a
                    .score
                    .partial_cmp(&b.score)
                    .unwrap_or(Ordering::Equal),
            }
        })
    }

    /// Returns the improvement recorded in the run metrics.
    pub fn improvement(&self) -> f64 {
        self.metrics.improvement
    }

    /// Returns `true` when the run stopped early or used its whole trial
    /// budget.
    pub fn converged(&self) -> bool {
        self.metrics.early_stopped || self.metrics.trials_completed == self.config.num_trials
    }
}
/// A worked example (inputs together with outputs) that can be rendered into
/// a prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Demonstration {
    /// Input values keyed by field name.
    pub inputs: HashMap<String, String>,
    /// Output values keyed by field name.
    pub outputs: HashMap<String, String>,
    /// Optional label describing where the demonstration came from.
    pub source: Option<String>,
    /// Optional quality score for the demonstration.
    pub quality_score: Option<f64>,
}
impl Demonstration {
    /// Creates a demonstration with no values, no source and no quality score.
    pub fn new() -> Self {
        Self {
            inputs: HashMap::default(),
            outputs: HashMap::default(),
            source: None,
            quality_score: None,
        }
    }

    /// Returns the demonstration with input `name` set to `value`, replacing
    /// any previous value stored under that name.
    #[must_use]
    pub fn with_input(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, text) = (name.into(), value.into());
        self.inputs.insert(key, text);
        self
    }

    /// Returns the demonstration with output `name` set to `value`, replacing
    /// any previous value stored under that name.
    #[must_use]
    pub fn with_output(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, text) = (name.into(), value.into());
        self.outputs.insert(key, text);
        self
    }

    /// Returns the demonstration with its source label set.
    #[must_use]
    pub fn with_source(self, source: impl Into<String>) -> Self {
        Self {
            source: Some(source.into()),
            ..self
        }
    }

    /// Renders the demonstration as `<prefix> <value>` lines joined by
    /// newlines.
    ///
    /// Lines follow the signature's field order, inputs first and outputs
    /// after. Fields with no stored value are skipped, and stored values that
    /// match no signature field are ignored.
    pub fn format(&self, signature: &Signature) -> String {
        let input_lines = signature.input_fields.iter().filter_map(|field| {
            self.inputs
                .get(&field.name)
                .map(|value| format!("{} {}", field.prefix, value))
        });
        let output_lines = signature.output_fields.iter().filter_map(|field| {
            self.outputs
                .get(&field.name)
                .map(|value| format!("{} {}", field.prefix, value))
        });
        input_lines
            .chain(output_lines)
            .collect::<Vec<String>>()
            .join("\n")
    }
}
impl Default for Demonstration {
fn default() -> Self {
Self::new()
}
}
/// A labeled example used to drive and validate optimization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Input values keyed by field name.
    pub inputs: HashMap<String, String>,
    /// Expected output values keyed by field name.
    pub expected_outputs: HashMap<String, String>,
    /// Optional weight for the example.
    pub weight: Option<f64>,
}
impl TrainingExample {
    /// Creates an example with no inputs, no expected outputs and no weight.
    pub fn new() -> Self {
        Self {
            inputs: HashMap::default(),
            expected_outputs: HashMap::default(),
            weight: None,
        }
    }

    /// Returns the example with input `name` set to `value`, replacing any
    /// previous value stored under that name.
    #[must_use]
    pub fn with_input(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, text) = (name.into(), value.into());
        self.inputs.insert(key, text);
        self
    }

    /// Returns the example with expected output `name` set to `value`,
    /// replacing any previous value stored under that name.
    #[must_use]
    pub fn with_expected(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, text) = (name.into(), value.into());
        self.expected_outputs.insert(key, text);
        self
    }

    /// Returns the example with its weight set.
    #[must_use]
    pub fn with_weight(self, weight: f64) -> Self {
        Self {
            weight: Some(weight),
            ..self
        }
    }
}
impl Default for TrainingExample {
fn default() -> Self {
Self::new()
}
}
/// Searches over instruction and demonstration combinations for a signature.
#[derive(Debug, Clone)]
pub struct PromptOptimizer {
    /// Settings applied to every run; exposed read-only through `config()`.
    config: OptimizerConfig,
}
impl PromptOptimizer {
    /// Creates an optimizer that runs with `config`.
    pub fn new(config: OptimizerConfig) -> Self {
        Self { config }
    }

    /// Creates an optimizer that uses `OptimizerConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(OptimizerConfig::default())
    }

    /// Returns the configuration this optimizer runs with.
    pub fn config(&self) -> &OptimizerConfig {
        &self.config
    }

    /// Checks that `data` is large enough and that every example carries
    /// every field named by `signature`.
    ///
    /// # Errors
    ///
    /// * [`OptimizationError::InsufficientData`] if fewer than
    ///   `minibatch_size` examples are supplied.
    /// * [`OptimizationError::InvalidSignature`] naming the first example and
    ///   field for which a value is missing (inputs are checked before
    ///   expected outputs).
    fn validate_training_data(
        &self,
        signature: &Signature,
        data: &[TrainingExample],
    ) -> Result<()> {
        let min_required = self.config.minibatch_size;
        if data.len() < min_required {
            return Err(OptimizationError::InsufficientData {
                required: min_required,
                provided: data.len(),
            });
        }

        for (i, example) in data.iter().enumerate() {
            let missing_input = signature
                .input_fields
                .iter()
                .find(|field| !example.inputs.contains_key(&field.name));
            if let Some(field) = missing_input {
                return Err(OptimizationError::InvalidSignature(format!(
                    "Training example {i} missing input field '{}'",
                    field.name
                )));
            }

            let missing_output = signature
                .output_fields
                .iter()
                .find(|field| !example.expected_outputs.contains_key(&field.name));
            if let Some(field) = missing_output {
                return Err(OptimizationError::InvalidSignature(format!(
                    "Training example {i} missing expected output field '{}'",
                    field.name
                )));
            }
        }
        Ok(())
    }

    /// Produces the instruction variants tried during optimization.
    ///
    /// The base instruction is the signature's own instructions, or a
    /// generated sentence listing the field names when those are empty. Five
    /// variants are returned: the base itself plus four fixed rewordings.
    // NOTE(review): `num_instruction_candidates` is not consulted here; the
    // list always has five entries. Confirm whether that is intended.
    fn generate_instruction_candidates(&self, signature: &Signature) -> Vec<String> {
        let base_instruction = if signature.instructions.is_empty() {
            format!(
                "Given the input fields {}, produce the output fields {}.",
                signature.input_names().join(", "),
                signature.output_names().join(", ")
            )
        } else {
            signature.instructions.clone()
        };
        vec![
            base_instruction.clone(),
            format!("You are an expert assistant. {base_instruction}"),
            format!("{base_instruction} Be concise and accurate."),
            format!("{base_instruction} Think step by step."),
            format!("Task: {base_instruction}\nProvide your response in the specified format."),
        ]
    }

    /// Converts the first `max_demos` training examples into demonstrations
    /// whose source is `"labeled"`.
    fn select_demonstrations(
        &self,
        _signature: &Signature,
        data: &[TrainingExample],
        max_demos: usize,
    ) -> Vec<Demonstration> {
        data.iter()
            .take(max_demos)
            .map(|example| {
                let mut demo = Demonstration::new().with_source("labeled");
                demo.inputs = example.inputs.clone();
                demo.outputs = example.expected_outputs.clone();
                demo
            })
            .collect()
    }

    /// Assembles the full prompt text.
    ///
    /// Layout: the instruction and a blank line; then, when `demos` is
    /// non-empty, the demonstrations separated by blank lines and fenced by
    /// `---` lines; finally one `<prefix> {<name>}` line per input field and
    /// one bare `<prefix>` line per output field.
    fn build_prompt(
        &self,
        signature: &Signature,
        instruction: &str,
        demos: &[Demonstration],
    ) -> String {
        let mut parts = vec![instruction.to_string(), String::new()];

        if !demos.is_empty() {
            parts.push("---".to_string());
            for (i, demo) in demos.iter().enumerate() {
                if i > 0 {
                    parts.push(String::new());
                }
                parts.push(demo.format(signature));
            }
            parts.push("---".to_string());
            parts.push(String::new());
        }

        parts.extend(
            signature
                .input_fields
                .iter()
                .map(|field| format!("{} {{{}}}", field.prefix, field.name)),
        );
        parts.extend(signature.output_fields.iter().map(|field| field.prefix.clone()));
        parts.join("\n")
    }

    /// Scores a prompt against the training data.
    ///
    /// Placeholder: ignores its arguments and always returns `0.5`.
    fn evaluate_prompt(
        &self,
        _prompt: &str,
        _signature: &Signature,
        _data: &[TrainingExample],
    ) -> f64 {
        0.5
    }

    /// Runs the optimization loop and returns the best prompt found.
    ///
    /// Trial `t` uses instruction candidate `t mod <candidate count>` and the
    /// first `t mod (max_bootstrapped_demos + 1)` demonstrations, capped at
    /// the number available. A trial becomes the new best only when its score
    /// exceeds the current best by more than `min_improvement_threshold`.
    /// With early stopping enabled, the loop ends once `patience` consecutive
    /// trials fail to improve.
    ///
    /// # Errors
    ///
    /// * Any error from training-data validation.
    /// * [`OptimizationError::OptimizationFailed`] if `num_trials` is zero,
    ///   since no prompt would be evaluated and the result would carry a
    ///   best score of negative infinity.
    pub fn optimize(
        &self,
        signature: &Signature,
        training_data: &[TrainingExample],
    ) -> Result<OptimizationResult> {
        use std::time::Instant;

        self.validate_training_data(signature, training_data)?;
        if self.config.num_trials == 0 {
            return Err(OptimizationError::OptimizationFailed(
                "num_trials must be at least 1".to_string(),
            ));
        }

        // Clamp before narrowing so an over-long duration saturates at
        // `u64::MAX` instead of wrapping.
        let elapsed_ms = |since: Instant| -> u64 {
            since.elapsed().as_millis().min(u128::from(u64::MAX)) as u64
        };

        let start_time = Instant::now();
        let mut trials_history = Vec::new();
        let mut score_history = Vec::new();
        let mut best_score = f64::NEG_INFINITY;
        let mut best_prompt = String::new();
        let mut best_instruction = String::new();
        let mut best_demos = Vec::new();
        let mut trials_without_improvement = 0;
        let mut early_stopped = false;

        let instructions = self.generate_instruction_candidates(signature);
        let all_demos =
            self.select_demonstrations(signature, training_data, self.config.max_labeled_demos);
        // Saturate so `max_bootstrapped_demos == usize::MAX` cannot overflow
        // into a zero modulus.
        let demo_cycle = self.config.max_bootstrapped_demos.saturating_add(1);

        for trial in 0..self.config.num_trials {
            let trial_start = Instant::now();
            let instruction = &instructions[trial % instructions.len()];
            let num_demos = (trial % demo_cycle).min(all_demos.len());
            let demos: Vec<Demonstration> = all_demos[..num_demos].to_vec();
            let prompt = self.build_prompt(signature, instruction, &demos);
            let score = self.evaluate_prompt(&prompt, signature, training_data);

            trials_history.push(TrialResult {
                trial_number: trial,
                prompt: prompt.clone(),
                instruction: instruction.clone(),
                num_demos,
                score,
                duration_ms: elapsed_ms(trial_start),
                metadata: HashMap::new(),
            });
            score_history.push(score);

            if score > best_score + self.config.min_improvement_threshold {
                best_score = score;
                best_prompt = prompt;
                best_instruction = instruction.clone();
                best_demos = demos;
                trials_without_improvement = 0;
                if self.config.verbose {
                    eprintln!(
                        "[Trial {}/{}] New best score: {:.4}",
                        trial + 1,
                        self.config.num_trials,
                        best_score
                    );
                }
            } else {
                trials_without_improvement += 1;
            }

            if let Some(patience) = self.config.early_stopping_patience {
                if trials_without_improvement >= patience {
                    if self.config.verbose {
                        eprintln!(
                            "[Trial {}/{}] Early stopping: no improvement for {} trials",
                            trial + 1,
                            self.config.num_trials,
                            patience
                        );
                    }
                    early_stopped = true;
                    break;
                }
            }
        }

        let initial_score = score_history.first().copied().unwrap_or(0.0);
        let metrics = OptimizationMetrics {
            total_duration_ms: elapsed_ms(start_time),
            trials_completed: trials_history.len(),
            // One evaluation per trial; the placeholder evaluator reports no
            // token usage.
            llm_calls: trials_history.len(),
            tokens_used: 0,
            improvement: best_score - initial_score,
            early_stopped,
            score_history,
        };

        Ok(OptimizationResult {
            best_prompt,
            best_instruction,
            best_demos,
            best_score,
            trials_history,
            metrics,
            signature: signature.clone(),
            config: self.config.clone(),
        })
    }
}
impl Default for PromptOptimizer {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a list type whose elements are `inner`.
    fn list_of(inner: FieldType) -> FieldType {
        FieldType::List(Box::new(inner))
    }

    #[test]
    fn test_field_type_parsing() {
        let cases = vec![
            ("str", FieldType::String),
            ("int", FieldType::Integer),
            ("float", FieldType::Float),
            ("bool", FieldType::Boolean),
            ("json", FieldType::Json),
            ("list", list_of(FieldType::String)),
            ("list[int]", list_of(FieldType::Integer)),
        ];
        for (spec, expected) in cases {
            assert_eq!(spec.parse::<FieldType>().unwrap(), expected);
        }
    }

    #[test]
    fn test_field_type_display() {
        let cases = vec![
            (FieldType::String, "str"),
            (FieldType::Integer, "int"),
            (list_of(FieldType::String), "list"),
            (list_of(FieldType::Float), "list[float]"),
        ];
        for (field_type, expected) in cases {
            assert_eq!(field_type.to_string(), expected);
        }
    }

    #[test]
    fn test_signature_field_creation() {
        let SignatureField {
            name,
            field_type,
            prefix,
            description,
        } = SignatureField::new("user_query", FieldType::String);
        assert_eq!(name, "user_query");
        assert_eq!(prefix, "User Query:");
        assert_eq!(field_type, FieldType::String);
        assert!(description.is_none());
    }

    #[test]
    fn test_signature_field_parsing() {
        let cases = vec![
            ("query:str", "query", FieldType::String),
            ("count:int", "count", FieldType::Integer),
            ("items:list[str]", "items", list_of(FieldType::String)),
        ];
        for (spec, expected_name, expected_type) in cases {
            let field = SignatureField::parse(spec).unwrap();
            assert_eq!(field.name, expected_name);
            assert_eq!(field.field_type, expected_type);
        }
    }

    #[test]
    fn test_simple_signature_parsing() {
        let sig = Signature::parse("question -> answer").unwrap();
        assert_eq!(sig.input_fields.len(), 1);
        assert_eq!(sig.output_fields.len(), 1);
        assert_eq!(sig.input_fields[0].name, "question");
        assert_eq!(sig.output_fields[0].name, "answer");
    }

    #[test]
    fn test_multi_field_signature_parsing() {
        let sig = Signature::parse("question, context -> answer").unwrap();
        assert_eq!(sig.input_fields.len(), 2);
        assert_eq!(sig.output_fields.len(), 1);
        let first_two: Vec<&str> = sig
            .input_fields
            .iter()
            .take(2)
            .map(|field| field.name.as_str())
            .collect();
        assert_eq!(first_two, vec!["question", "context"]);
    }

    #[test]
    fn test_typed_signature_parsing() {
        let sig =
            Signature::parse("query:str, docs:list -> summary:str, confidence:float").unwrap();

        let input_types: Vec<FieldType> = sig
            .input_fields
            .iter()
            .map(|field| field.field_type.clone())
            .collect();
        let output_types: Vec<FieldType> = sig
            .output_fields
            .iter()
            .map(|field| field.field_type.clone())
            .collect();

        assert_eq!(
            input_types,
            vec![FieldType::String, list_of(FieldType::String)]
        );
        assert_eq!(output_types, vec![FieldType::String, FieldType::Float]);
    }

    #[test]
    fn test_signature_spec_string() {
        let round_trips = vec![
            ("question, context -> answer", "question, context -> answer"),
            (
                "query:str -> summary:str, score:float",
                "query -> summary, score:float",
            ),
        ];
        for (spec, expected) in round_trips {
            let sig = Signature::parse(spec).unwrap();
            assert_eq!(sig.to_spec_string(), expected);
        }
    }

    #[test]
    fn test_signature_with_instructions() {
        let text = "Answer the question accurately and concisely.";
        let sig = Signature::parse("question -> answer")
            .unwrap()
            .with_instructions(text);
        assert_eq!(sig.instructions, text);
    }

    #[test]
    fn test_invalid_signature() {
        for spec in ["no arrow here", "-> only output", "only input ->"] {
            assert!(Signature::parse(spec).is_err());
        }
    }

    #[test]
    fn test_optimizer_config_defaults() {
        let OptimizerConfig {
            num_trials,
            minibatch_size,
            max_bootstrapped_demos,
            max_labeled_demos,
            num_instruction_candidates,
            ..
        } = OptimizerConfig::default();
        assert_eq!(num_trials, 20);
        assert_eq!(minibatch_size, 25);
        assert_eq!(max_bootstrapped_demos, 4);
        assert_eq!(max_labeled_demos, 16);
        assert_eq!(num_instruction_candidates, 10);
    }

    #[test]
    fn test_optimizer_config_builder() {
        let config = OptimizerConfig::default()
            .with_num_trials(50)
            .with_minibatch_size(32)
            .with_seed(42)
            .with_verbose(true);
        assert_eq!(
            (config.num_trials, config.minibatch_size, config.seed),
            (50, 32, Some(42))
        );
        assert!(config.verbose);
    }

    #[test]
    fn test_demonstration_creation() {
        let demo = Demonstration::new()
            .with_input("question", "What is Rust?")
            .with_output("answer", "A systems programming language.")
            .with_source("labeled");
        assert_eq!(demo.inputs.get("question").unwrap(), "What is Rust?");
        assert_eq!(
            demo.outputs.get("answer").unwrap(),
            "A systems programming language."
        );
        assert_eq!(demo.source, Some("labeled".to_string()));
    }

    #[test]
    fn test_training_example_creation() {
        let example = TrainingExample::new()
            .with_input("question", "What is 2+2?")
            .with_expected("answer", "4")
            .with_weight(1.0);
        assert_eq!(example.inputs.get("question").unwrap(), "What is 2+2?");
        assert_eq!(example.expected_outputs.get("answer").unwrap(), "4");
        assert_eq!(example.weight, Some(1.0));
    }

    #[test]
    fn test_prompt_optimizer_creation() {
        let default_optimizer = PromptOptimizer::with_defaults();
        assert_eq!(default_optimizer.config().num_trials, 20);

        let custom_optimizer =
            PromptOptimizer::new(OptimizerConfig::default().with_num_trials(100));
        assert_eq!(custom_optimizer.config().num_trials, 100);
    }

    #[test]
    fn test_prompt_optimizer_validation() {
        let optimizer = PromptOptimizer::with_defaults();
        let signature = Signature::parse("question -> answer").unwrap();

        // No data at all is below the default minibatch size.
        assert!(optimizer.optimize(&signature, &[]).is_err());

        // Thirty complete examples are enough.
        let mut training_data: Vec<TrainingExample> = Vec::new();
        for i in 0..30 {
            training_data.push(
                TrainingExample::new()
                    .with_input("question", format!("Question {i}"))
                    .with_expected("answer", format!("Answer {i}")),
            );
        }
        assert!(optimizer.optimize(&signature, &training_data).is_ok());
    }

    #[test]
    fn test_demonstration_formatting() {
        let sig = Signature::parse("question -> answer").unwrap();
        let formatted = Demonstration::new()
            .with_input("question", "What is Rust?")
            .with_output("answer", "A programming language.")
            .format(&sig);
        for needle in [
            "Question:",
            "What is Rust?",
            "Answer:",
            "A programming language.",
        ] {
            assert!(formatted.contains(needle));
        }
    }

    #[test]
    fn test_signature_prompt_template() {
        let template = Signature::parse("question, context -> answer")
            .unwrap()
            .with_instructions("Answer based on the context.")
            .to_prompt_template();
        for needle in [
            "Answer based on the context.",
            "Question:",
            "Context:",
            "Answer:",
        ] {
            assert!(template.contains(needle));
        }
    }
}