use crate::episode::Episode;
use crate::types::{ComplexityLevel, RewardScore, TaskOutcome};
use tracing::{debug, instrument};
use crate::reward::domain_stats::DomainStatistics;
/// Configuration for computing episode reward scores.
///
/// The two weights control how the duration score and the step-count score
/// are blended into a single efficiency value. The two fallback fields are
/// the baselines used when no reliable domain statistics are supplied.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptiveRewardCalculator {
    /// Multiplier applied to the duration score when the efficiency scores
    /// are combined.
    pub duration_weight: f32,
    /// Multiplier applied to the step-count score when the efficiency scores
    /// are combined.
    pub step_count_weight: f32,
    /// Baseline duration, in seconds, that an episode is compared against
    /// when the fixed (non-adaptive) efficiency path is taken.
    pub fallback_duration_secs: f32,
    /// Baseline number of steps that an episode is compared against when the
    /// fixed (non-adaptive) efficiency path is taken.
    pub fallback_step_count: usize,
}
impl Default for AdaptiveRewardCalculator {
fn default() -> Self {
Self::new()
}
}
impl AdaptiveRewardCalculator {
#[must_use]
pub fn new() -> Self {
Self {
duration_weight: 0.5,
step_count_weight: 0.5,
fallback_duration_secs: 60.0,
fallback_step_count: 10,
}
}
/// Creates a calculator from explicit settings.
///
/// * `duration_weight` - multiplier for the duration score in the combined
///   efficiency value.
/// * `step_count_weight` - multiplier for the step-count score in the
///   combined efficiency value.
/// * `fallback_duration_secs` - duration baseline (seconds) for the fixed
///   efficiency path.
/// * `fallback_step_count` - step-count baseline for the fixed efficiency
///   path.
///
/// The values are stored as given; no validation is performed here.
#[must_use]
pub fn with_config(
    duration_weight: f32,
    step_count_weight: f32,
    fallback_duration_secs: f32,
    fallback_step_count: usize,
) -> Self {
    Self {
        fallback_step_count,
        fallback_duration_secs,
        step_count_weight,
        duration_weight,
    }
}
/// Computes the full reward breakdown for `episode`.
///
/// The total is `base * efficiency * complexity_bonus * quality_multiplier
/// + learning_bonus`. Efficiency is measured against `domain_stats` only
/// when statistics are supplied and report themselves as reliable; in every
/// other case the calculator's fixed fallback baselines are used.
///
/// A `debug` event carrying each component is emitted before returning.
#[instrument(skip(self, episode, domain_stats), fields(episode_id = %episode.episode_id))]
pub fn calculate(
    &self,
    episode: &Episode,
    domain_stats: Option<&DomainStatistics>,
) -> RewardScore {
    let base = self.calculate_base_reward(episode);

    // Adaptive scoring needs statistics that are both present and reliable.
    let efficiency = match domain_stats {
        Some(stats) if stats.is_reliable() => {
            self.calculate_adaptive_efficiency(episode, stats)
        }
        _ => self.calculate_fixed_efficiency(episode),
    };

    let complexity_bonus = self.calculate_complexity_bonus(episode);
    let quality_multiplier = self.calculate_quality_multiplier(episode);
    let learning_bonus = self.calculate_learning_bonus(episode);

    // The learning bonus is additive so it is not scaled by the multipliers.
    let scaled = base * efficiency * complexity_bonus * quality_multiplier;
    let score = RewardScore {
        total: scaled + learning_bonus,
        base,
        efficiency,
        complexity_bonus,
        quality_multiplier,
        learning_bonus,
    };

    debug!(
        base = score.base,
        efficiency = score.efficiency,
        complexity_bonus = score.complexity_bonus,
        quality_multiplier = score.quality_multiplier,
        learning_bonus = score.learning_bonus,
        total = score.total,
        adaptive = domain_stats.map(|s| s.is_reliable()).unwrap_or(false),
        "Calculated adaptive reward score"
    );

    score
}
/// Maps the episode outcome to a base reward.
///
/// A success scores `1.0`; a failure or a missing outcome scores `0.0`. A
/// partial success scores the fraction of completed items out of completed
/// plus failed, or `0.5` when both lists are empty.
fn calculate_base_reward(&self, episode: &Episode) -> f32 {
    match &episode.outcome {
        Some(TaskOutcome::Success { .. }) => 1.0,
        Some(TaskOutcome::PartialSuccess {
            completed, failed, ..
        }) => {
            let done = completed.len();
            match done + failed.len() {
                // Nothing was recorded either way: settle on the midpoint.
                0 => 0.5,
                attempted => done as f32 / attempted as f32,
            }
        }
        Some(TaskOutcome::Failure { .. }) | None => 0.0,
    }
}
fn calculate_adaptive_efficiency(&self, episode: &Episode, stats: &DomainStatistics) -> f32 {
let duration_score = if let Some(duration) = episode.duration() {
let duration_secs = duration.num_seconds() as f32;
if duration_secs <= 0.0 {
return 1.5; }
let baseline = stats.p50_duration_secs.max(1.0); let ratio = duration_secs / baseline;
let score = (-ratio / 2.0).exp();
0.5 + (score * 1.0)
} else {
1.0
};
let step_count_score = {
let step_count = episode.steps.len();
if step_count == 0 {
return 0.5;
}
let baseline = stats.p50_step_count.max(1); let ratio = step_count as f32 / baseline as f32;
let score = (-ratio / 2.0).exp();
0.5 + (score * 1.0)
};
let combined =
(duration_score * self.duration_weight) + (step_count_score * self.step_count_weight);
combined.clamp(0.5, 1.5)
}
/// Scores efficiency against the calculator's fixed fallback baselines.
///
/// A known duration of zero seconds or less returns `1.5` immediately, and
/// an episode with no steps then returns `0.5`. Otherwise the duration is
/// divided by `fallback_duration_secs` and the step count by
/// `fallback_step_count`. Each ratio is mapped through
/// `0.5 + exp(-ratio / 2)`, and the weighted sum is clamped to
/// `[0.5, 1.5]`. An unknown duration contributes a neutral `1.0`.
fn calculate_fixed_efficiency(&self, episode: &Episode) -> f32 {
    let duration_score = match episode.duration().map(|d| d.num_seconds() as f32) {
        Some(secs) if secs <= 0.0 => return 1.5,
        Some(secs) => {
            let ratio = secs / self.fallback_duration_secs;
            0.5 + (-ratio / 2.0).exp()
        }
        None => 1.0,
    };

    let steps_taken = episode.steps.len();
    if steps_taken == 0 {
        return 0.5;
    }
    let step_ratio = steps_taken as f32 / self.fallback_step_count as f32;
    let step_count_score = 0.5 + (-step_ratio / 2.0).exp();

    let weighted =
        duration_score * self.duration_weight + step_count_score * self.step_count_weight;
    weighted.clamp(0.5, 1.5)
}
/// Returns the multiplier for the episode's declared complexity: `1.0` for
/// simple, `1.1` for moderate and `1.2` for complex tasks.
fn calculate_complexity_bonus(&self, episode: &Episode) -> f32 {
    let level = &episode.context.complexity;
    match level {
        ComplexityLevel::Complex => 1.2,
        ComplexityLevel::Moderate => 1.1,
        ComplexityLevel::Simple => 1.0,
    }
}
/// Computes a quality multiplier in `[0.5, 1.5]`, starting from `1.0`.
///
/// Adjustments for a successful outcome only:
/// * `+0.1` if any artifact name contains `"coverage"` or `"test"`;
/// * `+0.05` if there are at least three artifacts;
/// * `+0.15` if the `test_coverage` metadata parses to more than `80`, or
///   `+0.1` if it parses to more than `60`.
///
/// Adjustments for every outcome:
/// * step error rate above `0.3` gives `-0.2`, above `0.1` gives `-0.1`,
///   and exactly zero gives `+0.1` (skipped when there are no steps);
/// * `+0.05` if the `clippy_warnings` metadata is exactly `"0"`.
fn calculate_quality_multiplier(&self, episode: &Episode) -> f32 {
    let mut quality: f32 = 1.0;

    if let Some(TaskOutcome::Success { artifacts, .. }) = &episode.outcome {
        let has_test_coverage = artifacts
            .iter()
            .any(|a| a.contains("coverage") || a.contains("test"));
        if has_test_coverage {
            quality += 0.1;
        }
        if artifacts.len() >= 3 {
            quality += 0.05;
        }

        // A missing or unparseable coverage value earns no bonus.
        let coverage = episode
            .metadata
            .get("test_coverage")
            .and_then(|raw| raw.parse::<f32>().ok());
        quality += match coverage {
            Some(pct) if pct > 80.0 => 0.15,
            Some(pct) if pct > 60.0 => 0.1,
            _ => 0.0,
        };
    }

    let total_steps = episode.steps.len();
    if total_steps > 0 {
        let error_rate = episode.failed_steps_count() as f32 / total_steps as f32;
        if error_rate > 0.3 {
            quality -= 0.2;
        } else if error_rate > 0.1 {
            quality -= 0.1;
        } else if error_rate == 0.0 {
            quality += 0.1;
        }
    }

    // A single lookup is enough: `get` already reports absence as `None`.
    if let Some(warnings) = episode.metadata.get("clippy_warnings") {
        if warnings == "0" {
            quality += 0.05;
        }
    }

    quality.clamp(0.5, 1.5)
}
fn calculate_learning_bonus(&self, episode: &Episode) -> f32 {
let mut bonus = 0.0;
let pattern_count = episode.patterns.len();
if pattern_count > 0 {
bonus += (pattern_count as f32 * 0.1).min(0.3);
}
if let Some(novelty) = self.calculate_novelty_bonus(episode) {
bonus += novelty;
}
let total_steps = episode.steps.len();
if total_steps > 0 {
let success_rate = episode.successful_steps_count() as f32 / total_steps as f32;
if success_rate > 0.9 && total_steps >= 5 {
bonus += 0.2;
} else if success_rate == 1.0 && total_steps >= 3 {
bonus += 0.15;
}
}
if self.detect_error_recovery(episode) {
bonus += 0.15;
}
if let Some(duration) = episode.duration() {
let duration_secs = duration.num_seconds() as f32;
if duration_secs < 30.0 && total_steps > 0 && total_steps < 10 {
bonus += 0.1;
}
}
bonus.min(0.5)
}
/// Rewards episodes that use a variety of tools.
///
/// Returns `None` for episodes with fewer than three steps or fewer than
/// three distinct tools, `Some(0.1)` for three or four distinct tools, and
/// `Some(0.15)` for five or more.
fn calculate_novelty_bonus(&self, episode: &Episode) -> Option<f32> {
    if episode.steps.len() < 3 {
        return None;
    }

    let distinct_tools = episode
        .steps
        .iter()
        .map(|s| &s.tool)
        .collect::<std::collections::HashSet<_>>()
        .len();

    match distinct_tools {
        0..=2 => None,
        3 | 4 => Some(0.1),
        _ => Some(0.15),
    }
}
/// Reports whether any unsuccessful step is immediately followed by a
/// successful one. Episodes with fewer than two steps yield `false`.
fn detect_error_recovery(&self, episode: &Episode) -> bool {
    let steps = &episode.steps;
    // Pair every step with its successor and look for a fail -> success edge.
    steps
        .iter()
        .zip(steps.iter().skip(1))
        .any(|(current, next)| !current.is_success() && next.is_success())
}
}