10_gepa_llm_judge/
10-gepa-llm-judge.rs

use anyhow::Result;
use bon::Builder;
use dspy_rs::*;
use dsrs_macros::{Optimizable, Signature};
use std::sync::Arc;

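// Signature for the task being optimized: given a word problem, produce
// step-by-step reasoning plus a final answer. The `cot` argument asks the
// macro for a chain-of-thought variant of the signature.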
#[Signature(cot)]
struct MathWordProblem {
    #[input]
    pub problem: String,

    #[output]
    pub reasoning: String,

    #[output]
    pub answer: String,
}

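// Signature for the LLM judge. It sees the problem, the expected answer, and
// the student's answer and reasoning, and emits a free-form evaluation that
// becomes part of the feedback text below.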
#[Signature]
struct MathJudge {
    #[input(desc = "The math problem that was given")]
    pub problem: String,

    #[input(desc = "The expected correct answer")]
    pub expected_answer: String,

    #[input(desc = "The student's answer")]
    pub student_answer: String,

    #[input(desc = "The student's reasoning/work shown")]
    pub student_reasoning: String,

    #[output(desc = "Detailed evaluation of the work")]
    pub evaluation: String,
}

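// The module under optimization. Only `solver` is tagged #[parameter], so the
// optimizer mutates the solver's instruction while the judge program and its
// dedicated LM remain fixed.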
#[derive(Builder, Optimizable)]
struct MathSolver {
    #[parameter]
    solver: Predict,

    judge: Predict,

    judge_lm: Arc<LM>,
}

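// Inference path: forward only runs the solver; the judge participates solely
// in evaluation.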
impl Module for MathSolver {
    async fn forward(&self, inputs: Example) -> Result<Prediction> {
        self.solver.forward(inputs).await
    }
}

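// The scalar metric simply reuses the feedback metric's score, so plain
// evaluation and feedback-driven optimization stay consistent.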
impl Evaluator for MathSolver {
    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
        let feedback = self.feedback_metric(example, prediction).await;
        feedback.score
    }
}

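// Rich feedback for GEPA: combine an exact-match check on the final answer
// with a judge pass over the reasoning, returning both a blended score and a
// textual critique.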
impl FeedbackEvaluator for MathSolver {
    async fn feedback_metric(&self, example: &Example, prediction: &Prediction) -> FeedbackMetric {
        let problem = example
            .get("problem", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let expected = example
            .get("expected_answer", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let student_answer = prediction
            .get("answer", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let student_reasoning = prediction
            .get("reasoning", None)
            .as_str()
            .unwrap_or("No reasoning provided")
            .to_string();

        let answer_matches = student_answer.trim() == expected.trim();

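        // Ask the judge to evaluate the work on its own (cooler) LM. If the
        // judge call fails, degrade gracefully to exact-match scoring instead
        // of failing the whole rollout.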
        let judge_input = example! {
            "problem": "input" => &problem,
            "expected_answer": "input" => &expected,
            "student_answer": "input" => &student_answer,
            "student_reasoning": "input" => &student_reasoning
        };

        let judge_output = match self
            .judge
            .forward_with_config(judge_input, Arc::clone(&self.judge_lm))
            .await
        {
            Ok(output) => output,
            Err(_) => {
                let score = if answer_matches { 1.0 } else { 0.0 };
                let simple_feedback = format!(
                    "Problem: {}\nExpected: {}\nPredicted: {}\nAnswer: {}",
                    problem,
                    expected,
                    student_answer,
                    if answer_matches { "CORRECT" } else { "INCORRECT" }
                );
                return FeedbackMetric::new(score, simple_feedback);
            }
        };

        let judge_evaluation = judge_output
            .get("evaluation", None)
            .as_str()
            .unwrap_or("Unable to evaluate")
            .to_string();

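        // Scoring rubric: a matching answer earns full credit only when the
        // judge also endorses the reasoning; a wrong answer can still earn
        // partial credit for a promising approach. The keyword probes below
        // are a deliberately crude heuristic over the judge's free-form text.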
        let eval_lower = judge_evaluation.to_lowercase();
        let score = if answer_matches {
            if eval_lower.contains("sound reasoning") || eval_lower.contains("correct approach") {
                1.0
            } else {
                0.7
            }
        } else if eval_lower.contains("correct approach") || eval_lower.contains("good start") {
            0.3
        } else {
            0.0
        };

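        // Assemble the feedback string: concrete answer comparison first,
        // then the judge's qualitative analysis.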
        let mut feedback = String::new();
        feedback.push_str(&format!("Problem: {}\n", problem));
        feedback.push_str(&format!("Expected: {}\n", expected));
        feedback.push_str(&format!("Predicted: {}\n", student_answer));

        if answer_matches {
            feedback.push_str("Answer: CORRECT\n\n");
        } else {
            feedback.push_str("Answer: INCORRECT\n\n");
        }

        feedback.push_str("Reasoning Quality Analysis:\n");
        feedback.push_str(&judge_evaluation);

        FeedbackMetric::new(score, feedback)
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    println!("GEPA with LLM-as-a-Judge Example\n");
    println!("This example shows how to use an LLM judge to automatically");
    println!("generate rich feedback for optimizing a math solver.\n");

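    // Two LMs, assumed here to pick up provider credentials and a default
    // model from the environment: a warmer one for solving and a cooler, more
    // deterministic one for judging.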
    let task_lm = LM::builder().temperature(0.7).build().await.unwrap();

    let judge_lm = LM::builder().temperature(0.3).build().await.unwrap();

    configure(task_lm, ChatAdapter);

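    // A small synthetic training set; each example pairs a word problem with
    // its expected final answer.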
    let trainset = vec![
        example! {
            "problem": "input" => "Sarah has 12 apples. She gives 3 to her friend and buys 5 more. How many apples does she have now?",
            "expected_answer": "input" => "14"
        },
        example! {
            "problem": "input" => "A train travels 60 miles in 1 hour. How far will it travel in 3.5 hours at the same speed?",
            "expected_answer": "input" => "210"
        },
        example! {
            "problem": "input" => "There are 24 students in a class. If 1/3 of them are absent, how many students are present?",
            "expected_answer": "input" => "16"
        },
        example! {
            "problem": "input" => "A rectangle has length 8 cm and width 5 cm. What is its area?",
            "expected_answer": "input" => "40"
        },
        example! {
            "problem": "input" => "John has $50. He spends $12 on lunch and $8 on a book. How much money does he have left?",
            "expected_answer": "input" => "30"
        },
    ];

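    // Wire everything together via the generated builder: optimizable solver,
    // fixed judge, and the judge's dedicated LM.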
    let mut module = MathSolver::builder()
        .solver(Predict::new(MathWordProblem::new()))
        .judge(Predict::new(MathJudge::new()))
        .judge_lm(Arc::new(judge_lm))
        .build();

    println!("Step 1: Baseline Performance");
    println!("Testing the solver before optimization...\n");
    let baseline_score = module.evaluate(trainset.clone()).await;
    println!("  Baseline average score: {:.3}\n", baseline_score);

    println!("Step 2: Configure GEPA");
    println!("Setting up the optimizer with budget controls...\n");

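    // Budget controls: a few reflective iterations over small minibatches,
    // stats tracking for the evolution history printed below, and a hard cap
    // on total LM calls (which include judge evaluations). The temperature
    // here presumably shapes how varied the proposed instruction rewrites are.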
    let gepa = GEPA::builder()
        .num_iterations(3)
        .minibatch_size(3)
        .temperature(0.9)
        .track_stats(true)
        .maybe_max_lm_calls(Some(100))
        .build();

    println!("Step 3: Run GEPA Optimization");
    println!("The judge will analyze reasoning quality and provide feedback...\n");

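    // compile_with_feedback exercises the FeedbackEvaluator impl above:
    // candidate instructions are scored on minibatches, and the judge's
    // textual critiques inform the next round of instruction mutations.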
    let result = gepa
        .compile_with_feedback(&mut module, trainset.clone())
        .await?;

    println!("\nStep 4: Results");
    println!("===============\n");
    println!("Optimization complete!");
    println!(
        "  Best average score: {:.3}",
        result.best_candidate.average_score()
    );
    println!(
        "  Improvement: {:.3}",
        result.best_candidate.average_score() - baseline_score
    );
    println!("  Total rollouts: {}", result.total_rollouts);
    println!(
        "  Total LM calls: {} (includes judge evaluations)",
        result.total_lm_calls
    );

    println!("\nEvolution over time:");
    for (generation, score) in &result.evolution_history {
        println!("  Generation {}: {:.3}", generation, score);
    }

    println!("\nOptimized instruction:");
    println!("  {}", result.best_candidate.instruction);

    println!("\nStep 5: Test Optimized Solver");
    println!("==============================\n");

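    // Spot-check the optimized solver on an unseen problem and print the
    // judge's detailed feedback. The expected answer "2" assumes the solver
    // answers in dollars without a currency symbol.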
    let test_problem = example! {
        "problem": "input" => "A store sells pencils for $0.25 each. If you buy 8 pencils, how much will you pay?",
        "expected_answer": "input" => "2"
    };

    let test_prediction = module.forward(test_problem.clone()).await?;
    let test_feedback = module
        .feedback_metric(&test_problem, &test_prediction)
        .await;

    println!(
        "Test problem: A store sells pencils for $0.25 each. If you buy 8 pencils, how much will you pay?"
    );
    println!("\nAnswer: {}", test_prediction.get("answer", None));
    println!("Score: {:.3}\n", test_feedback.score);
    println!("Detailed Feedback from Judge:");
    println!("{}", test_feedback.feedback);

    Ok(())
}