RLKit
一个基于Rust和Candle的深度强化学习库,提供了Q-Learning和DQN算法的完整实现,支持自定义环境、多种策略选择和灵活的训练配置。未来将支持更多强化学习算法,如DDPG、PPO、A2C等。
Read this in other languages: 中文, English.
功能特点
- 实现了经典的Q-Learning和DQN算法
- 支持经验回放缓冲区
- 支持目标网络更新(DQN)
- 提供多种动作选择策略(ε-贪婪、Boltzmann、Ornstein-Uhlenbeck、高斯噪声等)
- 基于Candle库实现,支持CPU和CUDA GPU计算
- 泛型设计,支持不同类型的状态和动作空间
- 完整的错误处理机制
- 详细的测试用例
Features
cuda: 启用CUDA GPU计算支持
项目结构
src/
├── algs/ # 强化学习算法实现
│ ├── dqn.rs # DQN算法
│ ├── q_learning.rs # Q-Learning算法
│ └── mod.rs # 算法接口和公共组件
├── network.rs # 神经网络实现(用于DQN)
├── policies.rs # 动作选择策略
├── replay_buffer.rs # 经验回放缓冲区
├── types.rs # 核心类型定义(状态、动作、环境接口等)
└── utils.rs # 工具函数
安装
将此库添加到你的Rust项目中:
[]
= { = "0.0.1", = ["cuda"] }
核心组件
1. 环境接口 (EnvTrait)
所有强化学习环境必须实现EnvTrait接口:
use ;
2. 算法使用
Q-Learning 示例
use ;
use EpsilonGreedy;
use ;
// 创建环境
let mut env = new;
// 创建Q-Learning算法实例
let mut q_learning = new.unwrap;
// 创建ε-贪婪策略
let mut policy = new;
// 配置训练参数
let train_args = TrainArgs ;
// 训练模型
q_learning.train.unwrap;
// 使用训练好的模型获取动作
let state = env.reset;
let action = q_learning.get_action.unwrap;
DQN 示例
use Device;
use ;
use DNQStateMode;
use EpsilonGreedy;
// 创建环境
let mut env = new;
// 选择计算设备(CPU或GPU)
let device = Cpu;
// 或使用GPU: let device = Device::new_cuda(0).unwrap();
// 创建DQN算法实例
let mut dqn = DQNnew.unwrap;
// 创建策略
let mut policy = new;
// 配置训练参数
let train_args = TrainArgs ;
// 训练模型
dqn.train.unwrap;
3. 策略选择
库提供了多种动作选择策略:
use ;
// ε-贪婪策略
let mut epsilon_greedy = new;
// Boltzmann策略
let mut boltzmann = new;
使用示例
库中包含两个完整的示例:
-
网格世界 (Grid World) - 位于
examples/grid_world-example/- 实现了一个带障碍物的网格导航环境
- 展示了Q-Learning和DQN两种算法的使用
- 包含路径可视化功能
-
兔子抓捕 (Catch Rabbit) - 位于
examples/catch_rabbit.rs- 实现了一个图结构上的多智能体抓捕环境
- 展示了Q-Learning算法在复杂环境中的应用
运行网格世界示例
然后按照提示选择要使用的算法(1: Q-Learning, 2: DQN)。
环境实现指南
要实现自定义环境,你需要:
- 创建环境结构体,保存环境状态
- 实现
EnvTrait接口的所有方法 - 定义合适的奖励函数和终止条件
详细示例可参考 examples/grid_world-example/env.rs 和 examples/catch_rabbit.rs。
训练参数说明
TrainArgs结构体包含以下参数:
epochs: 训练轮数max_steps: 每轮最大步数batch_size: 批量更新的样本数量learning_rate: 学习率gamma: 折扣因子update_freq: 每多少步进行一次网络更新update_interval: 目标网络更新间隔(仅DQN使用)
许可证
本项目使用MIT许可证。详情请查看LICENSE文件。