§RLKit
A multi-algorithm deep reinforcement learning library based on Rust and Candle.
§Features
§Core Algorithms
- Q-Learning: A classic tabular reinforcement learning algorithm
- DQN (Deep Q-Network): Q-value function estimation based on deep neural networks
§Key Features
- Supports multiple action selection policies (ε-greedy, Boltzmann, Gaussian noise, etc.); see the ε-greedy decay sketch after this list
- Implements an experience replay buffer to improve sample efficiency
- Supports target network updates (DQN) for improved training stability
- Generic design supporting different types of state and action spaces
- Comprehensive error handling mechanisms
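The ε-greedy policy above is constructed in the usage example further down as EpsilonGreedy::new(1.0, 0.01, 0.995), i.e. an initial exploration rate, a minimum exploration rate, and a decay rate. The standalone sketch below illustrates the usual multiplicative-decay reading of those three numbers; it is an assumption about typical ε-greedy semantics, not a copy of RLKit's internal implementation:

// Standalone sketch of a multiplicative epsilon-decay schedule (assumed
// semantics, not RLKit's internal code): epsilon starts at `initial`, is
// multiplied by `decay` after each update, and never drops below `min`.
fn decayed_epsilon(initial: f64, min: f64, decay: f64, updates: u32) -> f64 {
    (initial * decay.powi(updates as i32)).max(min)
}

fn main() {
    // Same values as EpsilonGreedy::new(1.0, 0.01, 0.995) in the usage example below.
    for updates in [0u32, 100, 500, 1000] {
        println!("after {updates:>4} updates: epsilon = {:.4}", decayed_epsilon(1.0, 0.01, 0.995, updates));
    }
}

With a decay rate of 0.995, this schedule reaches the 0.01 floor after roughly 900 updates, so a long training run spends most of its time at the minimum exploration rate.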
§Hardware Acceleration
- CUDA Support: Enable GPU acceleration using the cuda feature flag (see the device-selection sketch after this list)
- CPU-only mode available for systems without a GPU
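Building with the cuda feature makes a GPU device available through Candle. Below is a minimal sketch of the usual device-selection pattern using only candle_core; where exactly RLKit takes the resulting Device is not shown on this page, so the hand-off is left as a comment:

use candle_core::Device;

fn main() -> anyhow::Result<()> {
    // Picks the first CUDA GPU when the binary was built with CUDA support
    // and silently falls back to the CPU otherwise.
    let device = Device::cuda_if_available(0)?;
    println!("selected device: {:?}", device);

    // Hand `device` to whichever RLKit component needs it (assumption: the
    // exact constructor that accepts a Device is not shown on this page).
    Ok(())
}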
§Core Components
- Environment Interface: Defines standard methods for interacting with environments
- Algorithm Implementations: Contains implementations of various reinforcement learning algorithms
- Policies: Provides multiple action selection strategies
- Networks: Neural network structures (used by DQN and other algorithms)
- Replay Buffer: Stores and samples experiences
§Getting Started
To use this library, implement the EnvTrait trait for a custom environment, then select appropriate algorithms and policies for training.
§Usage Examples
Here is a simple example of training and evaluating a Q-Learning agent; DQN is driven through the same Algorithm trait interface:
// 1. Import necessary components
use rlkit::{
    Algorithm,
    QLearning,
    DQN,
    TrainArgs,
    EpsilonGreedy,
    DeterministicPolicy,
    EnvTrait,
    Status,
    Reward,
    Action,
    DNQStateMode,
    Policy,
};
use candle_core::Device;

// 2. Create a simple environment example
struct SimpleEnv {
    state: u16,
    max_steps: usize,
    current_step: usize,
}

impl SimpleEnv {
    fn new(max_steps: usize) -> Self {
        Self {
            state: 0,
            max_steps,
            current_step: 0,
        }
    }
}

impl EnvTrait<u16, u16> for SimpleEnv {
    fn step(&mut self, _state: &Status<u16>, action: &Action<u16>) -> (Status<u16>, Reward, bool) {
        // Simple environment logic
        self.current_step += 1;
        let next_state = (self.state + action.as_slice()[0]) % 10;
        self.state = next_state;

        // Simple reward function
        let reward = if next_state == 5 {
            Reward(10.0)
        } else {
            Reward(-1.0)
        };

        // Check for termination
        let done = self.current_step >= self.max_steps || next_state == 5;

        // Use the Status::new constructor
        (Status::new(vec![next_state], vec![10]), reward, done)
    }

    fn reset(&mut self) -> Status<u16> {
        self.state = 0;
        self.current_step = 0;
        Status::new(vec![self.state], vec![10])
    }

    fn action_space(&self) -> &[u16] {
        &[1]
    }

    fn state_space(&self) -> &[u16] {
        &[10]
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

fn main() -> anyhow::Result<()> {
    // Create environment
    let mut env = SimpleEnv::new(100);

    // Configure training parameters
    let train_args = TrainArgs {
        epochs: 100,
        max_steps: 50,
        batch_size: 64,
        learning_rate: 0.1,
        gamma: 0.99,
        update_freq: 1,
        update_interval: 10,
    };

    // 3. Create Q-Learning agent
    let mut q_agent = QLearning::new(&env, 10000)?;
    // Initial exploration rate, minimum exploration rate, decay rate
    let mut q_policy = EpsilonGreedy::new(1.0, 0.01, 0.995);

    // 4. Train Q-Learning agent
    println!("Training Q-Learning agent...");
    q_agent.train(&mut env, &mut q_policy, train_args)?;
    println!("Training complete!");

    // 5. Use the trained agent for inference (demonstrate get_action usage)
    println!("\nTesting the trained agent...");
    let mut current_state = env.reset();
    let mut total_reward = 0.0;
    let mut steps = 0;

    // Switch to a deterministic (greedy) policy for evaluation
    let mut q_policy = DeterministicPolicy;

    // Run a test episode
    while steps < 50 {
        // Use the get_action method to select an action
        let action = <rlkit::QLearning as Algorithm<SimpleEnv, u16, u16>>::get_action(&q_agent, &current_state, &mut q_policy)?;

        // Execute the action
        let (next_state, reward, done) = env.step(&current_state, &action);

        // Update state and cumulative reward
        current_state = next_state;
        total_reward += reward.0 as f64;
        steps += 1;

        println!(
            "Step {}: Action = {:?}, Reward = {}, Cumulative Reward = {:.2}",
            steps, action.as_slice(), reward.0, total_reward
        );

        if done {
            println!("Termination condition reached!");
            break;
        }
    }

    println!("\nTest completed! Total Reward: {:.2}, Total Steps: {}", total_reward, steps);
    Ok(())
}

For more detailed examples, please check the examples directory in the project.
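The as_any and as_any_mut methods implemented for SimpleEnv above make it possible to recover the concrete environment from a &dyn EnvTrait<u16, u16> trait object via standard downcasting. A small sketch, assuming it lives next to the SimpleEnv definition from the example and that the trait is object-safe as shown:

use rlkit::EnvTrait;

// Reads a SimpleEnv-specific field that the EnvTrait interface itself
// does not expose, by downcasting the trait object back to SimpleEnv.
fn inspect(env: &dyn EnvTrait<u16, u16>) {
    if let Some(simple) = env.as_any().downcast_ref::<SimpleEnv>() {
        println!("current step: {}", simple.current_step);
    }
}

This is the usual reason for exposing as_any on a trait: trait objects erase the concrete type, and Any::downcast_ref restores access to it when needed.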
§Cargo Features
- cuda: Enable CUDA GPU acceleration for faster training
- Default features: CPU-only mode
Re-exports§
pub use types::Status;
pub use types::Reward;
pub use types::Sample;
pub use types::EnvTrait;
pub use types::Action;
pub use types::QValue;
pub use algs::Algorithm;
pub use algs::QLearning;
pub use algs::DQN;
pub use algs::TrainArgs;
pub use policies::Policy;
pub use policies::EpsilonGreedy;
pub use policies::Boltzmann;
pub use policies::GaussianNoise;
pub use policies::DeterministicPolicy;
pub use algs::dqn::DNQStateMode;
Modules§
- algs
- Reinforcement learning algorithms module, encapsulating the training and usage of reinforcement learning algorithms.
- network
- Network implemented using the Candle library.
- policies
- Policy module, containing implementations of various action selection policies.
- replay_buffer
- Experience replay buffer, used for storing and sampling training samples.
- types
- Defines the core types and environment interfaces for reinforcement learning.
- utils
- Hybrid base encoding utility functions.
Constants§
- VERSION
- Version number