// oxicuda_rl/lib.rs
1//! # OxiCUDA-RL — GPU-Accelerated Reinforcement Learning Primitives (Vol.9)
2//!
3//! `oxicuda-rl` provides a comprehensive set of GPU-ready RL building blocks:
4//!
5//! ## Replay Buffers
6//!
7//! * [`buffer::UniformReplayBuffer`] — fixed-capacity circular buffer with
8//! uniform random sampling (DQN, SAC, TD3).
9//! * [`buffer::PrioritizedReplayBuffer`] — segment-tree PER with IS weight
10//! computation (PER-DQN, PER-SAC).
11//! * [`buffer::NStepBuffer`] — n-step return accumulation with configurable
12//! discount and episode-boundary handling.
13//!
14//! ## Policy Distributions
15//!
16//! * [`policy::CategoricalPolicy`] — discrete actions with Gumbel-max
17//! sampling, log-probability, entropy, KL-divergence.
18//! * [`policy::GaussianPolicy`] — diagonal Gaussian for continuous actions
19//! with reparameterisation trick and optional Tanh squashing (SAC).
20//! * [`policy::DeterministicPolicy`] — DDPG/TD3 with exploration noise and
21//! TD3 target policy smoothing.
22//!
23//! ## Return / Advantage Estimators
24//!
25//! * [`estimator::compute_gae`] — GAE advantages and value targets (PPO, A3C).
26//! * [`estimator::compute_td_lambda`] — TD(λ) multi-step returns.
27//! * [`estimator::compute_vtrace`] — V-trace off-policy correction (IMPALA).
28//! * [`estimator::compute_retrace`] — Retrace(λ) safe off-policy Q-targets.
29//!
30//! ## Loss Functions
31//!
32//! * [`loss::ppo_loss`] — PPO clip + value + entropy combined loss.
33//! * [`loss::dqn_loss`] / [`loss::double_dqn_loss`] — Bellman MSE / Huber.
34//! * [`loss::sac_critic_loss`] / [`loss::sac_actor_loss`] — SAC soft Q and
35//! policy losses with automatic temperature tuning.
36//! * [`loss::td3_critic_loss`] / [`loss::td3_actor_loss`] — TD3 twin-Q critic
37//! and deterministic actor losses.
38//!
39//! ## Normalization
40//!
41//! * [`normalize::ObservationNormalizer`] — running mean/variance with clip.
42//! * [`normalize::RewardNormalizer`] — return-based or clip normalization.
43//! * [`normalize::RunningStats`] — Welford online statistics tracker.
44//!
45//! ## Environment Abstractions
46//!
//! * [`env::env::Env`] — standard RL environment trait (`reset`, `step`).
//! * [`env::vectorized::VecEnv`] — vectorized multi-environment wrapper with auto-reset.
49//! * [`env::env::LinearQuadraticEnv`] — reference LQ environment for testing.
50//!
51//! ## PTX Kernels
52//!
53//! * [`ptx_kernels`] — GPU PTX source strings for TD-error, PPO ratio, SAC
54//! target, PER IS weight computation, and advantage normalisation.
55//!
56//! ## Quick Start
57//!
58//! ```rust
59//! use oxicuda_rl::buffer::UniformReplayBuffer;
60//! use oxicuda_rl::policy::CategoricalPolicy;
61//! use oxicuda_rl::estimator::{GaeConfig, compute_gae};
62//! use oxicuda_rl::loss::{PpoConfig, ppo_loss};
63//! use oxicuda_rl::handle::RlHandle;
64//!
65//! // Set up replay buffer
66//! let mut buf = UniformReplayBuffer::new(10_000, 8, 4);
67//! let mut handle = RlHandle::default_handle();
68//!
69//! // Push some experience
70//! for i in 0..100_usize {
71//! buf.push(
72//! vec![i as f32; 8],
73//! vec![0.0_f32; 4],
74//! 1.0,
75//! vec![i as f32 + 1.0; 8],
76//! false,
77//! );
78//! }
79//!
80//! // Sample a mini-batch
81//! let batch = buf.sample(32, &mut handle).unwrap();
82//! assert_eq!(batch.len(), 32);
83//!
84//! // Compute GAE for a 5-step rollout
85//! let rewards = vec![1.0_f32; 5];
86//! let values = vec![0.5_f32; 5];
87//! let next_vals = vec![0.5_f32; 5];
88//! let dones = vec![0.0_f32; 5];
89//! let gae = compute_gae(&rewards, &values, &next_vals, &dones, GaeConfig::default()).unwrap();
90//! assert_eq!(gae.advantages.len(), 5);
91//! ```
92//!
93//! (C) 2026 COOLJAPAN OU (Team KitaSan)
94
95#![warn(missing_docs)]
96#![warn(clippy::all)]
97#![allow(clippy::module_name_repetitions)]
98#![allow(clippy::module_inception)]
99#![allow(clippy::wildcard_imports)]
100
101// ─── Public modules ──────────────────────────────────────────────────────────
102
103/// Error types and result alias.
104pub mod error;
105
106/// RL session handle: SM version, device info, seeded RNG.
107pub mod handle;
108
109/// PTX kernel sources for GPU-accelerated RL operations.
110pub mod ptx_kernels;
111
112/// Experience replay buffers.
113pub mod buffer;
114
115/// Policy distributions for discrete and continuous action spaces.
116pub mod policy;
117
118/// Return and advantage estimators.
119pub mod estimator;
120
121/// RL algorithm loss functions.
122pub mod loss;
123
124/// Observation and reward normalization.
125pub mod normalize;
126
127/// Environment abstractions.
128pub mod env;
129
130// ─── Re-exports ───────────────────────────────────────────────────────────────
131
132pub use error::{RlError, RlResult};
133
134/// Convenience prelude: imports the most commonly used types.
135pub mod prelude {
136 pub use crate::buffer::{
137 NStepBuffer, NStepTransition, PrioritizedReplayBuffer, PrioritySample, Transition,
138 UniformReplayBuffer,
139 };
140 pub use crate::env::env::{Env, EnvInfo, LinearQuadraticEnv, StepResult};
141 pub use crate::env::vectorized::{VecEnv, VecStepResult};
142 pub use crate::error::{RlError, RlResult};
143 pub use crate::estimator::{
144 GaeConfig, RetraceConfig, TdConfig, VtraceConfig, VtraceOutput, compute_gae,
145 compute_retrace, compute_td_lambda, compute_vtrace,
146 };
147 pub use crate::handle::{LcgRng, RlHandle, SmVersion};
148 pub use crate::loss::{
149 DqnConfig, DqnLoss, PpoConfig, PpoLoss, SacConfig, SacLoss, Td3Config, Td3Loss,
150 double_dqn_loss, dqn_loss, ppo_loss, sac_actor_loss, sac_critic_loss, sac_temperature_loss,
151 td3_actor_loss, td3_critic_loss,
152 };
153 pub use crate::normalize::{ObservationNormalizer, RewardNormalizer, RunningStats};
154 pub use crate::policy::{
155 CategoricalPolicy, DeterministicPolicy, GaussianPolicy, deterministic::OrnsteinUhlenbeck,
156 };
157}
158
159// ─── Integration tests ────────────────────────────────────────────────────────
160
#[cfg(test)]
mod tests {
    use super::prelude::*;

    /// End-to-end DQN-style training loop simulation: collect transitions from a
    /// reference environment, sample a mini-batch, and compute a finite DQN loss.
    #[test]
    fn e2e_dqn_style_loop() {
        let obs_dim = 4;
        let n_actions = 2;
        let mut buf = UniformReplayBuffer::new(1000, obs_dim, 1);
        let mut handle = RlHandle::default_handle();
        let mut env = LinearQuadraticEnv::new(obs_dim, 200);
        let policy = CategoricalPolicy::new(n_actions);

        let mut obs = env.reset().unwrap();
        // Collect 200 transitions driven by a dummy policy.
        for _ in 0..200 {
            // Dummy logits: reuse the first `n_actions` observation components.
            let logits = obs.iter().take(n_actions).copied().collect::<Vec<_>>();
            let probs = policy.softmax(&logits).unwrap();
            // Fix: was `_action` — the underscore prefix wrongly signalled an
            // unused binding, but the value is pushed into the buffer below.
            let action = policy.sample_action(&probs, &mut handle).unwrap();
            // Zero action; array length matches obs_dim (= 4).
            let result = env.step(&[0.0; 4]).unwrap();
            buf.push(
                obs.clone(),
                vec![action as f32],
                result.reward,
                result.obs.clone(),
                result.done,
            );
            obs = if result.done {
                env.reset().unwrap()
            } else {
                result.obs
            };
        }
        assert!(buf.len() >= 32, "should have enough transitions");

        // Sample a mini-batch and compute the loss. No network is involved, so
        // Q(s,a) is faked from the stored rewards and Q(s',·) is all zeros.
        let batch = buf.sample(32, &mut handle).unwrap();
        let q_sa: Vec<f32> = batch.iter().map(|t| t.reward).collect();
        let rewards: Vec<f32> = batch.iter().map(|t| t.reward).collect();
        let max_q_next: Vec<f32> = batch.iter().map(|_| 0.0).collect();
        let dones: Vec<f32> = batch
            .iter()
            .map(|t| if t.done { 1.0 } else { 0.0 })
            .collect();
        let is_w = vec![1.0_f32; 32]; // uniform sampling → unit IS weights
        let l = dqn_loss(
            &q_sa,
            &rewards,
            &max_q_next,
            &dones,
            &is_w,
            DqnConfig::default(),
        )
        .unwrap();
        assert!(l.loss.is_finite(), "DQN loss should be finite");
    }

    /// End-to-end PPO-style advantage computation + loss on a synthetic rollout
    /// with episode boundaries every 10 steps.
    #[test]
    fn e2e_ppo_gae_loss() {
        let t = 128;
        // Small positive step reward, with a -1 penalty at each episode end.
        let rewards: Vec<f32> = (0..t)
            .map(|i| if i % 10 == 9 { -1.0 } else { 0.1 })
            .collect();
        let values: Vec<f32> = vec![0.5; t];
        let next_vals: Vec<f32> = vec![0.5; t];
        let dones: Vec<f32> = (0..t)
            .map(|i| if i % 10 == 9 { 1.0 } else { 0.0 })
            .collect();

        let gae = compute_gae(&rewards, &values, &next_vals, &dones, GaeConfig::default()).unwrap();
        assert_eq!(gae.advantages.len(), t);

        // Simulate a PPO mini-batch update with old == new policy, so the
        // probability ratio is exactly 1 everywhere.
        let lp_new = vec![-0.693_f32; t]; // ln(0.5)
        let lp_old = vec![-0.693_f32; t];
        let vp = vec![0.5_f32; t];
        let ent = vec![0.693_f32; t]; // entropy of a uniform 2-way categorical
        let ovp = vec![0.5_f32; t];
        let l = ppo_loss(
            &lp_new,
            &lp_old,
            &gae.advantages,
            &vp,
            &gae.returns,
            &ent,
            &ovp,
            PpoConfig::default(),
        )
        .unwrap();
        assert!(
            l.total.is_finite(),
            "PPO loss should be finite: {}",
            l.total
        );
        // clip_fraction is a fraction of elements, so it must lie in [0, 1].
        assert!((0.0..=1.0).contains(&l.clip_fraction));
    }

    /// End-to-end SAC-style off-policy update using the prioritized buffer and
    /// its importance-sampling weights.
    #[test]
    fn e2e_sac_style_update() {
        let mut buf = PrioritizedReplayBuffer::new(256, 8, 2, 0.6, 0.4);
        let mut handle = RlHandle::default_handle();
        // Fill the buffer to capacity with synthetic transitions; episodes end
        // every 20 steps.
        for i in 0..256_usize {
            buf.push(
                vec![i as f32 * 0.01; 8],
                vec![0.1_f32; 2],
                (i % 5) as f32 * 0.2,
                vec![(i + 1) as f32 * 0.01; 8],
                i % 20 == 19,
            );
        }
        let batch = buf.sample(32, &mut handle).unwrap();
        // As in the DQN test, Q-values are faked from rewards (no network).
        let q: Vec<f32> = batch.iter().map(|s| s.transition.reward).collect();
        let r: Vec<f32> = batch.iter().map(|s| s.transition.reward).collect();
        let d: Vec<f32> = batch
            .iter()
            .map(|s| if s.transition.done { 1.0 } else { 0.0 })
            .collect();
        let min_qn = vec![0.5_f32; 32];
        let lp_next = vec![-0.5_f32; 32];
        let is_w: Vec<f32> = batch.iter().map(|s| s.weight).collect();
        let (cl, _) =
            sac_critic_loss(&q, &r, &d, &min_qn, &lp_next, &is_w, SacConfig::default()).unwrap();
        assert!(cl.is_finite(), "SAC critic loss should be finite");
    }

    /// VecEnv with observation normalization: run 4 environments in lockstep
    /// and feed every per-env observation through the running normalizer.
    #[test]
    fn e2e_vecenv_with_obs_norm() {
        let envs: Vec<_> = (0..4).map(|_| LinearQuadraticEnv::new(3, 50)).collect();
        let mut ve = VecEnv::new(envs);
        let mut norm = ObservationNormalizer::new(3);
        let init_obs = ve.reset_all().unwrap();
        // Observations are flattened [n_envs * obs_dim]; split per environment.
        for chunk in init_obs.chunks_exact(3) {
            norm.process_one(chunk).unwrap();
        }
        let actions = vec![0.01_f32; 4 * 3];
        for _ in 0..20 {
            let result = ve.step(&actions).unwrap();
            for chunk in result.obs.chunks_exact(3) {
                let _norm_obs = norm.process_one(chunk).unwrap();
            }
        }
        assert!(norm.count() > 0);
    }

    /// N-step buffer integration: with no terminal states, every emitted
    /// transition should carry the full 3-step discounted return.
    #[test]
    fn e2e_n_step_buffer() {
        let mut nsbuf = NStepBuffer::new(3, 0.99);
        let mut transitions = Vec::new();
        for i in 0..20_usize {
            if let Some(t) = nsbuf.push([i as f32], [0.0], 1.0, [(i + 1) as f32], false) {
                transitions.push(t);
            }
        }
        // Should have transitions after the first 3 steps have accumulated.
        assert!(!transitions.is_empty(), "n-step should produce transitions");
        for t in &transitions {
            assert_eq!(t.actual_n, 3);
            // R = 1 + 0.99 + 0.99^2 ≈ 2.9701
            assert!(
                (t.n_step_return - (1.0 + 0.99 + 0.99_f32 * 0.99)).abs() < 0.01,
                "n_step_return={}",
                t.n_step_return
            );
        }
    }
}