scirs2-neural 0.4.3

Neural network building blocks module for SciRS2 (scirs2-neural) - Minimal Version
Documentation
//! Policy-based reinforcement learning algorithms

use crate::error::{NeuralError, Result};
use crate::layers::{Dense, Layer};
use scirs2_core::ndarray::prelude::*;
use scirs2_core::random::rng;

/// Base trait for policies.
///
/// A policy maps a one-dimensional `f32` observation to an action and can
/// score an action against an observation. Implementors must be
/// `Send + Sync`.
pub trait Policy: Send + Sync {
    /// Produce an action for the observation `state`.
    ///
    /// # Errors
    /// Returns an error when the implementation fails to produce an action.
    fn sample_action(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>>;

    /// Log probability of `action` under the policy at `state`.
    ///
    /// # Errors
    /// Returns an error when the implementation fails to evaluate the action.
    fn log_prob(&self, state: &ArrayView1<f32>, action: &ArrayView1<f32>) -> Result<f32>;

    /// Policy weight matrices (one per linear layer), returned as owned
    /// two-dimensional arrays.
    fn parameters(&self) -> Vec<Array2<f32>>;

    /// Replace the policy weights with the matrices in `params`.
    ///
    /// # Errors
    /// Returns an error when the implementation rejects `params`.
    fn set_parameters(&mut self, params: &[Array2<f32>]) -> Result<()>;
}

/// Policy-gradient helper that pairs a [`PolicyNetwork`] with a learning rate.
///
/// It exposes the REINFORCE loss via `compute_loss`; the code in this type
/// does not itself compute gradients or update the network.
pub struct PolicyGradient {
    /// The wrapped policy network.
    pub policy: PolicyNetwork,
    /// Learning rate supplied at construction; read back via `learning_rate()`.
    learning_rate: f32,
}

impl PolicyGradient {
    /// Create a new policy gradient wrapper around `policy` using the given
    /// `learning_rate`.
    pub fn new(policy: PolicyNetwork, learning_rate: f32) -> Self {
        Self {
            policy,
            learning_rate,
        }
    }

    /// Compute the policy-gradient loss (REINFORCE).
    ///
    /// The loss is the mean of `-log_prob * return` over the time steps for
    /// which both a log-probability and a return are available.
    ///
    /// * `log_probs` - log-probabilities of the actions that were taken.
    /// * `returns` - the return associated with each of those actions.
    ///
    /// If the slices differ in length, only the leading
    /// `min(log_probs.len(), returns.len())` pairs are used, and the mean is
    /// taken over exactly that many terms. Returns `0.0` when there are no
    /// pairs.
    pub fn compute_loss(&self, log_probs: &[f32], returns: &[f32]) -> f32 {
        // `zip` stops at the shorter slice, so the divisor must be the
        // paired count rather than `log_probs.len()`.
        let pairs = log_probs.len().min(returns.len());
        if pairs == 0 {
            return 0.0;
        }
        let total: f32 = log_probs
            .iter()
            .zip(returns)
            .map(|(lp, g)| -lp * g)
            .sum();
        total / pairs as f32
    }

    /// Learning rate accessor
    pub fn learning_rate(&self) -> f32 {
        self.learning_rate
    }
}

/// Neural-network policy (actor network).
///
/// Holds a stack of boxed [`Layer`]s that `forward` applies in order. When
/// built through `new`, a discrete policy ends in a `"softmax"` output layer
/// and a continuous policy ends in an output layer with no activation.
pub struct PolicyNetwork {
    /// Layers applied in order during the forward pass.
    layers: Vec<Box<dyn Layer<f32>>>,
    /// Number of action components (width of the output layer built by `new`).
    action_dim: usize,
    /// `true` for a continuous action space, `false` for a discrete one.
    continuous: bool,
    /// Learnable log-standard deviation for continuous actions.
    ///
    /// `new` sets this to zeros of length `action_dim` for continuous
    /// policies and to `None` for discrete ones.
    pub log_std: Option<Array1<f32>>,
}

impl PolicyNetwork {
    /// Build a policy network made of `Dense` layers.
    ///
    /// * `state_dim` - input width of the first layer.
    /// * `action_dim` - width of the output layer.
    /// * `hidden_sizes` - widths of the hidden layers, each using `"relu"`.
    /// * `continuous` - selects a continuous (no output activation, `log_std`
    ///   initialised to zeros) or discrete (`"softmax"` output, no `log_std`)
    ///   policy.
    ///
    /// # Errors
    /// Propagates any error returned while constructing a `Dense` layer.
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_sizes: Vec<usize>,
        continuous: bool,
    ) -> Result<Self> {
        let mut layers: Vec<Box<dyn Layer<f32>>> = Vec::with_capacity(hidden_sizes.len() + 1);

        // Hidden stack: each layer consumes the previous layer's width.
        let mut fan_in = state_dim;
        for &width in &hidden_sizes {
            let hidden = Dense::new(fan_in, width, Some("relu"), &mut rng())?;
            layers.push(Box::new(hidden));
            fan_in = width;
        }

        // Output head. Continuous policies emit a raw mean
        // (tanh applied externally for bounded actions).
        let head_activation = match continuous {
            true => None,
            false => Some("softmax"),
        };
        let head = Dense::new(fan_in, action_dim, head_activation, &mut rng())?;
        layers.push(Box::new(head));

        // Only continuous policies carry a log-standard-deviation vector.
        let log_std = continuous.then(|| Array1::zeros(action_dim));

        Ok(Self {
            layers,
            action_dim,
            continuous,
            log_std,
        })
    }

    /// Run the network on a single observation.
    ///
    /// Returns the output of the last layer for that observation: action
    /// probabilities for a discrete policy, or the action mean for a
    /// continuous one.
    ///
    /// # Errors
    /// Propagates any layer error, and returns
    /// `NeuralError::InvalidArgument` if the final output is not
    /// two-dimensional.
    pub fn forward(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
        // Layers take dynamic-dimensional input, so present the state as a
        // batch holding one row.
        let batched: ArrayD<f32> = state.to_owned().insert_axis(Axis(0)).into_dyn();

        // Thread the activations through every layer, stopping at the first error.
        let activations = self
            .layers
            .iter()
            .try_fold(batched, |acc, layer| layer.forward(&acc))?;

        // Drop the batch axis again to hand back a 1-D vector.
        let matrix = activations.into_dimensionality::<Ix2>().map_err(|e| {
            NeuralError::InvalidArgument(format!("policy forward reshape error: {e}"))
        })?;
        Ok(matrix.index_axis(Axis(0), 0).to_owned())
    }

    /// Sample an action from the policy
    pub fn sample_action(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
        let output = self.forward(state)?;
        if self.continuous {
            // Return the mean (no stochastic sampling in this stub)
            Ok(output)
        } else {
            // Argmax for discrete actions, returned as one-hot
            let best = output
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("non-NaN"))
                .map(|(i, _)| i)
                .unwrap_or(0);
            let mut action = Array1::zeros(self.action_dim);
            if best < self.action_dim {
                action[best] = 1.0;
            }
            Ok(action)
        }
    }

    /// Log probability of an action
    pub fn log_prob(&self, state: &ArrayView1<f32>, action: &ArrayView1<f32>) -> Result<f32> {
        let output = self.forward(state)?;
        if self.continuous {
            // Gaussian log-prob with learned log_std
            let log_std = self
                .log_std
                .clone()
                .unwrap_or_else(|| Array1::zeros(self.action_dim));
            let mut lp = 0.0f32;
            for i in 0..self.action_dim.min(output.len()).min(action.len()) {
                let std = log_std[i].exp().max(1e-6);
                let diff = action[i] - output[i];
                lp -= 0.5 * (diff / std).powi(2) + log_std[i] + 0.5 * std::f32::consts::TAU.ln();
            }
            Ok(lp)
        } else {
            // Categorical log-prob
            let act_idx = action
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("non-NaN"))
                .map(|(i, _)| i)
                .unwrap_or(0);
            let prob = output.get(act_idx).copied().unwrap_or(1e-10).max(1e-10);
            Ok(prob.ln())
        }
    }

    /// Whether the policy operates in continuous action space.
    ///
    /// Returns the `continuous` flag stored on the network.
    pub fn is_continuous(&self) -> bool {
        self.continuous
    }

    /// Action dimensionality.
    ///
    /// Returns the `action_dim` value stored on the network.
    pub fn action_dim(&self) -> usize {
        self.action_dim
    }
}

impl Policy for PolicyNetwork {
    /// Delegates to the inherent `PolicyNetwork::sample_action`.
    fn sample_action(&self, state: &ArrayView1<f32>) -> Result<Array1<f32>> {
        // Path syntax resolves to the inherent method, not back to this one.
        Self::sample_action(self, state)
    }

    /// Delegates to the inherent `PolicyNetwork::log_prob`.
    fn log_prob(&self, state: &ArrayView1<f32>, action: &ArrayView1<f32>) -> Result<f32> {
        Self::log_prob(self, state, action)
    }

    /// Currently a stub: always returns an empty list.
    fn parameters(&self) -> Vec<Array2<f32>> {
        // In a full implementation these would be the actual weight matrices.
        vec![]
    }

    /// Currently a stub: ignores its argument and always succeeds.
    fn set_parameters(&mut self, _new_weights: &[Array2<f32>]) -> Result<()> {
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A discrete policy must return a one-hot vector of length `action_dim`.
    #[test]
    fn test_discrete_policy_network() {
        let policy = PolicyNetwork::new(4, 2, vec![8], false).expect("create ok");
        let state = Array1::from_vec(vec![0.1, -0.2, 0.5, 0.3]);
        let action = policy.sample_action(&state.view()).expect("sample ok");
        assert_eq!(action.len(), 2);
        // One-hot: exactly one entry should be 1.0
        assert_eq!(action.iter().filter(|&&x| x > 0.5).count(), 1);
    }

    /// A continuous policy must return one value per action dimension.
    #[test]
    fn test_continuous_policy_network() {
        let policy = PolicyNetwork::new(4, 3, vec![8], true).expect("create ok");
        let state = Array1::from_vec(vec![0.1, -0.2, 0.5, 0.3]);
        let action = policy.sample_action(&state.view()).expect("sample ok");
        assert_eq!(action.len(), 3);
    }

    /// The log of a probability is finite (it is floored) and never positive.
    #[test]
    fn test_policy_log_prob_discrete() {
        let policy = PolicyNetwork::new(4, 2, vec![8], false).expect("create ok");
        let state = Array1::from_vec(vec![0.1, -0.2, 0.5, 0.3]);
        let action = Array1::from_vec(vec![1.0, 0.0]);
        let lp = policy
            .log_prob(&state.view(), &action.view())
            .expect("log_prob ok");
        assert!(lp.is_finite(), "floored probability must give a finite log");
        assert!(lp <= 0.0, "log-prob of a valid action must be ≤ 0");
    }

    /// The Gaussian log-density must be finite for finite inputs.
    #[test]
    fn test_policy_log_prob_continuous() {
        let policy = PolicyNetwork::new(4, 3, vec![8], true).expect("create ok");
        let state = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0]);
        let action = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let lp = policy
            .log_prob(&state.view(), &action.view())
            .expect("log_prob ok");
        assert!(lp.is_finite());
    }

    /// The REINFORCE loss is the mean of `-log_prob * return`.
    #[test]
    fn test_policy_gradient_loss() {
        let policy = PolicyNetwork::new(2, 2, vec![4], false).expect("create ok");
        let pg = PolicyGradient::new(policy, 1e-3);
        let log_probs = vec![-0.5, -0.3, -0.7];
        let returns = vec![1.0, 2.0, 0.5];
        let loss = pg.compute_loss(&log_probs, &returns);
        assert!(loss.is_finite());
        // Pin the value instead of only its sign: the sign depends on the
        // inputs and is not a general property of the loss.
        let expected = (0.5 * 1.0 + 0.3 * 2.0 + 0.7 * 0.5) / 3.0;
        assert!(
            (loss - expected).abs() < 1e-6,
            "expected mean of -log_prob * return = {expected}, got {loss}"
        );
    }
}