// adk_eval/optimizer.rs
1//! Prompt optimization engine.
2//!
3//! Iteratively improves an agent's system instructions using an optimizer LLM
4//! and an evaluation set. Used by the `adk optimize` CLI command.
5//!
6//! # Overview
7//!
8//! The [`PromptOptimizer`] runs an optimization loop:
9//! 1. Evaluate the agent against the eval set to get a baseline score
10//! 2. If the score already meets the target threshold, report "no optimization needed"
11//! 3. Otherwise, ask the optimizer LLM to propose improved instructions
12//! 4. Apply the best improvement and re-evaluate
13//! 5. Repeat until max iterations or target threshold is reached
14//! 6. Write the best-performing instructions to the output file
15//!
16//! # Example
17//!
18//! ```rust,ignore
19//! use adk_eval::optimizer::{PromptOptimizer, OptimizerConfig};
20//! use std::sync::Arc;
21//!
22//! let optimizer = PromptOptimizer::new(
23//!     optimizer_llm,
24//!     evaluator,
25//!     OptimizerConfig::default(),
26//! );
27//! let result = optimizer.optimize(agent, &eval_set).await?;
28//! println!("Final score: {}", result.final_score);
29//! ```
30
31use std::path::PathBuf;
32use std::sync::Arc;
33
34use futures::StreamExt;
35use tracing::{info, warn};
36
37use adk_core::types::Content;
38use adk_core::{Agent, Llm, LlmRequest};
39
40use crate::error::{EvalError, Result};
41use crate::evaluator::Evaluator;
42use crate::schema::EvalSet;
43
/// Configuration for the prompt optimization loop.
///
/// Controls how long the loop runs, when it stops early, and where the
/// winning instructions are written. [`OptimizerConfig::default`] gives
/// 5 iterations, a 0.9 threshold, and `optimized_instructions.txt`.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Maximum number of optimization iterations (default: 5).
    pub max_iterations: u32,
    /// Target evaluation score threshold (default: 0.9).
    /// Optimization stops early if this score is reached.
    pub target_threshold: f64,
    /// Path to write the best-performing instructions.
    pub output_path: PathBuf,
}
55
56impl Default for OptimizerConfig {
57    fn default() -> Self {
58        Self {
59            max_iterations: 5,
60            target_threshold: 0.9,
61            output_path: PathBuf::from("optimized_instructions.txt"),
62        }
63    }
64}
65
/// Result of a prompt optimization run.
///
/// Returned by [`PromptOptimizer::optimize`]. `final_score` is the *best*
/// score observed, which may equal `initial_score` when no iteration
/// improved on the baseline (or when no optimization was needed).
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Evaluation score before optimization.
    pub initial_score: f64,
    /// Best evaluation score achieved.
    pub final_score: f64,
    /// Number of iterations actually executed (0 if the baseline already
    /// met the target threshold).
    pub iterations_run: u32,
    /// The best-performing system instructions.
    pub best_instructions: String,
}
78
/// Iteratively improves an agent's system instructions using an optimizer LLM
/// and an evaluation set.
///
/// The optimizer runs a loop of evaluate → propose improvements → apply best →
/// repeat, logging progress via `tracing` at each iteration.
pub struct PromptOptimizer {
    // LLM that proposes improved instructions (distinct from the agent's LLM).
    optimizer_llm: Arc<dyn Llm>,
    // Scores the agent against individual eval cases.
    evaluator: Evaluator,
    // Loop bounds, target threshold, and output path.
    config: OptimizerConfig,
}
89
impl PromptOptimizer {
    /// Create a new prompt optimizer.
    ///
    /// # Arguments
    ///
    /// * `optimizer_llm` - The LLM used to propose instruction improvements
    ///   (separate from the agent's own LLM).
    /// * `evaluator` - The evaluator used to score the agent against the eval set.
    /// * `config` - Optimization configuration (max iterations, target threshold, output path).
    pub fn new(optimizer_llm: Arc<dyn Llm>, evaluator: Evaluator, config: OptimizerConfig) -> Self {
        Self { optimizer_llm, evaluator, config }
    }

    /// Run the optimization loop.
    ///
    /// Evaluates the agent, proposes improvements via the optimizer LLM,
    /// applies the best improvement, and repeats until the target threshold
    /// is met or max iterations are exhausted.
    ///
    /// On completion, writes the best-performing instructions to the configured
    /// output file.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::ConfigError`] when the eval set contains no cases;
    /// otherwise propagates evaluation, optimizer-LLM, and file-write errors.
    pub async fn optimize(
        &self,
        agent: Arc<dyn Agent>,
        eval_set: &EvalSet,
    ) -> Result<OptimizationResult> {
        // Eval cases are resolved relative to the current working directory.
        let base_path = std::path::Path::new(".");
        let cases = eval_set.get_all_cases(base_path)?;

        if cases.is_empty() {
            return Err(EvalError::ConfigError("eval set contains no cases".to_string()));
        }

        // Get initial instructions from the agent.
        // NOTE(review): this seeds from `description()`, not a dedicated
        // instructions accessor — confirm this really is the agent's system
        // instruction text and not a human-readable summary.
        let mut current_instructions = agent.description().to_string();

        // Run initial evaluation to establish the baseline score.
        let initial_score = self.evaluate_agent(agent.clone(), eval_set).await?;
        info!(iteration = 0, score = initial_score, "initial evaluation complete");

        // Check if initial score already meets threshold — skip the loop entirely.
        if initial_score >= self.config.target_threshold {
            info!(
                score = initial_score,
                threshold = self.config.target_threshold,
                "no optimization needed — initial score meets target threshold"
            );

            self.write_output(&current_instructions)?;

            return Ok(OptimizationResult {
                initial_score,
                final_score: initial_score,
                iterations_run: 0,
                best_instructions: current_instructions,
            });
        }

        let mut best_score = initial_score;
        let mut best_instructions = current_instructions.clone();
        let mut iterations_run = 0;

        for iteration in 1..=self.config.max_iterations {
            iterations_run = iteration;

            // Propose improved instructions via optimizer LLM.
            let proposed = self.propose_improvements(&current_instructions, best_score).await?;

            info!(
                iteration,
                current_score = best_score,
                proposed_changes = %proposed,
                "proposed instruction improvements"
            );

            // Apply the proposed instructions.
            // NOTE(review): `current_instructions` is only tracked locally;
            // nothing here pushes the proposed text back into `agent`, so the
            // re-evaluation below appears to exercise the unchanged agent.
            // Verify that instructions actually propagate to the agent before
            // relying on scores changing between iterations.
            current_instructions = proposed.clone();

            // Re-evaluate with the new instructions.
            let score = self.evaluate_agent(agent.clone(), eval_set).await?;

            info!(iteration, score, previous_best = best_score, "evaluation complete");

            if score > best_score {
                best_score = score;
                best_instructions = current_instructions.clone();
            } else {
                // Revert to best instructions if score didn't improve, so the
                // next proposal starts from the best-known text.
                warn!(
                    iteration,
                    score, best_score, "score did not improve, reverting to best instructions"
                );
                current_instructions = best_instructions.clone();
            }

            // Check if target threshold is met — stop early if so.
            if best_score >= self.config.target_threshold {
                info!(
                    iteration,
                    score = best_score,
                    threshold = self.config.target_threshold,
                    "target threshold reached — stopping early"
                );
                break;
            }
        }

        // Write best instructions to output file.
        self.write_output(&best_instructions)?;

        info!(
            initial_score,
            final_score = best_score,
            iterations_run,
            output_path = %self.config.output_path.display(),
            "optimization complete"
        );

        Ok(OptimizationResult {
            initial_score,
            final_score: best_score,
            iterations_run,
            best_instructions,
        })
    }

    /// Evaluate the agent against the eval set and return an aggregate score.
    ///
    /// The aggregate is the mean over all cases of each case's mean criterion
    /// score; a case with no per-criterion scores contributes 1.0 if it
    /// passed, 0.0 otherwise. Returns 0.0 for an empty case list.
    ///
    /// NOTE(review): this re-loads the cases from `base_path` even though
    /// `optimize` already loaded them — duplicated work, presumably cheap,
    /// but worth unifying.
    async fn evaluate_agent(&self, agent: Arc<dyn Agent>, eval_set: &EvalSet) -> Result<f64> {
        let base_path = std::path::Path::new(".");
        let cases = eval_set.get_all_cases(base_path)?;

        if cases.is_empty() {
            return Ok(0.0);
        }

        let mut total_score = 0.0;
        let mut case_count = 0u32;

        for case in &cases {
            let result = self.evaluator.evaluate_case(agent.clone(), case).await?;
            // Compute average of all criterion scores for this case
            let case_score = if result.scores.is_empty() {
                if result.passed { 1.0 } else { 0.0 }
            } else {
                result.scores.values().sum::<f64>() / result.scores.len() as f64
            };
            total_score += case_score;
            case_count += 1;
        }

        Ok(if case_count > 0 { total_score / f64::from(case_count) } else { 0.0 })
    }

    /// Ask the optimizer LLM to propose improved instructions.
    ///
    /// Sends the current instructions plus the current/target scores as a
    /// single user message, streams the response, and concatenates all text
    /// parts into the returned string.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::ExecutionError`] when the LLM call fails, the
    /// stream yields an error, or the model returns no text at all.
    async fn propose_improvements(
        &self,
        current_instructions: &str,
        current_score: f64,
    ) -> Result<String> {
        let prompt = format!(
            "You are a prompt optimization assistant. Your task is to improve the following \
             system instructions for an AI agent.\n\n\
             Current instructions:\n{current_instructions}\n\n\
             Current evaluation score: {current_score:.2} (target: {target:.2})\n\n\
             Please provide improved instructions that will help the agent perform better \
             on its evaluation set. Return ONLY the improved instructions text, nothing else.",
            target = self.config.target_threshold,
        );

        let request = LlmRequest::new(
            self.optimizer_llm.name(),
            vec![Content::new("user").with_text(prompt)],
        );

        // `false` here presumably disables tool/function calling — TODO confirm
        // against the `Llm::generate_content` signature.
        let mut stream =
            self.optimizer_llm.generate_content(request, false).await.map_err(|e| {
                EvalError::ExecutionError(format!("optimizer LLM call failed: {e}"))
            })?;

        // Drain the stream, accumulating every text part in order.
        let mut result_text = String::new();
        while let Some(response) = stream.next().await {
            let response = response.map_err(|e| {
                EvalError::ExecutionError(format!("optimizer LLM stream error: {e}"))
            })?;
            if let Some(content) = &response.content {
                for part in &content.parts {
                    if let Some(text) = part.text() {
                        result_text.push_str(text);
                    }
                }
            }
        }

        // An empty proposal would silently wipe the instructions — treat it
        // as a hard error instead.
        if result_text.is_empty() {
            return Err(EvalError::ExecutionError(
                "optimizer LLM returned empty response".to_string(),
            ));
        }

        Ok(result_text)
    }

    /// Write the best instructions to the output file.
    ///
    /// Overwrites any existing file at `config.output_path`; I/O errors are
    /// converted via the crate's `From<io::Error>` impl on the `?`.
    fn write_output(&self, instructions: &str) -> Result<()> {
        std::fs::write(&self.config.output_path, instructions)?;
        info!(
            path = %self.config.output_path.display(),
            "wrote optimized instructions to output file"
        );
        Ok(())
    }
}
302
/// Run the core optimization loop with injectable evaluation results.
///
/// This is the pure logic extracted for testability. `scores[0]` is the
/// initial (baseline) score and `scores[i]` is the score observed on
/// iteration `i`. If the loop runs more iterations than there are scores,
/// the **last** score is repeated (it does not cycle back to the start).
///
/// The loop respects `max_iterations` and stops early once `best_score`
/// reaches `target_threshold`; a baseline already at or above the threshold
/// runs zero iterations.
///
/// Returns `(iterations_run, best_score)`; `(0, 0.0)` for an empty slice.
pub fn run_optimization_loop(
    scores: &[f64],
    max_iterations: u32,
    target_threshold: f64,
) -> (u32, f64) {
    // No scores at all: nothing ran, nothing scored.
    let Some(&initial_score) = scores.first() else {
        return (0, 0.0);
    };

    // Early exit if initial score meets threshold.
    if initial_score >= target_threshold {
        return (0, initial_score);
    }

    // Non-empty is guaranteed above, so last() always exists.
    let last = *scores.last().expect("scores is non-empty");
    let mut best_score = initial_score;
    let mut iterations_run = 0u32;

    for iteration in 1..=max_iterations {
        iterations_run = iteration;

        // Score for this iteration; once scores run out, repeat the last one.
        let score = scores.get(iteration as usize).copied().unwrap_or(last);

        if score > best_score {
            best_score = score;
        }

        // Stop early once the target threshold is met.
        if best_score >= target_threshold {
            break;
        }
    }

    (iterations_run, best_score)
}