//! Critics: advantage estimation and state-value targets.
#![allow(clippy::use_self)]
mod opt;
mod rtg;
pub use opt::{ValuesOpt, ValuesOptConfig};
pub use rtg::{RewardToGo, RewardToGoConfig};
use super::features::HistoryFeatures;
use crate::logging::StatsLogger;
use crate::torch::modules::SeqPacked;
use crate::torch::packed::PackedTensor;
use serde::{Deserialize, Serialize};
use tch::Device;
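
/// A critic evaluates the advantages of the steps of collected episodes and
/// can be updated from the same history features.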
pub trait Critic {
    /// Evaluate the advantages of the steps of collected episodes.
    fn advantages(&self, features: &dyn HistoryFeatures) -> PackedTensor;

    /// Update the critic using the features of collected episodes.
    fn update(&mut self, features: &dyn HistoryFeatures, logger: &mut dyn StatsLogger);
}
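
/// Build a [`Critic`] instance.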
pub trait BuildCritic {
    type Critic: Critic;

    /// Build a critic for history features with the given input dimension.
    fn build_critic(&self, in_dim: usize, discount_factor: f64, device: Device) -> Self::Critic;
}
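
/// An advantage estimation function.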
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum AdvantageFn {
    /// Generalized Advantage Estimation. See <https://arxiv.org/abs/1506.02438>.
    Gae {
        /// Interpolation factor between one-step TD residuals (`0.0`) and the
        /// empirical reward-to-go minus the estimated state value (`1.0`).
        lambda: f32,
    },
}

impl Default for AdvantageFn {
    fn default() -> Self {
        Self::Gae { lambda: 0.95 }
    }
}

impl AdvantageFn {
    /// Evaluate the advantages of the steps of collected episodes.
    pub fn advantages<M: SeqPacked + ?Sized>(
        &self,
        state_value_fn: &M,
        discount_factor: f32,
        features: &dyn HistoryFeatures,
    ) -> PackedTensor {
        match self {
            Self::Gae { lambda } => gae(state_value_fn, discount_factor, *lambda, features),
        }
    }
}
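
/// The discounted reward-to-go of each step: the discounted sum of the rewards
/// from that step to the end of its episode.
///
/// A minimal sketch of the per-episode recurrence on a plain slice (the helper
/// name here is illustrative only; the packed version delegates to
/// `PackedTensor::discounted_cumsum_from_end`):
///
/// ```
/// fn reward_to_go_single_episode(rewards: &[f32], gamma: f32) -> Vec<f32> {
///     let mut out = vec![0.0; rewards.len()];
///     let mut acc = 0.0; // running discounted sum, accumulated from the episode end
///     for i in (0..rewards.len()).rev() {
///         acc = rewards[i] + gamma * acc;
///         out[i] = acc;
///     }
///     out
/// }
/// assert_eq!(reward_to_go_single_episode(&[1.0, 1.0, 1.0], 0.5), vec![1.75, 1.5, 1.0]);
/// ```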
pub fn reward_to_go(discount_factor: f32, features: &dyn HistoryFeatures) -> PackedTensor {
    features
        .rewards()
        .discounted_cumsum_from_end(discount_factor)
}
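
/// Evaluate a state-value function on the extended observation features of each
/// episode (the per-step observations followed by the successor observation of
/// the final step), squeezing the trailing singleton dimension and setting the
/// value to zero wherever the extended observation is marked invalid (absent).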
pub fn eval_extended_state_values<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    let (extended_observation_features, is_invalid) = features.extended_observation_features();
    let mut extended_estimated_values = state_value_fn
        .seq_packed(extended_observation_features)
        .batch_map(|t| t.squeeze_dim(-1));
    // Zero out the value estimates at invalid (absent) observations.
    let _ = extended_estimated_values
        .tensor_mut()
        .masked_fill_(is_invalid.tensor(), 0.0);
    extended_estimated_values
}
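
/// The one-step temporal-difference value target of each step:
/// `r + discount_factor * V(s')`, where `V(s')` is the estimated value of the
/// successor state (zero where no valid successor observation exists).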
pub fn one_step_values<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    discount_factor: f32,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    let estimated_next_values =
        eval_extended_state_values(state_value_fn, features).view_trim_start(1);
    features
        .rewards()
        .batch_map_ref(|rewards| rewards + discount_factor * estimated_next_values.tensor())
}
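
/// The one-step temporal-difference residual of each step:
/// `r + discount_factor * V(s') - V(s)`.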
pub fn temporal_differences<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    discount_factor: f32,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    let extended_state_values = eval_extended_state_values(state_value_fn, features);
    let estimated_values = extended_state_values.trim_end(1);
    let estimated_next_values = extended_state_values.view_trim_start(1);
    features.rewards().batch_map_ref(|rewards| {
        rewards + discount_factor * estimated_next_values.tensor() - estimated_values.tensor()
    })
}
pub fn gae<M: SeqPacked + ?Sized>(
    state_value_fn: &M,
    discount_factor: f32,
    lambda: f32,
    features: &dyn HistoryFeatures,
) -> PackedTensor {
    let residuals = temporal_differences(state_value_fn, discount_factor, features);
    residuals.discounted_cumsum_from_end(lambda * discount_factor)
}
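
/// A regression target for learned state values.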
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum StepValueTarget {
    /// The empirical discounted reward-to-go of each step.
    RewardToGo,
    /// The one-step temporal-difference target `r + discount_factor * V(s')`.
    OneStepTd,
}

impl Default for StepValueTarget {
    fn default() -> Self {
        Self::RewardToGo
    }
}

impl StepValueTarget {
    /// Evaluate the value targets of the steps of collected episodes.
    pub fn targets<M: SeqPacked + ?Sized>(
        &self,
        state_value_fn: &M,
        discount_factor: f32,
        features: &dyn HistoryFeatures,
    ) -> PackedTensor {
        match self {
            Self::RewardToGo => reward_to_go(discount_factor, features),
            Self::OneStepTd => one_step_values(state_value_fn, discount_factor, features),
        }
    }
}