09_gepa_sentiment/
09-gepa-sentiment.rs

1/// Example: Using GEPA to optimize a sentiment analysis module
2///
3/// This example demonstrates:
4/// 1. Implementing FeedbackEvaluator with rich textual feedback
5/// 2. Using GEPA optimizer for reflective prompt evolution
6/// 3. Tracking optimization progress with detailed statistics
7///
8/// To run:
9/// ```
10/// OPENAI_API_KEY=your_key cargo run --example 09-gepa-sentiment
11/// ```
12use anyhow::Result;
13use bon::Builder;
14use dspy_rs::*;
15use dsrs_macros::{Optimizable, Signature};
16
// Signature for the sentiment task: one input field (`text`) and two output
// fields (`sentiment`, `reasoning`).
//
// NOTE(review): the `///` line inside the struct body is presumably consumed by
// the `#[Signature]` macro as the task instruction sent to the model — confirm
// against dsrs_macros. Treat its wording as runtime text, not as a comment.
#[Signature]
struct SentimentSignature {
    /// Analyze the sentiment of the given text. Classify as 'Positive', 'Negative', or 'Neutral'.

    // The text to classify; training examples supply it under the key "text".
    #[input]
    pub text: String,

    // Predicted label; `feedback_metric` reads it back under the key "sentiment"
    // and lower-cases it before comparing with the expected label.
    #[output]
    pub sentiment: String,

    // Free-form explanation; `feedback_metric` reads it under the key "reasoning"
    // and echoes it into the textual feedback.
    #[output]
    pub reasoning: String,
}
30
/// Sentiment module being optimized: a thin wrapper around a single [`Predict`] step.
///
/// Constructed in `main` through the derived `SentimentAnalyzer::builder()`.
#[derive(Builder, Optimizable)]
struct SentimentAnalyzer {
    // Tagged `#[parameter]` for the `Optimizable` derive.
    // NOTE(review): presumably this is what lets GEPA rewrite the predictor's
    // instruction — confirm against dsrs_macros.
    #[parameter]
    predictor: Predict,
}
36
37impl Module for SentimentAnalyzer {
38    async fn forward(&self, inputs: Example) -> Result<Prediction> {
39        self.predictor.forward(inputs).await
40    }
41}
42
43impl Evaluator for SentimentAnalyzer {
44    async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
45        let feedback = self.feedback_metric(example, prediction).await;
46        feedback.score
47    }
48}
49
50impl FeedbackEvaluator for SentimentAnalyzer {
51    async fn feedback_metric(&self, example: &Example, prediction: &Prediction) -> FeedbackMetric {
52        let predicted = prediction
53            .get("sentiment", None)
54            .as_str()
55            .unwrap_or("")
56            .to_string()
57            .to_lowercase();
58
59        let expected = example
60            .get("expected_sentiment", None)
61            .as_str()
62            .unwrap_or("")
63            .to_string()
64            .to_lowercase();
65
66        let text = example.get("text", None).as_str().unwrap_or("").to_string();
67
68        let reasoning = prediction
69            .get("reasoning", None)
70            .as_str()
71            .unwrap_or("")
72            .to_string();
73
74        // Calculate score
75        let correct = predicted == expected;
76        let score = if correct { 1.0 } else { 0.0 };
77
78        // Create rich feedback
79        let mut feedback = if correct {
80            format!("Correct classification: \"{}\"\n", expected)
81        } else {
82            format!(
83                "Incorrect classification\n  Expected: \"{}\"\n  Predicted: \"{}\"\n",
84                expected, predicted
85            )
86        };
87
88        // Add context about the input
89        feedback.push_str(&format!("  Input text: \"{}\"\n", text));
90
91        // Add reasoning analysis
92        if !reasoning.is_empty() {
93            feedback.push_str(&format!("  Reasoning: {}\n", reasoning));
94
95            // Check if reasoning mentions key sentiment words
96            let has_reasoning_quality = if correct {
97                // For correct answers, check if reasoning is substantive
98                reasoning.len() > 20
99            } else {
100                // For incorrect answers, note what went wrong
101                false
102            };
103
104            if has_reasoning_quality {
105                feedback.push_str("  Reasoning appears detailed\n");
106            } else if !correct {
107                feedback.push_str("  May have misunderstood the text sentiment\n");
108            }
109        }
110
111        FeedbackMetric::new(score, feedback)
112    }
113}
114
115#[tokio::main]
116async fn main() -> Result<()> {
117    println!("GEPA Sentiment Analysis Optimization Example\n");
118
119    // Setup LM
120    let lm = LM::builder().temperature(0.7).build().await.unwrap();
121
122    configure(lm.clone(), ChatAdapter);
123
124    // Create training examples with diverse sentiments
125    let trainset = vec![
126        example! {
127            "text": "input" => "This movie was absolutely fantastic! I loved every minute of it.",
128            "expected_sentiment": "input" => "positive"
129        },
130        example! {
131            "text": "input" => "Terrible service, will never come back again.",
132            "expected_sentiment": "input" => "negative"
133        },
134        example! {
135            "text": "input" => "The weather is okay, nothing special.",
136            "expected_sentiment": "input" => "neutral"
137        },
138        example! {
139            "text": "input" => "Despite some minor issues, I'm quite happy with the purchase.",
140            "expected_sentiment": "input" => "positive"
141        },
142        example! {
143            "text": "input" => "I have mixed feelings about this product.",
144            "expected_sentiment": "input" => "neutral"
145        },
146        example! {
147            "text": "input" => "This is the worst experience I've ever had!",
148            "expected_sentiment": "input" => "negative"
149        },
150        example! {
151            "text": "input" => "It's fine. Does what it's supposed to do.",
152            "expected_sentiment": "input" => "neutral"
153        },
154        example! {
155            "text": "input" => "Exceeded all my expectations! Highly recommend!",
156            "expected_sentiment": "input" => "positive"
157        },
158        example! {
159            "text": "input" => "Disappointed and frustrated with the outcome.",
160            "expected_sentiment": "input" => "negative"
161        },
162        example! {
163            "text": "input" => "Standard quality, nothing remarkable.",
164            "expected_sentiment": "input" => "neutral"
165        },
166    ];
167
168    // Create module
169    let mut module = SentimentAnalyzer::builder()
170        .predictor(Predict::new(SentimentSignature::new()))
171        .build();
172
173    // Evaluate baseline performance
174    println!("Baseline Performance:");
175    let baseline_score = module.evaluate(trainset.clone()).await;
176    println!("  Average score: {:.3}\n", baseline_score);
177
178    // Configure GEPA optimizer
179    let gepa = GEPA::builder()
180        .num_iterations(5)
181        .minibatch_size(5)
182        .num_trials(3)
183        .temperature(0.9)
184        .track_stats(true)
185        .build();
186
187    // Run optimization
188    println!("Starting GEPA optimization...\n");
189    let result = gepa
190        .compile_with_feedback(&mut module, trainset.clone())
191        .await?;
192
193    // Display results
194    println!("\nOptimization Results:");
195    println!(
196        "  Best average score: {:.3}",
197        result.best_candidate.average_score()
198    );
199    println!("  Total rollouts: {}", result.total_rollouts);
200    println!("  Total LM calls: {}", result.total_lm_calls);
201    println!("  Generations: {}", result.evolution_history.len());
202
203    println!("\nBest Instruction:");
204    println!("  {}", result.best_candidate.instruction);
205
206    if !result.evolution_history.is_empty() {
207        println!("\nEvolution History:");
208        for entry in &result.evolution_history {
209            println!("  Generation {}: {:.3}", entry.0, entry.1);
210        }
211    }
212
213    // Test optimized module on a new example
214    println!("\nTesting Optimized Module:");
215    let test_example = example! {
216        "text": "input" => "This product changed my life! Absolutely amazing!",
217        "expected_sentiment": "input" => "positive"
218    };
219
220    let test_prediction = module.forward(test_example.clone()).await?;
221    let test_feedback = module
222        .feedback_metric(&test_example, &test_prediction)
223        .await;
224
225    println!(
226        "  Test prediction: {}",
227        test_prediction.get("sentiment", None)
228    );
229    println!("  Test score: {:.3}", test_feedback.score);
230    println!("  Feedback:\n{}", test_feedback.feedback);
231
232    Ok(())
233}