// 08_optimize_mipro/08-optimize-mipro.rs

/*
Example: Optimize a QA module using MIPROv2

This example demonstrates the advanced MIPROv2 optimizer, which uses a 3-stage process:
1. Generate traces from your training data
2. Use an LLM to generate candidate prompts with best practices
3. Evaluate candidates and select the best one

MIPROv2 is more sophisticated than COPRO and typically produces better results
by leveraging prompting best practices and program understanding.

Run with:
```
cargo run --example 08-optimize-mipro --features dataloaders
```

Note: The `dataloaders` feature is required for loading datasets.
*/
19
20use anyhow::Result;
21use bon::Builder;
22use dspy_rs::{
23    ChatAdapter, DataLoader, Evaluator, Example, LM, MIPROv2, Module, Optimizable, Optimizer,
24    Predict, Prediction, Predictor, Signature, configure, example,
25};
26
// DSPy signature: declares the typed input/output contract for the QA task.
// NOTE(review): the doc comment on the struct appears to become the LM
// instruction (it is what `instruction()` prints below and what MIPROv2
// rewrites) — confirm against the `#[Signature]` macro docs.
#[Signature]
struct QuestionAnswering {
    /// Answer the question accurately and concisely.

    // Field marked as an LM input by the signature macro.
    #[input]
    pub question: String,

    // Field marked as an LM output by the signature macro.
    #[output]
    pub answer: String,
}
37
// Minimal QA module wrapping a single `Predict` step.
// `#[parameter]` exposes `answerer` to optimizers (via `Optimizable`) so its
// prompt/instruction can be tuned; the builder default wires in the
// `QuestionAnswering` signature so `SimpleQA::builder().build()` needs no args.
#[derive(Builder, Optimizable)]
pub struct SimpleQA {
    #[parameter]
    #[builder(default = Predict::new(QuestionAnswering::new()))]
    pub answerer: Predict,
}
44
impl Module for SimpleQA {
    /// Runs the module on one example by delegating directly to the inner
    /// `Predict` step; any error from the predictor is propagated unchanged.
    async fn forward(&self, inputs: Example) -> Result<Prediction> {
        self.answerer.forward(inputs).await
    }
}
50
51impl Evaluator for SimpleQA {
52    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
53        let expected = example
54            .data
55            .get("answer")
56            .and_then(|v| v.as_str())
57            .unwrap_or("");
58        let predicted = prediction
59            .data
60            .get("answer")
61            .and_then(|v| v.as_str())
62            .unwrap_or("");
63
64        // Normalize and compare
65        let expected_normalized = expected.to_lowercase().trim().to_string();
66        let predicted_normalized = predicted.to_lowercase().trim().to_string();
67
68        if expected_normalized == predicted_normalized {
69            1.0
70        } else {
71            // Partial credit for substring matches
72            if expected_normalized.contains(&predicted_normalized)
73                || predicted_normalized.contains(&expected_normalized)
74            {
75                0.5
76            } else {
77                0.0
78            }
79        }
80    }
81}
82
83#[tokio::main]
84async fn main() -> Result<()> {
85    println!("=== MIPROv2 Optimizer Example ===\n");
86
87    // Configure the LM
88    configure(LM::default(), ChatAdapter);
89
90    // Load training data from HuggingFace
91    println!("Loading training data from HuggingFace...");
92    let train_examples = DataLoader::load_hf(
93        "hotpotqa/hotpot_qa",
94        vec!["question".to_string()],
95        vec!["answer".to_string()],
96        "fullwiki",
97        "validation",
98        true,
99    )?;
100
101    // Use a small subset for faster optimization
102    let train_subset = train_examples[..15].to_vec();
103    println!("Using {} training examples\n", train_subset.len());
104
105    // Create the module
106    let mut qa_module = SimpleQA::builder().build();
107
108    // Show initial instruction
109    println!("Initial instruction:");
110    println!(
111        "  \"{}\"\n",
112        qa_module.answerer.get_signature().instruction()
113    );
114
115    // Test baseline performance
116    println!("Evaluating baseline performance...");
117    let baseline_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
118    println!("Baseline score: {:.3}\n", baseline_score);
119
120    // Create MIPROv2 optimizer
121    let optimizer = MIPROv2::builder()
122        .num_candidates(8) // Generate 8 candidate prompts
123        .num_trials(15) // Run 15 evaluation trials
124        .minibatch_size(10) // Evaluate on 10 examples per candidate
125        .temperature(1.0) // Temperature for prompt generation
126        .track_stats(true) // Display detailed statistics
127        .build();
128
129    // Optimize the module
130    println!("Starting MIPROv2 optimization...");
131    println!("This will:");
132    println!("  1. Generate execution traces");
133    println!("  2. Create a program description using LLM");
134    println!("  3. Generate {} candidate prompts with best practices", 8);
135    println!("  4. Evaluate each candidate");
136    println!("  5. Select and apply the best prompt\n");
137
138    optimizer
139        .compile(&mut qa_module, train_subset.clone())
140        .await?;
141
142    // Show optimized instruction
143    println!("\nOptimized instruction:");
144    println!(
145        "  \"{}\"\n",
146        qa_module.answerer.get_signature().instruction()
147    );
148
149    // Test optimized performance
150    println!("Evaluating optimized performance...");
151    let optimized_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
152    println!("Optimized score: {:.3}", optimized_score);
153
154    // Show improvement
155    let improvement = ((optimized_score - baseline_score) / baseline_score) * 100.0;
156    println!(
157        "\n✓ Improvement: {:.1}% ({:.3} -> {:.3})",
158        improvement, baseline_score, optimized_score
159    );
160
161    // Test on a new example
162    println!("\n--- Testing on a new example ---");
163    let test_example = example! {
164        "question": "input" => "What is the capital of France?",
165    };
166
167    let result = qa_module.forward(test_example).await?;
168    println!("Question: What is the capital of France?");
169    println!("Answer: {}", result.get("answer", None));
170
171    println!("\n=== Example Complete ===");
172    Ok(())
173}