08_optimize_mipro/
08-optimize-mipro.rs1use anyhow::Result;
21use bon::Builder;
22use dspy_rs::{
23 MIPROv2, ChatAdapter, DataLoader, Evaluator, Example, LM, Module, Optimizable, Optimizer,
24 Predict, Prediction, Predictor, Signature, configure, example,
25};
26use secrecy::SecretString;
27
// Task signature for this example: one input field (`question`) and one
// output field (`answer`).
// NOTE(review): only plain `//` comments are used on this struct on purpose.
// The `#[Signature]` attribute macro may consume `///` doc attributes (e.g. as
// the prompt instruction). Confirm against dspy-rs before switching to `///`.
#[Signature]
struct QuestionAnswering {
    // Question text. The examples built in this file supply it under the
    // `"question"` key (HF loader input key and the `example!` literal).
    #[input]
    pub question: String,

    // Model-produced answer. `SimpleQA`'s metric and `main` read it back from
    // predictions under the `"answer"` key.
    #[output]
    pub answer: String,
}
38
/// Minimal single-predictor QA program that this example optimizes with MIPROv2.
///
/// Construct it with the `bon`-generated `SimpleQA::builder()`.
#[derive(Builder, Optimizable)]
pub struct SimpleQA {
    /// The program's only predictor; it answers `QuestionAnswering` prompts.
    // NOTE(review): `#[parameter]` presumably registers this field with the
    // `Optimizable` derive so the optimizer can rewrite its instruction — confirm.
    #[parameter]
    // When the builder is not given an `answerer`, a fresh `Predict` over the
    // `QuestionAnswering` signature is used.
    #[builder(default = Predict::new(QuestionAnswering::new()))]
    pub answerer: Predict,
}
45
46impl Module for SimpleQA {
47 async fn forward(&self, inputs: Example) -> Result<Prediction> {
48 self.answerer.forward(inputs).await
49 }
50}
51
52impl Evaluator for SimpleQA {
53 async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
54 let expected = example
55 .data
56 .get("answer")
57 .and_then(|v| v.as_str())
58 .unwrap_or("");
59 let predicted = prediction
60 .data
61 .get("answer")
62 .and_then(|v| v.as_str())
63 .unwrap_or("");
64
65 let expected_normalized = expected.to_lowercase().trim().to_string();
67 let predicted_normalized = predicted.to_lowercase().trim().to_string();
68
69 if expected_normalized == predicted_normalized {
70 1.0
71 } else {
72 if expected_normalized.contains(&predicted_normalized)
74 || predicted_normalized.contains(&expected_normalized)
75 {
76 0.5
77 } else {
78 0.0
79 }
80 }
81 }
82}
83
84#[tokio::main]
85async fn main() -> Result<()> {
86 println!("=== MIPROv2 Optimizer Example ===\n");
87
88 configure(
90 LM::builder()
91 .api_key(SecretString::from(std::env::var("OPENAI_API_KEY")?))
92 .build(),
93 ChatAdapter {},
94 );
95
96 println!("Loading training data from HuggingFace...");
98 let train_examples = DataLoader::load_hf(
99 "hotpotqa/hotpot_qa",
100 vec!["question".to_string()],
101 vec!["answer".to_string()],
102 "fullwiki",
103 "validation",
104 true,
105 )?;
106
107 let train_subset = train_examples[..15].to_vec();
109 println!("Using {} training examples\n", train_subset.len());
110
111 let mut qa_module = SimpleQA::builder().build();
113
114 println!("Initial instruction:");
116 println!(
117 " \"{}\"\n",
118 qa_module.answerer.get_signature().instruction()
119 );
120
121 println!("Evaluating baseline performance...");
123 let baseline_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
124 println!("Baseline score: {:.3}\n", baseline_score);
125
126 let optimizer = MIPROv2::builder()
128 .num_candidates(8) .num_trials(15) .minibatch_size(10) .temperature(1.0) .track_stats(true) .build();
134
135 println!("Starting MIPROv2 optimization...");
137 println!("This will:");
138 println!(" 1. Generate execution traces");
139 println!(" 2. Create a program description using LLM");
140 println!(" 3. Generate {} candidate prompts with best practices", 8);
141 println!(" 4. Evaluate each candidate");
142 println!(" 5. Select and apply the best prompt\n");
143
144 optimizer.compile(&mut qa_module, train_subset.clone()).await?;
145
146 println!("\nOptimized instruction:");
148 println!(
149 " \"{}\"\n",
150 qa_module.answerer.get_signature().instruction()
151 );
152
153 println!("Evaluating optimized performance...");
155 let optimized_score = qa_module.evaluate(train_subset[..5].to_vec()).await;
156 println!("Optimized score: {:.3}", optimized_score);
157
158 let improvement = ((optimized_score - baseline_score) / baseline_score) * 100.0;
160 println!(
161 "\n✓ Improvement: {:.1}% ({:.3} -> {:.3})",
162 improvement, baseline_score, optimized_score
163 );
164
165 println!("\n--- Testing on a new example ---");
167 let test_example = example! {
168 "question": "input" => "What is the capital of France?",
169 };
170
171 let result = qa_module.forward(test_example).await?;
172 println!("Question: What is the capital of France?");
173 println!("Answer: {}", result.get("answer", None));
174
175 println!("\n=== Example Complete ===");
176 Ok(())
177}