use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
pub mod actor_critic;
pub mod natural_gradients;
pub mod policy_gradient;
pub mod trust_region;
pub use actor_critic::{ActorCriticConfig, ActorCriticMethod, ActorCriticOptimizer};
pub use natural_gradients::{NaturalGradientConfig, NaturalPolicyGradient};
pub use policy_gradient::{PolicyGradientConfig, PolicyGradientMethod, PolicyGradientOptimizer};
pub use trust_region::{TrustRegionConfig, TrustRegionMethod, TrustRegionOptimizer};
/// Hyper-parameters shared by the reinforcement-learning optimizers in this module.
///
/// The defaults quoted on each field are the values produced by this type's
/// `Default` implementation.
#[derive(Debug, Clone)]
pub struct RLOptimizerConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Policy learning rate (default `3e-4`).
    pub policy_lr: T,
    /// Value-function learning rate (default `1e-3`).
    pub value_lr: T,
    /// Discount factor (default `0.99`); presumably the `gamma` passed to
    /// `TrajectoryBatch::compute_advantages` — confirm with callers.
    pub discount_factor: T,
    /// GAE lambda (default `0.95`); presumably the `lambda` passed to
    /// `TrajectoryBatch::compute_advantages` — confirm with callers.
    pub gae_lambda: T,
    /// Clipping epsilon (default `0.2`). NOTE(review): not consumed in this file;
    /// presumably the policy-ratio clip range of a PPO-style objective.
    pub clip_epsilon: T,
    /// Entropy coefficient (default `0.01`); not consumed in this file.
    pub entropy_coeff: T,
    /// Value-loss coefficient (default `0.5`); not consumed in this file.
    pub value_loss_coeff: T,
    /// Gradient-norm limit (default `0.5`); not consumed in this file.
    pub max_grad_norm: T,
    /// Number of epochs (default `4`); not consumed in this file.
    pub n_epochs: usize,
    /// Mini-batch size (default `64`); presumably passed to
    /// `TrajectoryBatch::get_mini_batches` — confirm with callers.
    pub mini_batchsize: usize,
    /// Optional trust-region settings (default `None`).
    pub trust_region_config: Option<TrustRegionConfig<T>>,
    /// Whether natural gradients are enabled (default `false`).
    pub use_natural_gradients: bool,
    /// Fisher approximation to use (default `FisherApproximationMethod::Diagonal`).
    pub fisher_approximation: FisherApproximationMethod,
}
/// Selector for how the Fisher information matrix is approximated.
///
/// This enum only names the strategy; the approximations themselves are not
/// implemented in this file (presumably in `natural_gradients` — confirm).
/// `RLOptimizerConfig::default()` selects `Diagonal`.
#[derive(Debug, Clone, Copy)]
pub enum FisherApproximationMethod {
    /// Empirical Fisher estimate.
    Empirical,
    /// Kronecker-factored approximation.
    KroneckerFactored,
    /// Diagonal-only approximation.
    Diagonal,
    /// Block-diagonal approximation.
    BlockDiagonal,
    /// Low-rank approximation.
    LowRank,
}
impl<T: Float + Debug + Send + Sync + 'static> Default for RLOptimizerConfig<T> {
fn default() -> Self {
Self {
policy_lr: T::from(3e-4).unwrap_or_else(|| T::zero()),
value_lr: T::from(1e-3).unwrap_or_else(|| T::zero()),
discount_factor: T::from(0.99).unwrap_or_else(|| T::zero()),
gae_lambda: T::from(0.95).unwrap_or_else(|| T::zero()),
clip_epsilon: T::from(0.2).unwrap_or_else(|| T::zero()),
entropy_coeff: T::from(0.01).unwrap_or_else(|| T::zero()),
value_loss_coeff: T::from(0.5).unwrap_or_else(|| T::zero()),
max_grad_norm: T::from(0.5).unwrap_or_else(|| T::zero()),
n_epochs: 4,
mini_batchsize: 64,
trust_region_config: None,
use_natural_gradients: false,
fisher_approximation: FisherApproximationMethod::Diagonal,
}
}
}
/// Rollout data for a batch of steps, stored as parallel per-step arrays.
///
/// Index `t` (row `t` of the 2-D arrays) refers to the same step in every field;
/// `TrajectoryBatch::new` rejects inputs whose lengths disagree with
/// `observations.nrows()`.
#[derive(Debug, Clone)]
pub struct TrajectoryBatch<T: Float + Debug + Send + Sync + 'static> {
    /// Observations, one row per step.
    pub observations: Array2<T>,
    /// Actions, one row per step.
    pub actions: Array2<T>,
    /// Per-step log-probabilities.
    pub log_probs: Array1<T>,
    /// Per-step rewards.
    pub rewards: Array1<T>,
    /// Per-step value estimates; `values[t + 1]` also serves as the bootstrap
    /// value for step `t` in `compute_advantages`.
    pub values: Array1<T>,
    /// Per-step episode-termination flags; `compute_advantages` uses them to
    /// avoid bootstrapping across an episode boundary.
    pub dones: Array1<bool>,
    /// Advantage estimates: zero after `new`, filled and standardised by
    /// `compute_advantages`.
    pub advantages: Array1<T>,
    /// Value targets (`advantage + value`, using the un-normalised advantage):
    /// zero after `new`, filled by `compute_advantages`.
    pub returns: Array1<T>,
}
impl<T: Float + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive> TrajectoryBatch<T> {
    /// Builds a batch from per-step rollout arrays.
    ///
    /// `advantages` and `returns` start zero-filled; call
    /// [`compute_advantages`](Self::compute_advantages) to populate them.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `actions`, `log_probs`, `rewards`,
    /// `values` or `dones` does not have exactly `observations.nrows()` entries.
    pub fn new(
        observations: Array2<T>,
        actions: Array2<T>,
        log_probs: Array1<T>,
        rewards: Array1<T>,
        values: Array1<T>,
        dones: Array1<bool>,
    ) -> Result<Self> {
        let batch_size = observations.nrows();
        if actions.nrows() != batch_size
            || log_probs.len() != batch_size
            || rewards.len() != batch_size
            || values.len() != batch_size
            || dones.len() != batch_size
        {
            return Err(OptimError::InvalidConfig(
                "Inconsistent batch dimensions".to_string(),
            ));
        }
        let advantages = Array1::zeros(batch_size);
        let returns = Array1::zeros(batch_size);
        Ok(Self {
            observations,
            actions,
            log_probs,
            rewards,
            values,
            dones,
            advantages,
            returns,
        })
    }

    /// Fills `advantages` and `returns` using Generalized Advantage Estimation.
    ///
    /// Walking backwards over the batch, for each step `t`:
    ///
    /// * `delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t`
    /// * `A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}`
    ///
    /// where `V_{t+1}` is `nextvalue` for the final step. `returns[t] = A_t + V_t`
    /// uses the raw advantage. Afterwards the advantages are standardised to zero
    /// mean and unit (population) standard deviation, unless that deviation is
    /// at most `1e-8`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `values`, `dones`, `advantages` or
    /// `returns` does not have the same length as `rewards`. This can happen
    /// because the fields are public and may be replaced after construction.
    pub fn compute_advantages(&mut self, gamma: T, lambda: T, nextvalue: T) -> Result<()> {
        let batch_size = self.rewards.len();
        if self.values.len() != batch_size
            || self.dones.len() != batch_size
            || self.advantages.len() != batch_size
            || self.returns.len() != batch_size
        {
            return Err(OptimError::InvalidConfig(
                "Inconsistent batch dimensions".to_string(),
            ));
        }

        let mut gae = T::zero();
        for t in (0..batch_size).rev() {
            // A terminal step must not bootstrap from the following state. This
            // holds for the final step too: if it ended an episode, `nextvalue`
            // belongs to the next episode and must be masked out.
            let not_done = if self.dones[t] { T::zero() } else { T::one() };
            let next_val = if t + 1 == batch_size {
                nextvalue
            } else {
                self.values[t + 1]
            };
            let delta = self.rewards[t] + gamma * next_val * not_done - self.values[t];
            gae = delta + gamma * lambda * not_done * gae;
            self.advantages[t] = gae;
            self.returns[t] = gae + self.values[t];
        }

        let mean = self.advantages.mean().unwrap_or(T::zero());
        let std = self
            .advantages
            .mapv(|x| (x - mean) * (x - mean))
            .mean()
            .unwrap_or(T::one())
            .sqrt();
        // Skip normalisation for (near-)constant advantages to avoid dividing by ~0.
        if std > T::from(1e-8).unwrap_or_else(|| T::zero()) {
            self.advantages.mapv_inplace(|x| (x - mean) / std);
        }
        Ok(())
    }

    /// Splits the batch into consecutive (unshuffled) mini-batches of at most
    /// `mini_batchsize` steps. The last mini-batch may be shorter.
    ///
    /// A `mini_batchsize` of `0` means "do not split": the whole batch is returned
    /// as a single mini-batch instead of dividing by zero. An empty batch yields
    /// an empty `Vec`.
    ///
    /// # Panics
    ///
    /// Panics if any per-step array is shorter than `observations.nrows()`,
    /// because slicing it would go out of bounds.
    pub fn get_mini_batches(&self, mini_batchsize: usize) -> Vec<TrajectoryBatch<T>> {
        let batch_size = self.observations.nrows();
        // `step_by(0)` panics, so map 0 to "whole batch". `max(1)` keeps the step
        // valid for an empty batch, where the range yields nothing anyway.
        let chunk = if mini_batchsize == 0 {
            batch_size.max(1)
        } else {
            mini_batchsize
        };
        (0..batch_size)
            .step_by(chunk)
            .map(|start| {
                // `start + min(..)` cannot overflow even for a huge `mini_batchsize`.
                let end = start + chunk.min(batch_size - start);
                TrajectoryBatch {
                    observations: self.observations.slice(s![start..end, ..]).to_owned(),
                    actions: self.actions.slice(s![start..end, ..]).to_owned(),
                    log_probs: self.log_probs.slice(s![start..end]).to_owned(),
                    rewards: self.rewards.slice(s![start..end]).to_owned(),
                    values: self.values.slice(s![start..end]).to_owned(),
                    dones: self.dones.slice(s![start..end]).to_owned(),
                    advantages: self.advantages.slice(s![start..end]).to_owned(),
                    returns: self.returns.slice(s![start..end]).to_owned(),
                }
            })
            .collect()
    }
}
/// Interface a policy model exposes to the optimizers in this module.
pub trait PolicyNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// Evaluates `actions` under the current policy for the given `observations`.
    ///
    /// Returns log-probabilities, entropies and optional named metrics (see
    /// [`PolicyEvaluation`]). NOTE(review): presumably `observations` and
    /// `actions` have one row per sample — confirm with implementors.
    fn evaluate_actions(
        &self,
        observations: &Array2<T>,
        actions: &Array2<T>,
    ) -> Result<PolicyEvaluation<T>>;
    /// Returns the action distribution for `observations` (see [`ActionDistribution`]).
    fn get_action_distribution(&self, observations: &Array2<T>) -> Result<ActionDistribution<T>>;
    /// Applies `gradients` to the network's parameters.
    ///
    /// NOTE(review): keys presumably match the names returned by
    /// `get_parameters` — confirm with implementors.
    fn update_parameters(&mut self, gradients: &HashMap<String, Array1<T>>) -> Result<()>;
    /// Returns the current parameters as flattened 1-D arrays keyed by name.
    fn get_parameters(&self) -> HashMap<String, Array1<T>>;
}
/// Interface a value-function model exposes to the optimizers in this module.
pub trait ValueNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// Returns value estimates for `observations`.
    ///
    /// NOTE(review): presumably one entry per observation row — confirm with implementors.
    fn evaluate_value(&self, observations: &Array2<T>) -> Result<Array1<T>>;
    /// Applies `gradients` to the network's parameters.
    ///
    /// NOTE(review): keys presumably match the names returned by
    /// `get_parameters` — confirm with implementors.
    fn update_parameters(&mut self, gradients: &HashMap<String, Array1<T>>) -> Result<()>;
    /// Returns the current parameters as flattened 1-D arrays keyed by name.
    fn get_parameters(&self) -> HashMap<String, Array1<T>>;
}
/// Output of [`PolicyNetwork::evaluate_actions`].
#[derive(Debug, Clone)]
pub struct PolicyEvaluation<T: Float + Debug + Send + Sync + 'static> {
    /// Log-probabilities of the evaluated actions (presumably one per sample — confirm).
    pub log_probs: Array1<T>,
    /// Policy entropy values (presumably one per sample — confirm).
    pub entropy: Array1<T>,
    /// Free-form named scalar diagnostics supplied by the implementation.
    pub metrics: HashMap<String, T>,
}
/// Parameters of the action distribution returned by
/// [`PolicyNetwork::get_action_distribution`].
///
/// NOTE(review): which optional fields are populated presumably depends on
/// `distribution_type` (e.g. `mean`/`std` for `Gaussian`, `logits` for
/// `Categorical`) — confirm against the `PolicyNetwork` implementations.
#[derive(Debug, Clone)]
pub struct ActionDistribution<T: Float + Debug + Send + Sync + 'static> {
    /// Distribution means, if applicable.
    pub mean: Option<Array2<T>>,
    /// Distribution standard deviations, if applicable.
    pub std: Option<Array2<T>>,
    /// Unnormalised logits, if applicable.
    pub logits: Option<Array2<T>>,
    /// Family of the distribution described by the fields above.
    pub distribution_type: DistributionType,
}
/// Family of an [`ActionDistribution`].
#[derive(Debug, Clone, Copy)]
pub enum DistributionType {
    /// Gaussian (normal) distribution.
    Gaussian,
    /// Categorical (discrete) distribution.
    Categorical,
    /// Beta distribution.
    Beta,
    /// Combination of families; exact meaning defined by implementors (TODO confirm).
    Mixed,
}
/// Learning-rate scheduler advanced once per update via `step`.
#[derive(Debug, Clone)]
pub struct RLScheduler<T: Float + Debug + Send + Sync + 'static> {
    /// Starting learning rate; the `Linear` and `Cosine` schedules scale this value.
    pub initiallr: T,
    /// Learning rate most recently produced (initially `initiallr`).
    pub current_lr: T,
    /// Multiplier used by the `Exponential` (every step) and `Step`
    /// (every `step_size` steps) schedules; `new` sets it to `0.99`.
    pub schedule: ScheduleType,
    /// Number of times `step` has been called.
    pub update_count: usize,
    /// Named numeric settings read by `step`: `decay_steps` (Linear),
    /// `step_size` (Step) and `max_steps` (Cosine); set via `set_param`.
    pub schedule_params: HashMap<String, T>,
}
/// Learning-rate schedule applied by `RLScheduler::step`.
#[derive(Debug, Clone, Copy)]
pub enum ScheduleType {
    /// The learning rate never changes.
    Constant,
    /// Linear decay from `initiallr` to zero over `decay_steps` updates (default `10000`).
    Linear,
    /// The learning rate is multiplied by `decay_factor` on every update.
    Exponential,
    /// Cosine-shaped decay from `initiallr` to zero at `max_steps` updates (default `10000`).
    Cosine,
    /// The learning rate is multiplied by `decay_factor` every `step_size` updates (default `1000`).
    Step,
    /// `step` leaves the learning rate unchanged; presumably adjusted externally (TODO confirm).
    Adaptive,
}
impl<T: Float + Debug + Send + Sync + 'static> RLScheduler<T> {
    /// Creates a scheduler starting at `initiallr` with a `decay_factor` of `0.99`
    /// and no schedule parameters set.
    pub fn new(initiallr: T, schedule: ScheduleType) -> Self {
        Self {
            initiallr,
            current_lr: initiallr,
            decay_factor: T::from(0.99).unwrap_or_else(|| T::zero()),
            schedule,
            update_count: 0,
            schedule_params: HashMap::new(),
        }
    }

    /// Advances the schedule by one update and returns the new learning rate.
    ///
    /// * `Constant` / `Adaptive`: unchanged.
    /// * `Linear`: `initiallr * max(0, 1 - n / decay_steps)`.
    /// * `Exponential`: multiplied by `decay_factor`.
    /// * `Step`: multiplied by `decay_factor` whenever `n` is a multiple of `step_size`.
    /// * `Cosine`: `initiallr * (1 + cos(pi * min(1, n / max_steps))) / 2`. Once
    ///   `n` reaches `max_steps` the learning rate stays at zero.
    ///
    /// Here `n` is `update_count` after incrementing. Missing parameters default to
    /// `decay_steps = 10000`, `step_size = 1000` and `max_steps = 10000`.
    pub fn step(&mut self) -> T {
        self.update_count += 1;
        let steps_done = T::from(self.update_count).unwrap_or_else(|| T::zero());
        match self.schedule {
            ScheduleType::Constant | ScheduleType::Adaptive => {}
            ScheduleType::Linear => {
                let decay_steps = self.param_or("decay_steps", 10000.0);
                let progress = steps_done / decay_steps;
                self.current_lr = self.initiallr * (T::one() - progress).max(T::zero());
            }
            ScheduleType::Exponential => {
                self.current_lr = self.current_lr * self.decay_factor;
            }
            ScheduleType::Step => {
                let step_size = self.param_or("step_size", 1000.0);
                if steps_done % step_size == T::zero() {
                    self.current_lr = self.current_lr * self.decay_factor;
                }
            }
            ScheduleType::Cosine => {
                let max_steps = self.param_or("max_steps", 10000.0);
                // Clamp progress at 1, as the Linear arm clamps at 0. Otherwise the
                // cosine curve would swing back up to `initiallr` after `max_steps`,
                // and `max_steps == 0` would give cos(inf) = NaN.
                let progress = (steps_done / max_steps).min(T::one());
                let pi = T::from(std::f64::consts::PI).unwrap_or_else(|| T::zero());
                let two = T::one() + T::one();
                self.current_lr = self.initiallr * (T::one() + (pi * progress).cos()) / two;
            }
        }
        self.current_lr
    }

    /// Returns the learning rate produced by the most recent `step`
    /// (or `initiallr` before the first step).
    pub fn get_lr(&self) -> T {
        self.current_lr
    }

    /// Sets (or overwrites) a named schedule parameter read by `step`.
    pub fn set_param(&mut self, key: &str, value: T) {
        self.schedule_params.insert(key.to_string(), value);
    }

    /// Reads `schedule_params[key]`, falling back to `default` when it is unset.
    fn param_or(&self, key: &str, default: f64) -> T {
        self.schedule_params
            .get(key)
            .copied()
            .unwrap_or_else(|| T::from(default).unwrap_or_else(|| T::zero()))
    }
}
/// Diagnostics describing one optimisation update.
///
/// `Default` zeroes every scalar except `policy_lr` (`3e-4`) and `value_lr`
/// (`1e-3`), leaves the optional metrics `None`, and starts with no custom metrics.
#[derive(Debug, Clone)]
pub struct RLOptimizationMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Policy loss.
    pub policy_loss: T,
    /// Value-function loss.
    pub value_loss: T,
    /// Entropy loss term.
    pub entropy_loss: T,
    /// Combined loss. NOTE(review): how the components are combined is not shown here.
    pub total_loss: T,
    /// KL divergence; `None` by default, presumably left unset when not computed.
    pub kl_divergence: Option<T>,
    /// Explained variance of the value predictions.
    pub explained_variance: T,
    /// Fraction of clipped samples; `None` by default, presumably left unset when not computed.
    pub clip_fraction: Option<T>,
    /// Policy learning rate used for the update.
    pub policy_lr: T,
    /// Value learning rate used for the update.
    pub value_lr: T,
    /// Policy gradient norm.
    pub policy_grad_norm: T,
    /// Value gradient norm.
    pub value_grad_norm: T,
    /// Additional named scalar metrics.
    pub custom_metrics: HashMap<String, T>,
}
impl<T: Float + Debug + Send + Sync + 'static> Default for RLOptimizationMetrics<T> {
fn default() -> Self {
Self {
policy_loss: T::zero(),
value_loss: T::zero(),
entropy_loss: T::zero(),
total_loss: T::zero(),
kl_divergence: None,
explained_variance: T::zero(),
clip_fraction: None,
policy_lr: T::from(3e-4).unwrap_or_else(|| T::zero()),
value_lr: T::from(1e-3).unwrap_or_else(|| T::zero()),
policy_grad_norm: T::zero(),
value_grad_norm: T::zero(),
custom_metrics: HashMap::new(),
}
}
}
use scirs2_core::ndarray::s;