1use std::path::PathBuf;
32use std::sync::Arc;
33
34use futures::StreamExt;
35use tracing::{info, warn};
36
37use adk_core::types::Content;
38use adk_core::{Agent, Llm, LlmRequest};
39
40use crate::error::{EvalError, Result};
41use crate::evaluator::Evaluator;
42use crate::schema::EvalSet;
43
/// Tuning parameters for [`PromptOptimizer`]'s propose/evaluate loop.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Maximum number of propose/evaluate iterations to attempt.
    pub max_iterations: u32,
    /// Mean evaluation score at or above which optimization stops early.
    pub target_threshold: f64,
    /// File path the best instructions are written to.
    pub output_path: PathBuf,
}
55
56impl Default for OptimizerConfig {
57 fn default() -> Self {
58 Self {
59 max_iterations: 5,
60 target_threshold: 0.9,
61 output_path: PathBuf::from("optimized_instructions.txt"),
62 }
63 }
64}
65
/// Summary of one [`PromptOptimizer::optimize`] run.
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Mean score of the agent before any instruction changes.
    pub initial_score: f64,
    /// Best mean score observed across all iterations (>= `initial_score`).
    pub final_score: f64,
    /// Number of propose/evaluate iterations actually executed
    /// (0 when the initial score already met the target threshold).
    pub iterations_run: u32,
    /// The instruction text that produced `final_score`.
    pub best_instructions: String,
}
78
/// Hill-climbing prompt optimizer: repeatedly asks an LLM to rewrite an
/// agent's instructions, re-evaluates the agent against an eval set, and
/// keeps whichever instruction text scored best.
pub struct PromptOptimizer {
    /// LLM used to propose improved instruction text (not the agent's own LLM).
    optimizer_llm: Arc<dyn Llm>,
    /// Scores the agent against individual eval cases.
    evaluator: Evaluator,
    /// Iteration limit, target score, and output location.
    config: OptimizerConfig,
}
89
90impl PromptOptimizer {
91 pub fn new(optimizer_llm: Arc<dyn Llm>, evaluator: Evaluator, config: OptimizerConfig) -> Self {
100 Self { optimizer_llm, evaluator, config }
101 }
102
103 pub async fn optimize(
112 &self,
113 agent: Arc<dyn Agent>,
114 eval_set: &EvalSet,
115 ) -> Result<OptimizationResult> {
116 let base_path = std::path::Path::new(".");
117 let cases = eval_set.get_all_cases(base_path)?;
118
119 if cases.is_empty() {
120 return Err(EvalError::ConfigError("eval set contains no cases".to_string()));
121 }
122
123 let mut current_instructions = agent.description().to_string();
125
126 let initial_score = self.evaluate_agent(agent.clone(), eval_set).await?;
128 info!(iteration = 0, score = initial_score, "initial evaluation complete");
129
130 if initial_score >= self.config.target_threshold {
132 info!(
133 score = initial_score,
134 threshold = self.config.target_threshold,
135 "no optimization needed — initial score meets target threshold"
136 );
137
138 self.write_output(¤t_instructions)?;
139
140 return Ok(OptimizationResult {
141 initial_score,
142 final_score: initial_score,
143 iterations_run: 0,
144 best_instructions: current_instructions,
145 });
146 }
147
148 let mut best_score = initial_score;
149 let mut best_instructions = current_instructions.clone();
150 let mut iterations_run = 0;
151
152 for iteration in 1..=self.config.max_iterations {
153 iterations_run = iteration;
154
155 let proposed = self.propose_improvements(¤t_instructions, best_score).await?;
157
158 info!(
159 iteration,
160 current_score = best_score,
161 proposed_changes = %proposed,
162 "proposed instruction improvements"
163 );
164
165 current_instructions = proposed.clone();
167
168 let score = self.evaluate_agent(agent.clone(), eval_set).await?;
170
171 info!(iteration, score, previous_best = best_score, "evaluation complete");
172
173 if score > best_score {
174 best_score = score;
175 best_instructions = current_instructions.clone();
176 } else {
177 warn!(
179 iteration,
180 score, best_score, "score did not improve, reverting to best instructions"
181 );
182 current_instructions = best_instructions.clone();
183 }
184
185 if best_score >= self.config.target_threshold {
187 info!(
188 iteration,
189 score = best_score,
190 threshold = self.config.target_threshold,
191 "target threshold reached — stopping early"
192 );
193 break;
194 }
195 }
196
197 self.write_output(&best_instructions)?;
199
200 info!(
201 initial_score,
202 final_score = best_score,
203 iterations_run,
204 output_path = %self.config.output_path.display(),
205 "optimization complete"
206 );
207
208 Ok(OptimizationResult {
209 initial_score,
210 final_score: best_score,
211 iterations_run,
212 best_instructions,
213 })
214 }
215
216 async fn evaluate_agent(&self, agent: Arc<dyn Agent>, eval_set: &EvalSet) -> Result<f64> {
218 let base_path = std::path::Path::new(".");
219 let cases = eval_set.get_all_cases(base_path)?;
220
221 if cases.is_empty() {
222 return Ok(0.0);
223 }
224
225 let mut total_score = 0.0;
226 let mut case_count = 0u32;
227
228 for case in &cases {
229 let result = self.evaluator.evaluate_case(agent.clone(), case).await?;
230 let case_score = if result.scores.is_empty() {
232 if result.passed { 1.0 } else { 0.0 }
233 } else {
234 result.scores.values().sum::<f64>() / result.scores.len() as f64
235 };
236 total_score += case_score;
237 case_count += 1;
238 }
239
240 Ok(if case_count > 0 { total_score / f64::from(case_count) } else { 0.0 })
241 }
242
243 async fn propose_improvements(
245 &self,
246 current_instructions: &str,
247 current_score: f64,
248 ) -> Result<String> {
249 let prompt = format!(
250 "You are a prompt optimization assistant. Your task is to improve the following \
251 system instructions for an AI agent.\n\n\
252 Current instructions:\n{current_instructions}\n\n\
253 Current evaluation score: {current_score:.2} (target: {target:.2})\n\n\
254 Please provide improved instructions that will help the agent perform better \
255 on its evaluation set. Return ONLY the improved instructions text, nothing else.",
256 target = self.config.target_threshold,
257 );
258
259 let request = LlmRequest::new(
260 self.optimizer_llm.name(),
261 vec![Content::new("user").with_text(prompt)],
262 );
263
264 let mut stream =
265 self.optimizer_llm.generate_content(request, false).await.map_err(|e| {
266 EvalError::ExecutionError(format!("optimizer LLM call failed: {e}"))
267 })?;
268
269 let mut result_text = String::new();
270 while let Some(response) = stream.next().await {
271 let response = response.map_err(|e| {
272 EvalError::ExecutionError(format!("optimizer LLM stream error: {e}"))
273 })?;
274 if let Some(content) = &response.content {
275 for part in &content.parts {
276 if let Some(text) = part.text() {
277 result_text.push_str(text);
278 }
279 }
280 }
281 }
282
283 if result_text.is_empty() {
284 return Err(EvalError::ExecutionError(
285 "optimizer LLM returned empty response".to_string(),
286 ));
287 }
288
289 Ok(result_text)
290 }
291
292 fn write_output(&self, instructions: &str) -> Result<()> {
294 std::fs::write(&self.config.output_path, instructions)?;
295 info!(
296 path = %self.config.output_path.display(),
297 "wrote optimized instructions to output file"
298 );
299 Ok(())
300 }
301}
302
/// Pure simulation of the optimizer's hill-climbing loop over a precomputed
/// score sequence: `scores[0]` is the initial score and `scores[i]` is the
/// score observed at iteration `i` (the last entry is reused once the
/// sequence is exhausted).
///
/// Returns `(iterations_run, best_score)`. Runs zero iterations when the
/// score list is empty or the initial score already meets the threshold;
/// otherwise iterates until the best score reaches `target_threshold` or
/// `max_iterations` is exhausted.
pub fn run_optimization_loop(
    scores: &[f64],
    max_iterations: u32,
    target_threshold: f64,
) -> (u32, f64) {
    // Guard: nothing to simulate.
    let first = match scores.first() {
        Some(&s) => s,
        None => return (0, 0.0),
    };

    // Guard: already at target before any iteration.
    if first >= target_threshold {
        return (0, first);
    }

    let last = scores[scores.len() - 1];
    let mut best = first;
    let mut ran = 0u32;

    for step in 1..=max_iterations {
        ran = step;

        // Past the end of the sequence, the score plateaus at the last entry.
        let observed = scores.get(step as usize).copied().unwrap_or(last);

        if observed > best {
            best = observed;
        }

        if best >= target_threshold {
            break;
        }
    }

    (ran, best)
}