use train_station::{
    gradtrack::{clear_all_graphs_known, NoGradTrack},
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

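/// Minimal deterministic pseudo-random number generator (a linear
/// congruential generator), so the example stays reproducible without
/// pulling in an external crate.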
struct SmallRng {
    /// Current generator state; seeded non-zero so the sequence never sticks at zero.
    state: u64,
}

impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }
    fn next_u32(&mut self) -> u32 {
        // Linear congruential step; use the higher bits, which are better
        // distributed than the low ones.
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }
    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }
    fn uniform(&mut self, low: f32, high: f32) -> f32 {
        low + (high - low) * self.next_f32()
    }
    fn sample_index(&mut self, upper_exclusive: usize) -> usize {
        (self.next_u32() as usize) % upper_exclusive.max(1)
    }
}

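/// Multi-layer perceptron: a stack of `LinearLayer`s with ReLU between hidden
/// layers and an optional activation on the final layer.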
struct Mlp {
    layers: Vec<LinearLayer>,
}

impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        assert!(sizes.len() >= 2);
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            // Vary the seed per layer so layers do not start with identical weights.
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }

    fn forward(&self, input: &Tensor, final_activation: Option<fn(&Tensor) -> Tensor>) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last {
                out.relu()
            } else if let Some(act) = final_activation {
                act(&out)
            } else {
                out
            };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for l in &mut self.layers {
            params.extend(l.parameters());
        }
        params
    }

    fn set_requires_grad_all(&mut self, enable: bool) {
        for l in &mut self.layers {
            l.weight.set_requires_grad(enable);
            l.bias.set_requires_grad(enable);
        }
    }

    fn copy_from(&mut self, other: &Self) {
        // Copy raw weight and bias data from `other`, then disable gradient
        // tracking on the destination (used for the frozen target network).
        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
            {
                let src = s.weight.data();
                let dst = t.weight.data_mut();
                dst.copy_from_slice(src);
            }
            {
                let src = s.bias.data();
                let dst = t.bias.data_mut();
                dst.copy_from_slice(src);
            }
            t.weight.set_requires_grad(false);
            t.bias.set_requires_grad(false);
        }
    }
}

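/// Q-network: maps a state to one Q-value per discrete action.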
struct QNet {
    net: Mlp,
}

impl QNet {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
        Self { net }
    }
    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state, None)
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
    fn set_requires_grad_all(&mut self, enable: bool) {
        self.net.set_requires_grad_all(enable);
    }
}

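/// Toy 1-D environment: a point mass with position and velocity, three
/// discrete thrust actions, and a quadratic penalty for straying from the
/// origin or using thrust. The state tensor is [pos, vel, 0.0].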
struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}

impl YardEnv {
    const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];

    fn new(seed: u64) -> Self {
        let mut env = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        env.reset();
        env
    }

    fn reset(&mut self) -> Tensor {
        self.pos = self.rng.uniform(-0.5, 0.5);
        self.vel = self.rng.uniform(-0.1, 0.1);
        self.steps = 0;
        self.state_tensor()
    }

    fn state_tensor(&self) -> Tensor {
        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
    }

    fn step(&mut self, action_index: usize) -> (Tensor, f32, bool) {
        let a = Self::ACTIONS[action_index.min(2)];
        // Simple dynamics: thrust accelerates, a weak spring term pulls back.
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;
        // Penalize squared distance from the origin and thrust usage.
        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}

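/// Fixed-capacity ring buffer of transitions (s, a, r, done, s'), stored as
/// flat `f32` vectors and sampled uniformly at random for minibatch updates.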
struct ReplayBuffer {
    capacity: usize,
    size: usize,
    pos: usize,
    state_dim: usize,
    states: Vec<f32>,
    actions: Vec<usize>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    next_states: Vec<f32>,
}

impl ReplayBuffer {
    fn new(capacity: usize, state_dim: usize) -> Self {
        Self {
            capacity,
            size: 0,
            pos: 0,
            state_dim,
            states: vec![0.0; capacity * state_dim],
            actions: vec![0usize; capacity],
            rewards: vec![0.0; capacity],
            dones: vec![0.0; capacity],
            next_states: vec![0.0; capacity * state_dim],
        }
    }

    fn push(&mut self, s: &[f32], a_idx: usize, r: f32, d: f32, s2: &[f32]) {
        let i = self.pos;
        let so = i * self.state_dim;
        self.states[so..so + self.state_dim].copy_from_slice(s);
        self.actions[i] = a_idx;
        self.rewards[i] = r;
        self.dones[i] = d;
        self.next_states[so..so + self.state_dim].copy_from_slice(s2);
        // Advance the write cursor, wrapping around once the buffer is full.
        self.pos = (self.pos + 1) % self.capacity;
        self.size = self.size.saturating_add(1).min(self.capacity);
    }

    fn can_sample(&self, batch_size: usize) -> bool {
        self.size >= batch_size
    }

    fn sample(
        &self,
        batch_size: usize,
        rng: &mut SmallRng,
    ) -> (Tensor, Vec<usize>, Tensor, Tensor, Tensor) {
        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
        let mut a_idx = Vec::with_capacity(batch_size);
        let mut r_vec = Vec::with_capacity(batch_size);
        let mut d_vec = Vec::with_capacity(batch_size);
        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
        for _ in 0..batch_size {
            let idx = rng.sample_index(self.size);
            let so = idx * self.state_dim;
            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
            a_idx.push(self.actions[idx]);
            r_vec.push(self.rewards[idx]);
            d_vec.push(self.dones[idx]);
            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
        }
        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
        (s, a_idx, r, d, s2)
    }
}

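/// Global-norm gradient clipping: if the L2 norm of all gradients exceeds
/// `max_norm`, scale every gradient by `max_norm / (norm + eps)`.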
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

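/// L2 norm of all gradients across the given parameters (for logging).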
fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}

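/// L2 norm of the parameter values themselves (for logging).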
fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let _ng = NoGradTrack::new();
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        for &v in p.data() {
            total_sq += v * v;
        }
    }
    total_sq.sqrt()
}

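/// Pseudo-Huber loss with delta = 1, averaged over all elements:
/// mean(sqrt(1 + diff^2) - 1). Quadratic near zero, linear for large errors.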
fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
    diff.pow_scalar(2.0)
        .add_scalar(1.0)
        .sqrt()
        .sub_scalar(1.0)
        .mean()
}

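/// DQN training on `YardEnv`: epsilon-greedy exploration, uniform replay
/// sampling, a Double-DQN target with a periodically synced target network,
/// and global-norm gradient clipping.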
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== DQN Example (YardEnv discrete) ===");

    let state_dim = 3usize;
    let action_dim = 3usize;

    let gamma = 0.99f32; // discount factor
    let batch_size = 64usize;
    let start_steps = 200usize; // warmup steps with purely random actions
    let target_update_interval = 200usize;
    let max_grad_norm = 1.0f32;
    let mut epsilon = 1.0f32;
    let eps_min = 0.05f32;
    let eps_decay_steps = 2_000usize;
    // Total environment steps; override via the DQN_STEPS environment variable.
    let total_steps = std::env::var("DQN_STEPS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(3000usize);

    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
    // Frozen target network, initialized as an exact copy of the online net.
    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
    q_targ.net.copy_from(&q_net.net);
    q_targ.set_requires_grad_all(false);

    let mut q_opt = Adam::with_learning_rate(3e-4);
    // Register every Q-network parameter with the optimizer.
    for p in q_net.parameters() {
        q_opt.add_parameter(p);
    }

    let mut rb = ReplayBuffer::new(100_000, state_dim);
    let mut env = YardEnv::new(12345);
    let mut rng = SmallRng::new(999_111);

    let mut state = env.reset();
    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32;
    let mut best_return = f32::NEG_INFINITY;

    for t in 0..total_steps {
        // Epsilon-greedy action selection: random during warmup or with
        // probability epsilon, otherwise greedy w.r.t. the online Q-network.
        let action_index = if t < start_steps || rng.next_f32() < epsilon {
            rng.sample_index(action_dim)
        } else {
            let _ng = NoGradTrack::new();
            let q_vals = q_net.forward(&state);
            let row = q_vals.data();
            let mut best_i = 0usize;
            let mut best_v = row[0];
            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
                if r > best_v {
                    best_v = r;
                    best_i = i;
                }
            }
            best_i
        };

        let (next_state, reward, done) = env.step(action_index);
        episode_return += reward;

        // Store the transition; `done` is kept as a 0/1 float mask.
        let s_slice = state.data().to_vec();
        let s2_slice = next_state.data().to_vec();
        rb.push(
            &s_slice,
            action_index,
            reward,
            if done { 1.0 } else { 0.0 },
            &s2_slice,
        );

        // On episode end: reset, update the EMA/best return stats, and log.
        state = if done {
            let st = env.reset();
            ema_return = Some(match ema_return {
                None => episode_return,
                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
            });
            if episode_return > best_return {
                best_return = episode_return;
            }
            println!(
                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
                t,
                episode,
                episode_return,
                ema_return.unwrap_or(episode_return),
                best_return,
                rb.size
            );
            episode_return = 0.0;
            episode += 1;
            st
        } else {
            next_state
        };

        // Linearly anneal epsilon from 1.0 down to eps_min over eps_decay_steps.
        if t < eps_decay_steps {
            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
        }

        // Learn once the buffer holds at least one full minibatch.
        if rb.can_sample(batch_size) {
            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);

            // Double-DQN target: the online network picks the next action,
            // the target network evaluates it.
            let target_q = {
                let _ng = NoGradTrack::new();
                let q_online_s2 = q_net.forward(&s2);
                let row_stride = action_dim;
                let qd = q_online_s2.data();
                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
                for i in 0..batch_size {
                    let base = i * row_stride;
                    let mut bi = 0usize;
                    let mut bv = qd[base];
                    for j in 1..action_dim {
                        let v = qd[base + j];
                        if v > bv {
                            bv = v;
                            bi = j;
                        }
                    }
                    next_actions.push(bi);
                }
                let q_targ_s2 = q_targ.forward(&s2);
                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
                // target = r + gamma * (1 - done) * Q_target(s', argmax_a Q_online(s', a))
                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
            };

            // Clear any stale gradients before the forward/backward pass.
            {
                let mut params = q_net.parameters();
                q_opt.zero_grad(&mut params);
            }

            // TD error on the taken actions, squashed by the pseudo-Huber loss.
            let q_all = q_net.forward(&s);
            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
            let diff = q_sa.sub_tensor(&target_q);
            let mut loss = pseudo_huber_mean(&diff);
            loss.backward(None);

            // Collect only the parameters that actually received gradients,
            // then clip, step, and log periodically.
            {
                let params = q_net.parameters();
                let mut with_grads: Vec<&mut Tensor> = Vec::new();
                for p in params {
                    if p.grad_owned().is_some() {
                        with_grads.push(p);
                    }
                }
                if !with_grads.is_empty() {
                    let gn = grad_global_norm(&mut with_grads);
                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                    q_opt.step(&mut with_grads);
                    q_opt.zero_grad(&mut with_grads);
                    if t % 100 == 0 {
                        let mut pn = q_net.parameters();
                        let pn_l2 = params_l2_norm(&mut pn);
                        let q_mean = q_all.mean().value();
                        println!(
                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
                            t, loss.value(), q_mean, gn, pn_l2, epsilon
                        );
                    }
                }
            }

            // Periodically sync the target network with the online network.
            if t % target_update_interval == 0 {
                q_targ.net.copy_from(&q_net.net);
            }

            // Free the autograd graphs accumulated during this update step.
            clear_all_graphs_known();
        }
    }

    println!("=== DQN training finished ===");
    Ok(())
}