// dspy_rs/optimizer/copro.rs

1use crate as dspy_rs;
2use crate::{
3    Evaluator, Example, LM, Module, Optimizable, Optimizer, Predict, Prediction, Predictor,
4    example, get_lm,
5};
6use anyhow::Result;
7use bon::Builder;
8use dsrs_macros::Signature;
9use futures::future::join_all;
10use std::sync::Arc;
11use std::{collections::HashMap, future::Future, pin::Pin, sync::LazyLock};
12
13#[Signature]
14struct BasicGenerateInstruction {
15    /// You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative.
16
17    #[input(desc = "The initial instructions before optimization")]
18    pub basic_instruction: String,
19    #[output(desc = "The improved instructions for the language model")]
20    pub proposed_instruction: String,
21}
22
23#[Signature]
24struct GenerateInstructionGivenAttempts {
25    /// You are an instruction optimizer for large language models. I will give some task instructions I've tried, along with their corresponding validation scores. The instructions are arranged in increasing order based on their scores, where higher scores indicate better quality.
26    ///
27    /// Your task is to propose a new instruction that will lead a good language model to perform the task even better. Don't be afraid to be creative.
28
29    #[input(
30        desc = "The instructions I've tried, along with their corresponding validation scores"
31    )]
32    pub attempted_instructions: Vec<String>,
33    #[output(desc = "The improved instructions for the language model")]
34    pub proposed_instruction: String,
35}
36
/// One evaluated prompt candidate for a single predictor.
#[derive(Clone)]
struct Candidate {
    /// Score recorded for the module while this candidate's instruction was
    /// installed.
    pub score: f32,
    /// Instruction text that was evaluated.
    pub instruction: String,
    /// Output-field prefix paired with the instruction.
    pub prefix: String,
}
43
/// Bookkeeping gathered while `COPRO::compile` runs.
///
/// Both maps are keyed by predictor name.
#[derive(Clone)]
struct ProgramStats {
    /// Per predictor: for each recorded depth, the mean of the (up to ten)
    /// highest candidate scores seen so far.
    pub results_best: HashMap<String, Vec<f32>>,
    /// Per predictor: for each recorded depth, the mean score of the most
    /// recent batch of candidates.
    pub results_latest: HashMap<String, Vec<f32>>,
    /// Number of module evaluations performed.
    pub total_calls: usize,
}
50
51#[derive(Builder)]
52pub struct COPRO {
53    #[builder(default = 10)]
54    pub breadth: usize,
55    #[builder(default = 3)]
56    pub depth: usize,
57    #[builder(default = 1.4)]
58    pub init_temperature: f32,
59    #[builder(default = false)]
60    pub track_stats: bool,
61    pub prompt_model: Option<LM>,
62}
63
64static BASIC_GENERATOR: LazyLock<Predict> =
65    LazyLock::new(|| Predict::new(BasicGenerateInstruction::new()));
66static REFINEMENT_GENERATOR: LazyLock<Predict> =
67    LazyLock::new(|| Predict::new(GenerateInstructionGivenAttempts::new()));
68
69impl COPRO {
70    fn get_output_field_prefix(&self, predictor: &dyn Optimizable) -> String {
71        // Get the last output field's prefix/desc
72        let output_fields = predictor.get_signature().output_fields();
73        if let Some(obj) = output_fields.as_object()
74            && let Some((_, field)) = obj.iter().next_back()
75            && let Some(desc) = field.get("desc")
76        {
77            return desc.as_str().unwrap_or("").to_string();
78        }
79        "".to_string()
80    }
81}
82
83impl Optimizer for COPRO {
84    async fn compile<M: Module + Optimizable + Evaluator>(
85        &self,
86        module: &mut M,
87        trainset: Vec<Example>,
88    ) -> Result<()> {
89        if self.breadth <= 1 {
90            return Err(anyhow::anyhow!("Breadth must be greater than 1"));
91        }
92
93        // Collect predictor information first
94        let predictor_info: Vec<(String, String, String)> = {
95            let named_predictors = module.parameters();
96            named_predictors
97                .iter()
98                .map(|(name, predictor)| {
99                    let basic_instruction = predictor.get_signature().instruction();
100                    let basic_prefix = self.get_output_field_prefix(*predictor);
101                    (name.clone(), basic_instruction, basic_prefix)
102                })
103                .collect()
104        };
105
106        let mut all_candidates: HashMap<String, Vec<(String, String)>> = HashMap::new();
107        let mut latest_candidates: HashMap<String, Vec<(String, String)>> = HashMap::new();
108        let mut evaluated_candidates: HashMap<String, HashMap<(String, String), Candidate>> =
109            HashMap::new();
110
111        let mut stats = ProgramStats {
112            results_best: HashMap::new(),
113            results_latest: HashMap::new(),
114            total_calls: 0,
115        };
116
117        // Seed with initial instructions - generate breadth-1 new + 1 original
118        for (predictor_name, basic_instruction, basic_prefix) in &predictor_info {
119            let mut candidates = Vec::new();
120
121            // Generate new candidates
122            if self.breadth > 1 {
123                let mut futures: Vec<Pin<Box<dyn Future<Output = Result<Prediction>> + Send>>> =
124                    Vec::new();
125
126                for _ in 0..self.breadth - 1 {
127                    let inst = basic_instruction.clone();
128                    if let Some(mut prompt_model) = self.prompt_model.clone() {
129                        prompt_model.temperature = self.init_temperature;
130                        futures.push(Box::pin(async move {
131                            BASIC_GENERATOR
132                                .forward_with_config(
133                                    example! {
134                                        "basic_instruction": "input" => inst
135                                    },
136                                    Arc::new(prompt_model),
137                                )
138                                .await
139                        }));
140                    } else {
141                        futures.push(Box::pin(async move {
142                            BASIC_GENERATOR
143                                .forward_with_config(
144                                    example! {
145                                        "basic_instruction": "input" => inst
146                                    },
147                                    Arc::clone(&get_lm()),
148                                )
149                                .await
150                        }));
151                    }
152                }
153
154                let results = join_all(futures).await;
155                let predictions = results.into_iter().collect::<Result<Vec<_>>>()?;
156
157                for pred in predictions {
158                    let instruction = pred
159                        .data
160                        .get("proposed_instruction")
161                        .and_then(|v| v.as_str())
162                        .unwrap_or(basic_instruction)
163                        .to_string();
164                    let prefix = pred
165                        .data
166                        .get("proposed_prefix_for_output_field")
167                        .and_then(|v| v.as_str())
168                        .unwrap_or(basic_prefix)
169                        .to_string();
170                    candidates.push((instruction, prefix));
171                }
172            }
173
174            candidates.push((basic_instruction.clone(), basic_prefix.clone()));
175
176            all_candidates.insert(predictor_name.clone(), candidates.clone());
177            latest_candidates.insert(predictor_name.clone(), candidates);
178            evaluated_candidates.insert(predictor_name.clone(), HashMap::new());
179
180            if self.track_stats {
181                stats
182                    .results_best
183                    .insert(predictor_name.clone(), Vec::new());
184                stats
185                    .results_latest
186                    .insert(predictor_name.clone(), Vec::new());
187            }
188        }
189
190        // Main optimization loop
191        for d in 0..self.depth {
192            println!("Iteration Depth: {}/{}", d + 1, self.depth);
193
194            // Evaluate candidates for each predictor
195            for (p_i, (predictor_name, _, _)) in predictor_info.iter().enumerate() {
196                // Determine which candidates to evaluate
197                let candidates_to_eval = if predictor_info.len() > 1 {
198                    // Re-evaluate all candidates when multiple predictors
199                    all_candidates.get(predictor_name).unwrap().clone()
200                } else {
201                    // Just evaluate latest candidates
202                    latest_candidates.get(predictor_name).unwrap().clone()
203                };
204
205                let mut latest_scores = Vec::new();
206
207                for (c_i, (instruction, prefix)) in candidates_to_eval.iter().enumerate() {
208                    // Check if already evaluated
209                    let key = (instruction.clone(), prefix.clone());
210
211                    let score = if let Some(existing) = evaluated_candidates
212                        .get(predictor_name)
213                        .and_then(|m| m.get(&key))
214                    {
215                        // Skip if already evaluated with same or better score
216                        existing.score
217                    } else {
218                        // Update predictor with candidate
219                        {
220                            let mut module_predictors = module.parameters();
221                            if let Some(predictor) = module_predictors.get_mut(predictor_name) {
222                                predictor.update_signature_instruction(instruction.clone())?;
223                                // Note: We can't update prefix without modifying the signature system
224                                // This would require extending MetaSignature trait
225                            }
226                        }
227
228                        println!(
229                            "At Depth {}/{}, Evaluating Prompt Candidate #{}/{} for Predictor {} of {}",
230                            d + 1,
231                            self.depth,
232                            c_i + 1,
233                            candidates_to_eval.len(),
234                            p_i + 1,
235                            predictor_info.len()
236                        );
237
238                        // Evaluate
239                        let score = module.evaluate(trainset.clone()).await;
240                        stats.total_calls += 1;
241
242                        // Store evaluated candidate
243                        evaluated_candidates
244                            .get_mut(predictor_name)
245                            .unwrap()
246                            .insert(
247                                key,
248                                Candidate {
249                                    score,
250                                    instruction: instruction.clone(),
251                                    prefix: prefix.clone(),
252                                },
253                            );
254
255                        score
256                    };
257
258                    // Track latest scores for stats
259                    if candidates_to_eval.len() - self.breadth <= c_i {
260                        latest_scores.push(score);
261                    }
262                }
263
264                // Update to best candidate for this predictor
265                if let Some(best) = evaluated_candidates.get(predictor_name).and_then(|m| {
266                    m.values()
267                        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
268                }) {
269                    {
270                        let mut module_predictors = module.parameters();
271                        if let Some(predictor) = module_predictors.get_mut(predictor_name) {
272                            predictor.update_signature_instruction(best.instruction.clone())?;
273                        }
274                    }
275
276                    println!(
277                        "Updating Predictor {} to best candidate with score {:.3}",
278                        predictor_name, best.score
279                    );
280                }
281
282                // Track stats
283                if self.track_stats && !latest_scores.is_empty() {
284                    let avg = latest_scores.iter().sum::<f32>() / latest_scores.len() as f32;
285                    stats
286                        .results_latest
287                        .get_mut(predictor_name)
288                        .unwrap()
289                        .push(avg);
290
291                    // Track best scores
292                    let mut best_scores: Vec<f32> = evaluated_candidates
293                        .get(predictor_name)
294                        .unwrap()
295                        .values()
296                        .map(|c| c.score)
297                        .collect();
298                    best_scores.sort_by(|a, b| b.partial_cmp(a).unwrap());
299                    best_scores.truncate(10);
300
301                    if !best_scores.is_empty() {
302                        let best_avg = best_scores.iter().sum::<f32>() / best_scores.len() as f32;
303                        stats
304                            .results_best
305                            .get_mut(predictor_name)
306                            .unwrap()
307                            .push(best_avg);
308                    }
309                }
310            }
311
312            // Skip generation on last iteration
313            if d == self.depth - 1 {
314                break;
315            }
316
317            // Generate new candidates based on attempts
318            let mut new_latest_candidates = HashMap::new();
319
320            for (predictor_name, _, _) in &predictor_info {
321                // Build few-shot examples from best attempts
322                let mut attempts_list = Vec::new();
323                let mut best_candidates: Vec<_> = evaluated_candidates
324                    .get(predictor_name)
325                    .unwrap()
326                    .values()
327                    .cloned()
328                    .collect();
329                best_candidates.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
330
331                // Take up to breadth best candidates
332                let num_examples = std::cmp::min(self.breadth, best_candidates.len());
333                for (i, candidate) in best_candidates.iter().take(num_examples).enumerate() {
334                    attempts_list.push(format!(
335                        "Instruction #{}: {}",
336                        i + 1,
337                        candidate.instruction
338                    ));
339                    attempts_list.push(format!("Prefix #{}: {}", i + 1, candidate.prefix));
340                    attempts_list.push(format!(
341                        "Resulting Score #{}: {:.3}",
342                        i + 1,
343                        candidate.score
344                    ));
345                }
346
347                let attempts_str = attempts_list.join("\n");
348
349                // Generate new candidates
350                let results = if let Some(mut prompt_model) = self.prompt_model.clone() {
351                    prompt_model.temperature = self.init_temperature;
352                    let attempts = attempts_str.clone();
353
354                    REFINEMENT_GENERATOR
355                        .batch_with_config(
356                            (0..self.breadth)
357                                .map(|_| {
358                                    example! {
359                                        "attempted_instructions": "input" => attempts.clone()
360                                    }
361                                })
362                                .collect(),
363                            Arc::new(prompt_model),
364                        )
365                        .await
366                } else {
367                    let attempts = attempts_str.clone();
368                    REFINEMENT_GENERATOR
369                        .batch_with_config(
370                            (0..self.breadth)
371                                .map(|_| {
372                                    example! {
373                                        "attempted_instructions": "input" => attempts.clone()
374                                    }
375                                })
376                                .collect(),
377                            Arc::clone(&get_lm()),
378                        )
379                        .await
380                };
381
382                if let Ok(predictions) = results {
383                    let mut new_candidates = Vec::new();
384
385                    for pred in predictions {
386                        // Handle both single and multiple completions
387                        let instructions = if let Some(arr) = pred
388                            .data
389                            .get("proposed_instruction")
390                            .and_then(|v| v.as_array())
391                        {
392                            arr.iter()
393                                .filter_map(|v| v.as_str())
394                                .map(|s| s.to_string())
395                                .collect()
396                        } else if let Some(s) = pred
397                            .data
398                            .get("proposed_instruction")
399                            .and_then(|v| v.as_str())
400                        {
401                            vec![s.to_string()]
402                        } else {
403                            vec![]
404                        };
405
406                        let prefixes = if let Some(arr) = pred
407                            .data
408                            .get("proposed_prefix_for_output_field")
409                            .and_then(|v| v.as_array())
410                        {
411                            arr.iter()
412                                .filter_map(|v| v.as_str())
413                                .map(|s| s.to_string())
414                                .collect()
415                        } else if let Some(s) = pred
416                            .data
417                            .get("proposed_prefix_for_output_field")
418                            .and_then(|v| v.as_str())
419                        {
420                            vec![s.to_string()]
421                        } else {
422                            vec![]
423                        };
424
425                        for (inst, pref) in instructions.iter().zip(prefixes.iter()) {
426                            new_candidates.push((inst.clone(), pref.clone()));
427                        }
428                    }
429
430                    // Add to all candidates
431                    all_candidates
432                        .get_mut(predictor_name)
433                        .unwrap()
434                        .extend(new_candidates.clone());
435                    new_latest_candidates.insert(predictor_name.clone(), new_candidates);
436                }
437            }
438
439            latest_candidates = new_latest_candidates;
440        }
441
442        // Find best overall candidate and update module
443        let mut best_overall: Option<(String, Candidate)> = None;
444
445        for (predictor_name, candidates_map) in &evaluated_candidates {
446            if let Some(best) = candidates_map
447                .values()
448                .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
449                && (best_overall.is_none() || best.score > best_overall.as_ref().unwrap().1.score)
450            {
451                best_overall = Some((predictor_name.clone(), best.clone()));
452            }
453        }
454
455        // Update original module with best candidates
456        if let Some((_, best_candidate)) = best_overall {
457            let module_predictors = module.parameters();
458            for (predictor_name, predictor) in module_predictors {
459                if let Some(best) = evaluated_candidates.get(&predictor_name).and_then(|m| {
460                    m.values()
461                        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
462                }) {
463                    predictor.update_signature_instruction(best.instruction.clone())?;
464                }
465            }
466
467            if self.track_stats {
468                println!("\n=== Optimization Complete ===");
469                println!("Total calls: {}", stats.total_calls);
470                println!("Best score: {:.3}", best_candidate.score);
471                println!("Best instruction: {}", best_candidate.instruction);
472                if !best_candidate.prefix.is_empty() {
473                    println!("Best prefix: {}", best_candidate.prefix);
474                }
475            }
476        }
477
478        Ok(())
479    }
480}