RLKit
A deep reinforcement learning library built on Rust and Candle. It provides complete implementations of the Q-Learning and DQN algorithms and supports custom environments, multiple action selection policies, and flexible training configuration. Support for additional reinforcement learning algorithms, such as DDPG, PPO, and A2C, is planned.
Key Characteristics
- Implements classic Q-Learning and DQN algorithms
- Supports an experience replay buffer
- Supports target network updates (for DQN)
- Provides multiple action selection policies (ε-greedy, Boltzmann, Ornstein-Uhlenbeck, Gaussian noise, etc.)
- Built on the Candle library, supporting both CPU and CUDA GPU computation
- Generic design, supporting different types of state and action spaces
- Comprehensive error handling
- Detailed test cases
Features
- `cuda`: enables CUDA GPU computation support
Project Structure
```
src/
├── algs/              # Reinforcement learning algorithm implementations
│   ├── dqn.rs         # DQN algorithm
│   ├── q_learning.rs  # Q-Learning algorithm
│   └── mod.rs         # Algorithm interfaces and common components
├── network.rs         # Neural network implementation (for DQN)
├── policies.rs        # Action selection policies
├── replay_buffer.rs   # Experience replay buffer
├── types.rs           # Core type definitions (state, action, environment interface, etc.)
└── utils.rs           # Utility functions
```
Installation
Add this library to your Rust project's `Cargo.toml`:

```toml
[dependencies]
rlkit = { version = "0.0.3", features = ["cuda"] }
```
Core Components
1. Environment Interface (EnvTrait)
All reinforcement learning environments must implement the `EnvTrait` trait:

```rust
// Import path assumes the default module layout (src/types.rs).
use rlkit::types::EnvTrait;
```
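The exact methods required by `EnvTrait` are defined in `src/types.rs`. As a rough, self-contained illustration of the kind of contract an environment provides (the trait and method names below are hypothetical, not RLKit's actual API), a minimal environment might look like this:

```rust
// Hypothetical, self-contained illustration of an environment contract.
// RLKit's actual `EnvTrait` methods and signatures may differ.
trait Env {
    type State;
    type Action;

    /// Reset the environment and return the initial state.
    fn reset(&mut self) -> Self::State;
    /// Apply an action, returning (next state, reward, done flag).
    fn step(&mut self, action: &Self::Action) -> (Self::State, f64, bool);
    /// Enumerate the actions available in the given state.
    fn actions(&self, state: &Self::State) -> Vec<Self::Action>;
}

// A toy 1-D corridor: the agent starts at 0 and must reach position `goal`.
struct Corridor { pos: i32, goal: i32 }

impl Env for Corridor {
    type State = i32;
    type Action = i32; // -1 = step left, +1 = step right

    fn reset(&mut self) -> i32 {
        self.pos = 0;
        self.pos
    }

    fn step(&mut self, action: &i32) -> (i32, f64, bool) {
        self.pos += action;
        let done = self.pos == self.goal;
        let reward = if done { 1.0 } else { -0.01 }; // small step penalty
        (self.pos, reward, done)
    }

    fn actions(&self, _state: &i32) -> Vec<i32> {
        vec![-1, 1]
    }
}
```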
2. Algorithm Usage
Q-Learning Example
```rust
// Import paths assume the default module layout; elided constructor and
// method arguments are shown as `/* ... */` placeholders.
use rlkit::algs::q_learning::QLearning;
use rlkit::algs::TrainArgs;
use rlkit::policies::EpsilonGreedy;

// Create the environment (any type implementing `EnvTrait`)
let mut env = GridWorldEnv::new(/* ... */);

// Create a Q-Learning algorithm instance
let mut q_learning = QLearning::new(/* ... */).unwrap();

// Create an ε-greedy policy
let mut policy = EpsilonGreedy::new(/* ... */);

// Configure training parameters
let train_args = TrainArgs { /* epochs, max_steps, batch_size, learning_rate, gamma, ... */ };

// Train the model
q_learning.train(/* &mut env, &mut policy, train_args */).unwrap();

// Get an action using the trained model
let state = env.reset();
let action = q_learning.get_action(/* &state */).unwrap();
```
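Under the hood, tabular Q-Learning repeatedly applies the update Q(s, a) ← Q(s, a) + α · (r + γ · max_a′ Q(s′, a′) − Q(s, a)). The following plain-Rust sketch of that update is purely conceptual and is not RLKit's internal implementation:

```rust
use std::collections::HashMap;

/// One tabular Q-Learning update:
/// Q(s, a) ← Q(s, a) + α · (r + γ · max_a' Q(s', a') − Q(s, a))
/// Assumes `actions` (the actions available in `s_next`) is non-empty.
fn q_update(
    q: &mut HashMap<(i32, i32), f64>, // (state, action) → value
    s: i32,
    a: i32,
    reward: f64,
    s_next: i32,
    actions: &[i32],
    alpha: f64, // learning rate
    gamma: f64, // discount factor
) {
    // Best estimated value achievable from the next state.
    let max_next = actions
        .iter()
        .map(|&a2| *q.get(&(s_next, a2)).unwrap_or(&0.0))
        .fold(f64::NEG_INFINITY, f64::max);

    let entry = q.entry((s, a)).or_insert(0.0);
    *entry += alpha * (reward + gamma * max_next - *entry);
}
```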
DQN Example
```rust
// Import paths assume the default module layout; elided arguments are
// shown as `/* ... */` placeholders.
use candle_core::Device;
use rlkit::algs::dqn::{DQN, DNQStateMode};
use rlkit::algs::TrainArgs;
use rlkit::policies::EpsilonGreedy;

// Create the environment (any type implementing `EnvTrait`)
let mut env = GridWorldEnv::new(/* ... */);

// Select the computation device (CPU or GPU)
let device = Device::Cpu;
// Or use GPU: let device = Device::new_cuda(0).unwrap();

// Create a DQN algorithm instance
let mut dqn = DQN::new(/* network sizes, DNQStateMode variant, device, ... */).unwrap();

// Create a policy
let mut policy = EpsilonGreedy::new(/* ... */);

// Configure training parameters
let train_args = TrainArgs { /* epochs, batch_size, update_interval, ... */ };

// Train the model
dqn.train(/* &mut env, &mut policy, train_args */).unwrap();
```
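During training, DQN samples transitions from the replay buffer and regresses the online network toward targets computed with the periodically updated target network: y = r + γ · max_a Q_target(s′, a). Here is a conceptual sketch of that target computation (plain Rust, not RLKit's internals):

```rust
/// One transition sampled from the replay buffer.
struct Transition {
    reward: f64,
    next_q_target: Vec<f64>, // target network's Q-values for the next state
    done: bool,
}

/// Compute TD targets y = r + γ · max_a Q_target(s', a) for a batch.
/// Terminal transitions do not bootstrap (the max term is dropped).
fn td_targets(batch: &[Transition], gamma: f64) -> Vec<f64> {
    batch
        .iter()
        .map(|t| {
            let max_next = t
                .next_q_target
                .iter()
                .cloned()
                .fold(f64::NEG_INFINITY, f64::max);
            if t.done {
                t.reward
            } else {
                t.reward + gamma * max_next
            }
        })
        .collect()
}
```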
3. Policy Selection
The library provides multiple action selection policies:
```rust
// Type names other than `EpsilonGreedy` are assumptions; see src/policies.rs
// for the actual policy types and constructor arguments.
use rlkit::policies::{Boltzmann, EpsilonGreedy};

// ε-greedy policy
let mut epsilon_greedy = EpsilonGreedy::new(/* ... */);

// Boltzmann policy
let mut boltzmann = Boltzmann::new(/* ... */);
```
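For intuition, the two selection rules can be sketched independently of the library (conceptual only, using the `rand` crate): ε-greedy explores uniformly with probability ε and otherwise picks the highest-valued action, while Boltzmann samples actions with probability proportional to exp(Q(s, a) / τ).

```rust
use rand::Rng;

/// ε-greedy: explore uniformly with probability `epsilon`, otherwise exploit.
fn epsilon_greedy(q_values: &[f64], epsilon: f64) -> usize {
    let mut rng = rand::thread_rng();
    if rng.gen::<f64>() < epsilon {
        rng.gen_range(0..q_values.len())
    } else {
        // Index of the highest Q-value (greedy action).
        q_values
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap()
    }
}

/// Boltzmann (softmax) exploration: sample an action with probability
/// proportional to exp(Q(s, a) / temperature).
fn boltzmann(q_values: &[f64], temperature: f64) -> usize {
    let mut rng = rand::thread_rng();
    let weights: Vec<f64> = q_values.iter().map(|q| (q / temperature).exp()).collect();
    let total: f64 = weights.iter().sum();
    let mut threshold = rng.gen::<f64>() * total;
    for (i, w) in weights.iter().enumerate() {
        threshold -= w;
        if threshold <= 0.0 {
            return i;
        }
    }
    weights.len() - 1 // numerical fallback
}
```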
Usage Examples
The library includes two complete examples:
- Grid World - located in `examples/grid_world-example/`
  - Implements a grid navigation environment with obstacles
  - Demonstrates the usage of both the Q-Learning and DQN algorithms
  - Includes path visualization functionality
- Catch Rabbit - located in `examples/catch_rabbit.rs`
  - Implements a multi-agent capture environment on a graph structure
  - Demonstrates the application of the Q-Learning algorithm in a complex environment
Running the Grid World Example
Run the example with Cargo, then follow the prompts to select the algorithm to use (1: Q-Learning, 2: DQN).
Environment Implementation Guide
To implement a custom environment, you need to:
- Create an environment struct to store the environment state
- Implement all methods of the `EnvTrait` trait
- Define an appropriate reward function and termination conditions
Detailed examples can be found in `examples/grid_world-example/env.rs` and `examples/catch_rabbit.rs`.
Training Parameter Explanation
The `TrainArgs` struct includes the following parameters:
- `epochs`: number of training epochs
- `max_steps`: maximum number of steps per epoch
- `batch_size`: number of samples per batch update
- `learning_rate`: learning rate
- `gamma`: discount factor
- `update_freq`: how often (in steps) the network is updated
- `update_interval`: interval between target network updates (only used in DQN)
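A typical configuration might look like the following sketch, assuming `TrainArgs` is a plain struct with exactly the fields listed above; the values are illustrative, not recommended defaults:

```rust
use rlkit::algs::TrainArgs; // path assumes the default module layout

// Illustrative values only; tune them for your environment.
let train_args = TrainArgs {
    epochs: 500,
    max_steps: 200,
    batch_size: 32,
    learning_rate: 1e-3,
    gamma: 0.99,
    update_freq: 1,
    update_interval: 100,
};
```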
License
This project is licensed under the MIT License. For more details, see the LICENSE file.