//! Advanced policy-gradient agents: a natural policy gradient (NPG) agent,
//! a MAML-style meta-learning agent, and a curiosity-driven agent that adds
//! an ICM intrinsic reward bonus.

use crate::error::Result;
use crate::reinforcement::curiosity::ICM;
use crate::reinforcement::policy::PolicyNetwork;
use crate::reinforcement::value::ValueNetwork;
use crate::reinforcement::{ExperienceBatch, LossInfo};
use scirs2_core::ndarray::prelude::*;
use std::collections::HashMap;
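/// Hyperparameters for [`NaturalPolicyGradient`].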
#[derive(Debug, Clone)]
pub struct NPGConfig {
/// Step size for the policy update.
pub learning_rate: f32,
/// Discount factor for future rewards.
pub gamma: f32,
/// GAE lambda controlling the bias-variance trade-off in advantage estimation.
pub lambda: f32,
/// Maximum number of conjugate-gradient iterations.
pub cg_iterations: usize,
/// Residual tolerance for stopping conjugate gradient early.
pub cg_tolerance: f32,
/// Damping added to the Fisher matrix for numerical stability.
pub fisher_damping: f32,
/// Number of samples used to estimate the Fisher matrix.
pub fisher_batch_size: usize,
/// Learning rate for the value function.
pub value_lr: f32,
}
impl Default for NPGConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
gamma: 0.99,
lambda: 0.95,
cg_iterations: 10,
cg_tolerance: 1e-8,
fisher_damping: 1e-2,
fisher_batch_size: 128,
value_lr: 1e-3,
}
}
}
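/// Natural policy gradient (NPG) agent: a policy network paired with a value
/// network, configured by [`NPGConfig`].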
pub struct NaturalPolicyGradient {
policy: PolicyNetwork,
value_function: ValueNetwork,
config: NPGConfig,
}
impl NaturalPolicyGradient {
pub fn new(
state_dim: usize,
action_dim: usize,
hidden_sizes: Vec<usize>,
continuous: bool,
config: NPGConfig,
) -> Result<Self> {
let policy = PolicyNetwork::new(state_dim, action_dim, hidden_sizes.clone(), continuous)?;
let value_function = ValueNetwork::new(state_dim, 1, hidden_sizes)?;
Ok(Self {
policy,
value_function,
config,
})
}
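/// Computes generalized advantage estimates (GAE) over a trajectory.
///
/// Advantages are accumulated backwards with
/// `delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)` and
/// `A_t = delta_t + gamma * lambda * A_{t+1}`, resetting the bootstrap at
/// terminal steps; the final step bootstraps with a next value of zero.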
pub fn compute_gae(&self, rewards: &[f32], values: &[f32], dones: &[bool]) -> Vec<f32> {
let n = rewards.len();
let mut advantages = vec![0.0f32; n];
let mut gae = 0.0f32;
for i in (0..n).rev() {
let next_v = if i + 1 < n { values[i + 1] } else { 0.0 };
let delta = rewards[i]
+ if dones[i] {
0.0
} else {
self.config.gamma * next_v
}
- values[i];
gae = delta
+ if dones[i] {
0.0
} else {
self.config.gamma * self.config.lambda * gae
};
advantages[i] = gae;
}
advantages
}
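/// Samples an action from the current policy for the given state.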
pub fn act(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
self.policy.sample_action(state)
}
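/// Computes a REINFORCE-style surrogate loss: the batch mean of
/// `-log_prob(a | s)` weighted by the stored reward.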
pub fn compute_policy_loss(&self, batch: &ExperienceBatch) -> Result<f32> {
let n = batch.states.nrows();
let mut loss = 0.0f32;
for i in 0..n {
let s = batch.states.row(i);
let a = batch.actions.row(i);
let lp = self.policy.log_prob(&s, &a)?;
loss -= lp * batch.rewards[i];
}
Ok(loss / n.max(1) as f32)
}
}
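/// Hyperparameters for [`MAMLAgent`] (model-agnostic meta-learning).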
#[derive(Debug, Clone)]
pub struct MAMLConfig {
/// Learning rate for task-specific (inner-loop) adaptation.
pub inner_lr: f32,
/// Learning rate for the meta (outer-loop) update.
pub outer_lr: f32,
/// Number of gradient steps per inner-loop adaptation.
pub n_inner_steps: usize,
/// Number of tasks sampled per meta-update.
pub n_tasks: usize,
/// Batch size for each inner-loop step.
pub inner_batch_size: usize,
}
impl Default for MAMLConfig {
fn default() -> Self {
Self {
inner_lr: 0.01,
outer_lr: 3e-4,
n_inner_steps: 5,
n_tasks: 8,
inner_batch_size: 32,
}
}
}
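/// Model-agnostic meta-learning (MAML) agent holding the meta-policy that is
/// adapted to individual tasks.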
pub struct MAMLAgent {
meta_policy: PolicyNetwork,
config: MAMLConfig,
}
impl MAMLAgent {
pub fn new(
state_dim: usize,
action_dim: usize,
hidden_sizes: Vec<usize>,
continuous: bool,
config: MAMLConfig,
) -> Result<Self> {
let meta_policy = PolicyNetwork::new(state_dim, action_dim, hidden_sizes, continuous)?;
Ok(Self {
meta_policy,
config,
})
}
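/// Samples an action from the meta-policy for the given state.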
pub fn act(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
self.meta_policy.sample_action(state)
}
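/// Evaluates the meta-policy on a support batch, returning the mean
/// REINFORCE-style loss `-log_prob(a | s) * reward`.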
pub fn adapt_and_evaluate(&self, support_batch: &ExperienceBatch) -> Result<f32> {
let n = support_batch.states.nrows();
let mut loss = 0.0f32;
for i in 0..n {
let s = support_batch.states.row(i);
let a = support_batch.actions.row(i);
let lp = self.meta_policy.log_prob(&s, &a)?;
loss -= lp * support_batch.rewards[i];
}
Ok(loss / n.max(1) as f32)
}
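/// Returns the MAML hyperparameters.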
pub fn config(&self) -> &MAMLConfig {
&self.config
}
}
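/// Configuration for the intrinsic curiosity module (ICM).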
#[derive(Debug, Clone)]
pub struct CuriosityConfig {
/// Scale applied to the intrinsic (curiosity) reward.
pub eta: f32,
/// Weight balancing the forward-model loss against the inverse-model loss.
pub beta: f32,
/// Dimensionality of the learned feature embedding.
pub feature_dim: usize,
/// Hidden layer sizes for the ICM networks.
pub hidden_sizes: Vec<usize>,
}
impl Default for CuriosityConfig {
fn default() -> Self {
Self {
eta: 0.01,
beta: 0.2,
feature_dim: 32,
hidden_sizes: vec![64, 64],
}
}
}
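/// Exploration agent that augments extrinsic rewards with an ICM-based
/// intrinsic bonus.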
pub struct CuriosityDrivenAgent {
policy: PolicyNetwork,
value_fn: ValueNetwork,
icm: ICM,
curiosity_config: CuriosityConfig,
curiosity_weight: f32,
}
impl CuriosityDrivenAgent {
pub fn new(
state_dim: usize,
action_dim: usize,
hidden_sizes: Vec<usize>,
continuous: bool,
curiosity_config: CuriosityConfig,
curiosity_weight: f32,
) -> Result<Self> {
let policy = PolicyNetwork::new(state_dim, action_dim, hidden_sizes.clone(), continuous)?;
let value_fn = ValueNetwork::new(state_dim, 1, hidden_sizes)?;
let icm = ICM::new(
state_dim,
action_dim,
curiosity_config.feature_dim,
curiosity_config.hidden_sizes.clone(),
curiosity_config.eta,
curiosity_config.beta,
)?;
Ok(Self {
policy,
value_fn,
icm,
curiosity_config,
curiosity_weight,
})
}
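/// Samples an action from the policy for the given state.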
pub fn act(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
self.policy.sample_action(state)
}
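/// Returns the extrinsic reward plus the weighted intrinsic reward computed
/// by the ICM for the transition `(state, action, next_state)`.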
pub fn augment_reward(
&self,
state: &ArrayView1<f32>,
action: &ArrayView1<f32>,
next_state: &ArrayView1<f32>,
extrinsic_reward: f32,
) -> Result<f32> {
let intrinsic = self
.icm
.compute_intrinsic_reward(state, action, next_state)?;
Ok(extrinsic_reward + self.curiosity_weight * intrinsic)
}
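/// Computes the combined training loss: a REINFORCE-style policy loss plus
/// the curiosity-weighted ICM loss, both averaged over the batch. The raw
/// curiosity loss is also reported under the `"curiosity_loss"` metric.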
pub fn compute_loss(
&self,
batch: &ExperienceBatch,
next_states: &ArrayView2<f32>,
) -> Result<LossInfo> {
let n = batch.states.nrows();
let mut policy_loss = 0.0f32;
let mut curiosity_loss = 0.0f32;
for i in 0..n {
let s = batch.states.row(i);
let a = batch.actions.row(i);
let ns = next_states.row(i);
let lp = self.policy.log_prob(&s, &a)?;
policy_loss -= lp * batch.rewards[i];
curiosity_loss += self.icm.compute_loss(&s, &a, &ns)?;
}
policy_loss /= n.max(1) as f32;
curiosity_loss /= n.max(1) as f32;
let total = policy_loss + self.curiosity_weight * curiosity_loss;
let mut metrics = HashMap::new();
metrics.insert("curiosity_loss".to_string(), curiosity_loss);
Ok(LossInfo {
policy_loss: Some(policy_loss),
value_loss: None,
entropy_loss: None,
total_loss: total,
metrics,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::reinforcement::ExperienceBatch;
#[test]
fn test_npg_config_default() {
let config = NPGConfig::default();
assert_eq!(config.cg_iterations, 10);
assert!((config.gamma - 0.99).abs() < 1e-6);
}
#[test]
fn test_npg_create_and_act() {
let npg = NaturalPolicyGradient::new(4, 2, vec![8], false, NPGConfig::default())
.expect("create ok");
let state = Array1::zeros(4);
let action = npg.act(&state.view()).expect("act ok");
assert_eq!(action.len(), 2);
}
#[test]
fn test_npg_compute_gae() {
let npg = NaturalPolicyGradient::new(4, 2, vec![8], false, NPGConfig::default())
.expect("create ok");
let rewards = vec![1.0f32; 5];
let values = vec![0.5f32; 5];
let dones = vec![false; 5];
let advs = npg.compute_gae(&rewards, &values, &dones);
assert_eq!(advs.len(), 5);
for a in &advs {
assert!(a.is_finite());
}
}
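#[test]
fn test_npg_compute_policy_loss() {
// Smoke test for the REINFORCE-style surrogate loss; the batch layout
// mirrors the one used in test_maml_adapt_evaluate below.
let npg = NaturalPolicyGradient::new(4, 2, vec![8], false, NPGConfig::default())
.expect("create ok");
let batch = ExperienceBatch {
states: Array2::zeros((4, 4)),
actions: Array2::from_shape_fn((4, 2), |(i, j)| if j == i % 2 { 1.0 } else { 0.0 }),
rewards: Array1::ones(4),
next_states: Array2::zeros((4, 4)),
dones: Array1::from_elem(4, false),
info: None,
};
let loss = npg.compute_policy_loss(&batch).expect("loss ok");
assert!(loss.is_finite());
}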
#[test]
fn test_maml_create_and_act() {
let maml = MAMLAgent::new(4, 2, vec![8], false, MAMLConfig::default()).expect("create ok");
let state = Array1::zeros(4);
let action = maml.act(&state.view()).expect("act ok");
assert_eq!(action.len(), 2);
}
#[test]
fn test_maml_adapt_evaluate() {
let maml = MAMLAgent::new(4, 2, vec![8], false, MAMLConfig::default()).expect("create ok");
let batch = ExperienceBatch {
states: Array2::zeros((4, 4)),
actions: Array2::from_shape_fn((4, 2), |(i, j)| if j == i % 2 { 1.0 } else { 0.0 }),
rewards: Array1::ones(4),
next_states: Array2::zeros((4, 4)),
dones: Array1::from_elem(4, false),
info: None,
};
let loss = maml.adapt_and_evaluate(&batch).expect("adapt ok");
assert!(loss.is_finite());
}
#[test]
fn test_curiosity_driven_agent_act() {
let agent =
CuriosityDrivenAgent::new(4, 2, vec![8], false, CuriosityConfig::default(), 0.1)
.expect("create ok");
let state = Array1::zeros(4);
let action = agent.act(&state.view()).expect("act ok");
assert_eq!(action.len(), 2);
}
#[test]
fn test_curiosity_augment_reward() {
let agent =
CuriosityDrivenAgent::new(4, 2, vec![8], false, CuriosityConfig::default(), 0.5)
.expect("create ok");
let state = Array1::zeros(4);
let action = Array1::from_vec(vec![1.0, 0.0]);
let next_state = Array1::ones(4);
let augmented = agent
.augment_reward(&state.view(), &action.view(), &next_state.view(), 1.0)
.expect("augment ok");
assert!(augmented.is_finite());
assert!(augmented >= 1.0, "augmented reward must be ≥ extrinsic");
}
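#[test]
fn test_curiosity_compute_loss() {
// Smoke test for the combined policy + curiosity loss; reuses the batch
// layout from test_maml_adapt_evaluate and passes next_states separately,
// as compute_loss expects.
let agent =
CuriosityDrivenAgent::new(4, 2, vec![8], false, CuriosityConfig::default(), 0.1)
.expect("create ok");
let batch = ExperienceBatch {
states: Array2::zeros((4, 4)),
actions: Array2::from_shape_fn((4, 2), |(i, j)| if j == i % 2 { 1.0 } else { 0.0 }),
rewards: Array1::ones(4),
next_states: Array2::zeros((4, 4)),
dones: Array1::from_elem(4, false),
info: None,
};
let next_states = Array2::ones((4, 4));
let info = agent
.compute_loss(&batch, &next_states.view())
.expect("loss ok");
assert!(info.total_loss.is_finite());
assert!(info.metrics.contains_key("curiosity_loss"));
}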
}