dspy_rs/evaluate/feedback_helpers.rs

/// Helper functions for creating rich feedback metrics
///
/// This module provides utilities for common feedback patterns in different domains:
/// - Document retrieval (precision, recall, F1)
/// - Code generation (compilation, execution, testing)
/// - Multi-objective evaluation
/// - Structured error reporting
use super::FeedbackMetric;
use serde_json::json;
use std::collections::{HashMap, HashSet};

// ============================================================================
// Retrieval Feedback Helpers
// ============================================================================

/// Create feedback for document retrieval tasks
///
/// The score is the F1 of the retrieved set against the expected set.
///
/// # Arguments
/// * `retrieved` - Documents retrieved by the system
/// * `expected` - Expected/gold documents
/// * `context_docs` - Optional list of all available documents for context
///
/// # Example Feedback
/// ```text
/// Retrieved 3/5 correct documents (Precision: 0.600, Recall: 0.600, F1: 0.600)
/// Correctly retrieved: doc1, doc2, doc3
/// Missed: doc4, doc5
/// Incorrectly retrieved: doc6, doc7
/// ```
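///
/// # Usage (illustrative)
/// A minimal sketch of calling this helper; `None::<&[&str]>` pins down the
/// element type of the otherwise-unconstrained `context_docs` argument:
/// ```ignore
/// let fb = retrieval_feedback(
///     &["doc1", "doc2", "doc6"],
///     &["doc1", "doc2", "doc3"],
///     None::<&[&str]>,
/// );
/// assert!(fb.score > 0.0 && fb.score < 1.0); // partial overlap => partial F1
/// ```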
pub fn retrieval_feedback(
    retrieved: &[impl AsRef<str>],
    expected: &[impl AsRef<str>],
    context_docs: Option<&[impl AsRef<str>]>,
) -> FeedbackMetric {
    let retrieved_set: HashSet<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
    let expected_set: HashSet<String> = expected.iter().map(|s| s.as_ref().to_string()).collect();

    // Sort each list so the feedback string is deterministic across runs.
    let mut correct: Vec<String> = retrieved_set.intersection(&expected_set).cloned().collect();
    correct.sort();
    let mut missed: Vec<String> = expected_set.difference(&retrieved_set).cloned().collect();
    missed.sort();
    let mut incorrect: Vec<String> = retrieved_set.difference(&expected_set).cloned().collect();
    incorrect.sort();

    // Precision is undefined for an empty retrieval; report 0.0.
    let precision = if retrieved.is_empty() {
        0.0
    } else {
        correct.len() as f32 / retrieved.len() as f32
    };

    // Recall over an empty gold set is vacuously perfect.
    let recall = if expected.is_empty() {
        1.0
    } else {
        correct.len() as f32 / expected.len() as f32
    };

    let f1 = if precision + recall > 0.0 {
        2.0 * precision * recall / (precision + recall)
    } else {
        0.0
    };

    let mut feedback = format!(
        "Retrieved {}/{} correct documents (Precision: {:.3}, Recall: {:.3}, F1: {:.3})\n",
        correct.len(),
        expected.len(),
        precision,
        recall,
        f1
    );

    if !correct.is_empty() {
        feedback.push_str(&format!("Correctly retrieved: {}\n", correct.join(", ")));
    }

    if !missed.is_empty() {
        feedback.push_str(&format!("Missed: {}\n", missed.join(", ")));
    }

    if !incorrect.is_empty() {
        feedback.push_str(&format!("Incorrectly retrieved: {}\n", incorrect.join(", ")));
    }

    let mut metadata = HashMap::new();
    metadata.insert("precision".to_string(), json!(precision));
    metadata.insert("recall".to_string(), json!(recall));
    metadata.insert("f1".to_string(), json!(f1));
    metadata.insert("correct_count".to_string(), json!(correct.len()));
    metadata.insert("missed_count".to_string(), json!(missed.len()));
    metadata.insert("incorrect_count".to_string(), json!(incorrect.len()));

    if let Some(docs) = context_docs {
        metadata.insert("total_available".to_string(), json!(docs.len()));
    }

    FeedbackMetric {
        score: f1,
        feedback,
        metadata,
    }
}

// ============================================================================
// Code Generation Feedback Helpers
// ============================================================================

/// Stage in a code execution pipeline
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodeStage {
    Parse,
    Compile,
    Execute,
    Test,
}

impl std::fmt::Display for CodeStage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CodeStage::Parse => write!(f, "Parse"),
            CodeStage::Compile => write!(f, "Compile"),
            CodeStage::Execute => write!(f, "Execute"),
            CodeStage::Test => write!(f, "Test"),
        }
    }
}

/// Result of a code stage
#[derive(Debug, Clone)]
pub enum StageResult {
    Success,
    Failure { error: String },
}

/// Create feedback for code generation pipelines
///
/// Stages are reported in order; processing stops at the first failure.
///
/// # Arguments
/// * `stages` - List of (stage, result) tuples showing pipeline progression
/// * `final_score` - Overall score (0.0 to 1.0)
///
/// # Example Feedback
/// ```text
/// Parse: Success
/// Compile: Success
/// Execute: RuntimeError: division by zero on line 10
/// ```
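///
/// # Usage (illustrative)
/// A minimal sketch; the compiler error string is an arbitrary example:
/// ```ignore
/// let stages = vec![
///     (CodeStage::Parse, StageResult::Success),
///     (CodeStage::Compile, StageResult::Failure { error: "E0308: mismatched types".into() }),
/// ];
/// let fb = code_pipeline_feedback(&stages, 0.25);
/// assert!(fb.feedback.contains("Compile"));
/// ```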
pub fn code_pipeline_feedback(
    stages: &[(CodeStage, StageResult)],
    final_score: f32,
) -> FeedbackMetric {
    let mut feedback = String::new();
    let mut metadata = HashMap::new();

    let mut last_successful_stage = None;
    let mut failure_stage = None;

    for (i, (stage, result)) in stages.iter().enumerate() {
        let stage_name = stage.to_string();
        metadata.insert(format!("stage_{}_name", i), json!(stage_name));

        match result {
            StageResult::Success => {
                feedback.push_str(&format!("{}: Success\n", stage));
                metadata.insert(format!("stage_{}_result", i), json!("success"));
                last_successful_stage = Some(stage);
            }
            StageResult::Failure { error } => {
                feedback.push_str(&format!("{}: {}\n", stage, error));
                metadata.insert(format!("stage_{}_result", i), json!("failure"));
                metadata.insert(format!("stage_{}_error", i), json!(error));
                failure_stage = Some((stage, error));
                break; // Stop at the first failure
            }
        }
    }

    if let Some((stage, error)) = failure_stage {
        metadata.insert("failed_at_stage".to_string(), json!(stage.to_string()));
        metadata.insert("failure_error".to_string(), json!(error));
    }

    if let Some(stage) = last_successful_stage {
        metadata.insert(
            "last_successful_stage".to_string(),
            json!(stage.to_string()),
        );
    }

    FeedbackMetric {
        score: final_score,
        feedback,
        metadata,
    }
}

// ============================================================================
// Multi-Objective Feedback Helpers
// ============================================================================

/// Create feedback for multi-objective optimization
///
/// Objectives are reported in alphabetical order. Missing weights default
/// to 1.0, so passing `None` for `weights` yields an unweighted average.
///
/// # Arguments
/// * `objectives` - Map of objective name to (score, feedback) pairs
/// * `weights` - Optional weights for aggregating objectives
///
/// # Example Feedback
/// ```text
/// [Correctness] Score: 0.900 - Output matches expected format
/// [Latency] Score: 0.700 - Response took 450ms (target: <300ms)
/// [Privacy] Score: 1.000 - No PII detected in output
///
/// Overall: 0.867 (weighted average)
/// ```
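///
/// # Usage (illustrative)
/// A minimal sketch with equal (default) weights:
/// ```ignore
/// use std::collections::HashMap;
///
/// let mut objectives = HashMap::new();
/// objectives.insert("accuracy".to_string(), (0.9, "Close to gold".to_string()));
/// objectives.insert("latency".to_string(), (0.5, "Too slow".to_string()));
/// let fb = multi_objective_feedback(&objectives, None);
/// assert!((fb.score - 0.7).abs() < 0.01); // (0.9 + 0.5) / 2
/// ```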
pub fn multi_objective_feedback(
    objectives: &HashMap<String, (f32, String)>,
    weights: Option<&HashMap<String, f32>>,
) -> FeedbackMetric {
    let mut feedback = String::new();
    let mut metadata = HashMap::new();

    let mut total_score = 0.0;
    let mut total_weight = 0.0;

    // Sort objective names so the feedback ordering is deterministic.
    let mut objective_names: Vec<_> = objectives.keys().collect();
    objective_names.sort();

    for name in objective_names {
        if let Some((score, obj_feedback)) = objectives.get(name.as_str()) {
            let weight = weights
                .and_then(|w| w.get(name.as_str()))
                .copied()
                .unwrap_or(1.0);

            feedback.push_str(&format!(
                "[{}] Score: {:.3} - {}\n",
                name, score, obj_feedback
            ));

            metadata.insert(format!("objective_{}_score", name), json!(score));
            metadata.insert(format!("objective_{}_weight", name), json!(weight));
            metadata.insert(format!("objective_{}_feedback", name), json!(obj_feedback));

            total_score += score * weight;
            total_weight += weight;
        }
    }

    let aggregate_score = if total_weight > 0.0 {
        total_score / total_weight
    } else {
        0.0
    };

    feedback.push_str(&format!(
        "\nOverall: {:.3} (weighted average)",
        aggregate_score
    ));
    metadata.insert("aggregate_score".to_string(), json!(aggregate_score));
    metadata.insert("num_objectives".to_string(), json!(objectives.len()));

    FeedbackMetric {
        score: aggregate_score,
        feedback,
        metadata,
    }
}

// ============================================================================
// String Similarity Feedback
// ============================================================================

/// Create feedback for string similarity tasks
///
/// Uses simple word-level comparison to provide actionable feedback:
/// exact matches score 1.0, case-insensitive matches 0.95, and anything
/// else is scored by word-overlap F1.
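///
/// # Usage (illustrative)
/// A minimal sketch:
/// ```ignore
/// let fb = string_similarity_feedback("the quick fox", "the quick brown fox");
/// assert!(fb.score < 1.0);
/// assert!(fb.feedback.contains("Missing words"));
/// ```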
pub fn string_similarity_feedback(predicted: &str, expected: &str) -> FeedbackMetric {
    let exact_match = predicted.trim() == expected.trim();

    if exact_match {
        return FeedbackMetric::new(1.0, "Exact match");
    }

    let pred_lower = predicted.to_lowercase();
    let exp_lower = expected.to_lowercase();

    // Trim here as well, so the exact-match and case-insensitive checks
    // agree on how surrounding whitespace is handled.
    if pred_lower.trim() == exp_lower.trim() {
        return FeedbackMetric::new(0.95, "Match ignoring case (minor formatting difference)");
    }

    // Word-level comparison
    let pred_words: HashSet<&str> = pred_lower.split_whitespace().collect();
    let exp_words: HashSet<&str> = exp_lower.split_whitespace().collect();

    let common_words: HashSet<_> = pred_words.intersection(&exp_words).collect();

    // Sort the word lists so the feedback string is deterministic across runs.
    let mut missing_words: Vec<_> = exp_words.difference(&pred_words).collect();
    missing_words.sort();
    let mut extra_words: Vec<_> = pred_words.difference(&exp_words).collect();
    extra_words.sort();

    let recall = if !exp_words.is_empty() {
        common_words.len() as f32 / exp_words.len() as f32
    } else {
        1.0
    };

    let precision = if !pred_words.is_empty() {
        common_words.len() as f32 / pred_words.len() as f32
    } else {
        0.0
    };

    let f1 = if precision + recall > 0.0 {
        2.0 * precision * recall / (precision + recall)
    } else {
        0.0
    };

    let mut feedback = format!("Partial match (F1: {:.3})\n", f1);
    feedback.push_str(&format!("Expected: \"{}\"\n", expected));
    feedback.push_str(&format!("Predicted: \"{}\"\n", predicted));

    if !missing_words.is_empty() {
        feedback.push_str(&format!(
            "Missing words: {}\n",
            missing_words
                .iter()
                .map(|w| format!("\"{}\"", w))
                .collect::<Vec<_>>()
                .join(", ")
        ));
    }

    if !extra_words.is_empty() {
        feedback.push_str(&format!(
            "Extra words: {}\n",
            extra_words
                .iter()
                .map(|w| format!("\"{}\"", w))
                .collect::<Vec<_>>()
                .join(", ")
        ));
    }

    FeedbackMetric::new(f1, feedback)
}

// ============================================================================
// Classification Feedback
// ============================================================================

/// Create feedback for classification tasks
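///
/// # Usage (illustrative)
/// A minimal sketch; the label names are placeholders:
/// ```ignore
/// let fb = classification_feedback("spam", "ham", Some(0.8));
/// assert_eq!(fb.score, 0.0);
/// ```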
pub fn classification_feedback(
    predicted_class: &str,
    expected_class: &str,
    confidence: Option<f32>,
) -> FeedbackMetric {
    let correct = predicted_class == expected_class;
    let score = if correct { 1.0 } else { 0.0 };

    let mut feedback = if correct {
        format!("Correct classification: \"{}\"", predicted_class)
    } else {
        format!(
            "Incorrect classification\n  Expected: \"{}\"\n  Predicted: \"{}\"",
            expected_class, predicted_class
        )
    };

    if let Some(conf) = confidence {
        feedback.push_str(&format!("\n  Confidence: {:.3}", conf));
    }

    let mut metadata = HashMap::new();
    metadata.insert("predicted_class".to_string(), json!(predicted_class));
    metadata.insert("expected_class".to_string(), json!(expected_class));
    metadata.insert("correct".to_string(), json!(correct));

    if let Some(conf) = confidence {
        metadata.insert("confidence".to_string(), json!(conf));
    }

    FeedbackMetric::with_metadata(score, feedback, metadata)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retrieval_feedback_perfect() {
        let retrieved = vec!["doc1", "doc2", "doc3"];
        let expected = vec!["doc1", "doc2", "doc3"];

        let feedback = retrieval_feedback(&retrieved, &expected, None::<&[&str]>);
        assert_eq!(feedback.score, 1.0);
        assert!(feedback.feedback.contains("3/3"));
    }

    #[test]
    fn test_retrieval_feedback_partial() {
        let retrieved = vec!["doc1", "doc2", "doc4"];
        let expected = vec!["doc1", "doc2", "doc3"];

        let feedback = retrieval_feedback(&retrieved, &expected, None::<&[&str]>);
        assert!(feedback.score < 1.0 && feedback.score > 0.0);
        assert!(feedback.feedback.contains("Missed: doc3"));
        assert!(feedback.feedback.contains("Incorrectly retrieved: doc4"));
    }
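
    // Added edge-case check (a sketch against the helper as written above):
    // an empty retrieval should score 0.0 rather than divide by zero.
    #[test]
    fn test_retrieval_feedback_empty_retrieved() {
        let retrieved: Vec<&str> = vec![];
        let expected = vec!["doc1"];

        let feedback = retrieval_feedback(&retrieved, &expected, None::<&[&str]>);
        assert_eq!(feedback.score, 0.0);
        assert!(feedback.feedback.contains("0/1"));
    }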

    #[test]
    fn test_code_pipeline_feedback() {
        let stages = vec![
            (CodeStage::Parse, StageResult::Success),
            (CodeStage::Compile, StageResult::Success),
            (
                CodeStage::Execute,
                StageResult::Failure {
                    error: "Division by zero".to_string(),
                },
            ),
        ];

        let feedback = code_pipeline_feedback(&stages, 0.6);
        assert!(feedback.feedback.contains("Parse"));
        assert!(feedback.feedback.contains("Compile"));
        assert!(feedback.feedback.contains("Execute"));
        assert_eq!(feedback.score, 0.6);
    }

    #[test]
    fn test_multi_objective_feedback() {
        let mut objectives = HashMap::new();
        objectives.insert("accuracy".to_string(), (0.9, "Good accuracy".to_string()));
        objectives.insert("latency".to_string(), (0.7, "Slow response".to_string()));

        let feedback = multi_objective_feedback(&objectives, None);
        assert!(feedback.feedback.contains("[accuracy]"));
        assert!(feedback.feedback.contains("[latency]"));
        assert!((feedback.score - 0.8).abs() < 0.01); // Average of 0.9 and 0.7
    }
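
    // Added weighted-aggregation check (a sketch against the helper as
    // written above): a 3:1 weighting should pull the average toward accuracy.
    #[test]
    fn test_multi_objective_feedback_weighted() {
        let mut objectives = HashMap::new();
        objectives.insert("accuracy".to_string(), (0.9, "Good accuracy".to_string()));
        objectives.insert("latency".to_string(), (0.7, "Slow response".to_string()));

        let mut weights = HashMap::new();
        weights.insert("accuracy".to_string(), 3.0);
        weights.insert("latency".to_string(), 1.0);

        let feedback = multi_objective_feedback(&objectives, Some(&weights));
        assert!((feedback.score - 0.85).abs() < 0.01); // (0.9*3 + 0.7*1) / 4
    }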

    #[test]
    fn test_string_similarity_exact() {
        let feedback = string_similarity_feedback("hello world", "hello world");
        assert_eq!(feedback.score, 1.0);
    }

    #[test]
    fn test_string_similarity_case() {
        let feedback = string_similarity_feedback("Hello World", "hello world");
        assert_eq!(feedback.score, 0.95);
    }

    #[test]
    fn test_classification_feedback() {
        let feedback = classification_feedback("positive", "positive", Some(0.95));
        assert_eq!(feedback.score, 1.0);
        assert!(feedback.feedback.contains("Correct"));

        let feedback = classification_feedback("negative", "positive", Some(0.85));
        assert_eq!(feedback.score, 0.0);
        assert!(feedback.feedback.contains("Incorrect"));
    }
}