1use super::FeedbackMetric;
9use serde_json::json;
10use std::collections::{HashMap, HashSet};
11
12pub fn retrieval_feedback(
31 retrieved: &[impl AsRef<str>],
32 expected: &[impl AsRef<str>],
33 context_docs: Option<&[impl AsRef<str>]>,
34) -> FeedbackMetric {
35 let retrieved_set: HashSet<String> = retrieved.iter().map(|s| s.as_ref().to_string()).collect();
36
37 let expected_set: HashSet<String> = expected.iter().map(|s| s.as_ref().to_string()).collect();
38
39 let correct: Vec<String> = retrieved_set.intersection(&expected_set).cloned().collect();
40
41 let missed: Vec<String> = expected_set.difference(&retrieved_set).cloned().collect();
42
43 let incorrect: Vec<String> = retrieved_set.difference(&expected_set).cloned().collect();
44
45 let precision = if retrieved.is_empty() {
46 0.0
47 } else {
48 correct.len() as f32 / retrieved.len() as f32
49 };
50
51 let recall = if expected.is_empty() {
52 1.0
53 } else {
54 correct.len() as f32 / expected.len() as f32
55 };
56
57 let f1 = if precision + recall > 0.0 {
58 2.0 * precision * recall / (precision + recall)
59 } else {
60 0.0
61 };
62
63 let mut feedback = format!(
64 "Retrieved {}/{} correct documents (Precision: {:.3}, Recall: {:.3}, F1: {:.3})\n",
65 correct.len(),
66 expected.len(),
67 precision,
68 recall,
69 f1
70 );
71
72 if !correct.is_empty() {
73 feedback.push_str(&format!("Correctly retrieved: {}\n", correct.join(", ")));
74 }
75
76 if !missed.is_empty() {
77 feedback.push_str(&format!("Missed: {}\n", missed.join(", ")));
78 }
79
80 if !incorrect.is_empty() {
81 feedback.push_str(&format!(
82 "Incorrectly retrieved: {}\n",
83 incorrect.join(", ")
84 ));
85 }
86
87 let mut metadata = HashMap::new();
88 metadata.insert("precision".to_string(), json!(precision));
89 metadata.insert("recall".to_string(), json!(recall));
90 metadata.insert("f1".to_string(), json!(f1));
91 metadata.insert("correct_count".to_string(), json!(correct.len()));
92 metadata.insert("missed_count".to_string(), json!(missed.len()));
93 metadata.insert("incorrect_count".to_string(), json!(incorrect.len()));
94
95 if let Some(docs) = context_docs {
96 metadata.insert("total_available".to_string(), json!(docs.len()));
97 }
98
99 FeedbackMetric {
100 score: f1,
101 feedback,
102 metadata,
103 }
104}
105
/// Identifies one stage reported to [`code_pipeline_feedback`].
///
/// The enum carries no data, so it is `Copy` and `Hash` in addition to being
/// comparable; this lets callers pass it by value and use it as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeStage {
    /// Displayed as `Parse`.
    Parse,
    /// Displayed as `Compile`.
    Compile,
    /// Displayed as `Execute`.
    Execute,
    /// Displayed as `Test`.
    Test,
}

impl std::fmt::Display for CodeStage {
    /// Writes the variant name exactly as spelled in the enum.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            CodeStage::Parse => "Parse",
            CodeStage::Compile => "Compile",
            CodeStage::Execute => "Execute",
            CodeStage::Test => "Test",
        };
        f.write_str(label)
    }
}
129
/// Outcome of a single pipeline stage.
///
/// Derives `PartialEq`/`Eq` so outcomes can be compared directly (for example
/// in assertions); two failures are equal when their error text is equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StageResult {
    /// The stage completed without error.
    Success,
    /// The stage failed; `error` holds the message describing the failure.
    Failure { error: String },
}
136
137pub fn code_pipeline_feedback(
150 stages: &[(CodeStage, StageResult)],
151 final_score: f32,
152) -> FeedbackMetric {
153 let mut feedback = String::new();
154 let mut metadata = HashMap::new();
155
156 let mut last_successful_stage = None;
157 let mut failure_stage = None;
158
159 for (i, (stage, result)) in stages.iter().enumerate() {
160 let stage_name = stage.to_string();
161 metadata.insert(format!("stage_{}_name", i), json!(stage_name));
162
163 match result {
164 StageResult::Success => {
165 feedback.push_str(&format!("{}: Success\n", stage));
166 metadata.insert(format!("stage_{}_result", i), json!("success"));
167 last_successful_stage = Some(stage);
168 }
169 StageResult::Failure { error } => {
170 feedback.push_str(&format!("{}: {}\n", stage, error));
171 metadata.insert(format!("stage_{}_result", i), json!("failure"));
172 metadata.insert(format!("stage_{}_error", i), json!(error));
173 failure_stage = Some((stage, error));
174 break; }
176 }
177 }
178
179 if let Some((stage, error)) = failure_stage {
180 metadata.insert("failed_at_stage".to_string(), json!(stage.to_string()));
181 metadata.insert("failure_error".to_string(), json!(error));
182 }
183
184 if let Some(stage) = last_successful_stage {
185 metadata.insert(
186 "last_successful_stage".to_string(),
187 json!(stage.to_string()),
188 );
189 }
190
191 FeedbackMetric {
192 score: final_score,
193 feedback,
194 metadata,
195 }
196}
197
198pub fn multi_objective_feedback(
216 objectives: &HashMap<String, (f32, String)>,
217 weights: Option<&HashMap<String, f32>>,
218) -> FeedbackMetric {
219 let mut feedback = String::new();
220 let mut metadata = HashMap::new();
221
222 let mut total_score = 0.0;
223 let mut total_weight = 0.0;
224
225 let mut objective_names: Vec<_> = objectives.keys().collect();
226 objective_names.sort();
227
228 for name in objective_names {
229 if let Some((score, obj_feedback)) = objectives.get(name.as_str()) {
230 let weight = weights
231 .and_then(|w| w.get(name.as_str()))
232 .copied()
233 .unwrap_or(1.0);
234
235 feedback.push_str(&format!(
236 "[{}] Score: {:.3} - {}\n",
237 name, score, obj_feedback
238 ));
239
240 metadata.insert(format!("objective_{}_score", name), json!(score));
241 metadata.insert(format!("objective_{}_weight", name), json!(weight));
242 metadata.insert(format!("objective_{}_feedback", name), json!(obj_feedback));
243
244 total_score += score * weight;
245 total_weight += weight;
246 }
247 }
248
249 let aggregate_score = if total_weight > 0.0 {
250 total_score / total_weight
251 } else {
252 0.0
253 };
254
255 feedback.push_str(&format!(
256 "\nOverall: {:.3} (weighted average)",
257 aggregate_score
258 ));
259 metadata.insert("aggregate_score".to_string(), json!(aggregate_score));
260 metadata.insert("num_objectives".to_string(), json!(objectives.len()));
261
262 FeedbackMetric {
263 score: aggregate_score,
264 feedback,
265 metadata,
266 }
267}
268
269pub fn string_similarity_feedback(predicted: &str, expected: &str) -> FeedbackMetric {
277 let exact_match = predicted.trim() == expected.trim();
278
279 if exact_match {
280 return FeedbackMetric::new(1.0, "Exact match");
281 }
282
283 let pred_lower = predicted.to_lowercase();
284 let exp_lower = expected.to_lowercase();
285
286 if pred_lower == exp_lower {
287 return FeedbackMetric::new(0.95, "Match ignoring case (minor formatting difference)");
288 }
289
290 let pred_words: HashSet<&str> = pred_lower.split_whitespace().collect();
292 let exp_words: HashSet<&str> = exp_lower.split_whitespace().collect();
293
294 let common_words: HashSet<_> = pred_words.intersection(&exp_words).collect();
295 let missing_words: Vec<_> = exp_words.difference(&pred_words).collect();
296 let extra_words: Vec<_> = pred_words.difference(&exp_words).collect();
297
298 let recall = if !exp_words.is_empty() {
299 common_words.len() as f32 / exp_words.len() as f32
300 } else {
301 1.0
302 };
303
304 let precision = if !pred_words.is_empty() {
305 common_words.len() as f32 / pred_words.len() as f32
306 } else {
307 0.0
308 };
309
310 let f1 = if precision + recall > 0.0 {
311 2.0 * precision * recall / (precision + recall)
312 } else {
313 0.0
314 };
315
316 let mut feedback = format!("Partial match (F1: {:.3})\n", f1);
317 feedback.push_str(&format!("Expected: \"{}\"\n", expected));
318 feedback.push_str(&format!("Predicted: \"{}\"\n", predicted));
319
320 if !missing_words.is_empty() {
321 feedback.push_str(&format!(
322 "Missing words: {}\n",
323 missing_words
324 .iter()
325 .map(|w| format!("\"{}\"", w))
326 .collect::<Vec<_>>()
327 .join(", ")
328 ));
329 }
330
331 if !extra_words.is_empty() {
332 feedback.push_str(&format!(
333 "Extra words: {}\n",
334 extra_words
335 .iter()
336 .map(|w| format!("\"{}\"", w))
337 .collect::<Vec<_>>()
338 .join(", ")
339 ));
340 }
341
342 FeedbackMetric::new(f1, feedback)
343}
344
345pub fn classification_feedback(
351 predicted_class: &str,
352 expected_class: &str,
353 confidence: Option<f32>,
354) -> FeedbackMetric {
355 let correct = predicted_class == expected_class;
356 let score = if correct { 1.0 } else { 0.0 };
357
358 let mut feedback = if correct {
359 format!("Correct classification: \"{}\"", predicted_class)
360 } else {
361 format!(
362 "Incorrect classification\n Expected: \"{}\"\n Predicted: \"{}\"",
363 expected_class, predicted_class
364 )
365 };
366
367 if let Some(conf) = confidence {
368 feedback.push_str(&format!("\n Confidence: {:.3}", conf));
369 }
370
371 let mut metadata = HashMap::new();
372 metadata.insert("predicted_class".to_string(), json!(predicted_class));
373 metadata.insert("expected_class".to_string(), json!(expected_class));
374 metadata.insert("correct".to_string(), json!(correct));
375
376 if let Some(conf) = confidence {
377 metadata.insert("confidence".to_string(), json!(conf));
378 }
379
380 FeedbackMetric::with_metadata(score, feedback, metadata)
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Retrieving exactly the expected documents scores a perfect F1.
    #[test]
    fn test_retrieval_feedback_perfect() {
        let docs = vec!["doc1", "doc2", "doc3"];

        let result = retrieval_feedback(&docs, &docs.clone(), None::<&[&str]>);

        assert_eq!(result.score, 1.0);
        assert!(result.feedback.contains("3/3"));
    }

    /// One wrong and one missing document gives a score strictly between 0 and 1.
    #[test]
    fn test_retrieval_feedback_partial() {
        let got = vec!["doc1", "doc2", "doc4"];
        let want = vec!["doc1", "doc2", "doc3"];

        let result = retrieval_feedback(&got, &want, None::<&[&str]>);

        assert!(result.score > 0.0);
        assert!(result.score < 1.0);
        assert!(result.feedback.contains("Missed: doc3"));
        assert!(result.feedback.contains("Incorrectly retrieved: doc4"));
    }

    /// Every stage up to and including the failing one appears in the report.
    #[test]
    fn test_code_pipeline_feedback() {
        let failure = StageResult::Failure {
            error: "Division by zero".to_string(),
        };
        let stages = vec![
            (CodeStage::Parse, StageResult::Success),
            (CodeStage::Compile, StageResult::Success),
            (CodeStage::Execute, failure),
        ];

        let result = code_pipeline_feedback(&stages, 0.6);

        for stage_name in ["Parse", "Compile", "Execute"].iter() {
            assert!(result.feedback.contains(stage_name));
        }
        assert_eq!(result.score, 0.6);
    }

    /// Without weights the aggregate is the plain mean of the objective scores.
    #[test]
    fn test_multi_objective_feedback() {
        let objectives: HashMap<_, _> = vec![
            ("accuracy".to_string(), (0.9, "Good accuracy".to_string())),
            ("latency".to_string(), (0.7, "Slow response".to_string())),
        ]
        .into_iter()
        .collect();

        let result = multi_objective_feedback(&objectives, None);

        assert!(result.feedback.contains("[accuracy]"));
        assert!(result.feedback.contains("[latency]"));
        assert!((result.score - 0.8).abs() < 0.01);
    }

    /// Identical strings score 1.0.
    #[test]
    fn test_string_similarity_exact() {
        let result = string_similarity_feedback("hello world", "hello world");
        assert_eq!(result.score, 1.0);
    }

    /// Strings differing only by case score 0.95.
    #[test]
    fn test_string_similarity_case() {
        let result = string_similarity_feedback("Hello World", "hello world");
        assert_eq!(result.score, 0.95);
    }

    /// Matching labels score 1.0; mismatching labels score 0.0.
    #[test]
    fn test_classification_feedback() {
        let hit = classification_feedback("positive", "positive", Some(0.95));
        assert_eq!(hit.score, 1.0);
        assert!(hit.feedback.contains("Correct"));

        let miss = classification_feedback("negative", "positive", Some(0.85));
        assert_eq!(miss.score, 0.0);
        assert!(miss.feedback.contains("Incorrect"));
    }
}