10_gepa_llm_judge/10-gepa-llm-judge.rs

//! Example: Using LLM-as-a-Judge with GEPA for Math Word Problems
//!
//! This example demonstrates how to use an LLM judge to automatically generate
//! rich textual feedback for GEPA optimization. The judge evaluates both the
//! correctness of answers AND the quality of reasoning.
//!
//! To run:
//! ```bash
//! OPENAI_API_KEY=your_key cargo run --example 10-gepa-llm-judge
//! ```
use anyhow::Result;
use bon::Builder;
use dspy_rs::*;
use dsrs_macros::{Optimizable, Signature};
use std::sync::Arc;

// ============================================================================
// Step 1: Define the task signature with chain-of-thought reasoning
// ============================================================================

#[Signature(cot)]
struct MathWordProblem {
    /// Solve the math word problem step by step. Show your work clearly.

    #[input]
    pub problem: String,

    #[output]
    pub reasoning: String,

    #[output]
    pub answer: String,
}
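
// The doc comment inside the signature doubles as the prompt instruction;
// that instruction is what GEPA mutates when proposing new candidates (see
// the "Optimized instruction" printed at the end of main).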

// ============================================================================
// Step 2: Define the LLM judge signature
// ============================================================================

#[Signature]
struct MathJudge {
    /// You are an expert math teacher evaluating student work. Analyze both
    /// the final answer and the reasoning process. Be specific about what
    /// went wrong or what was done well, and state explicitly whether the
    /// student used a correct approach and whether the reasoning was sound.

    #[input(desc = "The math problem that was given")]
    pub problem: String,

    #[input(desc = "The expected correct answer")]
    pub expected_answer: String,

    #[input(desc = "The student's answer")]
    pub student_answer: String,

    #[input(desc = "The student's reasoning/work shown")]
    pub student_reasoning: String,

    #[output(desc = "Detailed evaluation of the work")]
    pub evaluation: String,
}
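
// The judge returns free-form text in `evaluation`; `feedback_metric` below
// embeds that text verbatim in the GEPA feedback and also scans it for
// phrases like "correct approach" when assigning partial credit.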

// ============================================================================
// Step 3: Create the main module with LLM judge
// ============================================================================

#[derive(Builder, Optimizable)]
struct MathSolver {
    // The main predictor we want to optimize
    #[parameter]
    solver: Predict,

    // The judge predictor (not optimized, just used for evaluation)
    judge: Predict,

    // LM for the judge (could be different/cheaper model)
    judge_lm: Arc<LM>,
}

impl Module for MathSolver {
    async fn forward(&self, inputs: Example) -> Result<Prediction> {
        // Just forward to the solver - judge only used during evaluation
        self.solver.forward(inputs).await
    }
}

// ============================================================================
// Step 4: Implement regular Evaluator for non-GEPA optimizers
// ============================================================================

impl Evaluator for MathSolver {
    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
        // For regular optimizers, just return the scalar score
        let feedback = self.feedback_metric(example, prediction).await;
        feedback.score
    }
}

// ============================================================================
// Step 5: Implement FeedbackEvaluator with LLM judge for GEPA
// ============================================================================

impl FeedbackEvaluator for MathSolver {
    async fn feedback_metric(&self, example: &Example, prediction: &Prediction) -> FeedbackMetric {
        // Extract the problem and answers
        let problem = example
            .get("problem", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let expected = example
            .get("expected_answer", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let student_answer = prediction
            .get("answer", None)
            .as_str()
            .unwrap_or("")
            .to_string();

        let student_reasoning = prediction
            .get("reasoning", None)
            .as_str()
            .unwrap_or("No reasoning provided")
            .to_string();

        // Quick check: is the answer exactly correct?
        let answer_matches = student_answer.trim() == expected.trim();
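
        // Exact string equality is brittle ("2" vs "2.00" vs "$2.00"). As a
        // fallback (a sketch; adapt the normalization to your answer formats),
        // also accept answers that match numerically once a leading currency
        // symbol is stripped.
        let parse_num = |s: &str| s.trim().trim_start_matches('$').parse::<f64>();
        let answer_matches = answer_matches
            || matches!(
                (parse_num(&student_answer), parse_num(&expected)),
                (Ok(a), Ok(b)) if (a - b).abs() < 1e-6
            );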

        // Use LLM judge to analyze the reasoning quality
        // This is where the magic happens - the judge provides rich feedback
        let judge_input = example! {
            "problem": "input" => &problem,
            "expected_answer": "input" => &expected,
            "student_answer": "input" => &student_answer,
            "student_reasoning": "input" => &student_reasoning
        };

        let judge_output = match self
            .judge
            .forward_with_config(judge_input, Arc::clone(&self.judge_lm))
            .await
        {
            Ok(output) => output,
            Err(_) => {
                // If the judge fails, fall back to simple feedback
                let score = if answer_matches { 1.0 } else { 0.0 };
                let simple_feedback = format!(
                    "Problem: {}\nExpected: {}\nPredicted: {}\nAnswer: {}",
                    problem,
                    expected,
                    student_answer,
                    if answer_matches {
                        "CORRECT"
                    } else {
                        "INCORRECT"
                    }
                );
                return FeedbackMetric::new(score, simple_feedback);
            }
        };

        let judge_evaluation = judge_output
            .get("evaluation", None)
            .as_str()
            .unwrap_or("Unable to evaluate")
            .to_string();

        // Calculate score based on answer correctness and reasoning quality
        // The judge's evaluation helps us assign partial credit
        let eval_lower = judge_evaluation.to_lowercase();
        let score = if answer_matches {
            // Correct answer - check if reasoning is also sound
            if eval_lower.contains("sound reasoning") || eval_lower.contains("correct approach") {
                1.0 // Perfect: right answer, good reasoning
            } else {
                0.7 // Right answer but flawed reasoning (lucky guess?)
            }
        } else {
            // Wrong answer - check if there's any partial credit
            if eval_lower.contains("correct approach") || eval_lower.contains("good start") {
                0.3 // Wrong answer but some valid steps
            } else {
                0.0 // Completely wrong
            }
        };
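
        // Note: keyword-matching free-form judge text is brittle. A sturdier
        // variant (hypothetical, not implemented here) would add a structured
        // output field to MathJudge, e.g.
        //
        //     #[output(desc = "One of: SOUND, PARTIAL, FLAWED")]
        //     pub verdict: String,
        //
        // and map that verdict to a score with a simple match.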

        // Construct rich textual feedback
        // This combines factual info with the judge's analysis
        let mut feedback = String::new();

        feedback.push_str(&format!("Problem: {}\n", problem));
        feedback.push_str(&format!("Expected: {}\n", expected));
        feedback.push_str(&format!("Predicted: {}\n", student_answer));

        if answer_matches {
            feedback.push_str("Answer: CORRECT\n\n");
        } else {
            feedback.push_str("Answer: INCORRECT\n\n");
        }

        feedback.push_str("Reasoning Quality Analysis:\n");
        feedback.push_str(&judge_evaluation);

        // Return the feedback metric with score and rich text
        FeedbackMetric::new(score, feedback)
    }
}

// ============================================================================
// Step 6: Main function - Set up and run GEPA optimization
// ============================================================================

#[tokio::main]
async fn main() -> Result<()> {
    println!("GEPA with LLM-as-a-Judge Example\n");
    println!("This example shows how to use an LLM judge to automatically");
    println!("generate rich feedback for optimizing a math solver.\n");

    // Setup: Configure the LLMs
    // Main LM for the task
    let task_lm = LM::builder().temperature(0.7).build().await.unwrap();

    // Judge LM (could use a different/cheaper model)
    let judge_lm = LM::builder().temperature(0.3).build().await.unwrap();

    configure(task_lm, ChatAdapter);

    // Create training examples
    let trainset = vec![
        example! {
            "problem": "input" => "Sarah has 12 apples. She gives 3 to her friend and buys 5 more. How many apples does she have now?",
            "expected_answer": "input" => "14"
        },
        example! {
            "problem": "input" => "A train travels 60 miles in 1 hour. How far will it travel in 3.5 hours at the same speed?",
            "expected_answer": "input" => "210"
        },
        example! {
            "problem": "input" => "There are 24 students in a class. If 1/3 of them are absent, how many students are present?",
            "expected_answer": "input" => "16"
        },
        example! {
            "problem": "input" => "A rectangle has length 8 cm and width 5 cm. What is its area?",
            "expected_answer": "input" => "40"
        },
        example! {
            "problem": "input" => "John has $50. He spends $12 on lunch and $8 on a book. How much money does he have left?",
            "expected_answer": "input" => "30"
        },
    ];
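
    // Five examples keep the demo cheap; GEPA generally benefits from a larger
    // training set, since each iteration samples its minibatch from this pool.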

    // Create the module
    let mut module = MathSolver::builder()
        .solver(Predict::new(MathWordProblem::new()))
        .judge(Predict::new(MathJudge::new()))
        .judge_lm(Arc::new(judge_lm))
        .build();

    // Evaluate baseline performance
    println!("Step 1: Baseline Performance");
    println!("Testing the solver before optimization...\n");
    let baseline_score = module.evaluate(trainset.clone()).await;
    println!("  Baseline average score: {:.3}\n", baseline_score);

    // Configure GEPA optimizer
    println!("Step 2: Configure GEPA");
    println!("Setting up the optimizer with budget controls...\n");
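
    // Budget note: every scored rollout costs two LM calls here (one for the
    // solver, one for the judge), so a 100-call cap allows roughly 50
    // evaluations plus whatever GEPA spends proposing new instructions.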
    let gepa = GEPA::builder()
        .num_iterations(3) // Fewer iterations for demo
        .minibatch_size(3) // Smaller batches
        .temperature(0.9)
        .track_stats(true)
        .maybe_max_lm_calls(Some(100)) // Important: we're using 2x LM calls (task + judge)
        .build();

    // Run GEPA optimization
    println!("Step 3: Run GEPA Optimization");
    println!("The judge will analyze reasoning quality and provide feedback...\n");

    let result = gepa
        .compile_with_feedback(&mut module, trainset.clone())
        .await?;

    // Display results
    println!("\nStep 4: Results");
    println!("===============\n");
    println!("Optimization complete!");
    println!(
        "  Best average score: {:.3}",
        result.best_candidate.average_score()
    );
    println!(
        "  Improvement: {:.3}",
        result.best_candidate.average_score() - baseline_score
    );
    println!("  Total rollouts: {}", result.total_rollouts);
    println!(
        "  Total LM calls: {} (includes judge evaluations)",
        result.total_lm_calls
    );

    println!("\nEvolution over time:");
    for (generation, score) in &result.evolution_history {
        println!("  Generation {}: {:.3}", generation, score);
    }

    println!("\nOptimized instruction:");
    println!("  {}", result.best_candidate.instruction);

    // Test the optimized solver
    println!("\nStep 5: Test Optimized Solver");
    println!("==============================\n");

    let test_problem = example! {
        "problem": "input" => "A store sells pencils for $0.25 each. If you buy 8 pencils, how much will you pay?",
        "expected_answer": "input" => "2"
    };

    let test_prediction = module.forward(test_problem.clone()).await?;
    let test_feedback = module
        .feedback_metric(&test_problem, &test_prediction)
        .await;

    println!(
        "Test problem: A store sells pencils for $0.25 each. If you buy 8 pencils, how much will you pay?"
    );
    println!("\nAnswer: {}", test_prediction.get("answer", None));
    println!("Score: {:.3}\n", test_feedback.score);
    println!("Detailed Feedback from Judge:");
    println!("{}", test_feedback.feedback);

    Ok(())
}