// rs_adk/optimization/simple_prompt.rs
//! Simple prompt optimizer — iteratively rewrites instructions using an LLM.
//!
//! Mirrors ADK-Python's `SimplePromptOptimizer`.

use std::sync::Arc;

use async_trait::async_trait;

use super::optimizer::{AgentOptimizer, OptimizerError, OptimizerResult};
use super::sampler::Sampler;
use crate::llm::BaseLlm;

/// Configuration for the simple prompt optimizer.
///
/// Both fields are plain `usize`, so the config is `Copy`; `PartialEq`/`Eq`
/// make configs comparable in tests and callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimplePromptOptimizerConfig {
    /// Number of optimization iterations (one candidate generated per iteration).
    pub num_iterations: usize,
    /// Number of training examples sampled per evaluation batch.
    pub batch_size: usize,
}

impl Default for SimplePromptOptimizerConfig {
    /// Defaults mirror ADK-Python: 10 iterations, batches of 5 examples.
    fn default() -> Self {
        Self {
            num_iterations: 10,
            batch_size: 5,
        }
    }
}
30
/// Simple prompt optimizer that uses an LLM to iteratively rewrite instructions.
///
/// Process:
/// 1. Score the initial instruction on a training batch
/// 2. For each iteration, generate a candidate instruction using the optimizer LLM
/// 3. Score the candidate — keep it if it improves on the best score
/// 4. Validate the best instruction on the validation set
pub struct SimplePromptOptimizer {
    /// The LLM used to generate candidate instructions (not the agent model
    /// under evaluation — that is identified by `model_id` at `optimize` time).
    optimizer_llm: Arc<dyn BaseLlm>,
    /// Sampler that supplies training batches and the validation set, and
    /// scores an instruction against a set of cases.
    sampler: Arc<dyn Sampler>,
    /// Optimizer configuration (iteration count and batch size).
    config: SimplePromptOptimizerConfig,
}
46
47impl SimplePromptOptimizer {
48    /// Create a new simple prompt optimizer.
49    pub fn new(
50        optimizer_llm: Arc<dyn BaseLlm>,
51        sampler: Arc<dyn Sampler>,
52        config: SimplePromptOptimizerConfig,
53    ) -> Self {
54        Self {
55            optimizer_llm,
56            sampler,
57            config,
58        }
59    }
60
61    /// Generate a candidate instruction by asking the optimizer LLM to improve it.
62    async fn generate_candidate(
63        &self,
64        current_instruction: &str,
65        current_score: f64,
66    ) -> Result<String, OptimizerError> {
67        let prompt = format!(
68            "You are an expert prompt engineer. Your task is to improve the following \
69             agent instruction to achieve better performance.\n\n\
70             Current instruction (score: {current_score:.2}):\n\
71             ---\n{current_instruction}\n---\n\n\
72             Generate an improved version of the instruction. Focus on:\n\
73             - Clarity and specificity\n\
74             - Better task decomposition guidance\n\
75             - More effective tool use instructions\n\
76             - Appropriate constraints and guardrails\n\n\
77             Respond with ONLY the improved instruction text, nothing else."
78        );
79
80        let request = crate::llm::LlmRequest::from_text(&prompt);
81        let response = self
82            .optimizer_llm
83            .generate(request)
84            .await
85            .map_err(|e| OptimizerError::Llm(e.to_string()))?;
86
87        Ok(response.text())
88    }
89}
90
91#[async_trait]
92impl AgentOptimizer for SimplePromptOptimizer {
93    async fn optimize(
94        &self,
95        initial_instruction: &str,
96        model_id: &str,
97    ) -> Result<OptimizerResult, OptimizerError> {
98        let mut best_instruction = initial_instruction.to_string();
99        let mut score_history = Vec::new();
100
101        // Score baseline
102        let training_batch = self.sampler.sample_training(self.config.batch_size).await?;
103        let mut best_score = self
104            .sampler
105            .score(&best_instruction, model_id, &training_batch.cases)
106            .await?;
107        score_history.push((0, best_score));
108
109        // Optimization loop
110        for iteration in 1..=self.config.num_iterations {
111            let candidate = self
112                .generate_candidate(&best_instruction, best_score)
113                .await?;
114
115            let training_batch = self.sampler.sample_training(self.config.batch_size).await?;
116            let candidate_score = self
117                .sampler
118                .score(&candidate, model_id, &training_batch.cases)
119                .await?;
120
121            score_history.push((iteration, candidate_score));
122
123            if candidate_score > best_score {
124                best_instruction = candidate;
125                best_score = candidate_score;
126            }
127        }
128
129        // Final validation
130        let validation = self.sampler.validation_set().await?;
131        let validation_score = self
132            .sampler
133            .score(&best_instruction, model_id, &validation.cases)
134            .await?;
135
136        Ok(OptimizerResult {
137            best_instruction,
138            best_score: validation_score,
139            iterations: self.config.num_iterations,
140            score_history,
141        })
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must match the ADK-Python defaults.
    #[test]
    fn default_config() {
        let SimplePromptOptimizerConfig {
            num_iterations,
            batch_size,
        } = SimplePromptOptimizerConfig::default();
        assert_eq!((num_iterations, batch_size), (10, 5));
    }
}