# RLKit
A deep reinforcement learning library written in Rust on top of Candle. It provides complete implementations of the Q-Learning and DQN algorithms and supports custom environments, multiple action-selection policies, and flexible training configurations. Support for additional algorithms such as DDPG, PPO, and A2C is planned.
## Key Characteristics
- Implements classic Q-Learning and DQN algorithms
- Supports an experience replay buffer
- Supports target network updates (for DQN)
- Provides multiple action selection policies (ε-greedy, Boltzmann, Ornstein-Uhlenbeck, Gaussian noise, etc.)
- Implemented based on the Candle library, supporting CPU and CUDA GPU computation
- Generic design, supporting different types of state and action spaces
- Complete error handling mechanism
- Detailed test cases
## Features
- `cuda`: Enables CUDA GPU computation support
## Project Structure
```plaintext
src/
├── algs/              # Reinforcement learning algorithm implementations
│   ├── dqn.rs         # DQN algorithm
│   ├── q_learning.rs  # Q-Learning algorithm
│   └── mod.rs         # Algorithm interfaces and common components
├── network.rs         # Neural network implementation (for DQN)
├── policies.rs        # Action selection policies
├── replay_buffer.rs   # Experience replay buffer
├── types.rs           # Core type definitions (state, action, environment interface, etc.)
└── utils.rs           # Utility functions
```
## Installation
Add this library to your Rust project:
```toml
[dependencies]
rlkit = { version = "0.0.3", features = ["cuda"] }
```
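To run on the CPU only, omit the `cuda` feature:
```toml
[dependencies]
rlkit = "0.0.3"
```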
## Core Components
### 1. Environment Interface (EnvTrait)
All reinforcement learning environments must implement the `EnvTrait` interface:
```rust
use rlkit::types::{EnvTrait, Status, Reward, Action};
struct MyEnv {
    // Environment state
}

impl EnvTrait<u16, u16> for MyEnv {
    fn step(&mut self, state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
        // Execute the action and return the next state, the reward, and whether the episode is done
        todo!()
    }

    fn reset(&mut self) -> Status<u16> {
        // Reset the environment and return the initial state
        todo!()
    }

    fn action_space(&self) -> &[u16] {
        // Return the dimension information of the action space
        todo!()
    }

    fn state_space(&self) -> &[u16] {
        // Return the dimension information of the state space
        todo!()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
```
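For intuition about what the space methods describe, here is a partial sketch of a hypothetical 5×5 grid environment with four discrete actions. The struct, its field names, and the dimension values are illustrative assumptions, and `step`/`reset` are left as `todo!()` because constructing `Status`/`Action` values depends on the concrete types in `rlkit::types`:
```rust
use rlkit::types::{Action, EnvTrait, Reward, Status};

// Hypothetical 5x5 grid with 4 discrete actions (up/down/left/right).
struct GridEnv {
    // Descriptors returned by state_space()/action_space():
    // two state dimensions of size 5 each, one action dimension with 4 choices.
    state_dims: [u16; 2],
    action_dims: [u16; 1],
    agent_pos: (u16, u16),
}

impl GridEnv {
    fn new() -> Self {
        Self {
            state_dims: [5, 5],
            action_dims: [4],
            agent_pos: (0, 0),
        }
    }
}

impl EnvTrait<u16, u16> for GridEnv {
    fn step(&mut self, state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
        // Move the agent, compute the reward, and report whether the episode ended.
        todo!()
    }

    fn reset(&mut self) -> Status<u16> {
        self.agent_pos = (0, 0);
        // Build and return the initial Status value here.
        todo!()
    }

    fn action_space(&self) -> &[u16] {
        &self.action_dims // one discrete action dimension with 4 choices
    }

    fn state_space(&self) -> &[u16] {
        &self.state_dims // a 5x5 grid: two dimensions of size 5
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
```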
### 2. Algorithm Usage
#### Q-Learning Example
```rust
use rlkit::algs::{QLearning, TrainArgs};
use rlkit::policies::EpsilonGreedy;
use rlkit::types::EnvTrait;
// Create the environment
let mut env = MyEnv::new();
// Create a Q-Learning algorithm instance
let mut q_learning = QLearning::new(&env, 10000).unwrap();
// Create an ε-greedy policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 64,
    learning_rate: 0.1,
    gamma: 0.99,
    ..Default::default()
};
// Train the model
q_learning.train(&mut env, &mut policy, train_args).unwrap();
// Get an action using the trained model
let state = env.reset();
let action = q_learning.get_action(&state, &mut policy).unwrap();
```
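Once training has finished, the same `get_action` call can drive a rollout. The sketch below reuses the `env`, `q_learning`, and `policy` values from the snippet above; note that an ε-greedy policy may still explore here unless its ε has decayed to a small value:
```rust
// Roll out one episode with the trained agent (illustrative sketch).
let mut state = env.reset();
for step in 0..200 {
    let action = q_learning.get_action(&state, &mut policy).unwrap();
    let (next_state, _reward, done) = env.step(&state, &action);
    state = next_state;
    if done {
        println!("Episode finished after {} steps", step + 1);
        break;
    }
}
```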
#### DQN Example
```rust
use candle_core::Device;
use rlkit::algs::{DQN, TrainArgs};
use rlkit::algs::dqn::DNQStateMode;
use rlkit::policies::EpsilonGreedy;
// Create the environment
let mut env = MyEnv::new();
// Select the computation device (CPU or GPU)
let device = Device::Cpu;
// Or use GPU: let device = Device::new_cuda(0).unwrap();
// Create a DQN algorithm instance
let mut dqn = DQN::new(
    &env,
    10000,                // Replay buffer capacity
    &[128, 64, 16],       // Hidden layer structure
    DNQStateMode::OneHot, // State encoding mode
    &device,
).unwrap();
// Create a policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 32,
    learning_rate: 1e-3,
    gamma: 0.99,
    update_freq: 5,
    update_interval: 100,
};
// Train the model
dqn.train(&mut env, &mut policy, train_args).unwrap();
```
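Because `Device::new_cuda` returns a `Result`, device selection can fall back to the CPU when no GPU is available (or when the crate was built without the `cuda` feature), for example:
```rust
use candle_core::Device;

// Try to use the first CUDA GPU; fall back to the CPU if that fails
// (e.g. no GPU present or the `cuda` feature is disabled).
let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
```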
### 3. Policy Selection
The library provides multiple action selection policies:
```rust
use rlkit::policies::{PolicyConfig, EpsilonGreedy, Boltzmann};
// ε-greedy policy
let mut epsilon_greedy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// Boltzmann policy
let mut boltzmann = Boltzmann::new(1.0, 0.1, 0.99);
```
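The policies are interchangeable: wherever the examples above pass the ε-greedy policy to `train` or `get_action`, one of the other policies can be substituted instead.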
## Usage Examples
The library includes two complete examples:
1. **Grid World** - located in `examples/grid_world-example/`
   - Implements a grid navigation environment with obstacles
   - Demonstrates the usage of both the Q-Learning and DQN algorithms
   - Includes path visualization functionality
2. **Catch Rabbit** - located in `examples/catch_rabbit.rs`
   - Implements a multi-agent capture environment on a graph structure
   - Demonstrates the Q-Learning algorithm in a more complex environment
### Running the Grid World Example
```bash
cd examples/grid_world-example
cargo run
```
Then follow the prompts to select the algorithm to use (1: Q-Learning, 2: DQN).
## Environment Implementation Guide
To implement a custom environment, you need to:
1. Create an environment struct to store the environment state
2. Implement all methods of the `EnvTrait` interface
3. Define an appropriate reward function and termination conditions
Detailed examples can be found in `examples/grid_world-example/env.rs` and `examples/catch_rabbit.rs`.
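For step 3, the reward signal and termination condition are what shape the learned behavior. A minimal, hypothetical sketch (plain Rust, independent of rlkit's `Reward` type) for a goal-reaching grid task might look like this:
```rust
// Hypothetical reward/termination logic for a goal-reaching grid task:
// a small step cost encourages short paths, obstacles are penalized,
// and the episode terminates only when the goal is reached.
fn reward_and_done(pos: (u16, u16), goal: (u16, u16), is_obstacle: bool) -> (f64, bool) {
    if pos == goal {
        (10.0, true) // reached the goal: large reward, episode done
    } else if is_obstacle {
        (-5.0, false) // stepped into an obstacle: penalty
    } else {
        (-0.1, false) // default per-step cost
    }
}
```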
## Training Parameter Explanation
The `TrainArgs` struct includes the following parameters:
- `epochs`: Number of training epochs
- `max_steps`: Maximum number of steps per epoch
- `batch_size`: Number of samples for batch updates
- `learning_rate`: Learning rate
- `gamma`: Discount factor
- `update_freq`: Number of steps between model updates
- `update_interval`: Number of steps between target network updates (used only by DQN)
## License
This project is licensed under the MIT License. For more details, see the LICENSE file.