stateset-rl-core
High-performance Rust implementations of reinforcement learning operations, with optional Python bindings via PyO3.
Features
- GAE (Generalized Advantage Estimation) - Fast, parallel GAE computation
- Advantage Computation - Group-relative advantages for GRPO/PPO training
- Reward Normalization - Welford's algorithm for numerically stable online normalization
- GSPO Support - Sequence-level importance ratios and clipping for GSPO
- PPO Surrogate - Clipped surrogate objective computation
- Parallel Processing - Automatic parallelization via Rayon
Installation
As a Rust crate
[dependencies]
stateset-rl-core = "0.1"
As a Python extension
Usage
Rust
use ;
// Compute GAE
let rewards = vec!;
let values = vec!; // n+1 values for bootstrap
let advantages = compute_gae_internal;
// Compute group-relative advantages
let group_rewards = vec!;
let advantages = compute_advantages_for_group;
Python
# Compute GAE
=
=
=
# Batch GAE (parallel)
=
=
=
# Group-relative advantages for GRPO
= # 16 groups, 4 samples each
=
# Reward normalization with running stats
=
, , , =
# GSPO importance ratios
=
=
=
=
# PPO surrogate objective
=
=
=
API Reference
GAE Functions
- compute_gae(rewards, values, gamma=0.99, gae_lambda=0.95) - Single-trajectory GAE
- batch_compute_gae(all_rewards, all_values, gamma=0.99, gae_lambda=0.95) - Parallel batch GAE
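For readers who want to see what these functions compute, here is a minimal pure-Python sketch of the standard GAE recursion. It is a reference for the math, not the crate's Rust implementation. The names gae_reference and batch_gae_reference are hypothetical, and the sketch assumes values holds one more entry than rewards (the bootstrap value) and that there are no episode-termination flags.

```python
def gae_reference(rewards, values, gamma=0.99, gae_lambda=0.95):
    """Standard GAE for one trajectory (hypothetical helper, not the crate API).

    Assumes len(values) == len(rewards) + 1, with the last value used
    as the bootstrap estimate for the state after the final step.
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must hold one more entry than rewards")
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the discounted sum from t + 1.
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD residual
        running = delta + gamma * gae_lambda * running
        advantages[t] = running
    return advantages


def batch_gae_reference(all_rewards, all_values, gamma=0.99, gae_lambda=0.95):
    """Sequential stand-in for a batch call: one GAE pass per trajectory."""
    return [
        gae_reference(r, v, gamma, gae_lambda)
        for r, v in zip(all_rewards, all_values)
    ]
```

Each trajectory is independent, which is what makes the batch variant straightforward to parallelize.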
Advantage Functions
- compute_group_advantages(rewards_2d, baseline_type, normalize) - GRPO-style group advantages
  - baseline_type: "mean", "median", or "min"
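The sketch below illustrates one plausible reading of group-relative advantages: subtract a per-group baseline from each reward, then optionally scale by the group's standard deviation. The function name group_advantages_reference, the eps parameter, and the choice of population standard deviation are assumptions made for illustration; the crate may differ in those details.

```python
import statistics


def group_advantages_reference(rewards_2d, baseline_type="mean",
                               normalize=True, eps=1e-8):
    """Group-relative advantages (hypothetical helper, not the crate API)."""
    result = []
    for group in rewards_2d:
        if baseline_type == "mean":
            baseline = statistics.fmean(group)
        elif baseline_type == "median":
            baseline = statistics.median(group)
        elif baseline_type == "min":
            baseline = min(group)
        else:
            raise ValueError(f"unknown baseline_type: {baseline_type!r}")
        centered = [r - baseline for r in group]
        if normalize:
            # Assumption: population std of the group, with eps guarding
            # against division by zero when all rewards are equal.
            std = statistics.pstdev(group)
            centered = [c / (std + eps) for c in centered]
        result.append(centered)
    return result
```

With the "min" baseline every advantage is non-negative, while "mean" and "median" centre the group so that below-baseline samples receive negative advantages.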
Reward Functions
- normalize_rewards(rewards, running_mean=0, running_var=1, count=0, epsilon=1e-8) - Online normalization
- clip_rewards(rewards, min_val, max_val) - Reward clipping
- compute_reward_statistics(rewards) - Compute mean, std, min, max, median
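As background for the online normalization, here is a minimal sketch of Welford's update in pure Python. It tracks the sum of squared deviations (m2) rather than a running variance, so its state differs from the running_var parameter above; welford_update and normalize_with_running_stats are hypothetical names, and the way the statistics are applied to the batch is an assumption.

```python
import math


def welford_update(rewards, mean=0.0, m2=0.0, count=0):
    """Fold a batch into running stats using Welford's algorithm.

    m2 is the running sum of squared deviations from the mean;
    the population variance is m2 / count.
    """
    for r in rewards:
        count += 1
        delta = r - mean
        mean += delta / count
        m2 += delta * (r - mean)  # uses the updated mean, per Welford
    return mean, m2, count


def normalize_with_running_stats(rewards, mean=0.0, m2=0.0, count=0,
                                 epsilon=1e-8):
    """Update the stats with this batch, then normalize the batch.

    Hypothetical helper: returns (normalized, mean, m2, count) so the
    state can be passed back in on the next call.
    """
    mean, m2, count = welford_update(rewards, mean, m2, count)
    variance = m2 / count if count > 0 else 1.0
    std = math.sqrt(variance)
    normalized = [(r - mean) / (std + epsilon) for r in rewards]
    return normalized, mean, m2, count
```

Feeding rewards in several chunks yields the same mean and variance as a single pass over all of them, which is the property that makes the update suitable for streaming training loops.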
Policy Gradient Functions
- compute_gspo_importance_ratios(log_probs_new, log_probs_old, sequence_lengths) - GSPO ratios
- apply_gspo_clipping(ratios, advantages, clip_left=3e-4, clip_right=4e-4) - GSPO clipping
- compute_ppo_surrogate(ratios, advantages, clip_epsilon=0.2) - PPO clipped objective
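The following pure-Python sketch shows the math these functions are built around, under stated assumptions. It assumes log_probs_new and log_probs_old are per-sequence sums of token log-probabilities, that the clip bounds are 1 - clip_left and 1 + clip_right, and that the surrogate is returned per sample rather than averaged. The names gspo_ratios_reference and clipped_surrogate_reference are hypothetical, and the crate's actual input layout and return types may differ.

```python
import math


def gspo_ratios_reference(log_probs_new, log_probs_old, sequence_lengths):
    """Length-normalized sequence-level importance ratios.

    Assumes each log-prob is the SUM over the sequence's tokens, so the
    ratio is exp((new - old) / length), i.e. the geometric mean of the
    per-token ratios.
    """
    ratios = []
    for new, old, length in zip(log_probs_new, log_probs_old, sequence_lengths):
        if length <= 0:
            raise ValueError("sequence length must be positive")
        ratios.append(math.exp((new - old) / length))
    return ratios


def clipped_surrogate_reference(ratios, advantages, clip_left, clip_right):
    """Per-sample clipped surrogate: min(r * A, clip(r) * A).

    With clip_left == clip_right == clip_epsilon this is the usual PPO
    objective; small asymmetric bounds give a GSPO-style clip.
    """
    terms = []
    for r, a in zip(ratios, advantages):
        clipped = min(max(r, 1.0 - clip_left), 1.0 + clip_right)
        terms.append(min(r * a, clipped * a))
    return terms
```

Averaging the returned terms and negating the result gives a loss to minimize. The length normalization keeps sequence-level ratios close to 1, which is why the GSPO clip range can be so much narrower than the PPO default of 0.2.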
Performance
This crate is optimized for performance:
- LTO enabled - Link-time optimization for maximum speed
- Single codegen unit - Better optimization opportunities
- Rayon parallelization - Automatic multi-threading for batch operations
- Zero-copy Python interop - Minimal overhead when called from Python
Typical speedups over pure Python/NumPy: 10-100x for batch operations.
License
MIT