// Example: 08_optimize_mipro/
08-optimize-mipro.rs1use anyhow::Result;
21use bon::Builder;
22use dspy_rs::{
23 ChatAdapter, DataLoader, Evaluator, Example, LM, MIPROv2, Module, Optimizable, Optimizer,
24 Predict, Prediction, Predictor, Signature, configure, example,
25};
26
// Task signature for the QA program: one text input (`question`) and one text
// output (`answer`). `main` builds inputs with the `question` key, and the
// evaluator metric reads the `answer` key from examples and predictions.
//
// NOTE(review): plain `//` comments are used on purpose. `#[Signature]` may turn
// `///` doc comments into prompt text (instruction / field descriptions), which
// would change the baseline prompt. Confirm against the dspy-rs docs before
// converting these to doc comments.
#[Signature]
struct QuestionAnswering {
    // Question text supplied to the model.
    #[input]
    pub question: String,

    // Answer text the model is asked to produce.
    #[output]
    pub answer: String,
}
37
/// Single-step question-answering program wrapping one `Predict` over the
/// `QuestionAnswering` signature.
#[derive(Builder, Optimizable)]
pub struct SimpleQA {
    // NOTE(review): `#[parameter]` presumably registers this predictor with the
    // `Optimizable` derive so optimizers can rewrite it. `main` reads its
    // instruction before and after `MIPROv2::compile`. Confirm in dspy-rs docs.
    #[parameter]
    // Used when `SimpleQA::builder().build()` is called without an explicit answerer.
    #[builder(default = Predict::new(QuestionAnswering::new()))]
    pub answerer: Predict,
}
44
45impl Module for SimpleQA {
46 async fn forward(&self, inputs: Example) -> Result<Prediction> {
47 self.answerer.forward(inputs).await
48 }
49}
50
51impl Evaluator for SimpleQA {
52 async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
53 let expected = example
54 .data
55 .get("answer")
56 .and_then(|v| v.as_str())
57 .unwrap_or("");
58 let predicted = prediction
59 .data
60 .get("answer")
61 .and_then(|v| v.as_str())
62 .unwrap_or("");
63
64 let expected_normalized = expected.to_lowercase().trim().to_string();
66 let predicted_normalized = predicted.to_lowercase().trim().to_string();
67
68 if expected_normalized == predicted_normalized {
69 1.0
70 } else {
71 if expected_normalized.contains(&predicted_normalized)
73 || predicted_normalized.contains(&expected_normalized)
74 {
75 0.5
76 } else {
77 0.0
78 }
79 }
80 }
81}
82
83#[tokio::main]
84async fn main() -> Result<()> {
85 println!("=== MIPROv2 Optimizer Example ===\n");
86
87 configure(LM::default(), ChatAdapter);
89
90 println!("Loading training data from HuggingFace...");
92 let train_examples = DataLoader::load_hf(
93 "hotpotqa/hotpot_qa",
94 vec!["question".to_string()],
95 vec!["answer".to_string()],
96 "fullwiki",
97 "validation",
98 true,
99 )?;
100
101 let train_subset = train_examples[..15].to_vec();
103 println!("Using {} training examples\n", train_subset.len());
104
105 let mut qa_module = SimpleQA::builder().build();
107
108 println!("Initial instruction:");
110 println!(
111 " \"{}\"\n",
112 qa_module.answerer.get_signature().instruction()
113 );
114
115 println!("Evaluating baseline performance...");
117 let baseline_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
118 println!("Baseline score: {:.3}\n", baseline_score);
119
120 let optimizer = MIPROv2::builder()
122 .num_candidates(8) .num_trials(15) .minibatch_size(10) .temperature(1.0) .track_stats(true) .build();
128
129 println!("Starting MIPROv2 optimization...");
131 println!("This will:");
132 println!(" 1. Generate execution traces");
133 println!(" 2. Create a program description using LLM");
134 println!(" 3. Generate {} candidate prompts with best practices", 8);
135 println!(" 4. Evaluate each candidate");
136 println!(" 5. Select and apply the best prompt\n");
137
138 optimizer
139 .compile(&mut qa_module, train_subset.clone())
140 .await?;
141
142 println!("\nOptimized instruction:");
144 println!(
145 " \"{}\"\n",
146 qa_module.answerer.get_signature().instruction()
147 );
148
149 println!("Evaluating optimized performance...");
151 let optimized_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
152 println!("Optimized score: {:.3}", optimized_score);
153
154 let improvement = ((optimized_score - baseline_score) / baseline_score) * 100.0;
156 println!(
157 "\n✓ Improvement: {:.1}% ({:.3} -> {:.3})",
158 improvement, baseline_score, optimized_score
159 );
160
161 println!("\n--- Testing on a new example ---");
163 let test_example = example! {
164 "question": "input" => "What is the capital of France?",
165 };
166
167 let result = qa_module.forward(test_example).await?;
168 println!("Question: What is the capital of France?");
169 println!("Answer: {}", result.get("answer", None));
170
171 println!("\n=== Example Complete ===");
172 Ok(())
173}