Expand description
Reinforcement learning algorithms for the Burn ML framework.
rl4burn provides generic, backend-agnostic RL building blocks that
exploit Burn’s type system: write PPO<B: AutodiffBackend> once and
run on WGPU, CUDA, NdArray, or LibTorch.
§Modules
- env — Environment trait, spaces, vectorized environments, wrappers
- envs — Built-in environments (CartPole)
- algo — Algorithms (PPO, DQN)
- nn — Neural network utilities (init, gradient clipping, polyak, losses, policy traits)
- collect — Data collection (GAE, V-trace, UPGO, replay buffer, advantage normalization)
Re-exports§
pub use env::adapter::DiscreteEnvAdapter;
pub use env::space::Space;
pub use env::vec_env::SyncVecEnv;
pub use env::wrapper;
pub use env::Env;
pub use env::Step;
pub use algo::dqn::dqn_update;
pub use algo::dqn::epsilon_greedy;
pub use algo::dqn::epsilon_schedule;
pub use algo::dqn::DqnConfig;
pub use algo::dqn::DqnStats;
pub use algo::dqn::QNetwork;
pub use algo::dqn::Transition;
pub use algo::ppo::ppo_collect;
pub use algo::ppo::ppo_update;
pub use algo::ppo::PpoConfig;
pub use algo::ppo::PpoRollout;
pub use algo::ppo::PpoStats;
pub use algo::ppo_masked::masked_ppo_collect;
pub use algo::ppo_masked::masked_ppo_update;
pub use algo::ppo_masked::MaskedActorCritic;
pub use algo::ppo_masked::MaskedPpoRollout;
pub use algo::behavioral_cloning::bc_loss_discrete;
pub use algo::behavioral_cloning::bc_loss_multi_head;
pub use algo::behavioral_cloning::bc_step;
pub use algo::distillation::distillation_loss;
pub use algo::distillation::value_distillation_loss;
pub use algo::distillation::DistillationConfig;
pub use algo::cspl::CsplConfig;
pub use algo::cspl::CsplPhase;
pub use algo::cspl::CsplPipeline;
pub use algo::league::AgentRole;
pub use algo::league::League;
pub use algo::league::LeagueAgentConfig;
pub use algo::multi_agent::batch_multi_agent_obs;
pub use algo::multi_agent::broadcast_team_reward;
pub use algo::multi_agent::unbatch_actions;
pub use algo::multi_agent::MultiAgentRolloutData;
pub use algo::pfsp::PfspConfig;
pub use algo::pfsp::PfspMatchmaking;
pub use algo::pfsp::PlayerRecord;
pub use algo::privileged_critic::make_critic_input;
pub use algo::privileged_critic::PrivilegedActorCritic;
pub use algo::self_play::branch_agent;
pub use algo::self_play::SelfPlayPool;
pub use algo::z_conditioning::z_reward;
pub use algo::z_conditioning::ZConditioning;
pub use algo::z_conditioning::ZConditioningConfig;
pub use algo::imagination::imagine_rollout;
pub use algo::imagination::lambda_returns;
pub use algo::imagination::ImaginedTrajectory;
pub use algo::mcts::MctsConfig;
pub use algo::mcts::MctsTree;
pub use algo::distributed::DistributedConfig;
pub use algo::distributed::GradientSync;
pub use algo::distributed::LocalSync;
pub use algo::distributed::ReduceStrategy;
pub use algo::distributed::scale_gradients;
pub use nn::autoregressive::ActionHead;
pub use nn::autoregressive::CompositeDistribution;
pub use nn::dist::ActionDist;
pub use nn::dist::LogStdMode;
pub use nn::attention::AttentionPool;
pub use nn::attention::AttentionPoolConfig;
pub use nn::attention::MultiHeadAttention;
pub use nn::attention::MultiHeadAttentionConfig;
pub use nn::attention::PointerNet;
pub use nn::attention::PointerNetConfig;
pub use nn::attention::TargetAttention;
pub use nn::attention::TargetAttentionConfig;
pub use nn::attention::TransformerBlock;
pub use nn::attention::TransformerBlockConfig;
pub use nn::attention::TransformerEncoder;
pub use nn::attention::TransformerEncoderConfig;
pub use nn::clip::clip_grad_norm;
pub use nn::film::Film;
pub use nn::film::FilmConfig;
pub use nn::init::orthogonal_linear;
pub use nn::kl_balance::categorical_kl;
pub use nn::kl_balance::categorical_kl_groups;
pub use nn::kl_balance::kl_balanced_loss;
pub use nn::kl_balance::kl_balanced_loss_groups;
pub use nn::kl_balance::KlBalanceConfig;
pub use nn::loss::policy_loss_continuous;
pub use nn::loss::policy_loss_discrete;
pub use nn::loss::value_loss;
pub use nn::rnn::BlockGruCell;
pub use nn::rnn::BlockGruCellConfig;
pub use nn::rnn::GruCell;
pub use nn::rnn::GruCellConfig;
pub use nn::rnn::LstmCell;
pub use nn::rnn::LstmCellConfig;
pub use nn::rnn::LstmState;
pub use nn::rssm::Rssm;
pub use nn::rssm::RssmConfig;
pub use nn::rssm::RssmState;
pub use nn::policy::greedy_action;
pub use nn::policy::DiscreteAcOutput;
pub use nn::policy::DiscreteActorCritic;
pub use nn::symlog::symexp;
pub use nn::symlog::symlog;
pub use nn::symlog::TwohotEncoder;
pub use nn::vae::BetaVae;
pub use nn::vae::BetaVaeConfig;
pub use nn::vae::VaeOutput;
pub use env::render::Renderable;
pub use env::render::RgbFrame;
pub use nn::multi_head_value::multi_head_gae;
pub use nn::multi_head_value::multi_head_value_loss;
pub use nn::multi_head_value::MultiHeadGaeResult;
pub use nn::multi_head_value::MultiHeadValueConfig;
pub use nn::polyak::polyak_update;
pub use collect::advantage::normalize;
pub use collect::gae::gae;
pub use collect::intrinsic::combine_rewards;
pub use collect::intrinsic::CountBasedReward;
pub use collect::intrinsic::EntropyReductionReward;
pub use collect::intrinsic::IntrinsicReward;
pub use collect::percentile_normalize::PercentileNormalizer;
pub use collect::replay::ReplayBuffer;
pub use collect::sequence_replay::SequenceReplayBuffer;
pub use collect::sequence_replay::SequenceStep;
pub use collect::upgo::upgo as upgo_advantages;
pub use collect::vtrace::vtrace_targets;
pub use log::CompositeLogger;
pub use log::Loggable;
pub use log::Logger;
pub use log::NoopLogger;
pub use log::PrintLogger;
Modules§
- algo
- RL algorithms (PPO, DQN).
- collect
- Data collection and advantage estimation.
- env
- Environment abstractions: trait, spaces, vectorized envs, wrappers. Environment trait and step result type.
- envs
- Built-in environments (CartPole) for testing and benchmarking.
- log
- Logging infrastructure for training metrics.
- nn
- Neural network utilities for RL.