use crate::error::{NeuralError, Result};
use crate::layers::{Dense, Layer};
use scirs2_core::ndarray::prelude::*;
use scirs2_core::random::rng;
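/// Common interface for policies: action selection, log-probability scoring,
/// and access to trainable parameters.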
pub trait Policy: Send + Sync {
fn sample_action(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>>;
fn log_prob(&self, state: &ArrayView1<f32>, action: &ArrayView1<f32>) -> Result<f32>;
fn parameters(&self) -> Vec<Array2<f32>>;
fn set_parameters(&mut self, params: &[Array2<f32>]) -> Result<()>;
}
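/// REINFORCE-style policy-gradient helper that pairs a [`PolicyNetwork`]
/// with a learning rate.
///
/// A minimal usage sketch (`ignore`d because module paths depend on the
/// crate layout):
///
/// ```ignore
/// let policy = PolicyNetwork::new(4, 2, vec![8], false)?;
/// let pg = PolicyGradient::new(policy, 1e-3);
/// let loss = pg.compute_loss(&[-0.5, -0.3], &[1.0, 2.0]);
/// ```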
pub struct PolicyGradient {
pub policy: PolicyNetwork,
learning_rate: f32,
}
impl PolicyGradient {
pub fn new(policy: PolicyNetwork, learning_rate: f32) -> Self {
Self {
policy,
learning_rate,
}
}
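    /// Mean REINFORCE objective: `-(1 / N) * Σ_t log π(a_t | s_t) · G_t`,
    /// where `G_t` is the return for step `t`. An empty batch yields `0.0`.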
pub fn compute_loss(&self, log_probs: &[f32], returns: &[f32]) -> f32 {
log_probs
.iter()
.zip(returns.iter())
.map(|(lp, g)| -lp * g)
.sum::<f32>()
/ log_probs.len().max(1) as f32
}
pub fn learning_rate(&self) -> f32 {
self.learning_rate
}
}
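/// MLP policy head: ReLU hidden layers followed by either a softmax output
/// (discrete actions) or a linear output interpreted as the mean of a
/// diagonal Gaussian whose per-dimension log standard deviations are stored
/// in `log_std` (continuous actions).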
pub struct PolicyNetwork {
layers: Vec<Box<dyn Layer<f32>>>,
action_dim: usize,
continuous: bool,
pub log_std: Option<Array1<f32>>,
}
impl PolicyNetwork {
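    /// Builds the network: one ReLU `Dense` layer per entry in `hidden_sizes`,
    /// then a `Dense` output layer of width `action_dim` (softmax for discrete
    /// policies, linear for continuous ones). Continuous policies also get a
    /// zero-initialized `log_std` vector.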
pub fn new(
state_dim: usize,
action_dim: usize,
hidden_sizes: Vec<usize>,
continuous: bool,
) -> Result<Self> {
let mut layers: Vec<Box<dyn Layer<f32>>> = Vec::new();
let mut input_size = state_dim;
for hidden_size in &hidden_sizes {
layers.push(Box::new(Dense::new(
input_size,
*hidden_size,
Some("relu"),
&mut rng(),
)?));
input_size = *hidden_size;
}
        let output_activation = if continuous {
            None
        } else {
            Some("softmax")
        };
layers.push(Box::new(Dense::new(
input_size,
action_dim,
output_activation,
&mut rng(),
)?));
let log_std = if continuous {
Some(Array1::zeros(action_dim))
} else {
None
};
Ok(Self {
layers,
action_dim,
continuous,
log_std,
})
}
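    /// Feeds a single state through the layers and returns the output row:
    /// action probabilities for discrete policies, Gaussian means for
    /// continuous ones.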
pub fn forward(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
let mut x: ArrayD<f32> = state.to_owned().insert_axis(Axis(0)).into_dyn();
for layer in &self.layers {
x = layer.forward(&x)?;
}
let out = x.into_dimensionality::<Ix2>().map_err(|e| {
NeuralError::InvalidArgument(format!("policy forward reshape error: {e}"))
})?;
Ok(out.row(0).to_owned())
}
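    /// Selects an action for `state`. Despite the name, this is currently
    /// deterministic: continuous policies return the Gaussian mean without
    /// exploration noise, and discrete policies return a one-hot encoding
    /// of the greedy argmax.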
pub fn sample_action(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
let output = self.forward(state)?;
if self.continuous {
Ok(output)
} else {
let best = output
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("non-NaN"))
.map(|(i, _)| i)
.unwrap_or(0);
let mut action = Array1::zeros(self.action_dim);
if best < self.action_dim {
action[best] = 1.0;
}
Ok(action)
}
}
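    /// Log-probability of `action` under the policy at `state`.
    ///
    /// Continuous policies sum per-dimension diagonal-Gaussian log-densities,
    /// `-0.5 * ((a_i - μ_i) / σ_i)^2 - log σ_i - 0.5 * log(2π)`; discrete
    /// policies return the log of the softmax probability at the argmax of
    /// the one-hot `action`, clamped below at `1e-10`.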
pub fn log_prob(&self, state: &ArrayView1<f32>, action: &ArrayView1<f32>) -> Result<f32> {
let output = self.forward(state)?;
        if self.continuous {
            // Diagonal Gaussian with mean `output` and per-dimension std
            // `exp(log_std)`, floored at 1e-6 for numerical safety.
            let log_std = self
                .log_std
                .clone()
                .unwrap_or_else(|| Array1::zeros(self.action_dim));
            let mut lp = 0.0f32;
            for i in 0..self.action_dim.min(output.len()).min(action.len()) {
                let std = log_std[i].exp().max(1e-6);
                let diff = action[i] - output[i];
                // log N(a; μ, σ) = -0.5 * ((a - μ) / σ)² - log σ - 0.5 * log(2π);
                // TAU = 2π, so `0.5 * TAU.ln()` supplies the 0.5 * log(2π) term.
                lp -= 0.5 * (diff / std).powi(2) + log_std[i] + 0.5 * std::f32::consts::TAU.ln();
            }
            Ok(lp)
        } else {
            // Decode the one-hot `action` via argmax, then take the log of its
            // softmax probability, clamped below to avoid ln(0).
            let act_idx = action
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("non-NaN"))
                .map(|(i, _)| i)
                .unwrap_or(0);
            let prob = output.get(act_idx).copied().unwrap_or(1e-10).max(1e-10);
            Ok(prob.ln())
        }
}
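    /// Whether the policy emits continuous actions.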
pub fn is_continuous(&self) -> bool {
self.continuous
}
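    /// Dimensionality of the action space.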
pub fn action_dim(&self) -> usize {
self.action_dim
}
}
impl Policy for PolicyNetwork {
fn sample_action(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
PolicyNetwork::sample_action(self, state)
}
fn log_prob(&self, state: &ArrayView1<f32>, action: &ArrayView1<f32>) -> Result<f32> {
PolicyNetwork::log_prob(self, state, action)
}
    fn parameters(&self) -> Vec<Array2<f32>> {
        // Placeholder: parameter extraction from the boxed layers is not
        // implemented, so this returns an empty set.
        Vec::new()
    }
    fn set_parameters(&mut self, _params: &[Array2<f32>]) -> Result<()> {
        // Placeholder: no-op until `parameters` exposes the layer weights.
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_discrete_policy_network() {
let policy = PolicyNetwork::new(4, 2, vec![8], false).expect("create ok");
let state = Array1::from_vec(vec![0.1, -0.2, 0.5, 0.3]);
let action = policy.sample_action(&state.view()).expect("sample ok");
assert_eq!(action.len(), 2);
assert_eq!(action.iter().filter(|&&x| x > 0.5).count(), 1);
}
#[test]
fn test_continuous_policy_network() {
let policy = PolicyNetwork::new(4, 3, vec![8], true).expect("create ok");
let state = Array1::from_vec(vec![0.1, -0.2, 0.5, 0.3]);
let action = policy.sample_action(&state.view()).expect("sample ok");
assert_eq!(action.len(), 3);
}
#[test]
fn test_policy_log_prob_discrete() {
let policy = PolicyNetwork::new(4, 2, vec![8], false).expect("create ok");
let state = Array1::from_vec(vec![0.1, -0.2, 0.5, 0.3]);
let action = Array1::from_vec(vec![1.0, 0.0]);
let lp = policy
.log_prob(&state.view(), &action.view())
.expect("log_prob ok");
assert!(lp <= 0.0, "log-prob of a valid action must be ≤ 0");
}
#[test]
fn test_policy_log_prob_continuous() {
let policy = PolicyNetwork::new(4, 3, vec![8], true).expect("create ok");
let state = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0]);
let action = Array1::from_vec(vec![0.0, 0.0, 0.0]);
let lp = policy
.log_prob(&state.view(), &action.view())
.expect("log_prob ok");
assert!(lp.is_finite());
}
#[test]
fn test_policy_gradient_loss() {
let policy = PolicyNetwork::new(2, 2, vec![4], false).expect("create ok");
let pg = PolicyGradient::new(policy, 1e-3);
let log_probs = vec![-0.5, -0.3, -0.7];
let returns = vec![1.0, 2.0, 0.5];
let loss = pg.compute_loss(&log_probs, &returns);
assert!(loss.is_finite());
        assert!(
            loss >= 0.0,
            "loss is non-negative when log-probs are ≤ 0 and returns are positive"
        );
}
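    // Illustrative end-to-end sketch (a test name introduced here, not from
    // the original file): score a short trajectory with the policy's own
    // log-probs, then feed them to the REINFORCE loss.
    #[test]
    fn test_policy_gradient_trajectory_loss() {
        let policy = PolicyNetwork::new(2, 2, vec![4], false).expect("create ok");
        let states = [
            Array1::from_vec(vec![0.0, 1.0]),
            Array1::from_vec(vec![1.0, 0.0]),
        ];
        let mut log_probs = Vec::new();
        for state in &states {
            let action = policy.sample_action(&state.view()).expect("sample ok");
            let lp = policy
                .log_prob(&state.view(), &action.view())
                .expect("log_prob ok");
            log_probs.push(lp);
        }
        let pg = PolicyGradient::new(policy, 1e-3);
        let returns = vec![1.0, 0.5];
        let loss = pg.compute_loss(&log_probs, &returns);
        assert!(loss.is_finite());
        // Softmax log-probs are ≤ 0 and the returns are positive, so the
        // negated mean is non-negative.
        assert!(loss >= 0.0);
    }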
}