Skip to main content

Crate rl4burn

Crate rl4burn 

Source
Expand description

Reinforcement learning algorithms for the Burn ML framework.

rl4burn provides generic, backend-agnostic RL building blocks that exploit Burn’s type system: write `PPO<B: AutodiffBackend>` once and run it on WGPU, CUDA, NdArray, or LibTorch.

Modules

  • env — Environment trait, spaces, vectorized environments, rendering, wrappers
  • envs — Built-in environments (CartPole)
  • algo — Algorithms (PPO, masked PPO, DQN, behavioral cloning, distillation, self-play and league training, MCTS, and more)
  • nn — Neural network utilities (init, gradient clipping, polyak updates, losses, policy traits, action distributions, attention, recurrent cells, RSSM, VAE)
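  • collect — Data collection (GAE, V-trace, UPGO, replay and sequence-replay buffers, intrinsic rewards, advantage and percentile normalization)
  • log — Logging infrastructure for training metrics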

Re-exports

pub use env::adapter::DiscreteEnvAdapter;
pub use env::space::Space;
pub use env::vec_env::SyncVecEnv;
pub use env::wrapper;
pub use env::Env;
pub use env::Step;
pub use algo::dqn::dqn_update;
pub use algo::dqn::epsilon_greedy;
pub use algo::dqn::epsilon_schedule;
pub use algo::dqn::DqnConfig;
pub use algo::dqn::DqnStats;
pub use algo::dqn::QNetwork;
pub use algo::dqn::Transition;
pub use algo::ppo::ppo_collect;
pub use algo::ppo::ppo_update;
pub use algo::ppo::PpoConfig;
pub use algo::ppo::PpoRollout;
pub use algo::ppo::PpoStats;
pub use algo::ppo_masked::masked_ppo_collect;
pub use algo::ppo_masked::masked_ppo_update;
pub use algo::ppo_masked::MaskedActorCritic;
pub use algo::ppo_masked::MaskedPpoRollout;
pub use algo::behavioral_cloning::bc_loss_discrete;
pub use algo::behavioral_cloning::bc_loss_multi_head;
pub use algo::behavioral_cloning::bc_step;
pub use algo::distillation::distillation_loss;
pub use algo::distillation::value_distillation_loss;
pub use algo::distillation::DistillationConfig;
pub use algo::cspl::CsplConfig;
pub use algo::cspl::CsplPhase;
pub use algo::cspl::CsplPipeline;
pub use algo::league::AgentRole;
pub use algo::league::League;
pub use algo::league::LeagueAgentConfig;
pub use algo::multi_agent::batch_multi_agent_obs;
pub use algo::multi_agent::broadcast_team_reward;
pub use algo::multi_agent::unbatch_actions;
pub use algo::multi_agent::MultiAgentRolloutData;
pub use algo::pfsp::PfspConfig;
pub use algo::pfsp::PfspMatchmaking;
pub use algo::pfsp::PlayerRecord;
pub use algo::privileged_critic::make_critic_input;
pub use algo::privileged_critic::PrivilegedActorCritic;
pub use algo::self_play::branch_agent;
pub use algo::self_play::SelfPlayPool;
pub use algo::z_conditioning::z_reward;
pub use algo::z_conditioning::ZConditioning;
pub use algo::z_conditioning::ZConditioningConfig;
pub use algo::imagination::imagine_rollout;
pub use algo::imagination::lambda_returns;
pub use algo::imagination::ImaginedTrajectory;
pub use algo::mcts::MctsConfig;
pub use algo::mcts::MctsTree;
pub use algo::distributed::DistributedConfig;
pub use algo::distributed::GradientSync;
pub use algo::distributed::LocalSync;
pub use algo::distributed::ReduceStrategy;
pub use algo::distributed::scale_gradients;
pub use nn::autoregressive::ActionHead;
pub use nn::autoregressive::CompositeDistribution;
pub use nn::dist::ActionDist;
pub use nn::dist::LogStdMode;
pub use nn::attention::AttentionPool;
pub use nn::attention::AttentionPoolConfig;
pub use nn::attention::MultiHeadAttention;
pub use nn::attention::MultiHeadAttentionConfig;
pub use nn::attention::PointerNet;
pub use nn::attention::PointerNetConfig;
pub use nn::attention::TargetAttention;
pub use nn::attention::TargetAttentionConfig;
pub use nn::attention::TransformerBlock;
pub use nn::attention::TransformerBlockConfig;
pub use nn::attention::TransformerEncoder;
pub use nn::attention::TransformerEncoderConfig;
pub use nn::clip::clip_grad_norm;
pub use nn::film::Film;
pub use nn::film::FilmConfig;
pub use nn::init::orthogonal_linear;
pub use nn::kl_balance::categorical_kl;
pub use nn::kl_balance::categorical_kl_groups;
pub use nn::kl_balance::kl_balanced_loss;
pub use nn::kl_balance::kl_balanced_loss_groups;
pub use nn::kl_balance::KlBalanceConfig;
pub use nn::loss::policy_loss_continuous;
pub use nn::loss::policy_loss_discrete;
pub use nn::loss::value_loss;
pub use nn::rnn::BlockGruCell;
pub use nn::rnn::BlockGruCellConfig;
pub use nn::rnn::GruCell;
pub use nn::rnn::GruCellConfig;
pub use nn::rnn::LstmCell;
pub use nn::rnn::LstmCellConfig;
pub use nn::rnn::LstmState;
pub use nn::rssm::Rssm;
pub use nn::rssm::RssmConfig;
pub use nn::rssm::RssmState;
pub use nn::policy::greedy_action;
pub use nn::policy::DiscreteAcOutput;
pub use nn::policy::DiscreteActorCritic;
pub use nn::symlog::symexp;
pub use nn::symlog::symlog;
pub use nn::symlog::TwohotEncoder;
pub use nn::vae::BetaVae;
pub use nn::vae::BetaVaeConfig;
pub use nn::vae::VaeOutput;
pub use env::render::Renderable;
pub use env::render::RgbFrame;
pub use nn::multi_head_value::multi_head_gae;
pub use nn::multi_head_value::multi_head_value_loss;
pub use nn::multi_head_value::MultiHeadGaeResult;
pub use nn::multi_head_value::MultiHeadValueConfig;
pub use nn::polyak::polyak_update;
pub use collect::advantage::normalize;
pub use collect::gae::gae;
pub use collect::intrinsic::combine_rewards;
pub use collect::intrinsic::CountBasedReward;
pub use collect::intrinsic::EntropyReductionReward;
pub use collect::intrinsic::IntrinsicReward;
pub use collect::percentile_normalize::PercentileNormalizer;
pub use collect::replay::ReplayBuffer;
pub use collect::sequence_replay::SequenceReplayBuffer;
pub use collect::sequence_replay::SequenceStep;
pub use collect::upgo::upgo as upgo_advantages;
pub use collect::vtrace::vtrace_targets;
pub use log::CompositeLogger;
pub use log::Loggable;
pub use log::Logger;
pub use log::NoopLogger;
pub use log::PrintLogger;

Modules

algo
RL algorithms (PPO, DQN, and more).
collect
Data collection and advantage estimation.
env
Environment abstractions: the environment trait and step result type, spaces, vectorized envs, and wrappers.
envs
Built-in environments (CartPole) for testing and benchmarking.
log
Logging infrastructure for training metrics.
nn
Neural network utilities for RL.