use crate as dspy_rs;
use crate::{
    Evaluator, Example, LM, Module, Optimizable, Optimizer, Predict, Prediction, Predictor,
    example, get_lm,
};
use anyhow::Result;
use bon::Builder;
use dsrs_macros::Signature;
use futures::future::join_all;
use std::sync::Arc;
use std::{collections::HashMap, future::Future, pin::Pin, sync::LazyLock};

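/// Signature used in the first round: asks the prompt model to rewrite a
/// predictor's initial instruction into an improved one.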
#[Signature]
struct BasicGenerateInstruction {
    #[input(desc = "The initial instructions before optimization")]
    pub basic_instruction: String,
    #[output(desc = "The improved instructions for the language model")]
    pub proposed_instruction: String,
    #[output(desc = "The string at the end of the prompt, which will help the model start solving the task")]
    pub proposed_prefix_for_output_field: String,
}

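/// Signature used in later rounds: given prior instruction attempts and their
/// validation scores, asks the prompt model to propose a better instruction.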
#[Signature]
struct GenerateInstructionGivenAttempts {
    #[input(
        desc = "The instructions I've tried, along with their corresponding validation scores"
    )]
    pub attempted_instructions: Vec<String>,
    #[output(desc = "The improved instructions for the language model")]
    pub proposed_instruction: String,
    #[output(desc = "The string at the end of the prompt, which will help the model start solving the task")]
    pub proposed_prefix_for_output_field: String,
}

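/// A single (instruction, prefix) proposal together with its evaluation score.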
#[derive(Clone)]
struct Candidate {
    pub score: f32,
    pub instruction: String,
    pub prefix: String,
}

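/// Per-predictor score traces collected when `track_stats` is enabled.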
#[derive(Clone)]
struct ProgramStats {
    pub results_best: HashMap<String, Vec<f32>>,
    pub results_latest: HashMap<String, Vec<f32>>,
    pub total_calls: usize,
}

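/// COPRO instruction optimizer: proposes candidate instructions for every
/// predictor in a module, scores each candidate on the trainset, and keeps the
/// best one. `breadth` is the number of candidates generated per round, `depth`
/// the number of refinement rounds, `init_temperature` the sampling temperature
/// for candidate generation, and `prompt_model` an optional dedicated LM for
/// proposing instructions (the globally configured LM is used otherwise).
///
/// Minimal usage sketch (assumes `program` implements `Module + Optimizable +
/// Evaluator` and `trainset` is a `Vec<Example>`):
///
/// ```ignore
/// let copro = COPRO::builder().breadth(8).depth(3).build();
/// copro.compile(&mut program, trainset).await?;
/// ```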
#[derive(Builder)]
pub struct COPRO {
    #[builder(default = 10)]
    pub breadth: usize,
    #[builder(default = 3)]
    pub depth: usize,
    #[builder(default = 1.4)]
    pub init_temperature: f32,
    #[builder(default = false)]
    pub track_stats: bool,
    pub prompt_model: Option<LM>,
}

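// Lazily-initialized Predict programs shared by all COPRO runs in this process.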
static BASIC_GENERATOR: LazyLock<Predict> =
    LazyLock::new(|| Predict::new(BasicGenerateInstruction::new()));
static REFINEMENT_GENERATOR: LazyLock<Predict> =
    LazyLock::new(|| Predict::new(GenerateInstructionGivenAttempts::new()));

impl COPRO {
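    /// Returns the `desc` of the predictor's last output field, which COPRO
    /// uses as that predictor's output prefix.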
    fn get_output_field_prefix(&self, predictor: &dyn Optimizable) -> String {
        let output_fields = predictor.get_signature().output_fields();
        if let Some(obj) = output_fields.as_object()
            && let Some((_, field)) = obj.iter().next_back()
            && let Some(desc) = field.get("desc")
        {
            return desc.as_str().unwrap_or("").to_string();
        }
        "".to_string()
    }
}

impl Optimizer for COPRO {
    async fn compile<M: Module + Optimizable + Evaluator>(
        &self,
        module: &mut M,
        trainset: Vec<Example>,
    ) -> Result<()> {
        if self.breadth <= 1 {
            return Err(anyhow::anyhow!("Breadth must be greater than 1"));
        }

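        // Snapshot each predictor's name, current instruction, and output prefix.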
        let predictor_info: Vec<(String, String, String)> = {
            let named_predictors = module.parameters();
            named_predictors
                .iter()
                .map(|(name, predictor)| {
                    let basic_instruction = predictor.get_signature().instruction();
                    let basic_prefix = self.get_output_field_prefix(*predictor);
                    (name.clone(), basic_instruction, basic_prefix)
                })
                .collect()
        };

        let mut all_candidates: HashMap<String, Vec<(String, String)>> = HashMap::new();
        let mut latest_candidates: HashMap<String, Vec<(String, String)>> = HashMap::new();
        let mut evaluated_candidates: HashMap<String, HashMap<(String, String), Candidate>> =
            HashMap::new();

        let mut stats = ProgramStats {
            results_best: HashMap::new(),
            results_latest: HashMap::new(),
            total_calls: 0,
        };

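        // Round 0: for each predictor, ask the prompt model for `breadth - 1` new
        // instruction candidates and keep the original as a baseline.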
        for (predictor_name, basic_instruction, basic_prefix) in &predictor_info {
            let mut candidates = Vec::new();

            if self.breadth > 1 {
                let mut futures: Vec<Pin<Box<dyn Future<Output = Result<Prediction>> + Send>>> =
                    Vec::new();

                for _ in 0..self.breadth - 1 {
                    let inst = basic_instruction.clone();
                    if let Some(mut prompt_model) = self.prompt_model.clone() {
                        prompt_model.temperature = self.init_temperature;
                        futures.push(Box::pin(async move {
                            BASIC_GENERATOR
                                .forward_with_config(
                                    example! {
                                        "basic_instruction": "input" => inst
                                    },
                                    Arc::new(prompt_model),
                                )
                                .await
                        }));
                    } else {
                        futures.push(Box::pin(async move {
                            BASIC_GENERATOR
                                .forward_with_config(
                                    example! {
                                        "basic_instruction": "input" => inst
                                    },
                                    Arc::clone(&get_lm()),
                                )
                                .await
                        }));
                    }
                }

                let results = join_all(futures).await;
                let predictions = results.into_iter().collect::<Result<Vec<_>>>()?;

                for pred in predictions {
                    let instruction = pred
                        .data
                        .get("proposed_instruction")
                        .and_then(|v| v.as_str())
                        .unwrap_or(basic_instruction)
                        .to_string();
                    let prefix = pred
                        .data
                        .get("proposed_prefix_for_output_field")
                        .and_then(|v| v.as_str())
                        .unwrap_or(basic_prefix)
                        .to_string();
                    candidates.push((instruction, prefix));
                }
            }

            candidates.push((basic_instruction.clone(), basic_prefix.clone()));

            all_candidates.insert(predictor_name.clone(), candidates.clone());
            latest_candidates.insert(predictor_name.clone(), candidates);
            evaluated_candidates.insert(predictor_name.clone(), HashMap::new());

            if self.track_stats {
                stats
                    .results_best
                    .insert(predictor_name.clone(), Vec::new());
                stats
                    .results_latest
                    .insert(predictor_name.clone(), Vec::new());
            }
        }

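        // Main loop: evaluate candidates for each predictor, adopt the best one,
        // then ask the refinement generator for a new batch (except on the last round).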
        for d in 0..self.depth {
            println!("Iteration Depth: {}/{}", d + 1, self.depth);

            for (p_i, (predictor_name, _, _)) in predictor_info.iter().enumerate() {
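                // With several predictors, the full candidate pool is re-scored each
                // round; with a single predictor only the latest batch is.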
                let candidates_to_eval = if predictor_info.len() > 1 {
                    all_candidates.get(predictor_name).unwrap().clone()
                } else {
                    latest_candidates.get(predictor_name).unwrap().clone()
                };

                let mut latest_scores = Vec::new();

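                // Score each candidate, reusing cached scores for (instruction,
                // prefix) pairs that were already evaluated.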
                for (c_i, (instruction, prefix)) in candidates_to_eval.iter().enumerate() {
                    let key = (instruction.clone(), prefix.clone());

                    let score = if let Some(existing) = evaluated_candidates
                        .get(predictor_name)
                        .and_then(|m| m.get(&key))
                    {
                        existing.score
                    } else {
                        {
                            let mut module_predictors = module.parameters();
                            if let Some(predictor) = module_predictors.get_mut(predictor_name) {
                                predictor.update_signature_instruction(instruction.clone())?;
                            }
                        }

                        println!(
                            "At Depth {}/{}, Evaluating Prompt Candidate #{}/{} for Predictor {} of {}",
                            d + 1,
                            self.depth,
                            c_i + 1,
                            candidates_to_eval.len(),
                            p_i + 1,
                            predictor_info.len()
                        );

                        let score = module.evaluate(trainset.clone()).await;
                        stats.total_calls += 1;

                        evaluated_candidates
                            .get_mut(predictor_name)
                            .unwrap()
                            .insert(
                                key,
                                Candidate {
                                    score,
                                    instruction: instruction.clone(),
                                    prefix: prefix.clone(),
                                },
                            );

                        score
                    };

                    // Only the most recent `breadth` candidates count toward the
                    // "latest" stats; written additively to avoid usize underflow.
                    if c_i + self.breadth >= candidates_to_eval.len() {
                        latest_scores.push(score);
                    }
                }

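                // Adopt the best-scoring candidate found so far for this predictor.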
                if let Some(best) = evaluated_candidates.get(predictor_name).and_then(|m| {
                    m.values()
                        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
                }) {
                    {
                        let mut module_predictors = module.parameters();
                        if let Some(predictor) = module_predictors.get_mut(predictor_name) {
                            predictor.update_signature_instruction(best.instruction.clone())?;
                        }
                    }

                    println!(
                        "Updating Predictor {} to best candidate with score {:.3}",
                        predictor_name, best.score
                    );
                }

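                // Record average scores of the latest batch and of the top 10
                // candidates seen so far (stats tracking only).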
                if self.track_stats && !latest_scores.is_empty() {
                    let avg = latest_scores.iter().sum::<f32>() / latest_scores.len() as f32;
                    stats
                        .results_latest
                        .get_mut(predictor_name)
                        .unwrap()
                        .push(avg);

                    let mut best_scores: Vec<f32> = evaluated_candidates
                        .get(predictor_name)
                        .unwrap()
                        .values()
                        .map(|c| c.score)
                        .collect();
                    best_scores.sort_by(|a, b| b.partial_cmp(a).unwrap());
                    best_scores.truncate(10);

                    if !best_scores.is_empty() {
                        let best_avg = best_scores.iter().sum::<f32>() / best_scores.len() as f32;
                        stats
                            .results_best
                            .get_mut(predictor_name)
                            .unwrap()
                            .push(best_avg);
                    }
                }
            }

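            // No refinement is needed after the final evaluation round.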
            if d == self.depth - 1 {
                break;
            }

            let mut new_latest_candidates = HashMap::new();

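            // Refinement: show each predictor's best attempts (with scores) to the
            // prompt model and collect `breadth` new candidate proposals.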
            for (predictor_name, _, _) in &predictor_info {
                let mut attempts_list = Vec::new();
                let mut best_candidates: Vec<_> = evaluated_candidates
                    .get(predictor_name)
                    .unwrap()
                    .values()
                    .cloned()
                    .collect();
                // Keep only the top `breadth` candidates, presented in ascending
                // score order so the strongest attempt appears last in the prompt.
                best_candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
                best_candidates.truncate(self.breadth);
                best_candidates.reverse();

                for (i, candidate) in best_candidates.iter().enumerate() {
                    attempts_list.push(format!(
                        "Instruction #{}: {}",
                        i + 1,
                        candidate.instruction
                    ));
                    attempts_list.push(format!("Prefix #{}: {}", i + 1, candidate.prefix));
                    attempts_list.push(format!(
                        "Resulting Score #{}: {:.3}",
                        i + 1,
                        candidate.score
                    ));
                }

                let attempts_str = attempts_list.join("\n");

                let results = if let Some(mut prompt_model) = self.prompt_model.clone() {
                    prompt_model.temperature = self.init_temperature;
                    let attempts = attempts_str.clone();

                    REFINEMENT_GENERATOR
                        .batch_with_config(
                            (0..self.breadth)
                                .map(|_| {
                                    example! {
                                        "attempted_instructions": "input" => attempts.clone()
                                    }
                                })
                                .collect(),
                            Arc::new(prompt_model),
                        )
                        .await
                } else {
                    let attempts = attempts_str.clone();
                    REFINEMENT_GENERATOR
                        .batch_with_config(
                            (0..self.breadth)
                                .map(|_| {
                                    example! {
                                        "attempted_instructions": "input" => attempts.clone()
                                    }
                                })
                                .collect(),
                            Arc::clone(&get_lm()),
                        )
                        .await
                };

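                // Parse the proposed (instruction, prefix) pairs; each field may come
                // back as a single string or as an array of strings.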
                if let Ok(predictions) = results {
                    let mut new_candidates = Vec::new();

                    for pred in predictions {
                        let instructions = if let Some(arr) = pred
                            .data
                            .get("proposed_instruction")
                            .and_then(|v| v.as_array())
                        {
                            arr.iter()
                                .filter_map(|v| v.as_str())
                                .map(|s| s.to_string())
                                .collect()
                        } else if let Some(s) = pred
                            .data
                            .get("proposed_instruction")
                            .and_then(|v| v.as_str())
                        {
                            vec![s.to_string()]
                        } else {
                            vec![]
                        };

                        let prefixes = if let Some(arr) = pred
                            .data
                            .get("proposed_prefix_for_output_field")
                            .and_then(|v| v.as_array())
                        {
                            arr.iter()
                                .filter_map(|v| v.as_str())
                                .map(|s| s.to_string())
                                .collect()
                        } else if let Some(s) = pred
                            .data
                            .get("proposed_prefix_for_output_field")
                            .and_then(|v| v.as_str())
                        {
                            vec![s.to_string()]
                        } else {
                            vec![]
                        };

                        for (inst, pref) in instructions.iter().zip(prefixes.iter()) {
                            new_candidates.push((inst.clone(), pref.clone()));
                        }
                    }

                    all_candidates
                        .get_mut(predictor_name)
                        .unwrap()
                        .extend(new_candidates.clone());
                    new_latest_candidates.insert(predictor_name.clone(), new_candidates);
                }
            }

            latest_candidates = new_latest_candidates;
        }

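        // Final pass: report the best candidate overall and make sure every
        // predictor ends up with its own best-scoring instruction applied.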
        let mut best_overall: Option<(String, Candidate)> = None;

        for (predictor_name, candidates_map) in &evaluated_candidates {
            if let Some(best) = candidates_map
                .values()
                .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
                && (best_overall.is_none() || best.score > best_overall.as_ref().unwrap().1.score)
            {
                best_overall = Some((predictor_name.clone(), best.clone()));
            }
        }

        if let Some((_, best_candidate)) = best_overall {
            let module_predictors = module.parameters();
            for (predictor_name, predictor) in module_predictors {
                if let Some(best) = evaluated_candidates.get(&predictor_name).and_then(|m| {
                    m.values()
                        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
                }) {
                    predictor.update_signature_instruction(best.instruction.clone())?;
                }
            }

            if self.track_stats {
                println!("\n=== Optimization Complete ===");
                println!("Total calls: {}", stats.total_calls);
                println!("Best score: {:.3}", best_candidate.score);
                println!("Best instruction: {}", best_candidate.instruction);
                if !best_candidate.prefix.is_empty() {
                    println!("Best prefix: {}", best_candidate.prefix);
                }
            }
        }

        Ok(())
    }
}