use crate::error::{NeuralError, Result};
use crate::reinforcement::policy::PolicyNetwork;
use crate::reinforcement::value::ValueNetwork;
use oxicode::{config as oxicode_config, serde as oxicode_serde};
use scirs2_core::ndarray::prelude::*;
use serde::{Deserialize, Serialize};
/// Hyperparameters for the [`TRPO`] agent.
///
/// Only `gamma` and `gae_lambda` are read by the `TRPO` methods in this module
/// (in `compute_gae`). The other fields are stored, cloned and serialized with
/// the agent. Their meanings below are inferred from the field names.
/// NOTE(review): confirm against the intended trust-region implementation.
///
/// NOTE(review): this struct is embedded in the snapshot written by
/// `TRPO::save`. The oxicode encoding is presumably positional (bincode-style),
/// so reordering, removing or inserting fields likely breaks loading existing
/// files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TRPOConfig {
    /// Presumably the trust-region bound on KL divergence per update. Default `0.01`.
    pub max_kl: f32,
    /// Presumably damping added to Fisher-vector products. Default `0.1`.
    pub damping: f32,
    /// Presumably the minimum improvement ratio accepted by the line search. Default `0.1`.
    pub accept_ratio: f32,
    /// Presumably the maximum number of backtracking line-search steps. Default `10`.
    pub max_line_search_iter: usize,
    /// Presumably the number of conjugate-gradient iterations. Default `10`.
    pub cg_iters: usize,
    /// Presumably the conjugate-gradient residual tolerance. Default `1e-8`.
    pub cg_tol: f32,
    /// Presumably the number of value-function fitting iterations. Default `5`.
    pub vf_iters: usize,
    /// Presumably the value-function learning rate. Default `1e-3`.
    pub vf_lr: f32,
    /// GAE lambda; scales how far advantages propagate backwards in `TRPO::compute_gae`. Default `0.97`.
    pub gae_lambda: f32,
    /// Discount factor applied to bootstrapped values and carried advantages in `TRPO::compute_gae`. Default `0.99`.
    pub gamma: f32,
    /// Presumably the weight of an entropy bonus. Default `0.0`.
    pub entropy_coef: f32,
}
impl Default for TRPOConfig {
fn default() -> Self {
Self {
max_kl: 0.01,
damping: 0.1,
accept_ratio: 0.1,
max_line_search_iter: 10,
cg_iters: 10,
cg_tol: 1e-8,
vf_iters: 5,
vf_lr: 1e-3,
gae_lambda: 0.97,
gamma: 0.99,
entropy_coef: 0.0,
}
}
}
/// Trust Region Policy Optimization (TRPO) agent.
///
/// Bundles a policy network, a state-value network and the hyperparameters in
/// [`TRPOConfig`]. Note that [`TRPO::update`] currently only evaluates the
/// surrogate policy loss; it does not perform a trust-region parameter step.
pub struct TRPO {
    /// Policy sampled by [`TRPO::act`] and scored by [`TRPO::update`].
    policy: PolicyNetwork,
    /// State-value estimator queried by [`TRPO::value`].
    value_fn: ValueNetwork,
    /// Hyperparameters; `gamma` and `gae_lambda` are read by [`TRPO::compute_gae`].
    config: TRPOConfig,
}

impl TRPO {
    /// Builds an agent with freshly constructed policy and value networks.
    ///
    /// # Arguments
    /// * `state_dim` - input size of both networks.
    /// * `action_dim` - output size of the policy network.
    /// * `hidden_sizes` - hidden-layer widths, shared by both networks.
    /// * `continuous` - forwarded to `PolicyNetwork::new`; presumably selects a
    ///   continuous rather than discrete action distribution.
    /// * `config` - hyperparameters stored on the agent.
    ///
    /// # Errors
    /// Propagates any error from constructing either network.
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_sizes: Vec<usize>,
        continuous: bool,
        config: TRPOConfig,
    ) -> Result<Self> {
        let policy = PolicyNetwork::new(state_dim, action_dim, hidden_sizes.clone(), continuous)?;
        // The value network produces a single scalar estimate per state.
        let value_fn = ValueNetwork::new(state_dim, 1, hidden_sizes)?;
        Ok(Self {
            policy,
            value_fn,
            config,
        })
    }

    /// Samples an action for `state` via `PolicyNetwork::sample_action`.
    ///
    /// # Errors
    /// Propagates errors from the policy network.
    pub fn act(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
        self.policy.sample_action(state)
    }

    /// Returns the value network's estimate for `state`.
    ///
    /// # Errors
    /// Propagates errors from the value network.
    pub fn value(&self, state: &ArrayView1<f32>) -> Result<f32> {
        self.value_fn.predict(state)
    }

    /// Generalized Advantage Estimation over a single rollout.
    ///
    /// Walks the trajectory backwards, computing
    /// `delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)` and
    /// `A_t = delta_t + gamma * lambda * A_{t+1}`.
    /// On steps where `dones[t]` is true, both the bootstrap term and the
    /// carried advantage are dropped, so nothing flows across episode boundaries.
    ///
    /// # Arguments
    /// * `rewards` - reward per step; its length `n` is the output length.
    /// * `values` - value estimate per step; only the first `n` entries are read.
    /// * `dones` - terminal flag per step; only the first `n` entries are read.
    /// * `next_value` - bootstrap value used as `V(s_{t+1})` for the final step.
    ///
    /// # Panics
    /// Panics if `values` or `dones` is shorter than `rewards`.
    pub fn compute_gae(
        &self,
        rewards: &[f32],
        values: &[f32],
        dones: &[bool],
        next_value: f32,
    ) -> Vec<f32> {
        let n = rewards.len();
        assert!(
            values.len() >= n && dones.len() >= n,
            "compute_gae: values ({}) and dones ({}) must have at least as many entries as rewards ({})",
            values.len(),
            dones.len(),
            n
        );
        let gamma = self.config.gamma;
        let lambda = self.config.gae_lambda;
        let mut advantages = vec![0.0f32; n];
        let mut gae = 0.0f32;
        for i in (0..n).rev() {
            let next_val = if i + 1 < n { values[i + 1] } else { next_value };
            // Branch rather than multiplying by a 0/1 mask: a non-finite
            // next value or running advantage must not leak (as NaN) into a
            // terminal step.
            let (bootstrap, carried) = if dones[i] {
                (0.0, 0.0)
            } else {
                (gamma * next_val, gamma * lambda * gae)
            };
            let delta = rewards[i] + bootstrap - values[i];
            gae = delta + carried;
            advantages[i] = gae;
        }
        advantages
    }

    /// Evaluates the mean vanilla policy-gradient surrogate loss
    /// `-mean_i(log pi(a_i | s_i) * A_i)` over a batch.
    ///
    /// Despite its name, this method does not modify any parameters. The
    /// trust-region step (conjugate gradient plus line search, configured by
    /// `cg_iters`, `max_kl`, ...) is not performed here.
    ///
    /// # Arguments
    /// * `states` - one state per row.
    /// * `actions` - one action per row, aligned with `states`.
    /// * `advantages` - one advantage per state.
    ///
    /// Returns `0.0` for an empty batch.
    ///
    /// # Errors
    /// Returns `NeuralError::InvalidArgument` if `actions` rows or `advantages`
    /// length differ from the number of states. Also propagates errors from
    /// `PolicyNetwork::log_prob`.
    pub fn update(
        &mut self,
        states: &ArrayView2<f32>,
        actions: &ArrayView2<f32>,
        advantages: &ArrayView1<f32>,
    ) -> Result<f32> {
        let n = states.nrows();
        if actions.nrows() != n || advantages.len() != n {
            return Err(NeuralError::InvalidArgument(format!(
                "TRPO update: got {} states, {} actions and {} advantages; all three must match",
                n,
                actions.nrows(),
                advantages.len()
            )));
        }
        if n == 0 {
            return Ok(0.0);
        }
        let mut loss = 0.0f32;
        for ((s, a), &adv) in states
            .outer_iter()
            .zip(actions.outer_iter())
            .zip(advantages.iter())
        {
            loss -= self.policy.log_prob(&s, &a)? * adv;
        }
        Ok(loss / n as f32)
    }

    /// Returns the agent's hyperparameters.
    pub fn config(&self) -> &TRPOConfig {
        &self.config
    }

    /// Serializes both networks' parameters and the config to `path`
    /// using oxicode's standard configuration.
    ///
    /// # Errors
    /// Returns `SerializationError` if encoding fails and `IOError` if the file
    /// cannot be written.
    pub fn save(&self, path: &str) -> Result<()> {
        let bytes = oxicode_serde::encode_to_vec(&self.snapshot(), oxicode_config::standard())
            .map_err(|e| NeuralError::SerializationError(format!("TRPO save: {e}")))?;
        std::fs::write(path, &bytes)
            .map_err(|e| NeuralError::IOError(format!("TRPO save write: {e}")))
    }

    /// Loads parameters and config previously written by [`TRPO::save`].
    ///
    /// On success, both networks and the config are replaced. If restoring
    /// either network fails, both networks are rolled back to their previous
    /// parameters (best effort), the config is left untouched, and the error
    /// is returned.
    ///
    /// # Errors
    /// Returns `IOError` if the file cannot be read and `DeserializationError`
    /// if decoding fails. Also propagates errors from `restore_params`
    /// (presumably an architecture mismatch).
    pub fn load(&mut self, path: &str) -> Result<()> {
        let bytes = std::fs::read(path)
            .map_err(|e| NeuralError::IOError(format!("TRPO load read: {e}")))?;
        let cfg = oxicode_config::standard();
        let (snapshot, _): (TrpoSnapshot, _) = oxicode_serde::decode_owned_from_slice(&bytes, cfg)
            .map_err(|e| NeuralError::DeserializationError(format!("TRPO load: {e}")))?;
        let backup = self.snapshot();
        if let Err(err) = self.restore_networks(&snapshot) {
            // Best-effort rollback so a half-applied snapshot does not leave the
            // policy and value networks out of sync. The original error is the
            // one the caller needs, so a rollback failure is ignored.
            let _ = self.restore_networks(&backup);
            return Err(err);
        }
        self.config = snapshot.config;
        Ok(())
    }

    /// Captures the current network parameters and config as a serializable snapshot.
    fn snapshot(&self) -> TrpoSnapshot {
        TrpoSnapshot {
            policy_params: self.policy.collect_params(),
            value_params: self.value_fn.collect_params(),
            config: self.config.clone(),
        }
    }

    /// Writes both networks' parameters from `snapshot`, stopping at the first error.
    fn restore_networks(&mut self, snapshot: &TrpoSnapshot) -> Result<()> {
        self.policy.restore_params(&snapshot.policy_params)?;
        self.value_fn.restore_params(&snapshot.value_params)?;
        Ok(())
    }
}
/// Payload written to disk by `TRPO::save` and decoded by `TRPO::load`.
///
/// It is encoded with `oxicode_config::standard()`.
/// NOTE(review): that encoding is presumably positional (bincode-style), so
/// field order and types here are part of the file format. Changing them
/// likely invalidates previously saved files.
#[derive(Serialize, Deserialize)]
struct TrpoSnapshot {
    /// Output of `PolicyNetwork::collect_params`. Each entry looks like a
    /// (flattened values, shape) pair — TODO confirm against `collect_params`.
    policy_params: Vec<(Vec<f32>, Vec<usize>)>,
    /// Output of `ValueNetwork::collect_params`, same layout as `policy_params`.
    value_params: Vec<(Vec<f32>, Vec<usize>)>,
    /// Hyperparameters at save time; `TRPO::load` replaces the agent's config with these.
    config: TRPOConfig,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing hand-computed f32 advantages.
    const EPS: f32 = 1e-5;

    /// Small agent shared by most tests.
    fn default_agent() -> TRPO {
        TRPO::new(4, 2, vec![8], false, TRPOConfig::default()).expect("create ok")
    }

    #[test]
    fn test_trpo_config_default() {
        let config = TRPOConfig::default();
        assert_eq!(config.cg_iters, 10);
        assert!((config.max_kl - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_trpo_create_and_act() {
        let trpo = default_agent();
        let state = Array1::zeros(4);
        let action = trpo.act(&state.view()).expect("act ok");
        assert_eq!(action.len(), 2);
    }

    #[test]
    fn test_trpo_compute_gae() {
        let trpo = default_agent();
        let advs = trpo.compute_gae(&[1.0f32; 5], &[0.5f32; 5], &[false; 5], 0.0);
        assert_eq!(advs.len(), 5);
        assert!(advs.iter().all(|a| a.is_finite()));
    }

    #[test]
    fn test_trpo_compute_gae_hand_computed() {
        let trpo = default_agent();
        let gamma = trpo.config().gamma;
        let lambda = trpo.config().gae_lambda;
        let rewards = [1.0f32, 1.0];
        let values = [0.0f32, 0.0];

        // No terminal: the final step's advantage flows back scaled by gamma * lambda.
        let advs = trpo.compute_gae(&rewards, &values, &[false, false], 0.0);
        assert!((advs[1] - 1.0).abs() < EPS);
        assert!((advs[0] - (1.0 + gamma * lambda)).abs() < EPS);

        // Terminal at step 0: nothing propagates across the episode boundary.
        let advs = trpo.compute_gae(&rewards, &values, &[true, false], 0.0);
        assert!((advs[0] - 1.0).abs() < EPS);
        assert!((advs[1] - 1.0).abs() < EPS);

        // `next_value` bootstraps only a non-terminal final step.
        let advs = trpo.compute_gae(&[0.0f32], &[0.0f32], &[false], 2.0);
        assert!((advs[0] - gamma * 2.0).abs() < EPS);
        let advs = trpo.compute_gae(&[0.0f32], &[0.0f32], &[true], 2.0);
        assert!(advs[0].abs() < EPS);

        assert!(trpo.compute_gae(&[], &[], &[], 1.0).is_empty());
    }

    #[test]
    fn test_trpo_update() {
        let mut trpo = default_agent();
        let states = Array2::zeros((4, 4));
        let actions = Array2::from_shape_fn((4, 2), |(i, j)| if j == i % 2 { 1.0 } else { 0.0 });
        let advantages = Array1::ones(4);
        let loss = trpo
            .update(&states.view(), &actions.view(), &advantages.view())
            .expect("update ok");
        assert!(loss.is_finite());
    }

    #[test]
    fn test_trpo_update_empty_batch() {
        let mut trpo = default_agent();
        let states = Array2::<f32>::zeros((0, 4));
        let actions = Array2::<f32>::zeros((0, 2));
        let advantages = Array1::<f32>::zeros(0);
        let loss = trpo
            .update(&states.view(), &actions.view(), &advantages.view())
            .expect("update ok");
        assert_eq!(loss, 0.0);
    }

    #[test]
    fn test_trpo_save_load_round_trip() {
        // Per-process file name so concurrent test runs cannot clobber each other.
        let tmp = std::env::temp_dir().join(format!(
            "trpo_test_save_load_{}.oxicode",
            std::process::id()
        ));
        let path = tmp.to_str().expect("valid temp path");

        // Save with a non-default config so the load is observable.
        let saved_config = TRPOConfig {
            max_kl: 0.02,
            gamma: 0.95,
            cg_iters: 7,
            ..TRPOConfig::default()
        };
        let trpo_orig = TRPO::new(4, 2, vec![8, 8], false, saved_config).expect("create trpo");
        let policy_before = trpo_orig.policy.collect_params();
        let value_before = trpo_orig.value_fn.collect_params();
        trpo_orig.save(path).expect("trpo save");

        let mut trpo_loaded =
            TRPO::new(4, 2, vec![8, 8], false, TRPOConfig::default()).expect("create trpo load");
        let load_result = trpo_loaded.load(path);
        // Remove the file before asserting so a failure does not leave it behind.
        let _ = std::fs::remove_file(&tmp);
        load_result.expect("trpo load");

        // Parameters (values and shapes, including counts) must survive unchanged.
        assert_eq!(
            policy_before,
            trpo_loaded.policy.collect_params(),
            "policy params differ after round trip"
        );
        assert_eq!(
            value_before,
            trpo_loaded.value_fn.collect_params(),
            "value params differ after round trip"
        );

        // The saved config replaces the loader's default config.
        assert_eq!(trpo_loaded.config().cg_iters, 7);
        assert!((trpo_loaded.config().max_kl - 0.02).abs() < 1e-6);
        assert!((trpo_loaded.config().gamma - 0.95).abs() < 1e-6);
    }
}