#![allow(clippy::if_not_else)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::map_unwrap_or)]
#![allow(clippy::doc_markdown)]
pub mod adaptive;
pub mod domain_stats;
#[cfg(feature = "agentfs")]
pub mod external;
#[cfg(test)]
pub mod tests;
pub use adaptive::AdaptiveRewardCalculator;
pub use domain_stats::{DomainStatistics, DomainStatisticsCache};
use crate::episode::Episode;
use crate::types::{ComplexityLevel, RewardScore, TaskOutcome};
use tracing::{debug, instrument};
/// Reference duration in seconds: an episode's elapsed seconds are divided by
/// this value before the duration efficiency decay is applied.
const EFFICIENT_DURATION_SECS: f32 = 60.0;
/// Reference step count: an episode's number of steps is divided by this value
/// before the step-count efficiency decay is applied.
const EFFICIENT_STEP_COUNT: usize = 10;
/// Upper bound of the efficiency multiplier; also the score returned for a
/// non-positive duration.
const MAX_EFFICIENCY_MULTIPLIER: f32 = 1.5;
/// Lower bound of the efficiency multiplier; also the score returned for an
/// episode with no steps.
const MIN_EFFICIENCY_MULTIPLIER: f32 = 0.5;
/// Computes reward scores for episodes.
///
/// The two weights control how the duration and step-count efficiency scores
/// are blended into a single efficiency multiplier.
///
/// `Debug` and `PartialEq` are derived so the calculator can be logged,
/// embedded in other `Debug` types, and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct RewardCalculator {
    /// Factor applied to the duration efficiency score.
    duration_weight: f32,
    /// Factor applied to the step-count efficiency score.
    step_count_weight: f32,
}
impl Default for RewardCalculator {
/// Returns the same calculator as [`RewardCalculator::new`].
fn default() -> Self {
Self::new()
}
}
impl RewardCalculator {
#[must_use]
pub fn new() -> Self {
Self {
duration_weight: 0.5,
step_count_weight: 0.5,
}
}
/// Creates a calculator with explicit efficiency weights.
///
/// `duration_weight` multiplies the duration efficiency score and
/// `step_count_weight` multiplies the step-count efficiency score before the
/// two are summed into the efficiency multiplier. The weights are stored as
/// given, without validation or normalisation.
#[must_use]
pub fn with_weights(duration_weight: f32, step_count_weight: f32) -> Self {
Self {
duration_weight,
step_count_weight,
}
}
/// Computes the full reward breakdown for `episode`.
///
/// The total is
/// `base * efficiency * complexity_bonus * quality_multiplier + learning_bonus`.
/// Every component is returned alongside the total and is also emitted in a
/// debug-level tracing event.
#[instrument(skip(self, episode), fields(episode_id = %episode.episode_id))]
pub fn calculate(&self, episode: &Episode) -> RewardScore {
    let base = self.calculate_base_reward(episode);
    let efficiency = self.calculate_efficiency_multiplier(episode);
    let complexity_bonus = self.calculate_complexity_bonus(episode);
    let quality_multiplier = self.calculate_quality_multiplier(episode);
    let learning_bonus = self.calculate_learning_bonus(episode);

    // The learning bonus is additive so it is not scaled by the multipliers.
    let scaled_base = base * efficiency * complexity_bonus * quality_multiplier;
    let score = RewardScore {
        total: scaled_base + learning_bonus,
        base,
        efficiency,
        complexity_bonus,
        quality_multiplier,
        learning_bonus,
    };

    debug!(
        base = score.base,
        efficiency = score.efficiency,
        complexity_bonus = score.complexity_bonus,
        quality_multiplier = score.quality_multiplier,
        learning_bonus = score.learning_bonus,
        total = score.total,
        "Calculated reward score"
    );

    score
}
/// Maps the episode outcome to a base reward.
///
/// * `Success` yields 1.0.
/// * `PartialSuccess` yields the fraction of completed items among completed
///   plus failed items, or 0.5 when both lists are empty.
/// * `Failure` and a missing outcome both yield 0.0.
fn calculate_base_reward(&self, episode: &Episode) -> f32 {
    match &episode.outcome {
        Some(TaskOutcome::Success { .. }) => 1.0,
        Some(TaskOutcome::PartialSuccess {
            completed, failed, ..
        }) => {
            let done = completed.len();
            match done + failed.len() {
                // Nothing was recorded either way: fall back to a neutral value.
                0 => 0.5,
                attempted => done as f32 / attempted as f32,
            }
        }
        Some(TaskOutcome::Failure { .. }) | None => 0.0,
    }
}
/// Blends the duration and step-count efficiency scores into one multiplier.
///
/// Each score is scaled by its configured weight, the two products are
/// summed, and the sum is clamped to
/// `[MIN_EFFICIENCY_MULTIPLIER, MAX_EFFICIENCY_MULTIPLIER]`.
fn calculate_efficiency_multiplier(&self, episode: &Episode) -> f32 {
    let weighted_duration = self.calculate_duration_efficiency(episode) * self.duration_weight;
    let weighted_steps = self.calculate_step_count_efficiency(episode) * self.step_count_weight;

    (weighted_duration + weighted_steps)
        .clamp(MIN_EFFICIENCY_MULTIPLIER, MAX_EFFICIENCY_MULTIPLIER)
}
/// Scores how quickly the episode ran.
///
/// Returns 1.0 when `episode.duration()` is `None` and
/// `MAX_EFFICIENCY_MULTIPLIER` when the duration is zero or negative in whole
/// seconds. Otherwise the score decays exponentially from the maximum towards
/// `MIN_EFFICIENCY_MULTIPLIER` as the duration grows relative to
/// `EFFICIENT_DURATION_SECS`.
fn calculate_duration_efficiency(&self, episode: &Episode) -> f32 {
    match episode.duration() {
        None => 1.0,
        Some(elapsed) => {
            // `num_seconds` yields whole seconds, so sub-second runs count as 0.
            let elapsed_secs = elapsed.num_seconds() as f32;
            if elapsed_secs <= 0.0 {
                MAX_EFFICIENCY_MULTIPLIER
            } else {
                let decay = (-(elapsed_secs / EFFICIENT_DURATION_SECS) / 2.0).exp();
                MIN_EFFICIENCY_MULTIPLIER
                    + decay * (MAX_EFFICIENCY_MULTIPLIER - MIN_EFFICIENCY_MULTIPLIER)
            }
        }
    }
}
/// Scores how few steps the episode needed.
///
/// An episode without steps gets `MIN_EFFICIENCY_MULTIPLIER`. Otherwise the
/// score decays exponentially from `MAX_EFFICIENCY_MULTIPLIER` towards the
/// minimum as the step count grows relative to `EFFICIENT_STEP_COUNT`.
fn calculate_step_count_efficiency(&self, episode: &Episode) -> f32 {
    match episode.steps.len() {
        0 => MIN_EFFICIENCY_MULTIPLIER,
        steps => {
            let decay = (-(steps as f32 / EFFICIENT_STEP_COUNT as f32) / 2.0).exp();
            MIN_EFFICIENCY_MULTIPLIER
                + decay * (MAX_EFFICIENCY_MULTIPLIER - MIN_EFFICIENCY_MULTIPLIER)
        }
    }
}
/// Returns the multiplier associated with the episode's complexity level:
/// 1.0 for simple, 1.1 for moderate and 1.2 for complex tasks.
fn calculate_complexity_bonus(&self, episode: &Episode) -> f32 {
    let level = &episode.context.complexity;
    match level {
        ComplexityLevel::Complex => 1.2,
        ComplexityLevel::Moderate => 1.1,
        ComplexityLevel::Simple => 1.0,
    }
}
/// Derives a quality multiplier from artifacts, metadata and step errors.
///
/// Starting from 1.0, the following adjustments are applied in order:
///
/// * Successful outcomes only: +0.1 if any artifact name contains
///   `"coverage"` or `"test"`, +0.05 for three or more artifacts, and
///   +0.15 / +0.1 when the `test_coverage` metadata value parses to more than
///   80 / more than 60.
/// * Step error rate: -0.2 above 30 %, -0.1 above 10 %, +0.1 when no step
///   failed. Skipped when the episode has no steps.
/// * +0.05 when the `clippy_warnings` metadata value is exactly `"0"`.
///
/// The result is clamped to `[0.5, 1.5]`.
fn calculate_quality_multiplier(&self, episode: &Episode) -> f32 {
    let mut quality: f32 = 1.0;

    if let Some(TaskOutcome::Success { artifacts, .. }) = &episode.outcome {
        let mentions_tests = artifacts
            .iter()
            .any(|artifact| artifact.contains("coverage") || artifact.contains("test"));
        if mentions_tests {
            quality += 0.1;
        }
        if artifacts.len() >= 3 {
            quality += 0.05;
        }

        // A missing or unparsable coverage value contributes nothing.
        let reported_coverage = episode
            .metadata
            .get("test_coverage")
            .and_then(|raw| raw.parse::<f32>().ok());
        match reported_coverage {
            Some(pct) if pct > 80.0 => quality += 0.15,
            Some(pct) if pct > 60.0 => quality += 0.1,
            _ => {}
        }
    }

    let step_total = episode.steps.len();
    if step_total > 0 {
        let error_rate = episode.failed_steps_count() as f32 / step_total as f32;
        match error_rate {
            rate if rate > 0.3 => quality -= 0.2,
            rate if rate > 0.1 => quality -= 0.1,
            rate if rate == 0.0 => quality += 0.1,
            _ => {}
        }
    }

    if let Some(warnings) = episode.metadata.get("clippy_warnings") {
        if warnings == "0" {
            quality += 0.05;
        }
    }

    quality.clamp(0.5, 1.5)
}
fn calculate_learning_bonus(&self, episode: &Episode) -> f32 {
let mut bonus = 0.0;
let pattern_count = episode.patterns.len();
if pattern_count > 0 {
bonus += (pattern_count as f32 * 0.1).min(0.3);
}
if let Some(novelty) = self.calculate_novelty_bonus(episode) {
bonus += novelty;
}
let total_steps = episode.steps.len();
if total_steps > 0 {
let success_rate = episode.successful_steps_count() as f32 / total_steps as f32;
if success_rate > 0.9 && total_steps >= 5 {
bonus += 0.2;
} else if success_rate == 1.0 && total_steps >= 3 {
bonus += 0.15;
}
}
if self.detect_error_recovery(episode) {
bonus += 0.15;
}
if let Some(duration) = episode.duration() {
let duration_secs = duration.num_seconds() as f32;
if duration_secs < 30.0 && total_steps > 0 && total_steps < 10 {
bonus += 0.1;
}
}
bonus.min(0.5)
}
/// Rewards episodes that use a varied set of tools.
///
/// Returns `None` for episodes with fewer than three steps or fewer than
/// three distinct tools, `Some(0.1)` for three or four distinct tools and
/// `Some(0.15)` for five or more.
fn calculate_novelty_bonus(&self, episode: &Episode) -> Option<f32> {
    if episode.steps.len() < 3 {
        return None;
    }

    let distinct_tools = episode
        .steps
        .iter()
        .map(|step| &step.tool)
        .collect::<std::collections::HashSet<_>>()
        .len();

    match distinct_tools {
        0..=2 => None,
        3..=4 => Some(0.1),
        _ => Some(0.15),
    }
}
/// Reports whether any unsuccessful step is immediately followed by a
/// successful one.
fn detect_error_recovery(&self, episode: &Episode) -> bool {
    let steps = &episode.steps;
    // Pair every step with its direct successor.
    steps
        .iter()
        .zip(steps.iter().skip(1))
        .any(|(earlier, later)| !earlier.is_success() && later.is_success())
}
/// Computes the bonus for applying previously learned patterns.
///
/// Returns 0.0 when the outcome was not successful or no pattern ids were
/// supplied. Otherwise each applied pattern contributes 0.1, capped at 0.3.
#[must_use]
pub fn calculate_adoption_bonus(
    &self,
    applied_pattern_ids: &[String],
    outcome_success: bool,
) -> f32 {
    if outcome_success && !applied_pattern_ids.is_empty() {
        (applied_pattern_ids.len() as f32 * 0.1).min(0.3)
    } else {
        0.0
    }
}
}
/// Unit tests for `RewardCalculator::calculate_adoption_bonus`.
#[cfg(test)]
mod adoption_bonus_tests {
    use super::*;

    /// Tolerance used for approximate `f32` comparisons.
    const TOLERANCE: f32 = 0.01;

    /// Builds owned pattern ids from string literals.
    fn pattern_ids(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_string()).collect()
    }

    /// No applied patterns means no bonus, even on success.
    #[test]
    fn test_adoption_bonus_no_patterns() {
        let bonus = RewardCalculator::new().calculate_adoption_bonus(&[], true);
        assert_eq!(bonus, 0.0);
    }

    /// A failed outcome yields no bonus regardless of applied patterns.
    #[test]
    fn test_adoption_bonus_failed_outcome() {
        let applied = pattern_ids(&["p1"]);
        let bonus = RewardCalculator::new().calculate_adoption_bonus(&applied, false);
        assert_eq!(bonus, 0.0);
    }

    /// One applied pattern on success is worth 0.1.
    #[test]
    fn test_adoption_bonus_single_pattern() {
        let applied = pattern_ids(&["p1"]);
        let bonus = RewardCalculator::new().calculate_adoption_bonus(&applied, true);
        assert!((bonus - 0.1).abs() < TOLERANCE);
    }

    /// Three applied patterns on success are worth 0.3.
    #[test]
    fn test_adoption_bonus_multiple_patterns() {
        let applied = pattern_ids(&["p1", "p2", "p3"]);
        let bonus = RewardCalculator::new().calculate_adoption_bonus(&applied, true);
        assert!((bonus - 0.3).abs() < TOLERANCE);
    }
}