do_memory_core/reward/
mod.rs1#![allow(clippy::if_not_else)]
13#![allow(clippy::cast_precision_loss)]
14#![allow(clippy::map_unwrap_or)]
15#![allow(clippy::doc_markdown)]
16
17pub mod adaptive;
19pub mod domain_stats;
20
21#[cfg(feature = "agentfs")]
22pub mod external;
23
24#[cfg(test)]
25pub mod tests;
26
27pub use adaptive::AdaptiveRewardCalculator;
29pub use domain_stats::{DomainStatistics, DomainStatisticsCache};
30
31use crate::episode::Episode;
32use crate::types::{ComplexityLevel, RewardScore, TaskOutcome};
33use tracing::{debug, instrument};
34
/// Reference duration in seconds; an episode's duration is divided by this
/// value before the exponential efficiency decay is applied.
const EFFICIENT_DURATION_SECS: f32 = 60.0;

/// Reference step count; an episode's step count is divided by this value
/// before the exponential efficiency decay is applied.
const EFFICIENT_STEP_COUNT: usize = 10;

/// Upper bound of the efficiency multiplier (also the top of the decay range).
const MAX_EFFICIENCY_MULTIPLIER: f32 = 1.5;

/// Lower bound of the efficiency multiplier (also the floor of the decay range).
const MIN_EFFICIENCY_MULTIPLIER: f32 = 0.5;
46
/// Computes reward scores for episodes.
///
/// The two weights control how the duration and step-count efficiency scores
/// are blended into a single efficiency multiplier.
#[derive(Debug, Clone)]
pub struct RewardCalculator {
    /// Factor applied to the duration efficiency score when the two
    /// efficiency scores are combined.
    duration_weight: f32,
    /// Factor applied to the step-count efficiency score when the two
    /// efficiency scores are combined.
    step_count_weight: f32,
}
55
56impl Default for RewardCalculator {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl RewardCalculator {
63 #[must_use]
65 pub fn new() -> Self {
66 Self {
67 duration_weight: 0.5,
68 step_count_weight: 0.5,
69 }
70 }
71
/// Creates a calculator with caller-supplied efficiency weights.
///
/// `duration_weight` scales the duration efficiency score and
/// `step_count_weight` scales the step-count efficiency score when the two
/// are combined. The values are stored exactly as given; this constructor
/// performs no validation or normalisation.
#[must_use]
pub fn with_weights(duration_weight: f32, step_count_weight: f32) -> Self {
    Self {
        duration_weight,
        step_count_weight,
    }
}
80
81 #[instrument(skip(self, episode), fields(episode_id = %episode.episode_id))]
83 pub fn calculate(&self, episode: &Episode) -> RewardScore {
84 let base = self.calculate_base_reward(episode);
85 let efficiency = self.calculate_efficiency_multiplier(episode);
86 let complexity_bonus = self.calculate_complexity_bonus(episode);
87 let quality_multiplier = self.calculate_quality_multiplier(episode);
88 let learning_bonus = self.calculate_learning_bonus(episode);
89
90 let total = (base * efficiency * complexity_bonus * quality_multiplier) + learning_bonus;
92
93 debug!(
94 base = base,
95 efficiency = efficiency,
96 complexity_bonus = complexity_bonus,
97 quality_multiplier = quality_multiplier,
98 learning_bonus = learning_bonus,
99 total = total,
100 "Calculated reward score"
101 );
102
103 RewardScore {
104 total,
105 base,
106 efficiency,
107 complexity_bonus,
108 quality_multiplier,
109 learning_bonus,
110 }
111 }
112
113 fn calculate_base_reward(&self, episode: &Episode) -> f32 {
115 match &episode.outcome {
116 Some(TaskOutcome::Success { .. }) => 1.0,
117 Some(TaskOutcome::PartialSuccess {
118 completed, failed, ..
119 }) => {
120 let total = completed.len() + failed.len();
122 if total == 0 {
123 0.5 } else {
125 completed.len() as f32 / total as f32
126 }
127 }
128 Some(TaskOutcome::Failure { .. }) => 0.0,
129 None => 0.0, }
131 }
132
133 fn calculate_efficiency_multiplier(&self, episode: &Episode) -> f32 {
135 let duration_score = self.calculate_duration_efficiency(episode);
136 let step_count_score = self.calculate_step_count_efficiency(episode);
137
138 let combined =
139 (duration_score * self.duration_weight) + (step_count_score * self.step_count_weight);
140
141 combined.clamp(MIN_EFFICIENCY_MULTIPLIER, MAX_EFFICIENCY_MULTIPLIER)
143 }
144
145 fn calculate_duration_efficiency(&self, episode: &Episode) -> f32 {
147 if let Some(duration) = episode.duration() {
148 let duration_secs = duration.num_seconds() as f32;
149
150 if duration_secs <= 0.0 {
151 return MAX_EFFICIENCY_MULTIPLIER;
152 }
153
154 let ratio = duration_secs / EFFICIENT_DURATION_SECS;
157 let score = (-ratio / 2.0).exp();
158
159 MIN_EFFICIENCY_MULTIPLIER
161 + (score * (MAX_EFFICIENCY_MULTIPLIER - MIN_EFFICIENCY_MULTIPLIER))
162 } else {
163 1.0 }
165 }
166
167 fn calculate_step_count_efficiency(&self, episode: &Episode) -> f32 {
169 let step_count = episode.steps.len();
170
171 if step_count == 0 {
172 return MIN_EFFICIENCY_MULTIPLIER;
173 }
174
175 let ratio = step_count as f32 / EFFICIENT_STEP_COUNT as f32;
177 let score = (-ratio / 2.0).exp();
178
179 MIN_EFFICIENCY_MULTIPLIER
181 + (score * (MAX_EFFICIENCY_MULTIPLIER - MIN_EFFICIENCY_MULTIPLIER))
182 }
183
184 fn calculate_complexity_bonus(&self, episode: &Episode) -> f32 {
186 match episode.context.complexity {
187 ComplexityLevel::Simple => 1.0,
188 ComplexityLevel::Moderate => 1.1,
189 ComplexityLevel::Complex => 1.2,
190 }
191 }
192
193 fn calculate_quality_multiplier(&self, episode: &Episode) -> f32 {
201 let mut quality: f32 = 1.0;
202
203 if let Some(TaskOutcome::Success { artifacts, .. }) = &episode.outcome {
205 let has_test_coverage = artifacts
207 .iter()
208 .any(|a| a.contains("coverage") || a.contains("test"));
209 if has_test_coverage {
210 quality += 0.1;
211 }
212
213 if artifacts.len() >= 3 {
215 quality += 0.05;
216 }
217
218 if let Some(coverage_str) = episode.metadata.get("test_coverage") {
220 if let Ok(coverage) = coverage_str.parse::<f32>() {
221 #[allow(clippy::excessive_nesting)]
223 if coverage > 80.0 {
224 quality += 0.15;
225 } else if coverage > 60.0 {
226 quality += 0.1;
227 }
228 }
229 }
230 }
231
232 let total_steps = episode.steps.len();
234 if total_steps > 0 {
235 let error_rate = episode.failed_steps_count() as f32 / total_steps as f32;
236
237 if error_rate > 0.3 {
239 quality -= 0.2;
240 } else if error_rate > 0.1 {
241 quality -= 0.1;
242 } else if error_rate == 0.0 {
243 quality += 0.1;
245 }
246 }
247
248 if episode.metadata.contains_key("clippy_warnings") {
250 if let Some(warnings) = episode.metadata.get("clippy_warnings") {
251 if warnings == "0" {
252 quality += 0.05;
253 }
254 }
255 }
256
257 quality.clamp(0.5, 1.5)
259 }
260
261 fn calculate_learning_bonus(&self, episode: &Episode) -> f32 {
268 let mut bonus = 0.0;
269
270 let pattern_count = episode.patterns.len();
272 if pattern_count > 0 {
273 bonus += (pattern_count as f32 * 0.1).min(0.3);
275 }
276
277 if let Some(novelty) = self.calculate_novelty_bonus(episode) {
279 bonus += novelty;
280 }
281
282 let total_steps = episode.steps.len();
284 if total_steps > 0 {
285 let success_rate = episode.successful_steps_count() as f32 / total_steps as f32;
286
287 if success_rate > 0.9 && total_steps >= 5 {
288 bonus += 0.2;
290 } else if success_rate == 1.0 && total_steps >= 3 {
291 bonus += 0.15;
293 }
294 }
295
296 if self.detect_error_recovery(episode) {
298 bonus += 0.15;
299 }
300
301 if let Some(duration) = episode.duration() {
303 let duration_secs = duration.num_seconds() as f32;
304 if duration_secs < 30.0 && total_steps > 0 && total_steps < 10 {
305 bonus += 0.1;
306 }
307 }
308
309 bonus.min(0.5)
311 }
312
313 fn calculate_novelty_bonus(&self, episode: &Episode) -> Option<f32> {
315 if episode.steps.len() < 3 {
316 return None;
317 }
318
319 let unique_tools: std::collections::HashSet<_> =
321 episode.steps.iter().map(|s| &s.tool).collect();
322
323 if unique_tools.len() >= 5 {
325 Some(0.15)
326 } else if unique_tools.len() >= 3 {
327 Some(0.1)
328 } else {
329 None
330 }
331 }
332
333 fn detect_error_recovery(&self, episode: &Episode) -> bool {
335 for i in 0..episode.steps.len().saturating_sub(1) {
336 let current = &episode.steps[i];
337 let next = &episode.steps[i + 1];
338
339 if !current.is_success() && next.is_success() {
341 return true;
342 }
343 }
344 false
345 }
346
347 #[must_use]
383 pub fn calculate_adoption_bonus(
384 &self,
385 applied_pattern_ids: &[String],
386 outcome_success: bool,
387 ) -> f32 {
388 if !outcome_success || applied_pattern_ids.is_empty() {
389 return 0.0;
390 }
391
392 let pattern_count = applied_pattern_ids.len();
395 (pattern_count as f32 * 0.1).min(0.3)
396 }
397}
398
#[cfg(test)]
mod adoption_bonus_tests {
    use super::*;

    /// Runs `calculate_adoption_bonus` on a default calculator with the
    /// given pattern ids and outcome flag.
    fn adoption_bonus(pattern_ids: &[&str], outcome_success: bool) -> f32 {
        let owned_ids: Vec<String> = pattern_ids.iter().map(|id| (*id).to_string()).collect();
        RewardCalculator::new().calculate_adoption_bonus(&owned_ids, outcome_success)
    }

    /// No applied patterns means no bonus, even on success.
    #[test]
    fn test_adoption_bonus_no_patterns() {
        assert_eq!(adoption_bonus(&[], true), 0.0);
    }

    /// A failed outcome earns no bonus, even with patterns applied.
    #[test]
    fn test_adoption_bonus_failed_outcome() {
        assert_eq!(adoption_bonus(&["p1"], false), 0.0);
    }

    /// One applied pattern on success earns 0.1.
    #[test]
    fn test_adoption_bonus_single_pattern() {
        let bonus = adoption_bonus(&["p1"], true);
        assert!((bonus - 0.1).abs() < 0.01);
    }

    /// Three applied patterns on success earn 0.3.
    #[test]
    fn test_adoption_bonus_multiple_patterns() {
        let bonus = adoption_bonus(&["p1", "p2", "p3"], true);
        assert!((bonus - 0.3).abs() < 0.01);
    }
}