//! Reinforcement-learning building blocks: replay buffers, return
//! estimators, loss functions, observation/reward normalizers, policies
//! and environments. Common items are re-exported from [`prelude`].
#![warn(missing_docs)]
#![warn(clippy::all)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::module_inception)]
#![allow(clippy::wildcard_imports)]
/// Crate-wide error (`RlError`) and result (`RlResult`) types.
pub mod error;
/// Runtime handle and RNG types (`RlHandle`, `LcgRng`, `SmVersion`).
pub mod handle;
/// PTX kernels — presumably GPU/CUDA kernel sources; confirm in-module docs.
pub mod ptx_kernels;
/// Experience replay: uniform, prioritized and n-step buffers.
pub mod buffer;
/// Policies: categorical, Gaussian and deterministic (incl. Ornstein-Uhlenbeck noise).
pub mod policy;
/// Return/advantage estimators: GAE, TD(lambda), V-trace and Retrace.
pub mod estimator;
/// Loss functions for DQN, PPO, SAC and TD3.
pub mod loss;
/// Running-statistics based observation and reward normalization.
pub mod normalize;
/// Environment trait plus concrete single and vectorized environments.
pub mod env;
pub use error::{RlError, RlResult};
/// Convenience re-exports of the crate's most commonly used items.
///
/// `use <crate>::prelude::*;` brings buffers, environments, estimators,
/// losses, normalizers, policies and the error/handle types into scope
/// in one line.
pub mod prelude {
pub use crate::buffer::{
NStepBuffer, NStepTransition, PrioritizedReplayBuffer, PrioritySample, Transition,
UniformReplayBuffer,
};
pub use crate::env::env::{Env, EnvInfo, LinearQuadraticEnv, StepResult};
pub use crate::env::vectorized::{VecEnv, VecStepResult};
pub use crate::error::{RlError, RlResult};
pub use crate::estimator::{
GaeConfig, RetraceConfig, TdConfig, VtraceConfig, VtraceOutput, compute_gae,
compute_retrace, compute_td_lambda, compute_vtrace,
};
pub use crate::handle::{LcgRng, RlHandle, SmVersion};
pub use crate::loss::{
DqnConfig, DqnLoss, PpoConfig, PpoLoss, SacConfig, SacLoss, Td3Config, Td3Loss,
double_dqn_loss, dqn_loss, ppo_loss, sac_actor_loss, sac_critic_loss, sac_temperature_loss,
td3_actor_loss, td3_critic_loss,
};
pub use crate::normalize::{ObservationNormalizer, RewardNormalizer, RunningStats};
pub use crate::policy::{
CategoricalPolicy, DeterministicPolicy, GaussianPolicy, deterministic::OrnsteinUhlenbeck,
};
}
#[cfg(test)]
mod tests {
use super::prelude::*;
#[test]
fn e2e_dqn_style_loop() {
    // End-to-end smoke test of a DQN-style collect/sample/loss cycle:
    // roll out 200 steps, store transitions, sample a batch and check that
    // `dqn_loss` on placeholder Q-values produces a finite loss.
    let obs_dim = 4;
    let n_actions = 2;
    let mut buf = UniformReplayBuffer::new(1000, obs_dim, 1);
    let mut handle = RlHandle::default_handle();
    let mut env = LinearQuadraticEnv::new(obs_dim, 200);
    let policy = CategoricalPolicy::new(n_actions);
    let mut obs = env.reset().unwrap();
    for _ in 0..200 {
        // Fake "logits": the first `n_actions` observation components.
        let logits = obs.iter().take(n_actions).copied().collect::<Vec<_>>();
        let probs = policy.softmax(&logits).unwrap();
        // Fix: the sampled action is stored in the buffer below, so it must
        // not carry an underscore prefix (that signals an unused binding).
        let action = policy.sample_action(&probs, &mut handle).unwrap();
        // Zero control input; length 4 matches obs_dim above.
        let result = env.step(&[0.0; 4]).unwrap();
        buf.push(
            obs.clone(),
            vec![action as f32],
            result.reward,
            result.obs.clone(),
            result.done,
        );
        obs = if result.done {
            env.reset().unwrap()
        } else {
            result.obs
        };
    }
    assert!(buf.len() >= 32, "should have enough transitions");
    let batch = buf.sample(32, &mut handle).unwrap();
    // Placeholder targets (no real network): Q(s,a)=reward, max Q(s')=0,
    // uniform importance weights.
    let q_sa: Vec<f32> = batch.iter().map(|t| t.reward).collect();
    let rewards: Vec<f32> = batch.iter().map(|t| t.reward).collect();
    let max_q_next: Vec<f32> = batch.iter().map(|_| 0.0).collect();
    let dones: Vec<f32> = batch
        .iter()
        .map(|t| if t.done { 1.0 } else { 0.0 })
        .collect();
    let is_w = vec![1.0_f32; 32];
    let l = dqn_loss(
        &q_sa,
        &rewards,
        &max_q_next,
        &dones,
        &is_w,
        DqnConfig::default(),
    )
    .unwrap();
    assert!(l.loss.is_finite(), "DQN loss should be finite");
}
#[test]
fn e2e_ppo_gae_loss() {
    // PPO pipeline smoke test: run GAE over a synthetic 128-step rollout
    // with an episode boundary every 10 steps, then feed the advantages and
    // returns into `ppo_loss` with uniform dummy policy/value vectors.
    let horizon = 128;
    let is_terminal = |i: usize| i % 10 == 9;
    let rewards: Vec<f32> = (0..horizon)
        .map(|i| if is_terminal(i) { -1.0 } else { 0.1 })
        .collect();
    let values = vec![0.5_f32; horizon];
    let next_vals = vec![0.5_f32; horizon];
    let dones: Vec<f32> = (0..horizon)
        .map(|i| if is_terminal(i) { 1.0 } else { 0.0 })
        .collect();
    let gae = compute_gae(&rewards, &values, &next_vals, &dones, GaeConfig::default()).unwrap();
    assert_eq!(gae.advantages.len(), horizon);
    // Uniform two-action policy: log(0.5) ~ -0.693; matching entropy.
    let lp_new = vec![-0.693_f32; horizon];
    let lp_old = vec![-0.693_f32; horizon];
    let vp = vec![0.5_f32; horizon];
    let ent = vec![0.693_f32; horizon];
    let ovp = vec![0.5_f32; horizon];
    let l = ppo_loss(
        &lp_new,
        &lp_old,
        &gae.advantages,
        &vp,
        &gae.returns,
        &ent,
        &ovp,
        PpoConfig::default(),
    )
    .unwrap();
    assert!(
        l.total.is_finite(),
        "PPO loss should be finite: {}",
        l.total
    );
    assert!((0.0..=1.0).contains(&l.clip_fraction));
}
#[test]
fn e2e_sac_style_update() {
    // SAC-style smoke test: fill a prioritized replay buffer, draw an
    // importance-weighted batch, and verify the critic loss is finite.
    let mut buf = PrioritizedReplayBuffer::new(256, 8, 2, 0.6, 0.4);
    let mut handle = RlHandle::default_handle();
    // 256 synthetic transitions; every 20th step ends an episode.
    for i in 0..256_usize {
        let obs = vec![i as f32 * 0.01; 8];
        let next_obs = vec![(i + 1) as f32 * 0.01; 8];
        let reward = (i % 5) as f32 * 0.2;
        let done = i % 20 == 19;
        buf.push(obs, vec![0.1_f32; 2], reward, next_obs, done);
    }
    let batch = buf.sample(32, &mut handle).unwrap();
    // Placeholder critic inputs: Q(s,a)=reward; fixed next-state values and
    // next-action log-probs; importance weights taken from the sampler.
    let q: Vec<f32> = batch.iter().map(|s| s.transition.reward).collect();
    let r: Vec<f32> = batch.iter().map(|s| s.transition.reward).collect();
    let d: Vec<f32> = batch
        .iter()
        .map(|s| if s.transition.done { 1.0 } else { 0.0 })
        .collect();
    let min_qn = vec![0.5_f32; 32];
    let lp_next = vec![-0.5_f32; 32];
    let is_w: Vec<f32> = batch.iter().map(|s| s.weight).collect();
    let (cl, _) =
        sac_critic_loss(&q, &r, &d, &min_qn, &lp_next, &is_w, SacConfig::default()).unwrap();
    assert!(cl.is_finite(), "SAC critic loss should be finite");
}
#[test]
fn e2e_vecenv_with_obs_norm() {
    // Drive 4 vectorized environments for 20 steps, feeding every per-env
    // observation slice through the running normalizer, then confirm the
    // normalizer accumulated statistics.
    let obs_dim = 3;
    let n_envs = 4;
    let envs: Vec<_> = (0..n_envs)
        .map(|_| LinearQuadraticEnv::new(obs_dim, 50))
        .collect();
    let mut ve = VecEnv::new(envs);
    let mut norm = ObservationNormalizer::new(obs_dim);
    // Seed the statistics with the initial observations.
    let init_obs = ve.reset_all().unwrap();
    for chunk in init_obs.chunks_exact(obs_dim) {
        norm.process_one(chunk).unwrap();
    }
    let actions = vec![0.01_f32; n_envs * obs_dim];
    for _ in 0..20 {
        let result = ve.step(&actions).unwrap();
        for chunk in result.obs.chunks_exact(obs_dim) {
            norm.process_one(chunk).unwrap();
        }
    }
    assert!(norm.count() > 0);
}
#[test]
fn e2e_n_step_buffer() {
    // Feed 20 non-terminal unit-reward steps through a 3-step buffer and
    // check every emitted transition spans exactly 3 steps with the
    // discounted return 1 + 0.99 + 0.99^2.
    let mut nsbuf = NStepBuffer::new(3, 0.99);
    let transitions: Vec<_> = (0..20_usize)
        .filter_map(|i| nsbuf.push([i as f32], [0.0], 1.0, [(i + 1) as f32], false))
        .collect();
    assert!(!transitions.is_empty(), "n-step should produce transitions");
    let expected = 1.0 + 0.99 + 0.99_f32 * 0.99;
    for t in &transitions {
        assert_eq!(t.actual_n, 3);
        assert!(
            (t.n_step_return - expected).abs() < 0.01,
            "n_step_return={}",
            t.n_step_return
        );
    }
}
}