dspy_rs/optimizer/mipro.rs

//! MIPROv2 Optimizer Implementation
//!
//! Multi-prompt Instruction Proposal Optimizer (MIPROv2) is an advanced optimizer
//! that automatically generates and evaluates candidate prompts using LLMs.
//!
//! ## Three-Stage Process
//!
//! 1. **Trace Generation**: Runs the module on the training data to collect execution traces
//! 2. **Prompt Generation**: Uses an LLM to generate candidate instructions based on:
//!    - A program description (itself LLM-generated)
//!    - Execution traces
//!    - A library of prompting tips
//! 3. **Evaluation & Selection**: Evaluates candidates on minibatches and applies the best-scoring one

use crate as dspy_rs;
use crate::{
    Evaluator, Example, LM, Module, Optimizable, Optimizer, Predict, Prediction, Predictor,
    example, get_lm,
};
use anyhow::{Context, Result};
use bon::Builder;
use dsrs_macros::Signature;
use std::sync::Arc;

// ============================================================================
// Signature Definitions for LLM-based Prompt Generation
// ============================================================================

#[Signature]
struct GenerateProgramDescription {
    /// You are an expert at understanding and describing programs. Given a task signature with input and output fields, and some example traces, generate a clear and concise description of what the program does.

    #[input(desc = "The task signature showing input and output fields")]
    pub signature_fields: String,

    #[input(desc = "Example input-output traces from the program")]
    pub example_traces: String,

    #[output(desc = "A clear description of what the program does")]
    pub program_description: String,
}

#[Signature]
struct GenerateInstructionFromTips {
    /// You are an expert prompt engineer. Given a program description, example traces, and a collection of prompting best practices, generate an effective instruction that will help a language model perform this task well.
    ///
    /// Be creative and consider various prompting techniques like chain-of-thought, few-shot examples, role-playing, and output formatting.

    #[input(desc = "Description of what the program should do")]
    pub program_description: String,

    #[input(desc = "Example input-output traces showing desired behavior")]
    pub example_traces: String,

    #[input(desc = "Best practices and tips for writing effective prompts")]
    pub prompting_tips: String,

    #[output(desc = "An optimized instruction for the language model")]
    pub instruction: String,
}

// ============================================================================
// Core Data Structures
// ============================================================================

/// Represents a single execution trace of the program
#[derive(Clone, Debug)]
pub struct Trace {
    /// Input example
    pub inputs: Example,
    /// Output prediction
    pub outputs: Prediction,
    /// Evaluation score (if available)
    pub score: Option<f32>,
}

impl Trace {
    /// Creates a new trace
    pub fn new(inputs: Example, outputs: Prediction, score: Option<f32>) -> Self {
        Self {
            inputs,
            outputs,
            score,
        }
    }

    /// Formats the trace as a human-readable string for LLM consumption
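    ///
    /// The rendered shape looks like this (the field names below are purely
    /// illustrative; real keys come from the example and prediction data):
    ///
    /// ```text
    /// Input:
    ///   question: What is 2 + 2?
    /// Output:
    ///   answer: 4
    /// Score: 1.000
    /// ```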
    pub fn format_for_prompt(&self) -> String {
        let mut result = String::new();
        result.push_str("Input:\n");

        for (key, value) in &self.inputs.data {
            result.push_str(&format!("  {}: {}\n", key, value));
        }

        result.push_str("Output:\n");
        for (key, value) in &self.outputs.data {
            result.push_str(&format!("  {}: {}\n", key, value));
        }

        if let Some(score) = self.score {
            result.push_str(&format!("Score: {:.3}\n", score));
        }

        result
    }
}

/// Represents a candidate prompt with its associated examples and score
#[derive(Clone, Debug)]
pub struct PromptCandidate {
    /// The instruction text
    pub instruction: String,
    /// Few-shot demonstration examples (reserved for future enhancement)
    #[allow(dead_code)]
    pub demos: Vec<Example>,
    /// Evaluation score
    pub score: f32,
}

impl PromptCandidate {
    /// Creates a new candidate with default score
    pub fn new(instruction: String, demos: Vec<Example>) -> Self {
        Self {
            instruction,
            demos,
            score: 0.0,
        }
    }

    /// Updates the candidate's score
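    ///
    /// A minimal sketch of the builder-style chain (the instruction text and
    /// score here are placeholders):
    ///
    /// ```ignore
    /// let candidate = PromptCandidate::new("Answer concisely.".to_string(), vec![])
    ///     .with_score(0.85);
    /// ```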
    pub fn with_score(mut self, score: f32) -> Self {
        self.score = score;
        self
    }
}

/// Library of prompting tips and best practices
pub struct PromptingTips {
    pub tips: Vec<String>,
}

impl PromptingTips {
    /// Creates a new prompting tips library with default tips
    pub fn default_tips() -> Self {
        Self {
            tips: vec![
                "Use clear and specific language".to_string(),
                "Provide context about the task domain".to_string(),
                "Specify the desired output format".to_string(),
                "Use chain-of-thought reasoning for complex tasks".to_string(),
                "Include few-shot examples when helpful".to_string(),
                "Break down complex instructions into steps".to_string(),
                "Use role-playing (e.g., 'You are an expert...') when appropriate".to_string(),
                "Specify constraints and edge cases".to_string(),
                "Request explanations or reasoning when needed".to_string(),
                "Use structured output formats (JSON, lists, etc.) when applicable".to_string(),
                "Consider the model's strengths and limitations".to_string(),
                "Be explicit about what to avoid or exclude".to_string(),
                "Use positive framing (what to do vs. what not to do)".to_string(),
                "Provide examples of both correct and incorrect outputs when useful".to_string(),
                "Use delimiters or markers to separate different sections".to_string(),
            ],
        }
    }

    /// Formats tips as a string for LLM consumption
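    ///
    /// Tips are rendered as a numbered list, one per line, e.g.:
    ///
    /// ```text
    /// 1. Use clear and specific language
    /// 2. Provide context about the task domain
    /// 3. Specify the desired output format
    /// ```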
    pub fn format_for_prompt(&self) -> String {
        self.tips
            .iter()
            .enumerate()
            .map(|(i, tip)| format!("{}. {}", i + 1, tip))
            .collect::<Vec<_>>()
            .join("\n")
    }
}

// ============================================================================
// MIPROv2 Optimizer
// ============================================================================

/// MIPROv2 (Multi-prompt Instruction Proposal Optimizer v2)
///
/// An advanced optimizer that uses LLMs to automatically generate and refine
/// prompts based on program traces, descriptions, and prompting best practices.
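///
/// # Example
///
/// A minimal usage sketch (the module value and training set are placeholders;
/// any type implementing `Module + Optimizable + Evaluator` works):
///
/// ```ignore
/// let optimizer = MIPROv2::builder()
///     .num_candidates(5)
///     .minibatch_size(10)
///     .build();
///
/// optimizer.compile(&mut my_module, trainset).await?;
/// ```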
#[derive(Builder)]
pub struct MIPROv2 {
    /// Number of candidate prompts to generate per iteration
    #[builder(default = 10)]
    pub num_candidates: usize,

    /// Maximum number of bootstrapped (generated) demos to include
    #[builder(default = 3)]
    pub max_bootstrapped_demos: usize,

    /// Maximum number of labeled demos to include from training set
    #[builder(default = 3)]
    pub max_labeled_demos: usize,

    /// Number of evaluation trials (iterations)
    #[builder(default = 20)]
    pub num_trials: usize,

    /// Size of minibatch for evaluation
    #[builder(default = 25)]
    pub minibatch_size: usize,

    /// Temperature for prompt generation
    #[builder(default = 1.0)]
    pub temperature: f32,

    /// Optional separate LM for prompt generation (defaults to global LM)
    pub prompt_model: Option<LM>,

    /// Track and display statistics
    #[builder(default = true)]
    pub track_stats: bool,

    /// Random seed for reproducibility
    pub seed: Option<u64>,
}

impl MIPROv2 {
    // ========================================================================
    // Stage 1: Trace Generation
    // ========================================================================

    /// Generates execution traces by running the module on training examples
    async fn generate_traces<M>(&self, module: &M, examples: &[Example]) -> Result<Vec<Trace>>
    where
        M: Module + Evaluator,
    {
        let mut traces = Vec::with_capacity(examples.len());

        println!(
            "Stage 1: Generating traces from {} examples",
            examples.len()
        );

        for (idx, example) in examples.iter().enumerate() {
            if idx % 10 == 0 {
                println!("  Processing example {}/{}", idx + 1, examples.len());
            }

            // Run forward pass
            let prediction = module
                .forward(example.clone())
                .await
                .context("Failed to generate prediction for trace")?;

            // Evaluate the prediction
            let score = module.metric(example, &prediction).await;

            traces.push(Trace::new(example.clone(), prediction, Some(score)));
        }

        println!("Generated {} traces", traces.len());
        Ok(traces)
    }

    /// Selects the best traces based on their scores
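    ///
    /// Unscored traces are dropped; the rest are sorted by score (descending)
    /// and truncated to `num_select`. A small illustrative call, where `mipro`
    /// is a configured `MIPROv2` instance:
    ///
    /// ```ignore
    /// // keep the top three traces as demo material
    /// let best = mipro.select_best_traces(&traces, 3);
    /// ```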
    pub fn select_best_traces(&self, traces: &[Trace], num_select: usize) -> Vec<Trace> {
        let mut scored_traces: Vec<_> = traces
            .iter()
            .filter(|t| t.score.is_some())
            .cloned()
            .collect();

        // Sort by score descending
        scored_traces.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        scored_traces.into_iter().take(num_select).collect()
    }

    // ========================================================================
    // Stage 2: Candidate Prompt Generation
    // ========================================================================

    /// Generates a program description using an LLM
    async fn generate_program_description(
        &self,
        signature_desc: &str,
        traces: &[Trace],
    ) -> Result<String> {
        let description_generator = Predict::new(GenerateProgramDescription::new());

        // Format traces for the prompt
        let traces_str = traces
            .iter()
            .take(5) // Use first 5 traces
            .map(|t| t.format_for_prompt())
            .collect::<Vec<_>>()
            .join("\n---\n");

        let input = example! {
            "signature_fields": "input" => signature_desc.to_string(),
            "example_traces": "input" => traces_str,
        };

        let prediction = if let Some(mut pm) = self.prompt_model.clone() {
            pm.temperature = 0.7;
            description_generator
                .forward_with_config(input, Arc::new(pm))
                .await?
        } else {
            let lm = get_lm();
            description_generator.forward_with_config(input, lm).await?
        };

        Ok(prediction
            .data
            .get("program_description")
            .and_then(|v| v.as_str())
            .unwrap_or("Generate accurate outputs for the given inputs.")
            .to_string())
    }

    /// Generates candidate instructions using LLM with prompting tips
    async fn generate_candidate_instructions(
        &self,
        program_description: &str,
        traces: &[Trace],
        num_candidates: usize,
    ) -> Result<Vec<String>> {
        let instruction_generator = Predict::new(GenerateInstructionFromTips::new());
        let tips = PromptingTips::default_tips();

        // Format traces
        let traces_str = traces
            .iter()
            .take(8)
            .map(|t| t.format_for_prompt())
            .collect::<Vec<_>>()
            .join("\n---\n");

        println!(
            "Stage 2: Generating {} candidate instructions",
            num_candidates
        );

        let mut candidates = Vec::new();

        // Generate candidates sequentially (simpler and avoids lifetime issues)
        for i in 0..num_candidates {
            let input = example! {
                "program_description": "input" => program_description.to_string(),
                "example_traces": "input" => traces_str.clone(),
                "prompting_tips": "input" => tips.format_for_prompt(),
            };

            let result = if let Some(mut pm) = self.prompt_model.clone() {
                pm.temperature = self.temperature;
                instruction_generator
                    .forward_with_config(input, Arc::new(pm))
                    .await
            } else {
                let lm = get_lm();
                instruction_generator.forward_with_config(input, lm).await
            };

            if let Ok(pred) = result
                && let Some(instruction) = pred.data.get("instruction").and_then(|v| v.as_str())
            {
                candidates.push(instruction.to_string());
            }

            if (i + 1) % 3 == 0 || i == num_candidates - 1 {
                println!(
                    "  Generated {}/{} candidates",
                    candidates.len(),
                    num_candidates
                );
            }
        }

        println!(
            "Generated {} total candidate instructions",
            candidates.len()
        );
        Ok(candidates)
    }

    /// Creates prompt candidates by pairing instructions with demo selections
    pub fn create_prompt_candidates(
        &self,
        instructions: Vec<String>,
        traces: &[Trace],
    ) -> Vec<PromptCandidate> {
        let best_traces = self.select_best_traces(traces, self.max_labeled_demos);
        let demo_examples: Vec<Example> = best_traces.into_iter().map(|t| t.inputs).collect();

        instructions
            .into_iter()
            .map(|inst| PromptCandidate::new(inst, demo_examples.clone()))
            .collect()
    }

    // ========================================================================
    // Stage 3: Evaluation and Selection
    // ========================================================================

    /// Evaluates a single prompt candidate
    async fn evaluate_candidate<M>(
        &self,
        module: &mut M,
        candidate: &PromptCandidate,
        eval_examples: &[Example],
        predictor_name: &str,
    ) -> Result<f32>
    where
        M: Module + Optimizable + Evaluator,
    {
        // Update module with candidate instruction
        {
            let mut params = module.parameters();
            if let Some(predictor) = params.get_mut(predictor_name) {
                predictor.update_signature_instruction(candidate.instruction.clone())?;

                // Note: Demo setting would require mutable signature access
                // This is a design consideration for future enhancement
            }
        }

        // Evaluate on minibatch
        let minibatch: Vec<Example> = eval_examples
            .iter()
            .take(self.minibatch_size)
            .cloned()
            .collect();

        let score = module.evaluate(minibatch).await;
        Ok(score)
    }

    /// Evaluates all candidates and returns the best one
    async fn evaluate_and_select_best<M>(
        &self,
        module: &mut M,
        candidates: Vec<PromptCandidate>,
        eval_examples: &[Example],
        predictor_name: &str,
    ) -> Result<PromptCandidate>
    where
        M: Module + Optimizable + Evaluator,
    {
        println!(
            "Stage 3: Evaluating {} candidates on minibatch of {} examples",
            candidates.len(),
            self.minibatch_size.min(eval_examples.len())
        );

        // Use the actual number of candidates received; generation failures can
        // make this smaller than `self.num_candidates`.
        let total_candidates = candidates.len();
        let mut evaluated_candidates = Vec::new();

        for (idx, candidate) in candidates.into_iter().enumerate() {
            println!("  Evaluating candidate {}/{}", idx + 1, total_candidates);

            let score = self
                .evaluate_candidate(module, &candidate, eval_examples, predictor_name)
                .await?;

            evaluated_candidates.push(candidate.with_score(score));

            if self.track_stats {
                println!("    Score: {:.3}", score);
            }
        }

        // Find best candidate
        let best = evaluated_candidates
            .into_iter()
            .max_by(|a, b| {
                a.score
                    .partial_cmp(&b.score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .context("No candidates to evaluate")?;

        println!("Best candidate score: {:.3}", best.score);
        Ok(best)
    }

    // ========================================================================
    // Helper Methods
    // ========================================================================

    /// Formats signature fields as a string
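    ///
    /// The resulting string has this shape (field names and descriptions are
    /// illustrative; they come from the signature's field metadata):
    ///
    /// ```text
    /// Input Fields:
    ///   - question: The question to answer
    ///
    /// Output Fields:
    ///   - answer: The final answer
    /// ```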
    pub fn format_signature_fields(&self, signature: &dyn crate::core::MetaSignature) -> String {
        let mut result = String::new();

        result.push_str("Input Fields:\n");
        if let Some(obj) = signature.input_fields().as_object() {
            for (name, field) in obj {
                let desc = field
                    .get("desc")
                    .and_then(|v| v.as_str())
                    .unwrap_or("No description");
                result.push_str(&format!("  - {}: {}\n", name, desc));
            }
        }

        result.push_str("\nOutput Fields:\n");
        if let Some(obj) = signature.output_fields().as_object() {
            for (name, field) in obj {
                let desc = field
                    .get("desc")
                    .and_then(|v| v.as_str())
                    .unwrap_or("No description");
                result.push_str(&format!("  - {}: {}\n", name, desc));
            }
        }

        result
    }
}

// ============================================================================
// Optimizer Trait Implementation
// ============================================================================

impl Optimizer for MIPROv2 {
    async fn compile<M>(&self, module: &mut M, trainset: Vec<Example>) -> Result<()>
    where
        M: Module + Optimizable + Evaluator,
    {
        println!("\n=== MIPROv2 Optimization Started ===");
        println!("Configuration:");
        println!("  Candidates: {}", self.num_candidates);
        println!("  Trials: {}", self.num_trials);
        println!("  Minibatch size: {}", self.minibatch_size);
        println!("  Training examples: {}", trainset.len());

        // Get predictor information
        let predictor_names: Vec<String> = module.parameters().keys().cloned().collect();

        if predictor_names.is_empty() {
            return Err(anyhow::anyhow!("No optimizable parameters found in module"));
        }

        println!(
            "  Optimizing {} predictor(s): {:?}\n",
            predictor_names.len(),
            predictor_names
        );

        // Optimize each predictor
        for predictor_name in predictor_names {
            println!("--- Optimizing predictor: {} ---", predictor_name);

            // Get signature for this predictor
            let signature_desc = {
                let params = module.parameters();
                if let Some(predictor) = params.get(&predictor_name) {
                    self.format_signature_fields(predictor.get_signature())
                } else {
                    continue;
                }
            };

            // Stage 1: Generate traces
            let traces = self.generate_traces(module, &trainset).await?;

            // Stage 2: Generate candidates
            let program_description = self
                .generate_program_description(&signature_desc, &traces)
                .await?;

            println!("Generated program description: {}", program_description);

            let instructions = self
                .generate_candidate_instructions(&program_description, &traces, self.num_candidates)
                .await?;

            let candidates = self.create_prompt_candidates(instructions, &traces);

            // Stage 3: Evaluate and select best
            let best_candidate = self
                .evaluate_and_select_best(module, candidates, &trainset, &predictor_name)
                .await?;

            // Apply best candidate
            {
                let mut params = module.parameters();
                if let Some(predictor) = params.get_mut(&predictor_name) {
                    predictor.update_signature_instruction(best_candidate.instruction.clone())?;
                    // Note: Demo setting would require mutable signature access
                    // This is a design consideration for future enhancement
                }
            }

            println!(
                "✓ Optimized {} with score {:.3}",
                predictor_name, best_candidate.score
            );
            println!("  Instruction: {}\n", best_candidate.instruction);
        }

        println!("=== MIPROv2 Optimization Complete ===\n");
        Ok(())
    }
}