//! Deep Q-Network (DQN) agent: experience replay, epsilon-greedy
//! exploration, a periodically synced target network, and optional
//! Double DQN targets.
use crate::error::{NeuralError, Result};
use crate::rl::policy::PolicyRng;
use crate::rl::value::QNetwork;
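/// A single environment transition `(state, action, reward, next_state,
/// done)`, the unit stored in and sampled from the replay buffer.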
#[derive(Debug, Clone)]
pub struct Experience {
pub state: Vec<f32>,
pub action: usize,
pub reward: f32,
pub next_state: Vec<f32>,
pub done: bool,
}
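/// Fixed-capacity circular replay buffer for DQN transitions: once full,
/// new experiences overwrite the oldest ones.
///
/// A minimal usage sketch, assuming `exp` is an [`Experience`] you have
/// already built from an environment transition:
///
/// ```ignore
/// let mut buf = DQNReplayBuffer::new(10_000);
/// buf.push(exp);
/// if buf.is_ready(1_000) {
///     let batch = buf.sample(32)?; // uniform, with replacement
/// }
/// ```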
pub struct DQNReplayBuffer {
capacity: usize,
buffer: Vec<Experience>,
ptr: usize,
size: usize,
rng: PolicyRng,
}
impl DQNReplayBuffer {
pub fn new(capacity: usize) -> Self {
assert!(capacity > 0, "capacity must be > 0");
Self {
capacity,
buffer: Vec::with_capacity(capacity),
ptr: 0,
size: 0,
rng: PolicyRng::new(0xabcd_ef01_2345_6789),
}
}
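/// Appends a transition, overwriting the oldest one once the buffer is
/// full. While filling, `ptr` and `size` advance together, so `Vec::push`
/// appends exactly at the write cursor; afterwards `ptr` wraps circularly.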
pub fn push(&mut self, exp: Experience) {
if self.size < self.capacity {
self.buffer.push(exp);
} else {
self.buffer[self.ptr] = exp;
}
self.ptr = (self.ptr + 1) % self.capacity;
self.size = (self.size + 1).min(self.capacity);
}
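/// Draws `batch_size` transitions uniformly at random, with replacement,
/// so `batch_size` may exceed the number of stored transitions. Takes
/// `&mut self` because sampling advances the internal RNG. Errors only
/// when the buffer is empty.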
pub fn sample(&mut self, batch_size: usize) -> Result<Vec<Experience>> {
if self.size == 0 {
return Err(NeuralError::InvalidState("cannot sample from empty replay buffer".into()));
}
let samples: Vec<Experience> = (0..batch_size)
.map(|_| self.buffer[self.rng.usize_below(self.size)].clone())
.collect();
Ok(samples)
}
pub fn len(&self) -> usize {
self.size
}
pub fn is_empty(&self) -> bool {
self.size == 0
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn is_ready(&self, min_size: usize) -> bool {
self.size >= min_size
}
}
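/// Hyperparameters for [`DQNAgent`]. `Default` supplies small-scale
/// values suited to CartPole-sized problems; override individual fields
/// with struct update syntax:
///
/// ```ignore
/// let cfg = DQNConfig { lr: 5e-4, batch_size: 64, ..Default::default() };
/// ```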
#[derive(Debug, Clone)]
pub struct DQNConfig {
pub lr: f32,
pub gamma: f32,
pub eps_start: f32,
pub eps_end: f32,
pub eps_decay_steps: usize,
pub batch_size: usize,
pub target_update_freq: usize,
pub buffer_capacity: usize,
pub learning_starts: usize,
pub double_dqn: bool,
pub hidden_dims: Vec<usize>,
}
impl Default for DQNConfig {
fn default() -> Self {
Self {
lr: 1e-3,
gamma: 0.99,
eps_start: 1.0,
eps_end: 0.05,
eps_decay_steps: 10_000,
batch_size: 32,
target_update_freq: 500,
buffer_capacity: 50_000,
learning_starts: 1_000,
double_dqn: true,
hidden_dims: vec![64, 64],
}
}
}
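/// DQN agent combining an online Q-network, a replay buffer, and a target
/// network that is re-synced every `target_update_freq` updates.
///
/// A minimal training-loop sketch, assuming a hypothetical `env` whose
/// `reset`/`step` already produce `Vec<f32>` observations, an `f32`
/// reward, and a `bool` done flag:
///
/// ```ignore
/// let mut agent = DQNAgent::new(obs_dim, num_actions, DQNConfig::default());
/// let mut obs = env.reset();
/// for _ in 0..total_steps {
///     let action = agent.select_action(&obs)?;
///     let (next_state, reward, done) = env.step(action);
///     agent.store_transition(Experience {
///         state: obs.clone(),
///         action,
///         reward,
///         next_state: next_state.clone(),
///         done,
///     });
///     let _loss = agent.update()?; // `None` until learning_starts is reached
///     obs = if done { env.reset() } else { next_state };
/// }
/// ```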
pub struct DQNAgent {
online_net: QNetwork,
target_net: QNetwork,
replay: DQNReplayBuffer,
config: DQNConfig,
steps: usize,
steps_since_target_update: usize,
rng: PolicyRng,
}
impl DQNAgent {
pub fn new(obs_dim: usize, num_actions: usize, config: DQNConfig) -> Self {
let online_net = QNetwork::new(obs_dim, &config.hidden_dims, num_actions);
let mut target_net = QNetwork::new(obs_dim, &config.hidden_dims, num_actions);
// Initialize the target net as an exact copy of the online net so early
// bootstrap targets come from the same weights rather than a second,
// independent random initialization.
target_net
.copy_from(&online_net)
.expect("online and target nets share an architecture");
let replay = DQNReplayBuffer::new(config.buffer_capacity);
let rng = PolicyRng::from_time();
Self {
online_net,
target_net,
replay,
config,
steps: 0,
steps_since_target_update: 0,
rng,
}
}
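/// Exploration rate, annealed linearly from `eps_start` to `eps_end` over
/// `eps_decay_steps` calls to [`Self::select_action`]:
/// `eps(t) = eps_start + min(t / eps_decay_steps, 1) * (eps_end - eps_start)`.
/// With the defaults (1.0 down to 0.05 over 10_000 steps), step 5_000
/// gives `1.0 + 0.5 * (0.05 - 1.0) = 0.525`.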
pub fn epsilon(&self) -> f32 {
let frac = (self.steps as f32 / self.config.eps_decay_steps.max(1) as f32).min(1.0);
self.config.eps_start + frac * (self.config.eps_end - self.config.eps_start)
}
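/// Epsilon-greedy action selection. Advances the step counter that drives
/// the epsilon schedule, so call this exactly once per environment step.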
pub fn select_action(&mut self, obs: &[f32]) -> Result<usize> {
let eps = self.epsilon();
self.steps += 1;
if self.rng.uniform_f32() < eps {
Ok(self.rng.usize_below(self.online_net.num_actions()))
} else {
self.online_net.greedy_action(obs)
}
}
pub fn store_transition(&mut self, exp: Experience) {
self.replay.push(exp);
}
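/// One learning step. Returns `Ok(None)` until the replay buffer holds
/// `learning_starts` transitions; afterwards it samples a batch, applies
/// a TD update, syncs the target network every `target_update_freq`
/// updates, and returns the mean per-sample loss.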
pub fn update(&mut self) -> Result<Option<f32>> {
if !self.replay.is_ready(self.config.learning_starts) {
return Ok(None);
}
let batch = self.replay.sample(self.config.batch_size)?;
let total_loss = self.td_update(&batch)?;
self.steps_since_target_update += 1;
if self.steps_since_target_update >= self.config.target_update_freq {
self.update_target()?;
self.steps_since_target_update = 0;
}
Ok(Some(total_loss / batch.len() as f32))
}
pub fn update_target(&mut self) -> Result<()> {
self.target_net.copy_from(&self.online_net)
}
pub fn steps(&self) -> usize {
self.steps
}
pub fn replay_buffer(&self) -> &DQNReplayBuffer {
&self.replay
}
pub fn online_net(&self) -> &QNetwork {
&self.online_net
}
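// TD target per transition:
//   terminal:    y = r
//   DQN:         y = r + gamma * max_a' Q_target(s', a')
//   Double DQN:  y = r + gamma * Q_target(s', argmax_a' Q_online(s', a'))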
fn td_update(&mut self, batch: &[Experience]) -> Result<f32> {
let gamma = self.config.gamma;
let double = self.config.double_dqn;
let lr = self.config.lr;
let num_actions = self.online_net.num_actions();
let mut total_loss = 0.0_f32;
for exp in batch {
let qs = self.online_net.q_values(&exp.state)?;
if qs.len() != num_actions {
return Err(NeuralError::ShapeMismatch(format!(
"Q-value length {} != num_actions {}",
qs.len(), num_actions
)));
}
let td_target = if exp.done {
exp.reward
} else {
let next_q = if double {
let best_next_action = self.online_net.greedy_action(&exp.next_state)?;
let target_qs = self.target_net.q_values(&exp.next_state)?;
target_qs.get(best_next_action)
.copied()
.ok_or_else(|| NeuralError::ComputationError(
"best_next_action out of range".into()
))?
} else {
let target_qs = self.target_net.q_values(&exp.next_state)?;
target_qs.iter().copied().fold(f32::NEG_INFINITY, f32::max)
};
exp.reward + gamma * next_q
};
// Guard the action index before the parameter update; an out-of-range
// action becomes a recoverable error instead of an index panic.
if exp.action >= num_actions {
return Err(NeuralError::InvalidState(format!(
"action {} out of range for {} actions",
exp.action, num_actions
)));
}
let loss = self.online_net.update_action(&exp.state, exp.action, td_target, lr)?;
total_loss += loss;
}
Ok(total_loss)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_exp(obs_dim: usize, action: usize) -> Experience {
Experience {
state: vec![0.1_f32; obs_dim],
action,
reward: 1.0,
next_state: vec![0.2_f32; obs_dim],
done: false,
}
}
#[test]
fn dqn_replay_buffer_push_and_len() {
let mut buf = DQNReplayBuffer::new(10);
assert!(buf.is_empty());
buf.push(make_exp(4, 0));
assert_eq!(buf.len(), 1);
}
#[test]
fn dqn_replay_buffer_circular_overwrite() {
let mut buf = DQNReplayBuffer::new(3);
for _ in 0..5 {
buf.push(make_exp(4, 0));
}
assert_eq!(buf.len(), 3);
}
#[test]
fn dqn_replay_buffer_sample_not_empty() {
let mut buf = DQNReplayBuffer::new(20);
for i in 0..20 {
buf.push(make_exp(4, i % 2));
}
let samples = buf.sample(8).expect("sample failed");
assert_eq!(samples.len(), 8);
}
#[test]
fn dqn_replay_buffer_is_ready() {
let mut buf = DQNReplayBuffer::new(100);
for _ in 0..9 {
buf.push(make_exp(4, 0));
}
assert!(!buf.is_ready(10));
buf.push(make_exp(4, 1));
assert!(buf.is_ready(10));
}
#[test]
fn dqn_agent_select_action_in_range() {
let cfg = DQNConfig { hidden_dims: vec![16], ..Default::default() };
let mut agent = DQNAgent::new(4, 3, cfg);
let obs = vec![0.1_f32; 4];
let a = agent.select_action(&obs).expect("select_action failed");
assert!(a < 3, "action should be in [0, num_actions)");
}
#[test]
fn dqn_agent_epsilon_decays() {
let cfg = DQNConfig {
eps_start: 1.0,
eps_end: 0.1,
eps_decay_steps: 10,
hidden_dims: vec![8],
..Default::default()
};
let mut agent = DQNAgent::new(4, 2, cfg);
let obs = vec![0.0_f32; 4];
assert!((agent.epsilon() - 1.0).abs() < 1e-4);
for _ in 0..10 {
let _ = agent.select_action(&obs).expect("act");
}
assert!((agent.epsilon() - 0.1).abs() < 0.05, "epsilon should reach eps_end");
}
#[test]
fn dqn_agent_returns_none_before_ready() {
let cfg = DQNConfig {
learning_starts: 100,
hidden_dims: vec![8],
..Default::default()
};
let mut agent = DQNAgent::new(4, 2, cfg);
agent.store_transition(make_exp(4, 0));
let result = agent.update().expect("update failed");
assert!(result.is_none(), "should not update before learning_starts");
}
#[test]
fn dqn_agent_update_after_learning_starts() {
let cfg = DQNConfig {
learning_starts: 10,
batch_size: 4,
hidden_dims: vec![8],
..Default::default()
};
let mut agent = DQNAgent::new(4, 2, cfg);
for i in 0..20 {
let done = i % 5 == 4;
agent.store_transition(Experience {
state: vec![i as f32 * 0.01; 4],
action: i % 2,
reward: 1.0,
next_state: vec![(i + 1) as f32 * 0.01; 4],
done,
});
}
let result = agent.update().expect("update failed");
assert!(result.is_some(), "should produce a loss after learning_starts");
let loss = result.expect("operation should succeed");
assert!(loss.is_finite(), "loss must be finite; got {}", loss);
}
#[test]
fn dqn_agent_target_update_copy() {
let cfg = DQNConfig {
target_update_freq: 1,
learning_starts: 5,
batch_size: 4,
hidden_dims: vec![8],
..Default::default()
};
let mut agent = DQNAgent::new(4, 2, cfg);
for i in 0..20 {
agent.store_transition(make_exp(4, i % 2));
}
// target_update_freq == 1 forces a target sync inside `update`; with no
// public accessor for the target net, this is a smoke test that the sync
// runs and the online net still produces outputs of the right shape.
agent.update().expect("update failed");
let obs = vec![0.1_f32; 4];
let online_q = agent.online_net().q_values(&obs).expect("online qs");
assert_eq!(online_q.len(), 2);
}
#[test]
fn dqn_cartpole_10_steps_no_panic() {
use crate::rl::environments::{CartPole, Environment};
let cfg = DQNConfig {
hidden_dims: vec![16],
learning_starts: 5,
batch_size: 4,
..Default::default()
};
let mut agent = DQNAgent::new(4, 2, cfg);
let mut env = CartPole::new();
let mut obs: Vec<f32> = env.reset().iter().map(|&x| x as f32).collect();
for _ in 0..10 {
let action = agent.select_action(&obs).expect("select_action");
let action_arr = scirs2_core::ndarray::array![action as f64];
let (next_f64, reward, done) = env.step(&action_arr);
let next_obs: Vec<f32> = next_f64.iter().map(|&x| x as f32).collect();
agent.store_transition(Experience {
state: obs.clone(),
action,
reward: reward as f32,
next_state: next_obs.clone(),
done,
});
let _ = agent.update().expect("update");
if done {
obs = env.reset().iter().map(|&x| x as f32).collect();
} else {
obs = next_obs;
}
}
}
#[test]
fn double_dqn_agent_no_panic() {
let cfg = DQNConfig {
double_dqn: true,
learning_starts: 8,
batch_size: 4,
hidden_dims: vec![8],
..Default::default()
};
let mut agent = DQNAgent::new(4, 2, cfg);
for i in 0..20 {
agent.store_transition(make_exp(4, i % 2));
}
let loss = agent.update().expect("update").expect("should produce loss");
assert!(loss.is_finite());
}
}