RLKit

A deep reinforcement learning library built on Rust and Candle. It provides complete implementations of the Q-Learning and DQN algorithms, and supports custom environments, multiple action selection policies, and flexible training configuration. More algorithms, such as DDPG, PPO, and A2C, are planned.

Key Characteristics

  • Implements the classic Q-Learning and DQN algorithms
  • Experience replay buffer
  • Target network updates (for DQN)
  • Multiple action selection policies (ε-greedy, Boltzmann, Ornstein-Uhlenbeck, Gaussian noise, etc.)
  • Built on the Candle library, with support for both CPU and CUDA GPU computation
  • Generic design that works with different state and action space types
  • Comprehensive error handling
  • Detailed test cases

Features

  • cuda: Enables CUDA GPU computation support (a device-selection sketch follows below)
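
The sketch below shows one way a crate that depends on rlkit might pick a Candle device at runtime. It assumes your own crate defines a cuda feature that forwards to rlkit's; select_device is a hypothetical helper, not part of rlkit's API.

use candle_core::Device;

// Hypothetical helper: prefer the first CUDA device when the `cuda`
// feature is enabled, and fall back to the CPU otherwise (or on failure).
fn select_device() -> Device {
    if cfg!(feature = "cuda") {
        Device::new_cuda(0).unwrap_or(Device::Cpu)
    } else {
        Device::Cpu
    }
}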

Project Structure

src/
├── algs/             # Reinforcement learning algorithm implementations
│   ├── dqn.rs        # DQN algorithm
│   ├── q_learning.rs # Q-Learning algorithm
│   └── mod.rs        # Algorithm interfaces and common components
├── network.rs        # Neural network implementation (for DQN)
├── policies.rs       # Action selection policies
├── replay_buffer.rs  # Experience replay buffer
├── types.rs          # Core type definitions (state, action, environment interface, etc.)
└── utils.rs          # Utility functions

Installation

Add the crate to your project's Cargo.toml. The cuda feature is optional; omit it for a CPU-only build:

[dependencies]
rlkit = { version = "0.0.3", features = ["cuda"] }

Core Components

1. Environment Interface (EnvTrait)

All reinforcement learning environments must implement the EnvTrait interface:

use rlkit::types::{EnvTrait, Status, Reward, Action};

struct MyEnv {
    // Environment state
}

impl EnvTrait<u16, u16> for MyEnv {
    fn step(&mut self, state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
        // Apply the action; return the next state, the reward, and whether the episode is done
        todo!()
    }

    fn reset(&mut self) -> Status<u16> {
        // Reset the environment and return the initial state
        todo!()
    }

    fn action_space(&self) -> &[u16] {
        // Return the dimensions of the action space
        todo!()
    }

    fn state_space(&self) -> &[u16] {
        // Return the dimensions of the state space
        todo!()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

2. Algorithm Usage

Q-Learning Example

use rlkit::algs::{QLearning, TrainArgs};
use rlkit::policies::EpsilonGreedy;
use rlkit::types::{EnvTrait};

// Create the environment
let mut env = MyEnv::new();

// Create a Q-Learning algorithm instance
let mut q_learning = QLearning::new(&env, 10000).unwrap();

// Create an ε-greedy policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 64,
    learning_rate: 0.1,
    gamma: 0.99,
    ..Default::default()
};

// Train the model
q_learning.train(&mut env, &mut policy, train_args).unwrap();

// Get an action using the trained model
let state = env.reset();
let action = q_learning.get_action(&state, &mut policy).unwrap();
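
Building on the two lines above, a full evaluation episode might look like the following sketch. It assumes Reward is a plain floating-point alias (e.g. f32); the loop itself is illustrative and not an rlkit API.

// Roll out one greedy episode with the trained agent (illustrative sketch).
let mut state = env.reset();
let mut total_reward = 0.0;
for _ in 0..200 {
    let action = q_learning.get_action(&state, &mut policy).unwrap();
    let (next_state, reward, done) = env.step(&state, &action);
    total_reward += reward;
    state = next_state;
    if done {
        break;
    }
}
println!("Episode reward: {total_reward}");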

DQN Example

use candle_core::Device;
use rlkit::algs::{DQN, TrainArgs};
use rlkit::algs::dqn::DNQStateMode;
use rlkit::policies::EpsilonGreedy;

// Create the environment
let mut env = MyEnv::new();

// Select the computation device (CPU or GPU)
let device = Device::Cpu;
// Or use GPU: let device = Device::new_cuda(0).unwrap();

// Create a DQN algorithm instance
let mut dqn = DQN::new(
    &env,
    10000,                  // Replay buffer capacity
    &[128, 64, 16],         // Hidden layer structure
    DNQStateMode::OneHot,   // State encoding mode
    &device
).unwrap();

// Create a policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 32,
    learning_rate: 1e-3,
    gamma: 0.99,
    update_freq: 5,
    update_interval: 100,
};

// Train the model
dqn.train(&mut env, &mut policy, train_args).unwrap();
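
After training, greedy actions can presumably be obtained the same way as in the Q-Learning example. The two lines below assume DQN exposes the same get_action method through the shared interface in algs/mod.rs; this is an assumption, not something confirmed by this README.

// Assumed to mirror QLearning::get_action; check the crate docs to confirm.
let state = env.reset();
let action = dqn.get_action(&state, &mut policy).unwrap();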

3. Policy Selection

The library provides multiple action selection policies:

use rlkit::policies::{PolicyConfig, EpsilonGreedy, Boltzmann};

// ε-greedy policy (arguments assumed to be: initial ε, minimum ε, decay rate)
let mut epsilon_greedy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// Boltzmann (softmax) policy (arguments assumed to be: initial temperature, minimum temperature, decay rate)
let mut boltzmann = Boltzmann::new(1.0, 0.1, 0.99);
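
Any of these policies can then be handed to an algorithm. The line below assumes train accepts whichever policy you construct, as the earlier examples (which pass an EpsilonGreedy) suggest; train_args is a TrainArgs value configured as in section 2.

// Train with the Boltzmann policy instead of ε-greedy (illustrative).
q_learning.train(&mut env, &mut boltzmann, train_args).unwrap();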

Usage Examples

The library includes two complete examples:

  1. Grid World - Located in examples/grid_world-example/

    • Implements a grid navigation environment with obstacles
    • Demonstrates the usage of both Q-Learning and DQN algorithms
    • Includes path visualization functionality
  2. Catch Rabbit - Located in examples/catch_rabbit.rs

    • Implements a multi-agent capture environment on a graph structure
    • Demonstrates the application of the Q-Learning algorithm in a complex environment

Running the Grid World Example

cd examples/grid_world-example
cargo run

Then follow the prompts to select the algorithm to use (1: Q-Learning, 2: DQN).
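
Running the Catch Rabbit Example

The Catch Rabbit example is a single file under examples/, so, assuming it is built as a standard Cargo example target, it can be run from the repository root with:

cargo run --example catch_rabbit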

Environment Implementation Guide

To implement a custom environment, you need to:

  1. Create an environment struct to store the environment state
  2. Implement all methods of the EnvTrait interface
  3. Define an appropriate reward function and termination conditions (a sketch follows below)
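
For step 3, the reward and termination logic often looks something like the sketch below. All names and values here are hypothetical, loosely in the spirit of the grid-world example:

// Hypothetical reward/termination logic for a grid-style environment.
fn reward_and_done(reached_goal: bool, hit_obstacle: bool, steps: u32, step_limit: u32) -> (f32, bool) {
    let reward = if reached_goal {
        10.0  // large positive reward for reaching the goal
    } else if hit_obstacle {
        -5.0  // penalty for collisions
    } else {
        -0.1  // small per-step cost to encourage short paths
    };
    let done = reached_goal || steps >= step_limit;
    (reward, done)
}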

Detailed examples can be found in examples/grid_world-example/env.rs and examples/catch_rabbit.rs.

Training Parameter Explanation

The TrainArgs struct includes the following parameters (a usage sketch follows the list):

  • epochs: Number of training epochs
  • max_steps: Maximum number of steps per epoch
  • batch_size: Number of samples per batch update
  • learning_rate: Learning rate
  • gamma: Discount factor
  • update_freq: How often the network is updated, in steps
  • update_interval: Interval between target network updates (DQN only)
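
As a rough sense of scale, the DQN example above uses epochs = 1000 and max_steps = 200, so training runs for at most 1000 × 200 = 200,000 environment steps, with a network update at most every 5 steps and a target network sync governed by update_interval. Because the Q-Learning example uses ..Default::default(), TrainArgs implements Default, so only the fields you care about need to be set:

// Override a subset of fields; the rest fall back to the crate's defaults.
let train_args = TrainArgs {
    epochs: 500,
    max_steps: 100,
    ..Default::default()
};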

License

This project is licensed under the MIT License. For more details, see the LICENSE file.