#[allow(dead_code)]
use super::{
ActionDistribution, DistributionType, PolicyNetwork, RLOptimizationMetrics, RLOptimizerConfig,
RLScheduler, TrajectoryBatch, ValueNetwork,
};
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use scirs2_core::random::Rng;
use std::collections::HashMap;
/// Actor-critic algorithm selected by [`ActorCriticConfig::method`].
///
/// In [`ActorCriticOptimizer`], `A2C` is served by `update_from_trajectory` and
/// `SAC`, `TD3` and `DDPG` by `update_from_replay`; `A3C`, `D4PG` and `MPO` are
/// currently rejected by both entry points.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActorCriticMethod {
    /// Advantage actor-critic (on-policy, trajectory based).
    A2C,
    /// Asynchronous advantage actor-critic.
    A3C,
    /// Soft actor-critic.
    SAC,
    /// Twin delayed deep deterministic policy gradient.
    TD3,
    /// Deep deterministic policy gradient.
    DDPG,
    /// Distributed distributional DDPG.
    D4PG,
    /// Maximum a-posteriori policy optimisation.
    MPO,
}
/// Configuration for [`ActorCriticOptimizer`].
///
/// `method` selects the update routine; each algorithm-specific sub-config is
/// only consulted by its matching routine.
#[derive(Debug, Clone)]
pub struct ActorCriticConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Shared RL settings (discount factor, GAE lambda, ...).
    pub base_config: RLOptimizerConfig<T>,
    /// Algorithm used by the optimizer's update entry points.
    pub method: ActorCriticMethod,
    /// Settings read by the SAC routine (initial temperature, entropy tuning).
    pub sac_config: SACConfig<T>,
    /// Settings read by the TD3 routine (target smoothing noise, policy delay).
    pub td3_config: TD3Config<T>,
    /// Settings read by the DDPG routine (OU noise parameters).
    pub ddpg_config: DDPGConfig<T>,
    /// When true, the replay-based routines call the soft target update each step.
    pub use_target_networks: bool,
    /// Polyak coefficient `tau` used by the soft target update.
    pub target_update_rate: T,
    /// Optional hard-update period (not read by the visible update methods).
    pub target_hard_update_freq: Option<usize>,
    /// Capacity of the replay buffer created by `ActorCriticOptimizer::new`.
    pub replay_buffer_size: usize,
    /// Prioritized-replay flag (not read by the visible update methods).
    pub prioritized_replay: bool,
    /// Prioritized-replay alpha, forwarded to the replay buffer.
    pub per_alpha: T,
    /// Prioritized-replay beta, forwarded to the replay buffer.
    pub per_beta: T,
    /// Declared critic count; the optimizer uses the `critics` vector it is given instead.
    pub n_critics: usize,
}
/// Soft actor-critic settings.
#[derive(Debug, Clone)]
pub struct SACConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Initial entropy temperature (alpha) copied into the optimizer.
    pub temperature: T,
    /// When true, the temperature is adjusted after every SAC update.
    pub auto_entropy_tuning: bool,
    /// Entropy target; `None` means "minus the action dimension".
    pub target_entropy: Option<T>,
    /// Step size for the temperature update.
    pub temperature_lr: T,
    /// Reparameterization flag (not read by the visible update methods).
    pub use_reparameterization: bool,
    /// Policy update period (not read by the visible update methods).
    pub policy_update_freq: usize,
    /// Target update period (not read by the visible update methods).
    pub target_update_freq: usize,
}
/// Twin delayed DDPG settings.
#[derive(Debug, Clone)]
pub struct TD3Config<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Scale of the noise added to target-actor actions.
    pub policy_noise: T,
    /// Symmetric clip applied to that target noise.
    pub noise_clip: T,
    /// The actor loss is computed only every `policy_delay` updates.
    pub policy_delay: usize,
    /// Exploration noise scale (not read by the visible update methods).
    pub exploration_noise: T,
    /// Optional `(min, max)` clamp applied to smoothed target actions.
    pub action_bounds: Option<(T, T)>,
}
/// Deep deterministic policy gradient settings.
#[derive(Debug, Clone)]
pub struct DDPGConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Exploration noise scale (not read by the visible update methods).
    pub exploration_noise: T,
    /// Mean-reversion rate used when evolving the OU noise state.
    pub ou_noise_theta: T,
    /// Noise scale used when evolving the OU noise state.
    pub ou_noise_sigma: T,
    /// Optional `(min, max)` action clamp (not read by the visible update methods).
    pub action_bounds: Option<(T, T)>,
}
impl<T> Default for ActorCriticConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// A2C with one critic, no target networks, prioritized replay disabled and
    /// a 100 000-entry replay buffer.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero if the cast fails.
        let lit = |value: f64| T::from(value).unwrap_or_else(|| T::zero());
        Self {
            base_config: RLOptimizerConfig::default(),
            method: ActorCriticMethod::A2C,
            sac_config: SACConfig::default(),
            td3_config: TD3Config::default(),
            ddpg_config: DDPGConfig::default(),
            use_target_networks: false,
            target_update_rate: lit(0.005),
            target_hard_update_freq: None,
            replay_buffer_size: 100_000,
            prioritized_replay: false,
            per_alpha: lit(0.6),
            per_beta: lit(0.4),
            n_critics: 1,
        }
    }
}
impl<T> Default for SACConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Temperature 0.2 with automatic entropy tuning (lr 3e-4) and an entropy
    /// target derived from the action dimension.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero if the cast fails.
        let lit = |value: f64| T::from(value).unwrap_or_else(|| T::zero());
        Self {
            temperature: lit(0.2),
            auto_entropy_tuning: true,
            target_entropy: None,
            temperature_lr: lit(3e-4),
            use_reparameterization: true,
            policy_update_freq: 1,
            target_update_freq: 1,
        }
    }
}
impl<T> Default for TD3Config<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Target noise 0.2 clipped at 0.5, policy delay 2, actions bounded to [-1, 1].
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero if the cast fails.
        let lit = |value: f64| T::from(value).unwrap_or_else(|| T::zero());
        let bounds = (lit(-1.0), lit(1.0));
        Self {
            policy_noise: lit(0.2),
            noise_clip: lit(0.5),
            policy_delay: 2,
            exploration_noise: lit(0.1),
            action_bounds: Some(bounds),
        }
    }
}
impl<T> Default for DDPGConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// OU noise with theta 0.15 and sigma 0.2, actions bounded to [-1, 1].
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero if the cast fails.
        let lit = |value: f64| T::from(value).unwrap_or_else(|| T::zero());
        let bounds = (lit(-1.0), lit(1.0));
        Self {
            exploration_noise: lit(0.1),
            ou_noise_theta: lit(0.15),
            ou_noise_sigma: lit(0.2),
            action_bounds: Some(bounds),
        }
    }
}
/// Actor-critic optimizer: A2C from trajectories, SAC / TD3 / DDPG from a replay buffer.
///
/// `T` carries `Send + Sync + 'static` because the config, metrics and replay
/// buffer fields all require those bounds; without them the field types are not
/// well-formed.
pub struct ActorCriticOptimizer<T, P, V>
where
    T: Float + Debug + Send + Sync + 'static,
    P: PolicyNetwork<T>,
    V: ValueNetwork<T>,
{
    /// Full optimizer configuration.
    config: ActorCriticConfig<T>,
    /// Online policy network.
    actor: P,
    /// Online critics; `new` guarantees at least one.
    critics: Vec<V>,
    /// Target policy; `new` leaves this `None`.
    target_actor: Option<P>,
    /// Target critics; `new` leaves this `None`.
    target_critics: Option<Vec<V>>,
    /// Current SAC entropy temperature.
    temperature: T,
    /// Optional schedulers; `new` sets them to `None`.
    actor_scheduler: Option<RLScheduler<T>>,
    critic_scheduler: Option<RLScheduler<T>>,
    temperature_scheduler: Option<RLScheduler<T>>,
    /// Metrics from the most recent update.
    metrics: ActorCriticMetrics<T>,
    /// Number of completed updates.
    update_count: usize,
    /// Policy-update counter (not advanced by the visible update methods).
    policy_update_count: usize,
    /// Transitions available to the replay-based routines.
    replay_buffer: ExperienceReplayBuffer<T>,
    /// Ornstein-Uhlenbeck noise state for DDPG; `new` leaves this `None`.
    ou_noise_state: Option<Array1<T>>,
}
/// Snapshot of the values recorded by the most recent optimizer update.
#[derive(Debug, Clone)]
pub struct ActorCriticMetrics<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Generic RL optimization metrics.
    pub base_metrics: RLOptimizationMetrics<T>,
    /// Actor loss from the latest update (zero on TD3 steps that skip the actor).
    pub actor_loss: T,
    /// One loss per critic that was evaluated in the latest update.
    pub critic_losses: Vec<T>,
    /// SAC temperature after the latest SAC update.
    pub temperature: Option<T>,
    /// SAC temperature loss when automatic entropy tuning is enabled.
    pub temperature_loss: Option<T>,
    /// Q-value statistics (not written by the visible update methods).
    pub q_values_mean: T,
    pub q_values_std: T,
    /// Target Q-value statistics (not written by the visible update methods).
    pub target_q_mean: T,
    pub target_q_std: T,
    /// Mean policy entropy, written by the A2C update.
    pub policy_entropy: T,
    /// Per-critic gradient norms (not written by the visible update methods).
    pub critic_grad_norms: Vec<T>,
    /// Replay buffer occupancy after the latest replay-based update.
    pub replay_buffer_size: usize,
    /// Replay sampling duration (not written by the visible update methods).
    pub replay_sampling_time: Option<std::time::Duration>,
}
impl<T> Default for ActorCriticMetrics<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// All-zero metrics with a single critic slot and no temperature data.
    fn default() -> Self {
        let zero = T::zero();
        Self {
            base_metrics: RLOptimizationMetrics::default(),
            actor_loss: zero,
            critic_losses: vec![zero],
            temperature: None,
            temperature_loss: None,
            q_values_mean: zero,
            q_values_std: zero,
            target_q_mean: zero,
            target_q_std: zero,
            policy_entropy: zero,
            critic_grad_norms: vec![zero],
            replay_buffer_size: 0,
            replay_sampling_time: None,
        }
    }
}
/// One transition stored in the replay buffer.
#[derive(Debug, Clone)]
pub struct Experience<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Observation the action was taken from.
    pub state: Array1<T>,
    /// Action taken.
    pub action: Array1<T>,
    /// Reward received for the transition.
    pub reward: T,
    /// Observation reached after the action.
    pub next_state: Array1<T>,
    /// Terminal flag; TD targets do not bootstrap past a terminal transition.
    pub done: bool,
    /// Replay priority (not read by the uniform sampler).
    pub priority: T,
    /// Free-form auxiliary values keyed by name.
    pub info: HashMap<String, T>,
}
/// Fixed-capacity ring buffer of [`Experience`]s used by the off-policy routines.
pub struct ExperienceReplayBuffer<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Stored transitions; grows up to `maxsize`, then slots are overwritten.
    buffer: Vec<Experience<T>>,
    /// Maximum number of stored transitions.
    maxsize: usize,
    /// Next slot to write, wrapping modulo `maxsize`.
    position: usize,
    /// Set once an existing slot has been overwritten.
    is_full: bool,
    /// Prioritized-replay alpha; stored but not read by the sampler.
    alpha: T,
    /// Prioritized-replay beta; stored but not read by the sampler.
    beta: T,
    /// Reserved priority storage; `new` sets it to `None`.
    priority_tree: Option<Vec<T>>,
}
impl<T: Float + Debug + Send + Sync + 'static> ExperienceReplayBuffer<T> {
    /// Creates an empty buffer that holds at most `maxsize` experiences.
    ///
    /// `alpha` and `beta` are stored for prioritized replay, but [`Self::sample`]
    /// samples uniformly and does not read them.
    pub fn new(maxsize: usize, alpha: T, beta: T) -> Self {
        Self {
            buffer: Vec::with_capacity(maxsize),
            maxsize,
            position: 0,
            is_full: false,
            alpha,
            beta,
            priority_tree: None,
        }
    }

    /// Stores `experience`, overwriting the oldest entry once the buffer is at capacity.
    ///
    /// A zero-capacity buffer cannot hold anything, so the experience is dropped
    /// instead of indexing an empty vector and computing `% 0` (both of which panic).
    pub fn add(&mut self, experience: Experience<T>) {
        if self.maxsize == 0 {
            return;
        }
        if self.buffer.len() < self.maxsize {
            self.buffer.push(experience);
        } else {
            // After the first fill `position` cycles from 0, i.e. the oldest slot.
            self.buffer[self.position] = experience;
            self.is_full = true;
        }
        self.position = (self.position + 1) % self.maxsize;
    }

    /// Draws `min(batchsize, len())` experiences uniformly at random, with replacement.
    pub fn sample(&self, batchsize: usize) -> Vec<Experience<T>> {
        let available = self.len();
        let sample_size = batchsize.min(available);
        // One RNG handle for the whole batch instead of one per draw.
        let mut rng = scirs2_core::random::thread_rng();
        (0..sample_size)
            .map(|_| self.buffer[rng.gen_range(0..available)].clone())
            .collect()
    }

    /// Number of stored experiences.
    pub fn len(&self) -> usize {
        if self.is_full {
            self.maxsize
        } else {
            self.buffer.len()
        }
    }

    /// True when no experience has been stored.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
impl<
T: Float + scirs2_core::numeric::FromPrimitive + std::iter::Sum + Send + Sync + ScalarOperand,
P: PolicyNetwork<T>,
V: ValueNetwork<T>,
> ActorCriticOptimizer<T, P, V>
{
pub fn new(config: ActorCriticConfig<T>, actor: P, critics: Vec<V>) -> Result<Self> {
if critics.is_empty() {
return Err(OptimError::InvalidConfig(
"At least one critic required".to_string(),
));
}
let replay_buffer = ExperienceReplayBuffer::new(
config.replay_buffer_size,
config.per_alpha,
config.per_beta,
);
let temperature = config.sac_config.temperature;
Ok(Self {
config,
actor,
critics,
target_actor: None,
target_critics: None,
temperature,
actor_scheduler: None,
critic_scheduler: None,
temperature_scheduler: None,
metrics: ActorCriticMetrics::default(),
update_count: 0,
policy_update_count: 0,
replay_buffer,
ou_noise_state: None,
})
}
pub fn update_from_replay(&mut self, batchsize: usize) -> Result<ActorCriticMetrics<T>> {
if self.replay_buffer.len() < batchsize {
return Err(OptimError::InvalidConfig(
"Not enough experiences in buffer".to_string(),
));
}
let experiences = self.replay_buffer.sample(batchsize);
match self.config.method {
ActorCriticMethod::SAC => self.update_sac(&experiences),
ActorCriticMethod::TD3 => self.update_td3(&experiences),
ActorCriticMethod::DDPG => self.update_ddpg(&experiences),
ActorCriticMethod::A2C => Err(OptimError::InvalidConfig(
"Method not implemented".to_string(),
)),
_ => Err(OptimError::InvalidConfig(
"Method not implemented".to_string(),
)),
}
}
pub fn update_from_trajectory(
&mut self,
trajectory: TrajectoryBatch<T>,
) -> Result<ActorCriticMetrics<T>> {
match self.config.method {
ActorCriticMethod::A2C => self.update_a2c(trajectory),
ActorCriticMethod::A3C => Err(OptimError::InvalidConfig(
"Method requires experience replay".to_string(),
)),
_ => Err(OptimError::InvalidConfig(
"Method requires experience replay".to_string(),
)),
}
}
fn update_sac(&mut self, experiences: &[Experience<T>]) -> Result<ActorCriticMetrics<T>> {
let _batch_size = experiences.len();
let states = self.extract_states(experiences)?;
let actions = self.extract_actions(experiences)?;
let rewards = self.extract_rewards(experiences)?;
let next_states = self.extract_next_states(experiences)?;
let dones = self.extract_dones(experiences)?;
let mut critic_losses = Vec::new();
for (_i, critic) in self.critics.iter().enumerate() {
let q_values = self.compute_q_values(critic, &states, &actions)?;
let targetq = self.compute_target_q_sac(&next_states, &rewards, &dones)?;
let critic_loss = self.compute_critic_loss(&q_values, &targetq)?;
critic_losses.push(critic_loss);
}
let actor_loss = self.compute_actor_loss_sac(&states)?;
let temperature_loss = if self.config.sac_config.auto_entropy_tuning {
Some(self.update_temperature_sac(&states)?)
} else {
None
};
if self.config.use_target_networks {
self.soft_update_targets()?;
}
self.metrics.actor_loss = actor_loss;
self.metrics.critic_losses = critic_losses;
self.metrics.temperature = Some(self.temperature);
self.metrics.temperature_loss = temperature_loss;
self.metrics.replay_buffer_size = self.replay_buffer.len();
self.update_count += 1;
Ok(self.metrics.clone())
}
fn update_td3(&mut self, experiences: &[Experience<T>]) -> Result<ActorCriticMetrics<T>> {
let _batch_size = experiences.len();
let states = self.extract_states(experiences)?;
let actions = self.extract_actions(experiences)?;
let rewards = self.extract_rewards(experiences)?;
let next_states = self.extract_next_states(experiences)?;
let dones = self.extract_dones(experiences)?;
let mut critic_losses = Vec::new();
if self.critics.len() >= 2 {
let target_actions = if let Some(ref target_actor) = self.target_actor {
let target_action_dist = target_actor.get_action_distribution(&next_states)?;
let mut target_actions =
self.sample_actions_from_distribution(&target_action_dist)?;
for action in target_actions.iter_mut() {
let noise = T::from(scirs2_core::random::thread_rng().random_f64() - 0.5).expect("unwrap failed")
* T::from(2.0).unwrap_or_else(|| T::zero())
* self.config.td3_config.policy_noise;
let clipped_noise = noise
.max(-self.config.td3_config.noise_clip)
.min(self.config.td3_config.noise_clip);
*action = *action + clipped_noise;
if let Some((min_action, max_action)) = self.config.td3_config.action_bounds {
*action = action.max(min_action).min(max_action);
}
}
target_actions
} else {
actions.clone()
};
let target_q1 = if let Some(ref target_critics) = self.target_critics {
if target_critics.len() >= 2 {
self.compute_q_values(&target_critics[0], &next_states, &target_actions)?
} else {
self.compute_q_values(&self.critics[0], &next_states, &target_actions)?
}
} else {
self.compute_q_values(&self.critics[0], &next_states, &target_actions)?
};
let target_q2 = if let Some(ref target_critics) = self.target_critics {
if target_critics.len() >= 2 {
self.compute_q_values(&target_critics[1], &next_states, &target_actions)?
} else {
self.compute_q_values(&self.critics[1], &next_states, &target_actions)?
}
} else {
self.compute_q_values(&self.critics[1], &next_states, &target_actions)?
};
let mut min_target_q = Array1::zeros(target_q1.len());
for i in 0..target_q1.len() {
min_target_q[i] = target_q1[i].min(target_q2[i]);
}
let gamma = self.config.base_config.discount_factor;
let mut td_targets = Array1::zeros(rewards.len());
for i in 0..rewards.len() {
td_targets[i] = rewards[i]
+ gamma * min_target_q[i] * T::from(if dones[i] { 0.0 } else { 1.0 }).unwrap_or_else(|| T::zero());
}
for (_i, critic) in self.critics.iter().enumerate().take(2) {
let q_values = self.compute_q_values(critic, &states, &actions)?;
let critic_loss = self.compute_critic_loss(&q_values, &td_targets)?;
critic_losses.push(critic_loss);
}
}
let actor_loss = if self.update_count % self.config.td3_config.policy_delay == 0 {
self.compute_actor_loss_td3(&states)?
} else {
T::zero()
};
if self.config.use_target_networks {
self.soft_update_targets()?;
}
self.metrics.actor_loss = actor_loss;
self.metrics.critic_losses = critic_losses;
self.metrics.replay_buffer_size = self.replay_buffer.len();
self.update_count += 1;
Ok(self.metrics.clone())
}
fn update_ddpg(&mut self, experiences: &[Experience<T>]) -> Result<ActorCriticMetrics<T>> {
let _batch_size = experiences.len();
let states = self.extract_states(experiences)?;
let actions = self.extract_actions(experiences)?;
let rewards = self.extract_rewards(experiences)?;
let next_states = self.extract_next_states(experiences)?;
let dones = self.extract_dones(experiences)?;
let target_actions = if let Some(ref target_actor) = self.target_actor {
let target_action_dist = target_actor.get_action_distribution(&next_states)?;
self.sample_actions_from_distribution(&target_action_dist)?
} else {
let action_dist = self.actor.get_action_distribution(&next_states)?;
self.sample_actions_from_distribution(&action_dist)?
};
let targetq = if let Some(ref target_critics) = self.target_critics {
self.compute_q_values(&target_critics[0], &next_states, &target_actions)?
} else {
self.compute_q_values(&self.critics[0], &next_states, &target_actions)?
};
let gamma = self.config.base_config.discount_factor;
let mut td_targets = Array1::zeros(rewards.len());
for i in 0..rewards.len() {
td_targets[i] = rewards[i]
+ gamma * targetq[i] * T::from(if dones[i] { 0.0 } else { 1.0 }).unwrap_or_else(|| T::zero());
}
let q_values = self.compute_q_values(&self.critics[0], &states, &actions)?;
let critic_loss = self.compute_critic_loss(&q_values, &td_targets)?;
let actor_loss = self.compute_actor_loss_ddpg(&states)?;
if self.config.use_target_networks {
self.soft_update_targets()?;
}
self.update_ou_noise()?;
self.metrics.actor_loss = actor_loss;
self.metrics.critic_losses = vec![critic_loss];
self.metrics.replay_buffer_size = self.replay_buffer.len();
self.update_count += 1;
Ok(self.metrics.clone())
}
fn compute_actor_loss_td3(&self, states: &Array2<T>) -> Result<T> {
let action_dist = self.actor.get_action_distribution(states)?;
let actions = self.sample_actions_from_distribution(&action_dist)?;
let q_values = self.compute_q_values(&self.critics[0], states, &actions)?;
let actor_loss =
-q_values.iter().copied().sum::<T>() / T::from(q_values.len()).unwrap_or(T::zero());
Ok(actor_loss)
}
fn compute_actor_loss_ddpg(&self, states: &Array2<T>) -> Result<T> {
let action_dist = self.actor.get_action_distribution(states)?;
let actions = self.sample_actions_from_distribution(&action_dist)?;
let q_values = self.compute_q_values(&self.critics[0], states, &actions)?;
let actor_loss =
-q_values.iter().copied().sum::<T>() / T::from(q_values.len()).unwrap_or(T::zero());
Ok(actor_loss)
}
fn update_ou_noise(&mut self) -> Result<()> {
if let Some(ref mut ou_state) = self.ou_noise_state {
let theta = self.config.ddpg_config.ou_noise_theta;
let sigma = self.config.ddpg_config.ou_noise_sigma;
for noise in ou_state.iter_mut() {
let dx = -theta * *noise
+ sigma * T::from(scirs2_core::random::thread_rng().random_f64() - 0.5).expect("unwrap failed");
*noise = *noise + dx;
}
}
Ok(())
}
fn update_a2c(&mut self, trajectory: TrajectoryBatch<T>) -> Result<ActorCriticMetrics<T>> {
let mut traj_copy = trajectory;
let next_value = if let Some(critic) = self.critics.first() {
let last_obs = traj_copy.observations.slice(s![-1.., ..]).to_owned();
let mut last_obs_batch = Array2::zeros((1, last_obs.ncols()));
last_obs_batch.row_mut(0).assign(&last_obs.row(0));
critic.evaluate_value(&last_obs_batch)?[0]
} else {
T::zero()
};
traj_copy.compute_advantages(
self.config.base_config.discount_factor,
self.config.base_config.gae_lambda,
next_value,
)?;
let values = self.critics[0].evaluate_value(&traj_copy.observations)?;
let critic_loss = (&values - &traj_copy.returns)
.mapv(|x| x * x)
.mean()
.unwrap_or(T::zero());
let policy_eval = self
.actor
.evaluate_actions(&traj_copy.observations, &traj_copy.actions)?;
let actor_loss = -(policy_eval.log_probs * traj_copy.advantages)
.mean()
.unwrap_or(T::zero());
self.metrics.actor_loss = actor_loss;
self.metrics.critic_losses = vec![critic_loss];
self.metrics.policy_entropy = policy_eval.entropy.iter().copied().sum::<T>()
/ T::from(policy_eval.entropy.len()).unwrap_or(T::zero());
self.update_count += 1;
Ok(self.metrics.clone())
}
fn update_a3c(&mut self, trajectory: TrajectoryBatch<T>) -> Result<ActorCriticMetrics<T>> {
self.update_a2c(trajectory)
}
fn update_a2c_from_experiences(
&mut self,
experiences: &[Experience<T>],
) -> Result<ActorCriticMetrics<T>> {
let trajectory = self.experiences_to_trajectory(experiences)?;
self.update_a2c(trajectory)
}
pub fn add_experience(&mut self, experience: Experience<T>) {
self.replay_buffer.add(experience);
}
pub fn get_metrics(&self) -> &ActorCriticMetrics<T> {
&self.metrics
}
fn extract_states(&self, experiences: &[Experience<T>]) -> Result<Array2<T>> {
if experiences.is_empty() {
return Err(OptimError::InvalidConfig(
"Empty experience batch".to_string(),
));
}
let batchsize = experiences.len();
let state_dim = experiences[0].state.len();
let mut states = Array2::zeros((batchsize, state_dim));
for (i, exp) in experiences.iter().enumerate() {
states.row_mut(i).assign(&exp.state);
}
Ok(states)
}
fn extract_actions(&self, experiences: &[Experience<T>]) -> Result<Array2<T>> {
let batchsize = experiences.len();
let action_dim = experiences[0].action.len();
let mut actions = Array2::zeros((batchsize, action_dim));
for (i, exp) in experiences.iter().enumerate() {
actions.row_mut(i).assign(&exp.action);
}
Ok(actions)
}
fn extract_rewards(&self, experiences: &[Experience<T>]) -> Result<Array1<T>> {
let rewards: Vec<T> = experiences.iter().map(|exp| exp.reward).collect();
Ok(Array1::from_vec(rewards))
}
fn extract_next_states(&self, experiences: &[Experience<T>]) -> Result<Array2<T>> {
let batchsize = experiences.len();
let state_dim = experiences[0].next_state.len();
let mut next_states = Array2::zeros((batchsize, state_dim));
for (i, exp) in experiences.iter().enumerate() {
next_states.row_mut(i).assign(&exp.next_state);
}
Ok(next_states)
}
fn extract_dones(&self, experiences: &[Experience<T>]) -> Result<Array1<bool>> {
let dones: Vec<bool> = experiences.iter().map(|exp| exp.done).collect();
Ok(Array1::from_vec(dones))
}
fn compute_q_values(
&self,
critic: &V,
states: &Array2<T>,
_actions: &Array2<T>,
) -> Result<Array1<T>> {
critic.evaluate_value(states)
}
fn compute_target_q_sac(
&self,
_next_states: &Array2<T>,
rewards: &Array1<T>,
_dones: &Array1<bool>,
) -> Result<Array1<T>> {
Ok(rewards.clone())
}
fn compute_critic_loss(&self, q_values: &Array1<T>, targetq: &Array1<T>) -> Result<T> {
Ok((q_values - targetq)
.mapv(|x| x * x)
.mean()
.unwrap_or(T::zero()))
}
fn compute_actor_loss_sac(&self, states: &Array2<T>) -> Result<T> {
let action_dist = self.actor.get_action_distribution(states)?;
let sampled_actions = self.sample_actions_from_distribution(&action_dist)?;
let q_values = if self.critics.len() >= 2 {
let q1 = self.compute_q_values(&self.critics[0], states, &sampled_actions)?;
let q2 = self.compute_q_values(&self.critics[1], states, &sampled_actions)?;
let mut min_q = Array1::zeros(q1.len());
for i in 0..q1.len() {
min_q[i] = q1[i].min(q2[i]);
}
min_q
} else {
self.compute_q_values(&self.critics[0], states, &sampled_actions)?
};
let log_probs = self.compute_log_probabilities(&action_dist, &sampled_actions)?;
let entropy_term = log_probs * self.temperature;
let actor_objective = q_values - entropy_term;
let actor_loss = -actor_objective.iter().copied().sum::<T>()
/ T::from(actor_objective.len()).unwrap_or(T::zero());
Ok(actor_loss)
}
fn update_temperature_sac(&mut self, states: &Array2<T>) -> Result<T> {
if !self.config.sac_config.auto_entropy_tuning {
return Ok(T::zero());
}
let action_dist = self.actor.get_action_distribution(states)?;
let sampled_actions = self.sample_actions_from_distribution(&action_dist)?;
let log_probs = self.compute_log_probabilities(&action_dist, &sampled_actions)?;
let current_entropy =
-log_probs.iter().copied().sum::<T>() / T::from(log_probs.len()).unwrap_or(T::zero());
let target_entropy = self
.config
.sac_config
.target_entropy
.unwrap_or(-T::from(sampled_actions.ncols()).expect("unwrap failed"));
let temperature_loss = self.temperature * (target_entropy - current_entropy);
let temp_lr = self.config.sac_config.temperature_lr;
let temp_gradient = target_entropy - current_entropy;
self.temperature =
(self.temperature - temp_lr * temp_gradient).max(T::from(0.001).unwrap_or_else(|| T::zero()));
Ok(temperature_loss)
}
fn sample_actions_from_distribution(
&self,
action_dist: &ActionDistribution<T>,
) -> Result<Array2<T>> {
match action_dist.distribution_type {
DistributionType::Gaussian => {
if let (Some(ref mean), Some(ref std)) = (&action_dist.mean, &action_dist.std) {
let mut actions = mean.clone();
for ((action, &m), &s) in actions.iter_mut().zip(mean.iter()).zip(std.iter()) {
let noise = T::from(scirs2_core::random::thread_rng().random_f64() - 0.5).expect("unwrap failed")
* T::from(2.0).unwrap_or_else(|| T::zero()); *action = m + s * noise;
}
Ok(actions)
} else {
Err(OptimError::InvalidConfig(
"Invalid Gaussian distribution".to_string(),
))
}
}
DistributionType::Categorical => {
if let Some(ref logits) = action_dist.logits {
let mut actions = Array2::zeros(logits.dim());
for i in 0..logits.nrows() {
let row = logits.row(i);
let max_logit = row.iter().fold(T::neg_infinity(), |acc, &x| acc.max(x));
let exp_logits: Vec<T> =
row.iter().map(|&x| (x - max_logit).exp()).collect();
let sum_exp: T = exp_logits.iter().cloned().sum();
let mut max_idx = 0;
let mut max_prob = T::zero();
for (j, &prob) in exp_logits.iter().enumerate() {
let normalized_prob = prob / sum_exp;
if normalized_prob > max_prob {
max_prob = normalized_prob;
max_idx = j;
}
}
actions[[i, max_idx]] = T::one();
}
Ok(actions)
} else {
Err(OptimError::InvalidConfig(
"Invalid categorical distribution".to_string(),
))
}
}
_ => Err(OptimError::InvalidConfig(
"Unsupported distribution type".to_string(),
)),
}
}
fn compute_log_probabilities(
&self,
action_dist: &ActionDistribution<T>,
actions: &Array2<T>,
) -> Result<Array1<T>> {
match action_dist.distribution_type {
DistributionType::Gaussian => {
if let (Some(ref mean), Some(ref std)) = (&action_dist.mean, &action_dist.std) {
let mut log_probs = Array1::zeros(actions.nrows());
for i in 0..actions.nrows() {
let mut log_prob = T::zero();
for j in 0..actions.ncols() {
let action = actions[[i, j]];
let mu = mean[[i, j]];
let sigma = std[[i, j]];
let normalized_diff = (action - mu) / sigma;
let log_prob_term =
-T::from(0.5).unwrap_or_else(|| T::zero()) * normalized_diff * normalized_diff
- sigma.ln()
- T::from(0.5 * 2.0 * std::f64::consts::PI).unwrap_or_else(|| T::zero()).ln();
log_prob = log_prob + log_prob_term;
}
log_probs[i] = log_prob;
}
Ok(log_probs)
} else {
Err(OptimError::InvalidConfig(
"Invalid Gaussian distribution".to_string(),
))
}
}
DistributionType::Categorical => {
if let Some(ref logits) = action_dist.logits {
let mut log_probs = Array1::zeros(actions.nrows());
for i in 0..actions.nrows() {
let mut action_idx = 0;
for j in 0..actions.ncols() {
if actions[[i, j]] > T::from(0.5).unwrap_or_else(|| T::zero()) {
action_idx = j;
break;
}
}
let row = logits.row(i);
let max_logit = row.iter().fold(T::neg_infinity(), |acc, &x| acc.max(x));
let log_sum_exp = (row.iter().map(|&x| (x - max_logit).exp()).sum::<T>())
.ln()
+ max_logit;
log_probs[i] = logits[[i, action_idx]] - log_sum_exp;
}
Ok(log_probs)
} else {
Err(OptimError::InvalidConfig(
"Invalid categorical distribution".to_string(),
))
}
}
_ => Err(OptimError::InvalidConfig(
"Unsupported distribution type".to_string(),
)),
}
}
fn soft_update_targets(&mut self) -> Result<()> {
let tau = self.config.target_update_rate;
let one_minus_tau = T::one() - tau;
if let Some(ref mut target_critics) = self.target_critics {
for (target_critic, online_critic) in target_critics.iter_mut().zip(self.critics.iter())
{
let online_params = online_critic.get_parameters();
let mut target_params = target_critic.get_parameters();
for (param_name, online_param) in online_params {
if let Some(target_param) = target_params.get_mut(¶m_name) {
*target_param =
&(target_param.clone() * one_minus_tau) + &(online_param * tau);
}
}
target_critic.update_parameters(&target_params)?;
}
}
if let Some(ref mut target_actor) = self.target_actor {
let online_params = self.actor.get_parameters();
let mut target_params = target_actor.get_parameters();
for (param_name, online_param) in online_params {
if let Some(target_param) = target_params.get_mut(¶m_name) {
*target_param = &(target_param.clone() * one_minus_tau) + &(online_param * tau);
}
}
target_actor.update_parameters(&target_params)?;
}
Ok(())
}
fn experiences_to_trajectory(
&self,
experiences: &[Experience<T>],
) -> Result<TrajectoryBatch<T>> {
let states = self.extract_states(experiences)?;
let actions = self.extract_actions(experiences)?;
let rewards = self.extract_rewards(experiences)?;
let dones = self.extract_dones(experiences)?;
let log_probs = Array1::zeros(experiences.len());
let values = Array1::zeros(experiences.len());
TrajectoryBatch::new(states, actions, log_probs, rewards, values, dones)
}
}
use scirs2_core::ndarray::s;