# DQN-Library
一个基于Rust和Candle的深度强化学习库,提供了Q-Learning和DQN算法的完整实现,支持自定义环境、多种策略选择和灵活的训练配置。
## 功能特点
- 实现了经典的Q-Learning和DQN算法
- 支持经验回放缓冲区
- 支持目标网络更新(DQN)
- 提供多种动作选择策略(ε-贪婪、Boltzmann、Ornstein-Uhlenbeck、高斯噪声等)
- 基于Candle库实现,支持CPU和CUDA GPU计算
- 泛型设计,支持不同类型的状态和动作空间
- 完整的错误处理机制
- 详细的测试用例
## 项目结构
```
src/
├── algs/ # 强化学习算法实现
│ ├── dqn.rs # DQN算法
│ ├── q_learning.rs # Q-Learning算法
│ └── mod.rs # 算法接口和公共组件
├── network.rs # 神经网络实现(用于DQN)
├── policies.rs # 动作选择策略
├── replay_buffer.rs # 经验回放缓冲区
├── types.rs # 核心类型定义(状态、动作、环境接口等)
└── utils.rs # 工具函数
```
## 安装
将此库添加到你的Rust项目中:
```toml
[dependencies]
dqn-library = { path = "path/to/dqn-library" }
```
## 核心组件
### 1. 环境接口 (EnvTrait)
所有强化学习环境必须实现`EnvTrait`接口:
```rust
use rlkit::types::{EnvTrait, Status, Reward, Action};
struct MyEnv {
// 环境状态
}
impl EnvTrait<u16, u16> for MyEnv {
fn step(&mut self, state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
// 执行动作并返回下一状态、奖励和是否完成
}
fn reset(&mut self) -> Status<u16> {
// 重置环境并返回初始状态
}
fn action_space(&self) -> &[u16] {
// 返回动作空间的维度信息
}
fn state_space(&self) -> &[u16] {
// 返回状态空间的维度信息
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
```
### 2. 算法使用
#### Q-Learning 示例
```rust
use rlkit::algs::{QLearning, TrainArgs};
use rlkit::policies::EpsilonGreedy;
use rlkit::types::{EnvTrait};
// 创建环境
let mut env = MyEnv::new();
// 创建Q-Learning算法实例
let mut q_learning = QLearning::new(&env, 10000).unwrap();
// 创建ε-贪婪策略
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// 配置训练参数
let train_args = TrainArgs {
epochs: 1000,
max_steps: 200,
batch_size: 64,
learning_rate: 0.1,
gamma: 0.99,
..Default::default()
};
// 训练模型
q_learning.train(&mut env, &mut policy, train_args).unwrap();
// 使用训练好的模型获取动作
let state = env.reset();
let action = q_learning.get_action(&state, &mut policy).unwrap();
```
#### DQN 示例
```rust
use candle_core::Device;
use rlkit::algs::{DQN, TrainArgs};
use rlkit::algs::dqn::DNQStateMode;
use rlkit::policies::EpsilonGreedy;
// 创建环境
let mut env = MyEnv::new();
// 选择计算设备(CPU或GPU)
let device = Device::new_cpu().unwrap();
// 或使用GPU: let device = Device::new_cuda(0).unwrap();
// 创建DQN算法实例
let mut dqn = DQN::new(
&env,
10000, // 回放缓冲区容量
&[128, 64, 16], // 隐藏层结构
DNQStateMode::OneHot, // 状态编码模式
&device
).unwrap();
// 创建策略
let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// 配置训练参数
let train_args = TrainArgs {
epochs: 1000,
max_steps: 200,
batch_size: 32,
learning_rate: 1e-3,
gamma: 0.99,
update_freq: 5,
update_interval: 100,
};
// 训练模型
dqn.train(&mut env, &mut policy, train_args).unwrap();
```
### 3. 策略选择
库提供了多种动作选择策略:
```rust
use rlkit::policies::{PolicyConfig, EpsilonGreedy, Boltzmann};
// ε-贪婪策略
let mut epsilon_greedy = EpsilonGreedy::new(1.0, 0.01, 0.995);
// Boltzmann策略
let mut boltzmann = Boltzmann::new(1.0, 0.1, 0.99);
// 通过配置创建策略
let policy = PolicyConfig::dqn_epsilon_greedy().create_policy(action_dim).unwrap();
```
## 使用示例
库中包含两个完整的示例:
1. **网格世界 (Grid World)** - 位于 `examples/grid_world-example/`
- 实现了一个带障碍物的网格导航环境
- 展示了Q-Learning和DQN两种算法的使用
- 包含路径可视化功能
2. **兔子抓捕 (Catch Rabbit)** - 位于 `examples/catch_rabbit.rs`
- 实现了一个图结构上的多智能体抓捕环境
- 展示了Q-Learning算法在复杂环境中的应用
### 运行网格世界示例
```bash
cd examples/grid_world-example
cargo run
```
然后按照提示选择要使用的算法(1: Q-Learning, 2: DQN)。
## 环境实现指南
要实现自定义环境,你需要:
1. 创建环境结构体,保存环境状态
2. 实现`EnvTrait`接口的所有方法
3. 定义合适的奖励函数和终止条件
详细示例可参考 `examples/grid_world-example/env.rs` 和 `examples/catch_rabbit.rs`。
## 训练参数说明
`TrainArgs`结构体包含以下参数:
- `epochs`: 训练轮数
- `max_steps`: 每轮最大步数
- `batch_size`: 批量更新的样本数量
- `learning_rate`: 学习率
- `gamma`: 折扣因子
- `update_freq`: 每多少步进行一次网络更新
- `update_interval`: 目标网络更新间隔(仅DQN使用)
## 依赖项
- [candle-core](https://crates.io/crates/candle-core): 张量计算库
- [candle-nn](https://crates.io/crates/candle-nn): 神经网络库
- [rand](https://crates.io/crates/rand): 随机数生成库
- [indicatif](https://crates.io/crates/indicatif): 进度条显示
## 许可证
本项目使用MIT许可证。详情请查看LICENSE文件。