#[allow(dead_code)]
use super::{PolicyNetwork, RLOptimizationMetrics, RLOptimizerConfig, TrajectoryBatch};
use crate::error::Result;
use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Configuration for [`NaturalPolicyGradient`].
#[derive(Debug, Clone)]
pub struct NaturalGradientConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Shared reinforcement-learning optimizer settings.
    pub base_config: RLOptimizerConfig<T>,
    /// Which Fisher-information estimator to run when the estimate is refreshed.
    pub fisher_method: FisherEstimationMethod,
    /// Value added to the diagonal of the Fisher estimate to keep it well conditioned.
    pub damping: T,
    /// The Fisher estimate is refreshed on every `fisher_update_freq`-th update.
    pub fisher_update_freq: usize,
    /// NOTE(review): not read by `NaturalPolicyGradient` in this module; estimator
    /// selection goes through `fisher_method`.
    pub use_empirical_fisher: bool,
    /// Maximum number of conjugate-gradient iterations when solving `F x = g`.
    pub cg_iters: usize,
    /// Conjugate gradient stops once the residual norm falls below this value.
    pub cg_tolerance: T,
    /// Multiplier applied to the preconditioned (natural) gradient.
    pub natural_grad_scale: T,
    /// When `false`, raw gradients are used without Fisher preconditioning.
    pub enable_preconditioning: bool,
    /// NOTE(review): not read by `NaturalPolicyGradient` in this module.
    pub diagonal_fisher: bool,
    /// NOTE(review): not read by `NaturalPolicyGradient` in this module.
    pub block_diagonal_fisher: bool,
    /// NOTE(review): not read by `NaturalPolicyGradient` in this module.
    pub kronecker_factored: bool,
}
/// Strategy used to estimate the Fisher information matrix.
///
/// Derives `PartialEq`/`Eq`/`Hash` so configurations can be compared, matched in
/// assertions, and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FisherEstimationMethod {
    /// Average of outer products of per-sample log-probability gradients.
    Empirical,
    /// Selects the "true" Fisher path (computed by the empirical estimator in this module).
    True,
    /// Keep only the diagonal of the Fisher matrix.
    Diagonal,
    /// Selects the block-diagonal estimator path.
    BlockDiagonal,
    /// Selects the Kronecker-factored estimator path.
    KroneckerFactored,
    /// No dedicated estimator in this module; the empirical estimator is used instead.
    GaussNewton,
    /// No dedicated estimator in this module; the empirical estimator is used instead.
    BFGS,
}
impl<T> Default for NaturalGradientConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Empirical Fisher, damping `1e-4`, refreshed every 10 updates, 10 CG iterations with
    /// tolerance `1e-8`, unit natural-gradient scale, preconditioning enabled.
    fn default() -> Self {
        // Converts an `f64` literal into `T`, falling back to zero if it cannot be represented.
        let lit = |value: f64| T::from(value).unwrap_or_else(|| T::zero());
        Self {
            base_config: RLOptimizerConfig::default(),
            fisher_method: FisherEstimationMethod::Empirical,
            damping: lit(1e-4),
            fisher_update_freq: 10,
            use_empirical_fisher: true,
            cg_iters: 10,
            cg_tolerance: lit(1e-8),
            natural_grad_scale: lit(1.0),
            enable_preconditioning: true,
            diagonal_fisher: false,
            block_diagonal_fisher: false,
            kronecker_factored: false,
        }
    }
}
/// Natural policy-gradient optimizer.
///
/// Keeps a Fisher-information estimate for a policy with `paramdim` parameters and uses it to
/// precondition policy gradients.
///
/// `T` must satisfy `Send + Sync + 'static` because every state type stored here
/// (`NaturalGradientConfig<T>`, `FisherAccumulator<T>`, …) requires it. Without these bounds,
/// the struct definition is rejected (E0277).
pub struct NaturalPolicyGradient<T, P>
where
    T: Float + Debug + Send + Sync + 'static,
    P: PolicyNetwork<T>,
{
    /// Optimizer configuration.
    _config: NaturalGradientConfig<T>,
    /// Policy network this optimizer was constructed with.
    policy: P,
    /// Dense, damped Fisher estimate; `None` until the empirical estimator has run.
    fisher_matrix: Option<Array2<T>>,
    /// Diagonal Fisher estimate; `None` until the diagonal estimator has run.
    fisher_diagonal: Option<Array1<T>>,
    /// Kronecker factors; NOTE(review): only ever initialised to `None` in this module.
    kronecker_factors: Option<KroneckerFactors<T>>,
    /// Running sums for the empirical Fisher between finalisations.
    empirical_fisher_accumulator: FisherAccumulator<T>,
    /// Momentum and trust-region state.
    natural_grad_state: NaturalGradientState<T>,
    /// Number of completed `update` calls.
    update_count: usize,
    /// Number of policy parameters (dimension of the Fisher matrix).
    paramdim: usize,
}
/// Per-layer factors for a Kronecker-factored Fisher approximation.
///
/// NOTE(review): nothing in this module populates these factors; the layout below
/// (one entry per layer, index-aligned) is assumed from the field names — confirm.
#[derive(Debug, Clone)]
pub struct KroneckerFactors<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Input-side factor matrices, one per tracked layer (assumed).
    pub input_factors: Vec<Array2<T>>,
    /// Output-side factor matrices, one per tracked layer (assumed).
    pub output_factors: Vec<Array2<T>>,
    /// Indices of the layers the factors belong to (assumed).
    pub layer_indices: Vec<usize>,
}
/// Running state for the empirical Fisher estimate.
#[derive(Debug, Clone)]
pub struct FisherAccumulator<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Sum of outer products `g gᵀ` of the gradients added since the last finalisation.
    pub fisher_sum: Array2<T>,
    /// Number of gradients folded into `fisher_sum`.
    pub sample_count: usize,
    /// Most recently added gradients, oldest first.
    pub gradient_history: Vec<Array1<T>>,
    /// Upper bound on the length of `gradient_history`.
    pub max_history_size: usize,
}
/// Mutable state carried between natural-gradient updates.
#[derive(Debug, Clone)]
pub struct NaturalGradientState<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Natural gradient from the previous update, blended into the next step.
    pub prev_natural_grad: Option<Array1<T>>,
    /// Weight given to `prev_natural_grad` when forming the next step.
    pub momentum: T,
    /// NOTE(review): only initialised (to `None`) in this module.
    pub adaptive_scales: Option<Array1<T>>,
    /// Maximum L2 norm of an applied update; larger updates are rescaled to this norm.
    pub trust_radius: T,
    /// NOTE(review): only initialised (empty) in this module.
    pub kl_history: Vec<T>,
}
impl<T: Float + Debug + ScalarOperand + std::ops::AddAssign + std::iter::Sum, P: PolicyNetwork<T + Send + Sync>>
NaturalPolicyGradient<T, P>
{
pub fn new(_config: NaturalGradientConfig<T>, policy: P, paramdim: usize) -> Self {
let fisher_accumulator = FisherAccumulator {
fisher_sum: Array2::zeros((paramdim, paramdim)),
sample_count: 0,
gradient_history: Vec::new(),
max_history_size: 1000,
};
let natural_grad_state = NaturalGradientState {
prev_natural_grad: None,
momentum: T::from(0.9).unwrap_or_else(|| T::zero()),
adaptive_scales: None,
trust_radius: T::from(1.0).unwrap_or_else(|| T::zero()),
kl_history: Vec::new(),
};
Self {
_config,
policy,
fisher_matrix: None,
fisher_diagonal: None,
kronecker_factors: None,
empirical_fisher_accumulator: fisher_accumulator,
natural_grad_state,
update_count: 0,
paramdim,
}
}
pub fn update(
&mut self,
trajectory: TrajectoryBatch<T>,
gradients: Array1<T>,
) -> Result<RLOptimizationMetrics<T>> {
if self.update_count % self._config.fisher_update_freq == 0 {
self.update_fisher_information(&trajectory)?;
}
let naturalgradients = self.compute_natural_gradients(&gradients)?;
self.apply_natural_gradient_update(&naturalgradients)?;
self.natural_grad_state.prev_natural_grad = Some(naturalgradients);
self.update_count += 1;
let mut metrics = RLOptimizationMetrics::default();
metrics.policy_grad_norm = self.vector_norm(&gradients);
Ok(metrics)
}
fn update_fisher_information(&mut self, trajectory: &TrajectoryBatch<T>) -> Result<()> {
match self._config.fisher_method {
FisherEstimationMethod::Empirical => self.update_empirical_fisher(trajectory)?,
FisherEstimationMethod::True => self.update_true_fisher(trajectory)?,
FisherEstimationMethod::Diagonal => self.update_diagonal_fisher(trajectory)?,
FisherEstimationMethod::BlockDiagonal => {
self.update_block_diagonal_fisher(trajectory)?
}
FisherEstimationMethod::KroneckerFactored => {
self.update_kronecker_factors(trajectory)?
}
_ => {
self.update_empirical_fisher(trajectory)?;
}
}
Ok(())
}
fn update_empirical_fisher(&mut self, trajectory: &TrajectoryBatch<T>) -> Result<()> {
let batch_size = trajectory.observations.nrows();
for i in 0..batch_size {
let obs = trajectory.observations.row(i).to_owned();
let action = trajectory.actions.row(i).to_owned();
let log_prob_grad = self.compute_log_prob_gradients(&obs, &action)?;
self.add_to_empirical_fisher(&log_prob_grad)?;
}
self.finalize_empirical_fisher()?;
Ok(())
}
fn update_true_fisher(&mut self, trajectory: &TrajectoryBatch<T>) -> Result<()> {
self.update_empirical_fisher(trajectory) }
fn update_diagonal_fisher(&mut self, trajectory: &TrajectoryBatch<T>) -> Result<()> {
let mut diagonal = Array1::zeros(self.paramdim);
let batch_size = trajectory.observations.nrows();
for i in 0..batch_size {
let obs = trajectory.observations.row(i).to_owned();
let action = trajectory.actions.row(i).to_owned();
let log_prob_grad = self.compute_log_prob_gradients(&obs, &action)?;
diagonal = diagonal + log_prob_grad.mapv(|x| x * x);
}
diagonal = diagonal / T::from(batch_size).unwrap_or_else(|| T::zero());
diagonal = diagonal + T::from(self._config.damping).unwrap_or_else(|| T::zero());
self.fisher_diagonal = Some(diagonal);
Ok(())
}
fn update_block_diagonal_fisher(&mut self, trajectory: &TrajectoryBatch<T>) -> Result<()> {
Ok(())
}
fn update_kronecker_factors(&mut self, trajectory: &TrajectoryBatch<T>) -> Result<()> {
Ok(())
}
fn compute_natural_gradients(&self, gradients: &Array1<T>) -> Result<Array1<T>> {
if !self._config.enable_preconditioning {
return Ok(gradients.clone());
}
let natural_grad = match self._config.fisher_method {
FisherEstimationMethod::Diagonal => {
if let Some(ref diag) = self.fisher_diagonal {
gradients / diag
} else {
gradients.clone()
}
}
_ => {
if let Some(ref fisher) = self.fisher_matrix {
self.solve_fisher_system(fisher, gradients)?
} else {
gradients.clone()
}
}
};
let scaled_natural_grad = natural_grad * self._config.natural_grad_scale;
Ok(scaled_natural_grad)
}
fn solve_fisher_system(&self, fisher: &Array2<T>, rhs: &Array1<T>) -> Result<Array1<T>> {
let n = rhs.len();
let mut x = Array1::zeros(n);
let mut r = rhs.clone();
let mut p = r.clone();
let mut rsold = self.dot(&r, &r);
for _i in 0..self._config.cg_iters {
let ap = fisher.dot(&p);
let alpha = rsold / self.dot(&p, &ap);
x = &x + &(&p * alpha);
r = &r - &(&ap * alpha);
let rsnew = self.dot(&r, &r);
if rsnew.sqrt() < self._config.cg_tolerance {
break;
}
let beta = rsnew / rsold;
p = &r + &(&p * beta);
rsold = rsnew;
}
Ok(x)
}
fn apply_natural_gradient_update(&mut self, naturalgradients: &Array1<T>) -> Result<()> {
let update = if let Some(ref prev_ng) = self.natural_grad_state.prev_natural_grad {
naturalgradients + &(prev_ng * self.natural_grad_state.momentum)
} else {
naturalgradients.clone()
};
let clipped_update = self.apply_trust_region_constraint(&update)?;
self.update_policy_parameters(&clipped_update)?;
Ok(())
}
fn apply_trust_region_constraint(&self, update: &Array1<T>) -> Result<Array1<T>> {
let update_norm = self.vector_norm(update);
let trust_radius = self.natural_grad_state.trust_radius;
if update_norm <= trust_radius {
Ok(update.clone())
} else {
Ok(update * (trust_radius / update_norm))
}
}
fn update_policy_parameters(&mut self, update: &Array1<T>) -> Result<()> {
Ok(())
}
fn compute_log_prob_gradients(
&self,
obs: &Array1<T>,
_action: &Array1<T>,
) -> Result<Array1<T>> {
Ok(Array1::zeros(self.paramdim))
}
fn add_to_empirical_fisher(&mut self, gradient: &Array1<T>) -> Result<()> {
for i in 0..self.paramdim {
for j in 0..self.paramdim {
self.empirical_fisher_accumulator.fisher_sum[[i, j]] += gradient[i] * gradient[j];
}
}
self.empirical_fisher_accumulator.sample_count += 1;
if self.empirical_fisher_accumulator.gradient_history.len()
>= self.empirical_fisher_accumulator.max_history_size
{
self.empirical_fisher_accumulator.gradient_history.remove(0);
}
self.empirical_fisher_accumulator
.gradient_history
.push(gradient.clone());
Ok(())
}
fn finalize_empirical_fisher(&mut self) -> Result<()> {
if self.empirical_fisher_accumulator.sample_count == 0 {
return Ok(());
}
let fisher = &self.empirical_fisher_accumulator.fisher_sum
/ T::from(self.empirical_fisher_accumulator.sample_count).unwrap_or_else(|| T::zero());
let mut damped_fisher = fisher;
for i in 0..self.paramdim {
damped_fisher[[i, i]] += self._config.damping;
}
self.fisher_matrix = Some(damped_fisher);
self.empirical_fisher_accumulator.fisher_sum.fill(T::zero());
self.empirical_fisher_accumulator.sample_count = 0;
Ok(())
}
fn dot(&self, a: &Array1<T>, b: &Array1<T>) -> T {
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
fn vector_norm(&self, v: &Array1<T>) -> T {
self.dot(v, v).sqrt()
}
pub fn get_fisher_matrix(&self) -> Option<&Array2<T>> {
self.fisher_matrix.as_ref()
}
pub fn get_natural_grad_state(&self) -> &NaturalGradientState<T> {
&self.natural_grad_state
}
}