rs_adk/optimization/
simple_prompt.rs1use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use super::optimizer::{AgentOptimizer, OptimizerError, OptimizerResult};
10use super::sampler::Sampler;
11use crate::llm::BaseLlm;
12
/// Tuning knobs for [`SimplePromptOptimizer`].
#[derive(Clone, Debug)]
pub struct SimplePromptOptimizerConfig {
    /// Number of rewrite-and-evaluate rounds run after the initial instruction is scored.
    pub num_iterations: usize,
    /// Number of training cases requested from the sampler for each evaluation batch.
    pub batch_size: usize,
}
21
22impl Default for SimplePromptOptimizerConfig {
23 fn default() -> Self {
24 Self {
25 num_iterations: 10,
26 batch_size: 5,
27 }
28 }
29}
30
31pub struct SimplePromptOptimizer {
39 optimizer_llm: Arc<dyn BaseLlm>,
41 sampler: Arc<dyn Sampler>,
43 config: SimplePromptOptimizerConfig,
45}
46
47impl SimplePromptOptimizer {
48 pub fn new(
50 optimizer_llm: Arc<dyn BaseLlm>,
51 sampler: Arc<dyn Sampler>,
52 config: SimplePromptOptimizerConfig,
53 ) -> Self {
54 Self {
55 optimizer_llm,
56 sampler,
57 config,
58 }
59 }
60
61 async fn generate_candidate(
63 &self,
64 current_instruction: &str,
65 current_score: f64,
66 ) -> Result<String, OptimizerError> {
67 let prompt = format!(
68 "You are an expert prompt engineer. Your task is to improve the following \
69 agent instruction to achieve better performance.\n\n\
70 Current instruction (score: {current_score:.2}):\n\
71 ---\n{current_instruction}\n---\n\n\
72 Generate an improved version of the instruction. Focus on:\n\
73 - Clarity and specificity\n\
74 - Better task decomposition guidance\n\
75 - More effective tool use instructions\n\
76 - Appropriate constraints and guardrails\n\n\
77 Respond with ONLY the improved instruction text, nothing else."
78 );
79
80 let request = crate::llm::LlmRequest::from_text(&prompt);
81 let response = self
82 .optimizer_llm
83 .generate(request)
84 .await
85 .map_err(|e| OptimizerError::Llm(e.to_string()))?;
86
87 Ok(response.text())
88 }
89}
90
91#[async_trait]
92impl AgentOptimizer for SimplePromptOptimizer {
93 async fn optimize(
94 &self,
95 initial_instruction: &str,
96 model_id: &str,
97 ) -> Result<OptimizerResult, OptimizerError> {
98 let mut best_instruction = initial_instruction.to_string();
99 let mut score_history = Vec::new();
100
101 let training_batch = self.sampler.sample_training(self.config.batch_size).await?;
103 let mut best_score = self
104 .sampler
105 .score(&best_instruction, model_id, &training_batch.cases)
106 .await?;
107 score_history.push((0, best_score));
108
109 for iteration in 1..=self.config.num_iterations {
111 let candidate = self
112 .generate_candidate(&best_instruction, best_score)
113 .await?;
114
115 let training_batch = self.sampler.sample_training(self.config.batch_size).await?;
116 let candidate_score = self
117 .sampler
118 .score(&candidate, model_id, &training_batch.cases)
119 .await?;
120
121 score_history.push((iteration, candidate_score));
122
123 if candidate_score > best_score {
124 best_instruction = candidate;
125 best_score = candidate_score;
126 }
127 }
128
129 let validation = self.sampler.validation_set().await?;
131 let validation_score = self
132 .sampler
133 .score(&best_instruction, model_id, &validation.cases)
134 .await?;
135
136 Ok(OptimizerResult {
137 best_instruction,
138 best_score: validation_score,
139 iterations: self.config.num_iterations,
140 score_history,
141 })
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards the documented default values against accidental drift.
    #[test]
    fn default_config() {
        let SimplePromptOptimizerConfig {
            num_iterations,
            batch_size,
            ..
        } = SimplePromptOptimizerConfig::default();
        assert_eq!(num_iterations, 10);
        assert_eq!(batch_size, 5);
    }
}