// oxicuda_rl/lib.rs

//! # OxiCUDA-RL — GPU-Accelerated Reinforcement Learning Primitives (Vol.9)
//!
//! `oxicuda-rl` provides a comprehensive set of GPU-ready RL building blocks:
//!
//! ## Replay Buffers
//!
//! * [`buffer::UniformReplayBuffer`] — fixed-capacity circular buffer with
//!   uniform random sampling (DQN, SAC, TD3).
//! * [`buffer::PrioritizedReplayBuffer`] — segment-tree PER with IS weight
//!   computation (PER-DQN, PER-SAC).
//! * [`buffer::NStepBuffer`] — n-step return accumulation with configurable
//!   discount and episode-boundary handling.
//!
//! ## Policy Distributions
//!
//! * [`policy::CategoricalPolicy`] — discrete actions with Gumbel-max
//!   sampling, log-probability, entropy, KL-divergence.
//! * [`policy::GaussianPolicy`] — diagonal Gaussian for continuous actions
//!   with reparameterisation trick and optional Tanh squashing (SAC).
//! * [`policy::DeterministicPolicy`] — DDPG/TD3 with exploration noise and
//!   TD3 target policy smoothing.
//!
//! ## Return / Advantage Estimators
//!
//! * [`estimator::compute_gae`] — GAE advantages and value targets (PPO, A3C).
//! * [`estimator::compute_td_lambda`] — TD(λ) multi-step returns.
//! * [`estimator::compute_vtrace`] — V-trace off-policy correction (IMPALA).
//! * [`estimator::compute_retrace`] — Retrace(λ) safe off-policy Q-targets.
//!
//! ## Loss Functions
//!
//! * [`loss::ppo_loss`] — PPO clip + value + entropy combined loss.
//! * [`loss::dqn_loss`] / [`loss::double_dqn_loss`] — Bellman MSE / Huber.
//! * [`loss::sac_critic_loss`] / [`loss::sac_actor_loss`] — SAC soft Q and
//!   policy losses with automatic temperature tuning.
//! * [`loss::td3_critic_loss`] / [`loss::td3_actor_loss`] — TD3 twin-Q critic
//!   and deterministic actor losses.
//!
//! ## Normalization
//!
//! * [`normalize::ObservationNormalizer`] — running mean/variance with clip.
//! * [`normalize::RewardNormalizer`] — return-based or clip normalization.
//! * [`normalize::RunningStats`] — Welford online statistics tracker.
//!
//! ## Environment Abstractions
//!
//! * [`env::Env`] — standard RL environment trait (`reset`, `step`).
//! * [`env::VecEnv`] — vectorized multi-environment wrapper with auto-reset.
//! * [`env::env::LinearQuadraticEnv`] — reference LQ environment for testing.
//!
//! ## PTX Kernels
//!
//! * [`ptx_kernels`] — GPU PTX source strings for TD-error, PPO ratio, SAC
//!   target, PER IS weight computation, and advantage normalisation.
//!
//! ## Quick Start
//!
//! ```rust
//! use oxicuda_rl::buffer::UniformReplayBuffer;
//! use oxicuda_rl::policy::CategoricalPolicy;
//! use oxicuda_rl::estimator::{GaeConfig, compute_gae};
//! use oxicuda_rl::loss::{PpoConfig, ppo_loss};
//! use oxicuda_rl::handle::RlHandle;
//!
//! // Set up replay buffer
//! let mut buf = UniformReplayBuffer::new(10_000, 8, 4);
//! let mut handle = RlHandle::default_handle();
//!
//! // Push some experience
//! for i in 0..100_usize {
//!     buf.push(
//!         vec![i as f32; 8],
//!         vec![0.0_f32; 4],
//!         1.0,
//!         vec![i as f32 + 1.0; 8],
//!         false,
//!     );
//! }
//!
//! // Sample a mini-batch
//! let batch = buf.sample(32, &mut handle).unwrap();
//! assert_eq!(batch.len(), 32);
//!
//! // Compute GAE for a 5-step rollout
//! let rewards    = vec![1.0_f32; 5];
//! let values     = vec![0.5_f32; 5];
//! let next_vals  = vec![0.5_f32; 5];
//! let dones      = vec![0.0_f32; 5];
//! let gae = compute_gae(&rewards, &values, &next_vals, &dones, GaeConfig::default()).unwrap();
//! assert_eq!(gae.advantages.len(), 5);
//! ```
//!
//! (C) 2026 COOLJAPAN OU (Team KitaSan)
95#![warn(missing_docs)]
96#![warn(clippy::all)]
97#![allow(clippy::module_name_repetitions)]
98#![allow(clippy::module_inception)]
99#![allow(clippy::wildcard_imports)]
100
101// ─── Public modules ──────────────────────────────────────────────────────────
102
103/// Error types and result alias.
104pub mod error;
105
106/// RL session handle: SM version, device info, seeded RNG.
107pub mod handle;
108
109/// PTX kernel sources for GPU-accelerated RL operations.
110pub mod ptx_kernels;
111
112/// Experience replay buffers.
113pub mod buffer;
114
115/// Policy distributions for discrete and continuous action spaces.
116pub mod policy;
117
118/// Return and advantage estimators.
119pub mod estimator;
120
121/// RL algorithm loss functions.
122pub mod loss;
123
124/// Observation and reward normalization.
125pub mod normalize;
126
127/// Environment abstractions.
128pub mod env;
129
130// ─── Re-exports ───────────────────────────────────────────────────────────────
131
132pub use error::{RlError, RlResult};
133
134/// Convenience prelude: imports the most commonly used types.
135pub mod prelude {
136    pub use crate::buffer::{
137        NStepBuffer, NStepTransition, PrioritizedReplayBuffer, PrioritySample, Transition,
138        UniformReplayBuffer,
139    };
140    pub use crate::env::env::{Env, EnvInfo, LinearQuadraticEnv, StepResult};
141    pub use crate::env::vectorized::{VecEnv, VecStepResult};
142    pub use crate::error::{RlError, RlResult};
143    pub use crate::estimator::{
144        GaeConfig, RetraceConfig, TdConfig, VtraceConfig, VtraceOutput, compute_gae,
145        compute_retrace, compute_td_lambda, compute_vtrace,
146    };
147    pub use crate::handle::{LcgRng, RlHandle, SmVersion};
148    pub use crate::loss::{
149        DqnConfig, DqnLoss, PpoConfig, PpoLoss, SacConfig, SacLoss, Td3Config, Td3Loss,
150        double_dqn_loss, dqn_loss, ppo_loss, sac_actor_loss, sac_critic_loss, sac_temperature_loss,
151        td3_actor_loss, td3_critic_loss,
152    };
153    pub use crate::normalize::{ObservationNormalizer, RewardNormalizer, RunningStats};
154    pub use crate::policy::{
155        CategoricalPolicy, DeterministicPolicy, GaussianPolicy, deterministic::OrnsteinUhlenbeck,
156    };
157}
// ─── Integration tests ────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::prelude::*;

    /// End-to-end DQN-style training loop simulation: collect transitions
    /// from the reference LQ environment, sample a mini-batch, and check
    /// that the DQN loss is finite.
    #[test]
    fn e2e_dqn_style_loop() {
        let obs_dim = 4;
        let n_actions = 2;
        let mut buf = UniformReplayBuffer::new(1000, obs_dim, 1);
        let mut handle = RlHandle::default_handle();
        let mut env = LinearQuadraticEnv::new(obs_dim, 200);
        let policy = CategoricalPolicy::new(n_actions);

        let mut obs = env.reset().unwrap();
        // Collect 200 transitions
        for _ in 0..200 {
            // Dummy logits derived from the current observation.
            let logits = obs.iter().take(n_actions).copied().collect::<Vec<_>>();
            let probs = policy.softmax(&logits).unwrap();
            // Binding is used below, so no underscore prefix.
            let action = policy.sample_action(&probs, &mut handle).unwrap();
            let result = env.step(&[0.0; 4]).unwrap();
            buf.push(
                obs.clone(),
                vec![action as f32],
                result.reward,
                result.obs.clone(),
                result.done,
            );
            // Auto-reset on episode end, otherwise continue from the new state.
            obs = if result.done {
                env.reset().unwrap()
            } else {
                result.obs
            };
        }
        assert!(buf.len() >= 32, "should have enough transitions");

        // Sample and compute loss (Q-values are stand-ins; we only assert finiteness).
        let batch = buf.sample(32, &mut handle).unwrap();
        let q_sa: Vec<f32> = batch.iter().map(|t| t.reward).collect();
        let rewards: Vec<f32> = batch.iter().map(|t| t.reward).collect();
        let max_q_next: Vec<f32> = batch.iter().map(|_| 0.0).collect();
        let dones: Vec<f32> = batch
            .iter()
            .map(|t| if t.done { 1.0 } else { 0.0 })
            .collect();
        let is_w = vec![1.0_f32; 32];
        let l = dqn_loss(
            &q_sa,
            &rewards,
            &max_q_next,
            &dones,
            &is_w,
            DqnConfig::default(),
        )
        .unwrap();
        assert!(l.loss.is_finite(), "DQN loss should be finite");
    }

    /// End-to-end PPO-style advantage computation + loss.
    #[test]
    fn e2e_ppo_gae_loss() {
        let t = 128;
        // Episodic reward pattern: -1 at every 10th step, +0.1 otherwise.
        let rewards: Vec<f32> = (0..t)
            .map(|i| if i % 10 == 9 { -1.0 } else { 0.1 })
            .collect();
        let values: Vec<f32> = vec![0.5; t];
        let next_vals: Vec<f32> = vec![0.5; t];
        let dones: Vec<f32> = (0..t)
            .map(|i| if i % 10 == 9 { 1.0 } else { 0.0 })
            .collect();

        let gae = compute_gae(&rewards, &values, &next_vals, &dones, GaeConfig::default()).unwrap();
        assert_eq!(gae.advantages.len(), t);

        // Simulate PPO mini-batch update with matching old/new log-probs.
        let lp_new = vec![-0.693_f32; t]; // ln(0.5)
        let lp_old = vec![-0.693_f32; t];
        let vp = vec![0.5_f32; t];
        let ent = vec![0.693_f32; t];
        let ovp = vec![0.5_f32; t];
        let l = ppo_loss(
            &lp_new,
            &lp_old,
            &gae.advantages,
            &vp,
            &gae.returns,
            &ent,
            &ovp,
            PpoConfig::default(),
        )
        .unwrap();
        assert!(
            l.total.is_finite(),
            "PPO loss should be finite: {}",
            l.total
        );
        // clip_fraction is a proportion, so it must lie in [0, 1].
        assert!((0.0..=1.0).contains(&l.clip_fraction));
    }

    /// End-to-end SAC-style off-policy update via the prioritized buffer.
    #[test]
    fn e2e_sac_style_update() {
        let mut buf = PrioritizedReplayBuffer::new(256, 8, 2, 0.6, 0.4);
        let mut handle = RlHandle::default_handle();
        for i in 0..256_usize {
            buf.push(
                vec![i as f32 * 0.01; 8],
                vec![0.1_f32; 2],
                (i % 5) as f32 * 0.2,
                vec![(i + 1) as f32 * 0.01; 8],
                i % 20 == 19,
            );
        }
        let batch = buf.sample(32, &mut handle).unwrap();
        let q: Vec<f32> = batch.iter().map(|s| s.transition.reward).collect();
        let r: Vec<f32> = batch.iter().map(|s| s.transition.reward).collect();
        let d: Vec<f32> = batch
            .iter()
            .map(|s| if s.transition.done { 1.0 } else { 0.0 })
            .collect();
        let min_qn = vec![0.5_f32; 32];
        let lp_next = vec![-0.5_f32; 32];
        // PER importance-sampling weights come from the sampled entries.
        let is_w: Vec<f32> = batch.iter().map(|s| s.weight).collect();
        let (cl, _) =
            sac_critic_loss(&q, &r, &d, &min_qn, &lp_next, &is_w, SacConfig::default()).unwrap();
        assert!(cl.is_finite(), "SAC critic loss should be finite");
    }

    /// VecEnv with observation normalization.
    #[test]
    fn e2e_vecenv_with_obs_norm() {
        let envs: Vec<_> = (0..4).map(|_| LinearQuadraticEnv::new(3, 50)).collect();
        let mut ve = VecEnv::new(envs);
        let mut norm = ObservationNormalizer::new(3);
        let init_obs = ve.reset_all().unwrap();
        // Observations come back flattened; split into per-env chunks of obs_dim.
        for chunk in init_obs.chunks_exact(3) {
            norm.process_one(chunk).unwrap();
        }
        let actions = vec![0.01_f32; 4 * 3];
        for _ in 0..20 {
            let result = ve.step(&actions).unwrap();
            for chunk in result.obs.chunks_exact(3) {
                let _norm_obs = norm.process_one(chunk).unwrap();
            }
        }
        assert!(norm.count() > 0);
    }

    /// N-step buffer integration.
    #[test]
    fn e2e_n_step_buffer() {
        let mut nsbuf = NStepBuffer::new(3, 0.99);
        let mut transitions = Vec::new();
        for i in 0..20_usize {
            if let Some(t) = nsbuf.push([i as f32], [0.0], 1.0, [(i + 1) as f32], false) {
                transitions.push(t);
            }
        }
        // Should have transitions after first 3 steps
        assert!(!transitions.is_empty(), "n-step should produce transitions");
        for t in &transitions {
            assert_eq!(t.actual_n, 3);
            // R = 1 + 0.99 + 0.99^2 ≈ 2.9701
            assert!(
                (t.n_step_return - (1.0 + 0.99 + 0.99_f32 * 0.99)).abs() < 0.01,
                "n_step_return={}",
                t.n_step_return
            );
        }
    }
}