10_gepa_llm_judge/10-gepa-llm-judge.rs

/// Example: Using LLM-as-a-Judge with GEPA for Math Word Problems
///
/// This example demonstrates how to use an LLM judge to automatically generate
/// rich textual feedback for GEPA optimization. The judge evaluates both the
/// correctness of answers AND the quality of reasoning.
///
/// To run:
/// ```bash
/// OPENAI_API_KEY=your_key cargo run --example 10-gepa-llm-judge
/// ```
use anyhow::Result;
use bon::Builder;
use dspy_rs::*;
use dsrs_macros::{Optimizable, Signature};
use std::sync::Arc;

// ============================================================================
// Step 1: Define the task signature with chain-of-thought reasoning
// ============================================================================

#[Signature(cot)]
struct MathWordProblem {
    /// Solve the math word problem step by step. Show your work clearly.

    #[input]
    pub problem: String,

    #[output]
    pub reasoning: String,

    #[output]
    pub answer: String,
}
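
// A minimal usage sketch for this signature on its own (the literal problem
// is illustrative; the sketch reuses only calls that appear later in this
// file):
//
//     let solver = Predict::new(MathWordProblem::new());
//     let input = example! {
//         "problem": "input" => "What is 7 times 6?"
//     };
//     let prediction = solver.forward(input).await?;
//     println!("{}", prediction.get("answer", None));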

// ============================================================================
// Step 2: Define the LLM judge signature
// ============================================================================

#[Signature]
struct MathJudge {
    /// You are an expert math teacher evaluating student work. Analyze both
    /// the final answer and the reasoning process. Be specific about what
    /// went wrong or what was done well.

    #[input(desc = "The math problem that was given")]
    pub problem: String,

    #[input(desc = "The expected correct answer")]
    pub expected_answer: String,

    #[input(desc = "The student's answer")]
    pub student_answer: String,

    #[input(desc = "The student's reasoning/work shown")]
    pub student_reasoning: String,

    #[output(desc = "Detailed evaluation of the work")]
    pub evaluation: String,
}
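
// Sketch of a direct judge call, mirroring the call made in `feedback_metric`
// below. The literal inputs are illustrative only, and `judge_lm` stands in
// for an `Arc<LM>` like the one built in `main`:
//
//     let judge = Predict::new(MathJudge::new());
//     let verdict = judge
//         .forward_with_config(
//             example! {
//                 "problem": "input" => "12 - 3 + 5 = ?",
//                 "expected_answer": "input" => "14",
//                 "student_answer": "input" => "14",
//                 "student_reasoning": "input" => "12 - 3 = 9, 9 + 5 = 14"
//             },
//             Arc::clone(&judge_lm),
//         )
//         .await?;
//     println!("{}", verdict.get("evaluation", None));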

// ============================================================================
// Step 3: Create the main module with LLM judge
// ============================================================================

#[derive(Builder, Optimizable)]
struct MathSolver {
    // The main predictor we want to optimize
    #[parameter]
    solver: Predict,

    // The judge predictor (not optimized, only used for evaluation)
    judge: Predict,

    // LM for the judge (could be a different/cheaper model)
    judge_lm: Arc<LM>,
}

impl Module for MathSolver {
    async fn forward(&self, inputs: Example) -> Result<Prediction> {
        // Delegate to the solver; the judge only runs during evaluation
        self.solver.forward(inputs).await
    }
}

// ============================================================================
// Step 4: Implement regular Evaluator for non-GEPA optimizers
// ============================================================================

impl Evaluator for MathSolver {
    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
        // Regular optimizers only need a scalar score, so reuse the
        // judge-backed feedback metric and discard its text
        let feedback = self.feedback_metric(example, prediction).await;
        feedback.score
    }
}
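
// With this impl in place, the plain `module.evaluate(trainset)` call used for
// the baseline in `main` works unchanged; it presumably averages `metric`
// over the examples (semantics assumed from how this file uses it).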

// ============================================================================
// Step 5: Implement FeedbackEvaluator with LLM judge for GEPA
// ============================================================================

impl FeedbackEvaluator for MathSolver {
    async fn feedback_metric(&self, example: &Example, prediction: &Prediction) -> FeedbackMetric {
        // Extract the problem and answers
        let problem = example
            .get("problem", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let expected = example
            .get("expected_answer", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let student_answer = prediction
            .get("answer", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let student_reasoning = prediction
            .get("reasoning", None)
            .as_str()
            .unwrap_or("No reasoning provided")
            .to_string();

        // Quick check: is the answer exactly correct?
        let answer_matches = student_answer.trim() == expected.trim();

        // Use the LLM judge to analyze reasoning quality. This is the key
        // step: the judge turns a bare pass/fail check into rich feedback.
        let judge_input = example! {
            "problem": "input" => &problem,
            "expected_answer": "input" => &expected,
            "student_answer": "input" => &student_answer,
            "student_reasoning": "input" => &student_reasoning
        };

        let judge_output = match self
            .judge
            .forward_with_config(judge_input, Arc::clone(&self.judge_lm))
            .await
        {
            Ok(output) => output,
            Err(_) => {
                // If the judge call fails, fall back to simple
                // correctness-only feedback
                let score = if answer_matches { 1.0 } else { 0.0 };
                let simple_feedback = format!(
                    "Problem: {}\nExpected: {}\nPredicted: {}\nAnswer: {}",
                    problem,
                    expected,
                    student_answer,
                    if answer_matches {
                        "CORRECT"
                    } else {
                        "INCORRECT"
                    }
                );
                return FeedbackMetric::new(score, simple_feedback);
            }
        };

        let judge_evaluation = judge_output
            .get("evaluation", None)
            .as_str()
            .unwrap_or("Unable to evaluate")
            .to_string();

        // Score on answer correctness and reasoning quality. Keyword matching
        // on the judge's prose is a simple heuristic for assigning partial
        // credit; the free text itself is what GEPA reflects on.
        let score = if answer_matches {
            // Correct answer: check whether the reasoning is also sound
            if judge_evaluation.to_lowercase().contains("sound reasoning")
                || judge_evaluation.to_lowercase().contains("correct approach")
            {
                1.0 // Perfect: right answer, good reasoning
            } else {
                0.7 // Right answer but flawed reasoning (lucky guess?)
            }
        } else {
            // Wrong answer: check whether any partial credit applies
            if judge_evaluation.to_lowercase().contains("correct approach")
                || judge_evaluation.to_lowercase().contains("good start")
            {
                0.3 // Wrong answer but some valid steps
            } else {
                0.0 // Completely wrong
            }
        };
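
        // Rubric summary (restating the branches above):
        //   correct answer + sound reasoning   => 1.0
        //   correct answer + flawed reasoning  => 0.7
        //   wrong answer   + valid approach    => 0.3
        //   wrong answer   + nothing salvaged  => 0.0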

        // Construct rich textual feedback that combines the factual result
        // with the judge's analysis
        let mut feedback = String::new();

        feedback.push_str(&format!("Problem: {}\n", problem));
        feedback.push_str(&format!("Expected: {}\n", expected));
        feedback.push_str(&format!("Predicted: {}\n", student_answer));

        if answer_matches {
            feedback.push_str("Answer: CORRECT\n\n");
        } else {
            feedback.push_str("Answer: INCORRECT\n\n");
        }

        feedback.push_str("Reasoning Quality Analysis:\n");
        feedback.push_str(&judge_evaluation);

        // Return the feedback metric with score and rich text
        FeedbackMetric::new(score, feedback)
    }
}
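
// Shape of the assembled feedback text (judge analysis elided; follows the
// format strings above):
//
//     Problem: Sarah has 12 apples. ...
//     Expected: 14
//     Predicted: 14
//     Answer: CORRECT
//
//     Reasoning Quality Analysis:
//     <judge's evaluation text>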

// ============================================================================
// Step 6: Main function - Set up and run GEPA optimization
// ============================================================================

#[tokio::main]
async fn main() -> Result<()> {
    println!("GEPA with LLM-as-a-Judge Example\n");
    println!("This example shows how to use an LLM judge to automatically");
    println!("generate rich feedback for optimizing a math solver.\n");

    // Setup: configure the LMs
    // Main LM for the task
    let task_lm = LM::new(LMConfig {
        temperature: 0.7,
        ..LMConfig::default()
    })
    .await;

    // Judge LM (could use a different/cheaper model); the lower temperature
    // makes its evaluations more consistent
    let judge_lm = LM::new(LMConfig {
        temperature: 0.3,
        ..LMConfig::default()
    })
    .await;

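    // `configure` installs the task LM and chat adapter as the global default
    // used by `solver.forward`; the judge bypasses it via `forward_with_config`
    // and its own LM handle (semantics assumed from how this file uses them).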
    configure(task_lm, ChatAdapter);

    // Create training examples
    let trainset = vec![
        example! {
            "problem": "input" => "Sarah has 12 apples. She gives 3 to her friend and buys 5 more. How many apples does she have now?",
            "expected_answer": "input" => "14"
        },
        example! {
            "problem": "input" => "A train travels 60 miles in 1 hour. How far will it travel in 3.5 hours at the same speed?",
            "expected_answer": "input" => "210"
        },
        example! {
            "problem": "input" => "There are 24 students in a class. If 1/3 of them are absent, how many students are present?",
            "expected_answer": "input" => "16"
        },
        example! {
            "problem": "input" => "A rectangle has length 8 cm and width 5 cm. What is its area?",
            "expected_answer": "input" => "40"
        },
        example! {
            "problem": "input" => "John has $50. He spends $12 on lunch and $8 on a book. How much money does he have left?",
            "expected_answer": "input" => "30"
        },
    ];

    // Create the module
    let mut module = MathSolver::builder()
        .solver(Predict::new(MathWordProblem::new()))
        .judge(Predict::new(MathJudge::new()))
        .judge_lm(Arc::new(judge_lm))
        .build();

    // Evaluate baseline performance
    println!("Step 1: Baseline Performance");
    println!("Testing the solver before optimization...\n");
    let baseline_score = module.evaluate(trainset.clone()).await;
    println!("  Baseline average score: {:.3}\n", baseline_score);

    // Configure GEPA optimizer
    println!("Step 2: Configure GEPA");
    println!("Setting up the optimizer with budget controls...\n");

    let gepa = GEPA::builder()
        .num_iterations(3) // Fewer iterations for demo
        .minibatch_size(3) // Smaller batches
        .temperature(0.9)
        .track_stats(true)
        .maybe_max_lm_calls(Some(100)) // Important: each scored example costs 2 LM calls (task + judge)
        .build();
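
    // Budget note: with the judge in the loop, every scored example costs two
    // LM calls (one solver rollout + one judge call; the fallback path costs
    // one), so a cap of 100 covers at most ~50 judged examples.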

    // Run GEPA optimization
    println!("Step 3: Run GEPA Optimization");
    println!("The judge will analyze reasoning quality and provide feedback...\n");

    let result = gepa
        .compile_with_feedback(&mut module, trainset.clone())
        .await?;

    // Display results
    println!("\nStep 4: Results");
    println!("===============\n");
    println!("Optimization complete!");
    println!(
        "  Best average score: {:.3}",
        result.best_candidate.average_score()
    );
    println!(
        "  Improvement: {:.3}",
        result.best_candidate.average_score() - baseline_score
    );
    println!("  Total rollouts: {}", result.total_rollouts);
    println!(
        "  Total LM calls: {} (includes judge evaluations)",
        result.total_lm_calls
    );

    println!("\nEvolution over time:");
    for (generation, score) in &result.evolution_history {
        println!("  Generation {}: {:.3}", generation, score);
    }

    println!("\nOptimized instruction:");
    println!("  {}", result.best_candidate.instruction);

    // Test the optimized solver
    println!("\nStep 5: Test Optimized Solver");
    println!("==============================\n");

    let test_problem = example! {
        "problem": "input" => "A store sells pencils for $0.25 each. If you buy 8 pencils, how much will you pay?",
        "expected_answer": "input" => "2"
    };
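
    // Note: correctness here is an exact string match after trimming, so an
    // answer like "2.00" or "$2.00" scores as incorrect even though 8 * $0.25
    // is numerically $2; the judge's text is what surfaces such mismatches.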

    let test_prediction = module.forward(test_problem.clone()).await?;
    let test_feedback = module
        .feedback_metric(&test_problem, &test_prediction)
        .await;

    println!(
        "Test problem: A store sells pencils for $0.25 each. If you buy 8 pencils, how much will you pay?"
    );
    println!("\nAnswer: {}", test_prediction.get("answer", None));
    println!("Score: {:.3}\n", test_feedback.score);
    println!("Detailed Feedback from Judge:");
    println!("{}", test_feedback.feedback);

    Ok(())
}