08_optimize_mipro/08-optimize-mipro.rs

/*
Example: Optimize a QA module using MIPROv2

This example demonstrates the advanced MIPROv2 optimizer, which uses a 3-stage process:
1. Generate execution traces from your training data
2. Use an LLM to generate candidate prompts that apply prompting best practices
3. Evaluate the candidates and select the best one

MIPROv2 is more sophisticated than COPRO and typically produces better results
because it leverages prompting best practices and an LLM-generated understanding
of your program.

Run with:
```
cargo run --example 08-optimize-mipro --features parquet
```

Note: The `parquet` feature is required for loading HuggingFace datasets.
*/

use anyhow::Result;
use bon::Builder;
use dspy_rs::{
    ChatAdapter, DataLoader, Evaluator, Example, LM, MIPROv2, Module, Optimizable, Optimizer,
    Predict, Prediction, Predictor, Signature, configure, example,
};
use secrecy::SecretString;

#[Signature]
struct QuestionAnswering {
    /// Answer the question accurately and concisely.

    #[input]
    pub question: String,

    #[output]
    pub answer: String,
}
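// `#[derive(Optimizable)]` together with `#[parameter]` exposes `answerer` as a
// tunable parameter, which is how the optimizer finds the predictor whose
// instruction it should rewrite.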
#[derive(Builder, Optimizable)]
pub struct SimpleQA {
    #[parameter]
    #[builder(default = Predict::new(QuestionAnswering::new()))]
    pub answerer: Predict,
}

impl Module for SimpleQA {
    // A single-step module: forward just delegates to the inner predictor.
    async fn forward(&self, inputs: Example) -> Result<Prediction> {
        self.answerer.forward(inputs).await
    }
}
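// The metric drives both the baseline/optimized evaluations and MIPROv2's
// candidate selection: 1.0 for an exact match after normalization, 0.5 when one
// answer contains the other, 0.0 otherwise. Higher scores are treated as better.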
impl Evaluator for SimpleQA {
    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
        let expected = example
            .data
            .get("answer")
            .and_then(|v| v.as_str())
            .unwrap_or("");
        let predicted = prediction
            .data
            .get("answer")
            .and_then(|v| v.as_str())
            .unwrap_or("");

        // Normalize (case- and whitespace-insensitive) before comparing
        let expected_normalized = expected.trim().to_lowercase();
        let predicted_normalized = predicted.trim().to_lowercase();

        if expected_normalized == predicted_normalized {
            1.0
        } else if !expected_normalized.is_empty()
            && !predicted_normalized.is_empty()
            && (expected_normalized.contains(&predicted_normalized)
                || predicted_normalized.contains(&expected_normalized))
        {
            // Partial credit for substring matches; the emptiness checks prevent
            // an empty prediction from scoring 0.5 via `contains("")`.
            0.5
        } else {
            0.0
        }
    }
}
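// A stricter, hypothetical alternative metric: token-level F1 over whitespace
// tokens, common for HotpotQA-style QA. It is not wired in anywhere below; swap
// it into `metric` if exact/substring matching proves too coarse.
#[allow(dead_code)]
fn token_f1(expected: &str, predicted: &str) -> f32 {
    let exp: Vec<&str> = expected.split_whitespace().collect();
    let pred: Vec<&str> = predicted.split_whitespace().collect();
    if exp.is_empty() || pred.is_empty() {
        return if exp == pred { 1.0 } else { 0.0 };
    }
    // Count tokens shared between the two answers (multiset intersection).
    let mut remaining = exp.clone();
    let mut overlap = 0usize;
    for tok in &pred {
        if let Some(pos) = remaining.iter().position(|t| t == tok) {
            remaining.swap_remove(pos);
            overlap += 1;
        }
    }
    if overlap == 0 {
        return 0.0;
    }
    let precision = overlap as f32 / pred.len() as f32;
    let recall = overlap as f32 / exp.len() as f32;
    2.0 * precision * recall / (precision + recall)
}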
#[tokio::main]
async fn main() -> Result<()> {
    println!("=== MIPROv2 Optimizer Example ===\n");

    // Configure the LM; the API key is read from the OPENAI_API_KEY environment variable
    configure(
        LM::builder()
            .api_key(SecretString::from(std::env::var("OPENAI_API_KEY")?))
            .build(),
        ChatAdapter {},
    );
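    // `configure` installs this LM/adapter pair as the defaults used by the
    // `Predict` calls below (and, presumably, by the optimizer's own
    // prompt-generation calls).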

    // Load training data from HuggingFace (requires the `parquet` feature)
    println!("Loading training data from HuggingFace...");
    let train_examples = DataLoader::load_hf(
        "hotpotqa/hotpot_qa",
        vec!["question".to_string()], // input fields
        vec!["answer".to_string()],   // output fields
        "fullwiki",                   // HF dataset config
        "validation",                 // split
        true,
    )?;

    // Use a small subset for faster optimization
    let train_subset = train_examples[..15].to_vec();
    println!("Using {} training examples\n", train_subset.len());
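    // NOTE: this example scores on slices of the same training subset below. With
    // more data you could hold out a dev split using only the APIs already shown
    // here, e.g.:
    //   let dev_subset = train_examples[15..30].to_vec();
    //   let dev_score = qa_module.evaluate(dev_subset).await;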

    // Create the module
    let mut qa_module = SimpleQA::builder().build();

    // Show initial instruction
    println!("Initial instruction:");
    println!(
        "  \"{}\"\n",
        qa_module.answerer.get_signature().instruction()
    );

    // Test baseline performance
    println!("Evaluating baseline performance...");
    let baseline_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
    println!("Baseline score: {:.3}\n", baseline_score);

    // Create MIPROv2 optimizer
    let optimizer = MIPROv2::builder()
        .num_candidates(8) // Generate 8 candidate prompts
        .num_trials(15) // Run 15 evaluation trials
        .minibatch_size(10) // Evaluate on 10 examples per candidate
        .temperature(1.0) // Temperature for prompt generation
        .track_stats(true) // Display detailed statistics
        .build();
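    // Trade-off: more candidates and trials generally yield a better final prompt
    // at the cost of proportionally more LLM and evaluation calls;
    // `minibatch_size` caps how many examples each candidate is scored against.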

    // Optimize the module
    println!("Starting MIPROv2 optimization...");
    println!("This will:");
    println!("  1. Generate execution traces");
    println!("  2. Create a program description using an LLM");
    println!("  3. Generate 8 candidate prompts using prompting best practices");
    println!("  4. Evaluate each candidate");
    println!("  5. Select and apply the best prompt\n");

    optimizer.compile(&mut qa_module, train_subset.clone()).await?;
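    // (`compile` mutated `qa_module` in place via `&mut`, applying the best
    // candidate instruction it found; the updated instruction is printed below.)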

    // Show optimized instruction
    println!("\nOptimized instruction:");
    println!(
        "  \"{}\"\n",
        qa_module.answerer.get_signature().instruction()
    );

    // Test optimized performance on the same 5 examples as the baseline
    println!("Evaluating optimized performance...");
    let optimized_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
    println!("Optimized score: {:.3}", optimized_score);

    // Show improvement (guarding against a zero baseline)
    if baseline_score > 0.0 {
        let improvement = ((optimized_score - baseline_score) / baseline_score) * 100.0;
        println!(
            "\n✓ Improvement: {:.1}% ({:.3} -> {:.3})",
            improvement, baseline_score, optimized_score
        );
    } else {
        println!("\n✓ Scores: {:.3} -> {:.3}", baseline_score, optimized_score);
    }

    // Test on a new example; `example!` tags "question" as an input field
    println!("\n--- Testing on a new example ---");
    let test_example = example! {
        "question": "input" => "What is the capital of France?",
    };

    let result = qa_module.forward(test_example).await?;
    println!("Question: What is the capital of France?");
    println!("Answer: {}", result.get("answer", None));

    println!("\n=== Example Complete ===");
    Ok(())
}