09_gepa_sentiment/
09-gepa-sentiment.rs1use anyhow::Result;
13use bon::Builder;
14use dspy_rs::*;
15use dsrs_macros::{Optimizable, Signature};
16
// Prompt signature for the sentiment-classification task.
//
// NOTE: plain `//` comments are used on purpose. `///` doc comments desugar to
// `#[doc]` attributes, which the `#[Signature]` attribute macro can see —
// presumably it turns them into the prompt instruction / field descriptions
// (verify against dsrs_macros), so adding them here could change what the LM
// is asked. Field order and the `#[input]` / `#[output]` markers likewise shape
// the generated prompt and should not be reshuffled casually.
#[Signature]
struct SentimentSignature {
    // The text whose sentiment is being classified.
    #[input]
    pub text: String,

    // Predicted label. The feedback metric in this file lowercases it and
    // compares it against the example's `expected_sentiment`; the training
    // labels used in `main` are "positive", "negative" and "neutral".
    #[output]
    pub sentiment: String,

    // Free-form justification; the feedback metric echoes it into the
    // feedback text when it is non-empty.
    #[output]
    pub reasoning: String,
}
30
// Single-predictor module that the GEPA optimizer works on.
//
// `Builder` (imported from `bon`) provides the `SentimentAnalyzer::builder()`
// used in `main`. `#[parameter]` marks `predictor` for the `Optimizable`
// derive — presumably this is what exposes the predictor's instruction to the
// optimizer (verify in dsrs_macros). Plain `//` comments are used so no extra
// `#[doc]` attributes reach the derive macros.
#[derive(Builder, Optimizable)]
struct SentimentAnalyzer {
    #[parameter]
    predictor: Predict,
}
36
37impl Module for SentimentAnalyzer {
38 async fn forward(&self, inputs: Example) -> Result<Prediction> {
39 self.predictor.forward(inputs).await
40 }
41}
42
43impl Evaluator for SentimentAnalyzer {
44 async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
45 let feedback = self.feedback_metric(example, prediction).await;
46 feedback.score
47 }
48}
49
50impl FeedbackEvaluator for SentimentAnalyzer {
51 async fn feedback_metric(&self, example: &Example, prediction: &Prediction) -> FeedbackMetric {
52 let predicted = prediction
53 .get("sentiment", None)
54 .as_str()
55 .unwrap_or("")
56 .to_string()
57 .to_lowercase();
58
59 let expected = example
60 .get("expected_sentiment", None)
61 .as_str()
62 .unwrap_or("")
63 .to_string()
64 .to_lowercase();
65
66 let text = example.get("text", None).as_str().unwrap_or("").to_string();
67
68 let reasoning = prediction
69 .get("reasoning", None)
70 .as_str()
71 .unwrap_or("")
72 .to_string();
73
74 let correct = predicted == expected;
76 let score = if correct { 1.0 } else { 0.0 };
77
78 let mut feedback = if correct {
80 format!("Correct classification: \"{}\"\n", expected)
81 } else {
82 format!(
83 "Incorrect classification\n Expected: \"{}\"\n Predicted: \"{}\"\n",
84 expected, predicted
85 )
86 };
87
88 feedback.push_str(&format!(" Input text: \"{}\"\n", text));
90
91 if !reasoning.is_empty() {
93 feedback.push_str(&format!(" Reasoning: {}\n", reasoning));
94
95 let has_reasoning_quality = if correct {
97 reasoning.len() > 20
99 } else {
100 false
102 };
103
104 if has_reasoning_quality {
105 feedback.push_str(" Reasoning appears detailed\n");
106 } else if !correct {
107 feedback.push_str(" May have misunderstood the text sentiment\n");
108 }
109 }
110
111 FeedbackMetric::new(score, feedback)
112 }
113}
114
115#[tokio::main]
116async fn main() -> Result<()> {
117 println!("GEPA Sentiment Analysis Optimization Example\n");
118
119 let lm = LM::builder().temperature(0.7).build().await.unwrap();
121
122 configure(lm.clone(), ChatAdapter);
123
124 let trainset = vec![
126 example! {
127 "text": "input" => "This movie was absolutely fantastic! I loved every minute of it.",
128 "expected_sentiment": "input" => "positive"
129 },
130 example! {
131 "text": "input" => "Terrible service, will never come back again.",
132 "expected_sentiment": "input" => "negative"
133 },
134 example! {
135 "text": "input" => "The weather is okay, nothing special.",
136 "expected_sentiment": "input" => "neutral"
137 },
138 example! {
139 "text": "input" => "Despite some minor issues, I'm quite happy with the purchase.",
140 "expected_sentiment": "input" => "positive"
141 },
142 example! {
143 "text": "input" => "I have mixed feelings about this product.",
144 "expected_sentiment": "input" => "neutral"
145 },
146 example! {
147 "text": "input" => "This is the worst experience I've ever had!",
148 "expected_sentiment": "input" => "negative"
149 },
150 example! {
151 "text": "input" => "It's fine. Does what it's supposed to do.",
152 "expected_sentiment": "input" => "neutral"
153 },
154 example! {
155 "text": "input" => "Exceeded all my expectations! Highly recommend!",
156 "expected_sentiment": "input" => "positive"
157 },
158 example! {
159 "text": "input" => "Disappointed and frustrated with the outcome.",
160 "expected_sentiment": "input" => "negative"
161 },
162 example! {
163 "text": "input" => "Standard quality, nothing remarkable.",
164 "expected_sentiment": "input" => "neutral"
165 },
166 ];
167
168 let mut module = SentimentAnalyzer::builder()
170 .predictor(Predict::new(SentimentSignature::new()))
171 .build();
172
173 println!("Baseline Performance:");
175 let baseline_score = module.evaluate(trainset.clone()).await;
176 println!(" Average score: {:.3}\n", baseline_score);
177
178 let gepa = GEPA::builder()
180 .num_iterations(5)
181 .minibatch_size(5)
182 .num_trials(3)
183 .temperature(0.9)
184 .track_stats(true)
185 .build();
186
187 println!("Starting GEPA optimization...\n");
189 let result = gepa
190 .compile_with_feedback(&mut module, trainset.clone())
191 .await?;
192
193 println!("\nOptimization Results:");
195 println!(
196 " Best average score: {:.3}",
197 result.best_candidate.average_score()
198 );
199 println!(" Total rollouts: {}", result.total_rollouts);
200 println!(" Total LM calls: {}", result.total_lm_calls);
201 println!(" Generations: {}", result.evolution_history.len());
202
203 println!("\nBest Instruction:");
204 println!(" {}", result.best_candidate.instruction);
205
206 if !result.evolution_history.is_empty() {
207 println!("\nEvolution History:");
208 for entry in &result.evolution_history {
209 println!(" Generation {}: {:.3}", entry.0, entry.1);
210 }
211 }
212
213 println!("\nTesting Optimized Module:");
215 let test_example = example! {
216 "text": "input" => "This product changed my life! Absolutely amazing!",
217 "expected_sentiment": "input" => "positive"
218 };
219
220 let test_prediction = module.forward(test_example.clone()).await?;
221 let test_feedback = module
222 .feedback_metric(&test_example, &test_prediction)
223 .await;
224
225 println!(
226 " Test prediction: {}",
227 test_prediction.get("sentiment", None)
228 );
229 println!(" Test score: {:.3}", test_feedback.score);
230 println!(" Feedback:\n{}", test_feedback.feedback);
231
232 Ok(())
233}