use std::sync::Arc;
use async_trait::async_trait;
use super::optimizer::{AgentOptimizer, OptimizerError, OptimizerResult};
use super::sampler::Sampler;
use crate::llm::BaseLlm;
/// Tuning knobs for [`SimplePromptOptimizer`].
///
/// All fields are plain counters, so the type is cheap to copy and compare;
/// the extra derives (`Copy`, `PartialEq`, `Eq`, `Hash`) make it usable as a
/// value in tests, logs, and map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SimplePromptOptimizerConfig {
    /// Number of candidate-generation/scoring rounds to run.
    pub num_iterations: usize,
    /// Number of training cases sampled per scoring round.
    pub batch_size: usize,
}

impl Default for SimplePromptOptimizerConfig {
    /// Defaults chosen as a modest budget: 10 rounds scored on 5-case batches.
    fn default() -> Self {
        Self {
            num_iterations: 10,
            batch_size: 5,
        }
    }
}
/// Hill-climbing prompt optimizer: repeatedly asks `optimizer_llm` to rewrite
/// an agent instruction, scores each candidate via `sampler`, and keeps the
/// best-scoring instruction seen so far.
pub struct SimplePromptOptimizer {
    // LLM used to generate improved instruction candidates (not the model
    // being optimized — that is addressed by `model_id` at optimize() time).
    optimizer_llm: Arc<dyn BaseLlm>,
    // Supplies training/validation cases and scores instructions against them.
    sampler: Arc<dyn Sampler>,
    // Iteration and batch-size budget for the optimization loop.
    config: SimplePromptOptimizerConfig,
}
impl SimplePromptOptimizer {
    /// Creates an optimizer that uses `optimizer_llm` to propose improved
    /// instructions and `sampler` to score them against sampled cases.
    pub fn new(
        optimizer_llm: Arc<dyn BaseLlm>,
        sampler: Arc<dyn Sampler>,
        config: SimplePromptOptimizerConfig,
    ) -> Self {
        Self {
            optimizer_llm,
            sampler,
            config,
        }
    }

    /// Asks the optimizer LLM for an improved version of
    /// `current_instruction`, reporting its `current_score` in the prompt so
    /// the model has context on how the instruction performs.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizerError::Llm`] when the LLM call fails, or when the
    /// model replies with an empty candidate — scoring an empty instruction
    /// would silently waste a full evaluation round.
    async fn generate_candidate(
        &self,
        current_instruction: &str,
        current_score: f64,
    ) -> Result<String, OptimizerError> {
        let prompt = format!(
            "You are an expert prompt engineer. Your task is to improve the following \
            agent instruction to achieve better performance.\n\n\
            Current instruction (score: {current_score:.2}):\n\
            ---\n{current_instruction}\n---\n\n\
            Generate an improved version of the instruction. Focus on:\n\
            - Clarity and specificity\n\
            - Better task decomposition guidance\n\
            - More effective tool use instructions\n\
            - Appropriate constraints and guardrails\n\n\
            Respond with ONLY the improved instruction text, nothing else."
        );
        let request = crate::llm::LlmRequest::from_text(&prompt);
        let response = self
            .optimizer_llm
            .generate(request)
            .await
            .map_err(|e| OptimizerError::Llm(e.to_string()))?;
        // LLMs frequently pad replies with leading/trailing whitespace even
        // when instructed to emit only the instruction text; strip it so the
        // scored candidate matches what the model actually proposed.
        let candidate = response.text().trim().to_string();
        if candidate.is_empty() {
            return Err(OptimizerError::Llm(
                "optimizer LLM returned an empty candidate instruction".to_string(),
            ));
        }
        Ok(candidate)
    }
}
#[async_trait]
impl AgentOptimizer for SimplePromptOptimizer {
    /// Runs the hill-climbing loop: score the initial instruction, then for
    /// each of `config.num_iterations` rounds generate a candidate, score it
    /// on a freshly sampled training batch, and keep it if it beats the
    /// current best. The winner is finally re-scored on the validation set,
    /// and that validation score is what `best_score` reports.
    ///
    /// NOTE(review): each round scores its candidate on a *different* sampled
    /// batch than the one the incumbent was scored on, so the comparison is
    /// noisy by design — presumably an intentional cost/variance tradeoff;
    /// confirm with the sampler's semantics.
    async fn optimize(
        &self,
        initial_instruction: &str,
        model_id: &str,
    ) -> Result<OptimizerResult, OptimizerError> {
        // Establish the baseline score for the starting instruction.
        let baseline_batch = self.sampler.sample_training(self.config.batch_size).await?;
        let mut champion = initial_instruction.to_owned();
        let mut champion_score = self
            .sampler
            .score(&champion, model_id, &baseline_batch.cases)
            .await?;
        // History entry 0 is the baseline; entries 1..=N are the candidates.
        let mut history = vec![(0, champion_score)];

        for round in 1..=self.config.num_iterations {
            let challenger = self.generate_candidate(&champion, champion_score).await?;
            let batch = self.sampler.sample_training(self.config.batch_size).await?;
            let challenger_score = self
                .sampler
                .score(&challenger, model_id, &batch.cases)
                .await?;
            history.push((round, challenger_score));
            // Greedy acceptance: only strictly better candidates replace the
            // incumbent, so ties keep the earlier instruction.
            if challenger_score > champion_score {
                champion = challenger;
                champion_score = challenger_score;
            }
        }

        // Final held-out evaluation of the winning instruction.
        let holdout = self.sampler.validation_set().await?;
        let holdout_score = self
            .sampler
            .score(&champion, model_id, &holdout.cases)
            .await?;

        Ok(OptimizerResult {
            best_instruction: champion,
            best_score: holdout_score,
            iterations: self.config.num_iterations,
            score_history: history,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented defaults: 10 optimization rounds, 5 cases per batch.
    #[test]
    fn default_config() {
        let SimplePromptOptimizerConfig { num_iterations, batch_size } =
            SimplePromptOptimizerConfig::default();
        assert_eq!(num_iterations, 10);
        assert_eq!(batch_size, 5);
    }
}