rlkit 0.0.2

一个基于Rust和candle的强化学习库
Documentation

RLKit

一个基于Rust和Candle的深度强化学习库,提供了Q-Learning和DQN算法的完整实现,支持自定义环境、多种策略选择和灵活的训练配置。未来将支持更多强化学习算法,如DDPG、PPO、A2C等。

Read this in other languages: 中文, English.

功能特点

  • 实现了经典的Q-Learning和DQN算法
  • 支持经验回放缓冲区
  • 支持目标网络更新(DQN)
  • 提供多种动作选择策略(ε-贪婪、Boltzmann、Ornstein-Uhlenbeck、高斯噪声等)
  • 基于Candle库实现,支持CPU和CUDA GPU计算
  • 泛型设计,支持不同类型的状态和动作空间
  • 完整的错误处理机制
  • 详细的测试用例

Features

  • cuda: 启用CUDA GPU计算支持

项目结构

src/
├── algs/           # 强化学习算法实现
│   ├── dqn.rs      # DQN算法
│   ├── q_learning.rs # Q-Learning算法
│   └── mod.rs      # 算法接口和公共组件
├── network.rs      # 神经网络实现(用于DQN)
├── policies.rs     # 动作选择策略
├── replay_buffer.rs # 经验回放缓冲区
├── types.rs        # 核心类型定义(状态、动作、环境接口等)
└── utils.rs        # 工具函数

安装

将此库添加到你的Rust项目中:

[dependencies]
rlkit = { version = "0.0.1", features = ["cuda"] }

核心组件

1. 环境接口 (EnvTrait)

所有强化学习环境必须实现EnvTrait接口:

use rlkit::types::{EnvTrait, Status, Reward, Action};

struct MyEnv {
    // 环境状态
}

impl EnvTrait<u16, u16> for MyEnv {
    fn step(&mut self, state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
        // 执行动作并返回下一状态、奖励和是否完成
    }
    
    fn reset(&mut self) -> Status<u16> {
        // 重置环境并返回初始状态
    }
    
    fn action_space(&self) -> &[u16] {
        // 返回动作空间的维度信息
    }
    
    fn state_space(&self) -> &[u16] {
        // 返回状态空间的维度信息
    }
    
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

2. 算法使用

Q-Learning 示例

use rlkit::algs::{QLearning, TrainArgs};
use rlkit::policies::EpsilonGreedy;
use rlkit::types::{EnvTrait};

// 创建环境
let mut env = MyEnv::new();

// 创建Q-Learning算法实例
let mut q_learning = QLearning::new(&env, 10000).unwrap();

// 创建ε-贪婪策略
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// 配置训练参数
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 64,
    learning_rate: 0.1,
    gamma: 0.99,
    ..Default::default()
};

// 训练模型
q_learning.train(&mut env, &mut policy, train_args).unwrap();

// 使用训练好的模型获取动作
let state = env.reset();
let action = q_learning.get_action(&state, &mut policy).unwrap();

DQN 示例

use candle_core::Device;
use rlkit::algs::{DQN, TrainArgs};
use rlkit::algs::dqn::DNQStateMode;
use rlkit::policies::EpsilonGreedy;

// 创建环境
let mut env = MyEnv::new();

// 选择计算设备(CPU或GPU)
let device = Device::Cpu;
// 或使用GPU: let device = Device::new_cuda(0).unwrap();

// 创建DQN算法实例
let mut dqn = DQN::new(
    &env,
    10000,                  // 回放缓冲区容量
    &[128, 64, 16],         // 隐藏层结构
    DNQStateMode::OneHot,   // 状态编码模式
    &device
).unwrap();

// 创建策略
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// 配置训练参数
let train_args = TrainArgs {
    epochs: 1000,
    max_steps: 200,
    batch_size: 32,
    learning_rate: 1e-3,
    gamma: 0.99,
    update_freq: 5,
    update_interval: 100,
};

// 训练模型
dqn.train(&mut env, &mut policy, train_args).unwrap();

3. 策略选择

库提供了多种动作选择策略:

use rlkit::policies::{PolicyConfig, EpsilonGreedy, Boltzmann};

// ε-贪婪策略
let mut epsilon_greedy = EpsilonGreedy::new(1.0, 0.01, 0.995);

// Boltzmann策略
let mut boltzmann = Boltzmann::new(1.0, 0.1, 0.99);

使用示例

库中包含两个完整的示例:

  1. 网格世界 (Grid World) - 位于 examples/grid_world-example/

    • 实现了一个带障碍物的网格导航环境
    • 展示了Q-Learning和DQN两种算法的使用
    • 包含路径可视化功能
  2. 兔子抓捕 (Catch Rabbit) - 位于 examples/catch_rabbit.rs

    • 实现了一个图结构上的多智能体抓捕环境
    • 展示了Q-Learning算法在复杂环境中的应用

运行网格世界示例

cd examples/grid_world-example
cargo run

然后按照提示选择要使用的算法(1: Q-Learning, 2: DQN)。

环境实现指南

要实现自定义环境,你需要:

  1. 创建环境结构体,保存环境状态
  2. 实现EnvTrait接口的所有方法
  3. 定义合适的奖励函数和终止条件

详细示例可参考 examples/grid_world-example/env.rsexamples/catch_rabbit.rs

训练参数说明

TrainArgs结构体包含以下参数:

  • epochs: 训练轮数
  • max_steps: 每轮最大步数
  • batch_size: 批量更新的样本数量
  • learning_rate: 学习率
  • gamma: 折扣因子
  • update_freq: 每多少步进行一次网络更新
  • update_interval: 目标网络更新间隔(仅DQN使用)

许可证

本项目使用MIT许可证。详情请查看LICENSE文件。