Skip to main content

do_memory_core/reward/
mod.rs

//! # Reward Calculator
//!
//! Calculates reward scores for episodes based on outcome, efficiency, and quality.
//! Supports both fixed thresholds and adaptive, domain-based calibration.
//!
//! ## Modules
//!
//! - `domain_stats` - Domain-specific statistics for adaptive calibration
//! - `adaptive` - Adaptive reward calculator using domain baselines
//! - `external` - Only compiled when the `agentfs` feature is enabled
11
12#![allow(clippy::if_not_else)]
13#![allow(clippy::cast_precision_loss)]
14#![allow(clippy::map_unwrap_or)]
15#![allow(clippy::doc_markdown)]
16
17// Public modules
18pub mod adaptive;
19pub mod domain_stats;
20
21#[cfg(feature = "agentfs")]
22pub mod external;
23
24#[cfg(test)]
25pub mod tests;
26
27// Re-export for convenience
28pub use adaptive::AdaptiveRewardCalculator;
29pub use domain_stats::{DomainStatistics, DomainStatisticsCache};
30
31use crate::episode::Episode;
32use crate::types::{ComplexityLevel, RewardScore, TaskOutcome};
33use tracing::{debug, instrument};
34
/// Reference duration (in seconds) that scales the duration efficiency decay.
const EFFICIENT_DURATION_SECS: f32 = 60.0;

/// Reference step count that scales the step-count efficiency decay.
const EFFICIENT_STEP_COUNT: usize = 10;

/// Upper bound for efficiency scores and for the combined efficiency multiplier.
const MAX_EFFICIENCY_MULTIPLIER: f32 = 1.5;

/// Lower bound for efficiency scores and for the combined efficiency multiplier.
const MIN_EFFICIENCY_MULTIPLIER: f32 = 0.5;
46
/// Calculator for episode reward scores.
///
/// Holds the two weights used to blend the duration-based and the
/// step-count-based efficiency scores into a single efficiency multiplier.
#[derive(Debug, Clone)]
pub struct RewardCalculator {
    /// Weight applied to the duration efficiency score
    duration_weight: f32,
    /// Weight applied to the step-count efficiency score
    step_count_weight: f32,
}
55
56impl Default for RewardCalculator {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl RewardCalculator {
63    /// Create a new reward calculator with default weights
64    #[must_use]
65    pub fn new() -> Self {
66        Self {
67            duration_weight: 0.5,
68            step_count_weight: 0.5,
69        }
70    }
71
72    /// Create a calculator with custom weights
73    #[must_use]
74    pub fn with_weights(duration_weight: f32, step_count_weight: f32) -> Self {
75        Self {
76            duration_weight,
77            step_count_weight,
78        }
79    }
80
81    /// Calculate reward score for an episode
82    #[instrument(skip(self, episode), fields(episode_id = %episode.episode_id))]
83    pub fn calculate(&self, episode: &Episode) -> RewardScore {
84        let base = self.calculate_base_reward(episode);
85        let efficiency = self.calculate_efficiency_multiplier(episode);
86        let complexity_bonus = self.calculate_complexity_bonus(episode);
87        let quality_multiplier = self.calculate_quality_multiplier(episode);
88        let learning_bonus = self.calculate_learning_bonus(episode);
89
90        // Calculate total: base reward * multipliers + bonuses
91        let total = (base * efficiency * complexity_bonus * quality_multiplier) + learning_bonus;
92
93        debug!(
94            base = base,
95            efficiency = efficiency,
96            complexity_bonus = complexity_bonus,
97            quality_multiplier = quality_multiplier,
98            learning_bonus = learning_bonus,
99            total = total,
100            "Calculated reward score"
101        );
102
103        RewardScore {
104            total,
105            base,
106            efficiency,
107            complexity_bonus,
108            quality_multiplier,
109            learning_bonus,
110        }
111    }
112
113    /// Calculate base reward from outcome
114    fn calculate_base_reward(&self, episode: &Episode) -> f32 {
115        match &episode.outcome {
116            Some(TaskOutcome::Success { .. }) => 1.0,
117            Some(TaskOutcome::PartialSuccess {
118                completed, failed, ..
119            }) => {
120                // Proportional reward based on completion ratio
121                let total = completed.len() + failed.len();
122                if total == 0 {
123                    0.5 // Default for partial success with no specifics
124                } else {
125                    completed.len() as f32 / total as f32
126                }
127            }
128            Some(TaskOutcome::Failure { .. }) => 0.0,
129            None => 0.0, // Not completed
130        }
131    }
132
133    /// Calculate efficiency multiplier based on duration and step count
134    fn calculate_efficiency_multiplier(&self, episode: &Episode) -> f32 {
135        let duration_score = self.calculate_duration_efficiency(episode);
136        let step_count_score = self.calculate_step_count_efficiency(episode);
137
138        let combined =
139            (duration_score * self.duration_weight) + (step_count_score * self.step_count_weight);
140
141        // Clamp to reasonable bounds
142        combined.clamp(MIN_EFFICIENCY_MULTIPLIER, MAX_EFFICIENCY_MULTIPLIER)
143    }
144
145    /// Calculate duration efficiency score
146    fn calculate_duration_efficiency(&self, episode: &Episode) -> f32 {
147        if let Some(duration) = episode.duration() {
148            let duration_secs = duration.num_seconds() as f32;
149
150            if duration_secs <= 0.0 {
151                return MAX_EFFICIENCY_MULTIPLIER;
152            }
153
154            // Efficiency decreases as duration increases
155            // Exponential decay: e^(-x/threshold)
156            let ratio = duration_secs / EFFICIENT_DURATION_SECS;
157            let score = (-ratio / 2.0).exp();
158
159            // Map to multiplier range
160            MIN_EFFICIENCY_MULTIPLIER
161                + (score * (MAX_EFFICIENCY_MULTIPLIER - MIN_EFFICIENCY_MULTIPLIER))
162        } else {
163            1.0 // Default if no duration
164        }
165    }
166
167    /// Calculate step count efficiency score
168    fn calculate_step_count_efficiency(&self, episode: &Episode) -> f32 {
169        let step_count = episode.steps.len();
170
171        if step_count == 0 {
172            return MIN_EFFICIENCY_MULTIPLIER;
173        }
174
175        // Efficiency decreases as step count increases
176        let ratio = step_count as f32 / EFFICIENT_STEP_COUNT as f32;
177        let score = (-ratio / 2.0).exp();
178
179        // Map to multiplier range
180        MIN_EFFICIENCY_MULTIPLIER
181            + (score * (MAX_EFFICIENCY_MULTIPLIER - MIN_EFFICIENCY_MULTIPLIER))
182    }
183
184    /// Calculate complexity bonus multiplier
185    fn calculate_complexity_bonus(&self, episode: &Episode) -> f32 {
186        match episode.context.complexity {
187            ComplexityLevel::Simple => 1.0,
188            ComplexityLevel::Moderate => 1.1,
189            ComplexityLevel::Complex => 1.2,
190        }
191    }
192
193    /// Calculate quality multiplier based on code quality metrics
194    ///
195    /// Analyzes artifacts and execution quality to determine a multiplier.
196    /// Factors include:
197    /// - Test coverage (detected from artifacts)
198    /// - Code quality indicators (linting, formatting)
199    /// - Error handling quality (low error rate)
200    fn calculate_quality_multiplier(&self, episode: &Episode) -> f32 {
201        let mut quality: f32 = 1.0;
202
203        // Analyze artifacts for quality indicators
204        if let Some(TaskOutcome::Success { artifacts, .. }) = &episode.outcome {
205            // Bonus for test coverage artifacts
206            let has_test_coverage = artifacts
207                .iter()
208                .any(|a| a.contains("coverage") || a.contains("test"));
209            if has_test_coverage {
210                quality += 0.1;
211            }
212
213            // Bonus for multiple quality artifacts (docs, tests, etc.)
214            if artifacts.len() >= 3 {
215                quality += 0.05;
216            }
217
218            // Check for quality-related metadata
219            if let Some(coverage_str) = episode.metadata.get("test_coverage") {
220                if let Ok(coverage) = coverage_str.parse::<f32>() {
221                    // Bonus for high test coverage (>80%)
222                    #[allow(clippy::excessive_nesting)]
223                    if coverage > 80.0 {
224                        quality += 0.15;
225                    } else if coverage > 60.0 {
226                        quality += 0.1;
227                    }
228                }
229            }
230        }
231
232        // Quality based on error handling
233        let total_steps = episode.steps.len();
234        if total_steps > 0 {
235            let error_rate = episode.failed_steps_count() as f32 / total_steps as f32;
236
237            // Penalize high error rates
238            if error_rate > 0.3 {
239                quality -= 0.2;
240            } else if error_rate > 0.1 {
241                quality -= 0.1;
242            } else if error_rate == 0.0 {
243                // Bonus for zero errors
244                quality += 0.1;
245            }
246        }
247
248        // Check for linting/formatting indicators
249        if episode.metadata.contains_key("clippy_warnings") {
250            if let Some(warnings) = episode.metadata.get("clippy_warnings") {
251                if warnings == "0" {
252                    quality += 0.05;
253                }
254            }
255        }
256
257        // Clamp to reasonable bounds (0.5 to 1.5)
258        quality.clamp(0.5, 1.5)
259    }
260
261    /// Calculate learning bonus for discovering patterns and improvements
262    ///
263    /// Awards bonus points for:
264    /// - Discovering new patterns (novel approaches)
265    /// - Improving on past attempts (learning from history)
266    /// - Efficient problem-solving (first-time success)
267    fn calculate_learning_bonus(&self, episode: &Episode) -> f32 {
268        let mut bonus = 0.0;
269
270        // Bonus for discovering new patterns
271        let pattern_count = episode.patterns.len();
272        if pattern_count > 0 {
273            // More patterns = more learning
274            bonus += (pattern_count as f32 * 0.1).min(0.3);
275        }
276
277        // Bonus for novel tool sequences
278        if let Some(novelty) = self.calculate_novelty_bonus(episode) {
279            bonus += novelty;
280        }
281
282        // Bonus for efficient problem solving (high success rate)
283        let total_steps = episode.steps.len();
284        if total_steps > 0 {
285            let success_rate = episode.successful_steps_count() as f32 / total_steps as f32;
286
287            if success_rate > 0.9 && total_steps >= 5 {
288                // High reliability with meaningful complexity
289                bonus += 0.2;
290            } else if success_rate == 1.0 && total_steps >= 3 {
291                // Perfect execution
292                bonus += 0.15;
293            }
294        }
295
296        // Bonus for error recovery (learning from failures)
297        if self.detect_error_recovery(episode) {
298            bonus += 0.15;
299        }
300
301        // Bonus for optimization (completing quickly with few steps)
302        if let Some(duration) = episode.duration() {
303            let duration_secs = duration.num_seconds() as f32;
304            if duration_secs < 30.0 && total_steps > 0 && total_steps < 10 {
305                bonus += 0.1;
306            }
307        }
308
309        // Cap learning bonus
310        bonus.min(0.5)
311    }
312
313    /// Calculate novelty bonus for unique tool combinations
314    fn calculate_novelty_bonus(&self, episode: &Episode) -> Option<f32> {
315        if episode.steps.len() < 3 {
316            return None;
317        }
318
319        // Count unique tools used
320        let unique_tools: std::collections::HashSet<_> =
321            episode.steps.iter().map(|s| &s.tool).collect();
322
323        // Bonus for diverse tool usage
324        if unique_tools.len() >= 5 {
325            Some(0.15)
326        } else if unique_tools.len() >= 3 {
327            Some(0.1)
328        } else {
329            None
330        }
331    }
332
333    /// Detect if the episode shows error recovery
334    fn detect_error_recovery(&self, episode: &Episode) -> bool {
335        for i in 0..episode.steps.len().saturating_sub(1) {
336            let current = &episode.steps[i];
337            let next = &episode.steps[i + 1];
338
339            // Error followed by success = recovery
340            if !current.is_success() && next.is_success() {
341                return true;
342            }
343        }
344        false
345    }
346
347    /// Calculate adoption bonus for patterns that were recommended AND applied AND succeeded.
348    ///
349    /// This bonus rewards episodes where the agent successfully applied recommended patterns.
350    /// The bonus is proportional to the number of successfully adopted patterns.
351    ///
352    /// # Arguments
353    ///
354    /// * `applied_pattern_ids` - Pattern IDs that were actually applied
355    /// * `outcome_success` - Whether the episode outcome was successful
356    ///
357    /// # Returns
358    ///
359    /// Bonus value between 0.0 and 0.3 (max 30% bonus for 3+ successful adoptions)
360    ///
361    /// # Example
362    ///
363    /// ```
364    /// use do_memory_core::reward::RewardCalculator;
365    ///
366    /// let calculator = RewardCalculator::new();
367    ///
368    /// // 2 patterns applied successfully
369    /// let bonus = calculator.calculate_adoption_bonus(
370    ///     &["p1".to_string(), "p2".to_string()],
371    ///     true
372    /// );
373    /// assert!(bonus > 0.0);
374    ///
375    /// // No bonus for failed outcome
376    /// let no_bonus = calculator.calculate_adoption_bonus(
377    ///     &["p1".to_string()],
378    ///     false
379    /// );
380    /// assert_eq!(no_bonus, 0.0);
381    /// ```
382    #[must_use]
383    pub fn calculate_adoption_bonus(
384        &self,
385        applied_pattern_ids: &[String],
386        outcome_success: bool,
387    ) -> f32 {
388        if !outcome_success || applied_pattern_ids.is_empty() {
389            return 0.0;
390        }
391
392        // Bonus scales with number of successfully adopted patterns
393        // 1 pattern = 0.1, 2 patterns = 0.2, 3+ patterns = 0.3 (capped)
394        let pattern_count = applied_pattern_ids.len();
395        (pattern_count as f32 * 0.1).min(0.3)
396    }
397}
398
/// Unit tests for `RewardCalculator::calculate_adoption_bonus`.
#[cfg(test)]
mod adoption_bonus_tests {
    use super::*;

    /// Absolute tolerance for comparing `f32` bonuses
    const TOLERANCE: f32 = 1e-6;

    /// Build `count` distinct pattern IDs (`p1`, `p2`, ...)
    fn pattern_ids(count: usize) -> Vec<String> {
        (1..=count).map(|n| format!("p{n}")).collect()
    }

    /// Assert that `actual` equals `expected` within `TOLERANCE`
    fn assert_bonus(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected bonus {expected}, got {actual}"
        );
    }

    #[test]
    fn test_adoption_bonus_no_patterns() {
        let calc = RewardCalculator::new();
        assert_bonus(calc.calculate_adoption_bonus(&[], true), 0.0);
    }

    #[test]
    fn test_adoption_bonus_failed_outcome() {
        let calc = RewardCalculator::new();
        assert_bonus(calc.calculate_adoption_bonus(&pattern_ids(1), false), 0.0);
        // A failed outcome earns nothing regardless of how many patterns were applied
        assert_bonus(calc.calculate_adoption_bonus(&pattern_ids(4), false), 0.0);
    }

    #[test]
    fn test_adoption_bonus_single_pattern() {
        let calc = RewardCalculator::new();
        assert_bonus(calc.calculate_adoption_bonus(&pattern_ids(1), true), 0.1);
    }

    #[test]
    fn test_adoption_bonus_two_patterns() {
        let calc = RewardCalculator::new();
        assert_bonus(calc.calculate_adoption_bonus(&pattern_ids(2), true), 0.2);
    }

    #[test]
    fn test_adoption_bonus_multiple_patterns() {
        let calc = RewardCalculator::new();
        // Three patterns reach the maximum exactly
        assert_bonus(calc.calculate_adoption_bonus(&pattern_ids(3), true), 0.3);
    }

    #[test]
    fn test_adoption_bonus_is_capped() {
        let calc = RewardCalculator::new();
        // More than three patterns must not exceed the cap
        assert_bonus(calc.calculate_adoption_bonus(&pattern_ids(5), true), 0.3);
        assert_bonus(calc.calculate_adoption_bonus(&pattern_ids(50), true), 0.3);
    }
}
433}