// Example directory: 09_gepa_sentiment/
// File: 09-gepa-sentiment.rs
use anyhow::Result;
use bon::Builder;
use dspy_rs::*;
use dsrs_macros::{Optimizable, Signature};
// Signature for the sentiment task: one input field (`text`) and two output fields
// (`sentiment`, `reasoning`).
// NOTE(review): `#[Signature]` may turn `///` doc comments into prompt/instruction text,
// so only plain `//` comments are used on this item — confirm against dsrs_macros.
#[Signature]
struct SentimentSignature {
    // Text whose sentiment is to be classified.
    #[input]
    pub text: String,

    // Predicted sentiment label; the feedback metric compares it against the
    // example's `expected_sentiment` value.
    #[output]
    pub sentiment: String,

    // Free-form explanation from the model; the feedback metric quotes it in its
    // feedback text when it is non-empty.
    #[output]
    pub reasoning: String,
}
/// Single-step sentiment classifier wrapping one [`Predict`] (built in `main` from
/// `SentimentSignature::new()`).
///
/// NOTE(review): `predictor` is marked `#[parameter]` for the `Optimizable` derive;
/// presumably this is what exposes its instruction to GEPA — confirm in dsrs_macros.
#[derive(Builder, Optimizable)]
struct SentimentAnalyzer {
    // The only predictor in the pipeline; `forward` delegates to it directly.
    #[parameter]
    predictor: Predict,
}
36
37impl Module for SentimentAnalyzer {
38 async fn forward(&self, inputs: Example) -> Result<Prediction> {
39 self.predictor.forward(inputs).await
40 }
41}
42
43impl Evaluator for SentimentAnalyzer {
44 async fn metric(&self, example: &Example, prediction: &Prediction) -> f32 {
45 let feedback = self.feedback_metric(example, prediction).await;
46 feedback.score
47 }
48}
49
50impl FeedbackEvaluator for SentimentAnalyzer {
51 async fn feedback_metric(&self, example: &Example, prediction: &Prediction) -> FeedbackMetric {
52 let predicted = prediction
53 .get("sentiment", None)
54 .as_str()
55 .unwrap_or("")
56 .to_string()
57 .to_lowercase();
58
59 let expected = example
60 .get("expected_sentiment", None)
61 .as_str()
62 .unwrap_or("")
63 .to_string()
64 .to_lowercase();
65
66 let text = example.get("text", None).as_str().unwrap_or("").to_string();
67
68 let reasoning = prediction
69 .get("reasoning", None)
70 .as_str()
71 .unwrap_or("")
72 .to_string();
73
74 let correct = predicted == expected;
76 let score = if correct { 1.0 } else { 0.0 };
77
78 let mut feedback = if correct {
80 format!("Correct classification: \"{}\"\n", expected)
81 } else {
82 format!(
83 "Incorrect classification\n Expected: \"{}\"\n Predicted: \"{}\"\n",
84 expected, predicted
85 )
86 };
87
88 feedback.push_str(&format!(" Input text: \"{}\"\n", text));
90
91 if !reasoning.is_empty() {
93 feedback.push_str(&format!(" Reasoning: {}\n", reasoning));
94
95 let has_reasoning_quality = if correct {
97 reasoning.len() > 20
99 } else {
100 false
102 };
103
104 if has_reasoning_quality {
105 feedback.push_str(" Reasoning appears detailed\n");
106 } else if !correct {
107 feedback.push_str(" May have misunderstood the text sentiment\n");
108 }
109 }
110
111 FeedbackMetric::new(score, feedback)
112 }
113}
114
115#[tokio::main]
116async fn main() -> Result<()> {
117 println!("GEPA Sentiment Analysis Optimization Example\n");
118
119 let lm = LM::new(LMConfig {
121 temperature: 0.7,
122 ..LMConfig::default()
123 })
124 .await;
125
126 configure(lm.clone(), ChatAdapter);
127
128 let trainset = vec![
130 example! {
131 "text": "input" => "This movie was absolutely fantastic! I loved every minute of it.",
132 "expected_sentiment": "input" => "positive"
133 },
134 example! {
135 "text": "input" => "Terrible service, will never come back again.",
136 "expected_sentiment": "input" => "negative"
137 },
138 example! {
139 "text": "input" => "The weather is okay, nothing special.",
140 "expected_sentiment": "input" => "neutral"
141 },
142 example! {
143 "text": "input" => "Despite some minor issues, I'm quite happy with the purchase.",
144 "expected_sentiment": "input" => "positive"
145 },
146 example! {
147 "text": "input" => "I have mixed feelings about this product.",
148 "expected_sentiment": "input" => "neutral"
149 },
150 example! {
151 "text": "input" => "This is the worst experience I've ever had!",
152 "expected_sentiment": "input" => "negative"
153 },
154 example! {
155 "text": "input" => "It's fine. Does what it's supposed to do.",
156 "expected_sentiment": "input" => "neutral"
157 },
158 example! {
159 "text": "input" => "Exceeded all my expectations! Highly recommend!",
160 "expected_sentiment": "input" => "positive"
161 },
162 example! {
163 "text": "input" => "Disappointed and frustrated with the outcome.",
164 "expected_sentiment": "input" => "negative"
165 },
166 example! {
167 "text": "input" => "Standard quality, nothing remarkable.",
168 "expected_sentiment": "input" => "neutral"
169 },
170 ];
171
172 let mut module = SentimentAnalyzer::builder()
174 .predictor(Predict::new(SentimentSignature::new()))
175 .build();
176
177 println!("Baseline Performance:");
179 let baseline_score = module.evaluate(trainset.clone()).await;
180 println!(" Average score: {:.3}\n", baseline_score);
181
182 let gepa = GEPA::builder()
184 .num_iterations(5)
185 .minibatch_size(5)
186 .num_trials(3)
187 .temperature(0.9)
188 .track_stats(true)
189 .build();
190
191 println!("Starting GEPA optimization...\n");
193 let result = gepa
194 .compile_with_feedback(&mut module, trainset.clone())
195 .await?;
196
197 println!("\nOptimization Results:");
199 println!(
200 " Best average score: {:.3}",
201 result.best_candidate.average_score()
202 );
203 println!(" Total rollouts: {}", result.total_rollouts);
204 println!(" Total LM calls: {}", result.total_lm_calls);
205 println!(" Generations: {}", result.evolution_history.len());
206
207 println!("\nBest Instruction:");
208 println!(" {}", result.best_candidate.instruction);
209
210 if !result.evolution_history.is_empty() {
211 println!("\nEvolution History:");
212 for entry in &result.evolution_history {
213 println!(" Generation {}: {:.3}", entry.0, entry.1);
214 }
215 }
216
217 println!("\nTesting Optimized Module:");
219 let test_example = example! {
220 "text": "input" => "This product changed my life! Absolutely amazing!",
221 "expected_sentiment": "input" => "positive"
222 };
223
224 let test_prediction = module.forward(test_example.clone()).await?;
225 let test_feedback = module
226 .feedback_metric(&test_example, &test_prediction)
227 .await;
228
229 println!(
230 " Test prediction: {}",
231 test_prediction.get("sentiment", None)
232 );
233 println!(" Test score: {:.3}", test_feedback.score);
234 println!(" Feedback:\n{}", test_feedback.feedback);
235
236 Ok(())
237}