# RLKit


A deep reinforcement learning library built in Rust on top of Candle, providing complete implementations of Q-Learning and DQN, with support for custom environments, multiple action-selection policies, and flexible training configuration. Support for additional algorithms such as DDPG, PPO, and A2C is planned.

## Key Characteristics


- Implements the classic Q-Learning and DQN algorithms
- Supports an experience replay buffer
- Supports target network updates (for DQN)
- Provides multiple action selection policies (ε-greedy, Boltzmann, Ornstein-Uhlenbeck, Gaussian noise, etc.)
- Built on the Candle library, supporting both CPU and CUDA GPU computation
- Generic design that supports different state and action space types
- Comprehensive error handling
- Detailed test cases

## Features


- `cuda`: Enables CUDA GPU computation support

## Project Structure


```plaintext
src/
├── algs/                   # Reinforcement learning algorithm implementations
│   ├── dqn.rs              # DQN algorithm
│   ├── q_learning.rs       # Q-Learning algorithm
│   └── mod.rs              # Algorithm interfaces and common components
├── network.rs              # Neural network implementation (for DQN)
├── policies.rs             # Action selection policies
├── replay_buffer.rs        # Experience replay buffer
├── types.rs                # Core type definitions (state, action, environment interface, etc.)
└── utils.rs                # Utility functions
```

## Installation


Add this library to your Rust project:

```toml
[dependencies]
rlkit = { version = "0.0.3", features = ["cuda"] }
```
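
The `cuda` feature is optional. On a CPU-only machine you can depend on the crate without it; a minimal `Cargo.toml` entry, assuming the default features are CPU-based:

```toml
[dependencies]
rlkit = "0.0.3"
```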

## Core Components


### 1. Environment Interface (EnvTrait)


All reinforcement learning environments must implement the `EnvTrait` trait:

```rust
use rlkit::types::{EnvTrait, Status, Reward, Action};

struct MyEnv {
    // Environment state
}

impl EnvTrait<u16, u16> for MyEnv {
    fn step(&mut self, state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
        // Execute the action and return the next state, the reward, and whether the episode is done
        todo!()
    }

    fn reset(&mut self) -> Status<u16> {
        // Reset the environment and return the initial state
        todo!()
    }

    fn action_space(&self) -> &[u16] {
        // Return the dimensions of the action space
        todo!()
    }

    fn state_space(&self) -> &[u16] {
        // Return the dimensions of the state space
        todo!()
    }
    
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
```

### 2. Algorithm Usage


#### Q-Learning Example


```rust
use rlkit::algs::{QLearning, TrainArgs};
use rlkit::policies::EpsilonGreedy;
use rlkit::types::EnvTrait;

// Create the environment
let mut env = MyEnv::new();

// Create a Q-Learning algorithm instance
let mut q_learning = QLearning::new(&env, 10000).unwrap();

// Create an ε-greedy policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 64,
    learning_rate: 0.1,
    gamma: 0.99,
    ..Default::default()
};

// Train the model
q_learning.train(&mut env, &mut policy, train_args).unwrap();

// Get an action using the trained model
let state = env.reset();
let action = q_learning.get_action(&state, &mut policy).unwrap();
```
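
After training, the learned values can be checked with a simple rollout that only uses the calls shown above (`reset`, `get_action`, `step`). This is a sketch continuing the example; it assumes the ε-greedy policy has decayed toward its minimum by the end of training (as the constructor's minimum and decay arguments suggest), so the chosen actions are mostly greedy:

```rust
// Evaluation sketch: roll out one episode with the trained Q-Learning agent.
let mut state = env.reset();
for _ in 0..200 {
    // Pick an action from the learned values (mostly greedy once ε has decayed).
    let action = q_learning.get_action(&state, &mut policy).unwrap();
    // Apply it to the environment.
    let (next_state, _reward, done) = env.step(&state, &action);
    if done {
        break;
    }
    state = next_state;
}
```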

#### DQN Example


```rust
use candle_core::Device;
use rlkit::algs::{DQN, TrainArgs};
use rlkit::algs::dqn::DNQStateMode;
use rlkit::policies::EpsilonGreedy;

// Create the environment
let mut env = MyEnv::new();

// Select the computation device (CPU or GPU)
let device = Device::Cpu;
// Or use GPU: let device = Device::new_cuda(0).unwrap();

// Create a DQN algorithm instance
let mut dqn = DQN::new(
    &env,
    10000,                  // Replay buffer capacity
    &[128, 64, 16],         // Hidden layer structure
    DNQStateMode::OneHot,   // State encoding mode
    &device
).unwrap();

// Create a policy
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// Configure training parameters
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 32,
    learning_rate: 1e-3,
    gamma: 0.99,
    update_freq: 5,
    update_interval: 100,
};

// Train the model
dqn.train(&mut env, &mut policy, train_args).unwrap();
```
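
For a binary that should run both with and without a GPU, Candle's `Device::cuda_if_available` helper avoids hard-coding the device. A small sketch (GPU builds presumably still require the `cuda` feature to be enabled):

```rust
use candle_core::Device;

// Use the first CUDA device if one is available, otherwise fall back to the CPU.
let device = Device::cuda_if_available(0).unwrap();
```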

### 3. Policy Selection


The library provides multiple action selection policies:

```rust
use rlkit::policies::{PolicyConfig, EpsilonGreedy, Boltzmann};

// ε-greedy policy
let mut epsilon_greedy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// Boltzmann policy
let mut boltzmann = Boltzmann::new(1.0, 0.1, 0.99);
```
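
Both policies play the same role: they are passed mutably to `train` and `get_action`. A brief sketch continuing the Q-Learning example, assuming (not verified here) that any policy from `rlkit::policies` can be used interchangeably in those calls:

```rust
// Sketch: swap the exploration strategy by constructing a different policy.
// Assumes train() accepts any rlkit policy type, as the examples above suggest.
let mut policy = Boltzmann::new(1.0, 0.1, 0.99);
q_learning.train(&mut env, &mut policy, train_args).unwrap();
```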

## Usage Examples


The library includes two complete examples:

1. **Grid World** - Located in `examples/grid_world-example/`
   - Implements a grid navigation environment with obstacles
   - Demonstrates the usage of both Q-Learning and DQN algorithms
   - Includes path visualization functionality

2. **Catch Rabbit** - Located in `examples/catch_rabbit.rs`
   - Implements a multi-agent capture environment on a graph structure
   - Demonstrates the application of the Q-Learning algorithm in a complex environment

### Running the Grid World Example


```bash
cd examples/grid_world-example
cargo run
```

Then follow the prompts to select the algorithm to use (1: Q-Learning, 2: DQN).
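
### Running the Catch Rabbit Example


Since `catch_rabbit.rs` lives directly under `examples/`, it should be runnable through Cargo's standard example runner (assuming it is picked up by Cargo's example auto-discovery and needs no extra features):

```bash
cargo run --example catch_rabbit
```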

## Environment Implementation Guide


To implement a custom environment, you need to:

1. Create an environment struct to store the environment state
2. Implement all methods of the `EnvTrait` interface
3. Define an appropriate reward function and termination conditions

Detailed examples can be found in `examples/grid_world-example/env.rs` and `examples/catch_rabbit.rs`.

## Training Parameters


The `TrainArgs` struct includes the following parameters (a short construction sketch follows the list):

- `epochs`: Number of training epochs
- `max_steps`: Maximum number of steps per epoch
- `batch_size`: Number of samples per batch update
- `learning_rate`: Learning rate
- `gamma`: Discount factor
- `update_freq`: Frequency of network updates, in environment steps
- `update_interval`: Interval between target network updates (DQN only)
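
Since `TrainArgs` implements `Default` (as used in the Q-Learning example above), you only need to spell out the fields you care about. A minimal sketch; the overridden values below are illustrative, and the remaining fields take whatever defaults the crate defines:

```rust
use rlkit::algs::TrainArgs;

// Override a couple of fields and take the crate's defaults for the rest.
let args = TrainArgs {
    epochs: 500,
    gamma: 0.95,
    ..Default::default()
};
```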

## License


This project is licensed under the MIT License. For more details, see the LICENSE file.