use train_station::{
    gradtrack::{clear_all_graphs_known, NoGradTrack},
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

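/// Minimal linear congruential RNG so the example stays dependency-free and
/// deterministic for a fixed seed.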
struct SmallRng {
    state: u64,
}

impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }
    fn next_u32(&mut self) -> u32 {
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }
    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }
    fn uniform(&mut self, low: f32, high: f32) -> f32 {
        low + (high - low) * self.next_f32()
    }
    fn sample_index(&mut self, upper_exclusive: usize) -> usize {
        (self.next_u32() as usize) % upper_exclusive.max(1)
    }
}

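/// Squashes raw network output into the action range [-1, 1].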
fn tanh_bounded(x: &Tensor) -> Tensor {
    x.tanh()
}

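/// Small feed-forward MLP: ReLU after every hidden layer, with an optional
/// activation on the output layer (tanh for the actor, none for the critics).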
struct Mlp {
    layers: Vec<LinearLayer>,
}

impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        assert!(sizes.len() >= 2);
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }

    fn forward(&self, input: &Tensor, final_activation: Option<fn(&Tensor) -> Tensor>) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last {
                out.relu()
            } else if let Some(act) = final_activation {
                act(&out)
            } else {
                out
            };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for l in &mut self.layers {
            params.extend(l.parameters());
        }
        params
    }

    fn set_requires_grad_all(&mut self, enable: bool) {
        for l in &mut self.layers {
            l.weight.set_requires_grad(enable);
            l.bias.set_requires_grad(enable);
        }
    }

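    /// Hard-copies weights and biases from `other` and disables gradient
    /// tracking on the destination; used to initialize the target networks.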
    fn copy_from(&mut self, other: &Self) {
        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
            {
                let src = s.weight.data();
                let dst = t.weight.data_mut();
                dst.copy_from_slice(src);
            }
            {
                let src = s.bias.data();
                let dst = t.bias.data_mut();
                dst.copy_from_slice(src);
            }
            t.weight.set_requires_grad(false);
            t.bias.set_requires_grad(false);
        }
    }

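    /// Polyak (soft) target update, performed without gradient tracking:
    /// target = (1 - tau) * target + tau * source.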
    fn soft_update_from(&mut self, source: &Self, tau: f32) {
        let _ng = NoGradTrack::new();
        for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
            let new_w = t
                .weight
                .mul_scalar(1.0 - tau)
                .add_tensor(&s.weight.mul_scalar(tau));
            let new_b = t
                .bias
                .mul_scalar(1.0 - tau)
                .add_tensor(&s.bias.mul_scalar(tau));
            {
                let src = new_w.data();
                let dst = t.weight.data_mut();
                dst.copy_from_slice(src);
            }
            {
                let src = new_b.data();
                let dst = t.bias.data_mut();
                dst.copy_from_slice(src);
            }
            t.weight.set_requires_grad(false);
            t.bias.set_requires_grad(false);
        }
    }
}

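/// Deterministic policy network: maps a state to a tanh-bounded action.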
struct Actor {
    net: Mlp,
}

impl Actor {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
        Self { net }
    }
    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state, Some(tanh_bounded))
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
    fn set_requires_grad_all(&mut self, enable: bool) {
        self.net.set_requires_grad_all(enable);
    }
}

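/// Q-network: concatenates state and action along the feature dimension and
/// outputs one Q-value per sample.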
struct Critic {
    net: Mlp,
}

impl Critic {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        let net = Mlp::new(&[state_dim + action_dim, 64, 64, 1], seed);
        Self { net }
    }
    fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
        let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
        let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
        let sa = Tensor::cat(&[s_view, a_view], 1);
        self.net.forward(&sa, None)
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
    fn set_requires_grad_all(&mut self, enable: bool) {
        self.net.set_requires_grad_all(enable);
    }
}

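/// Toy 1-D point-mass environment: the state carries position and velocity,
/// the action is a bounded acceleration, and the reward penalizes squared
/// distance from the origin plus a small action cost. Episodes end when the
/// position leaves [-3, 3] or after `max_steps` steps.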
struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}

impl YardEnv {
    fn new(seed: u64) -> Self {
        let mut env = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        env.reset();
        env
    }

    fn reset(&mut self) -> Tensor {
        self.pos = self.rng.uniform(-0.5, 0.5);
        self.vel = self.rng.uniform(-0.1, 0.1);
        self.steps = 0;
        self.state_tensor()
    }

    fn state_tensor(&self) -> Tensor {
        // Normalize position by the termination bound; the third feature is a
        // constant placeholder held at zero.
        let pos_n = self.pos / 3.0;
        let vel_n = self.vel.clamp(-1.0, 1.0);
        Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
    }

    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
        let a = action_value.clamp(-1.0, 1.0);
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;

        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}

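/// Fixed-capacity circular replay buffer. Transitions are stored as flat f32
/// vectors and drawn uniformly (with replacement) into batched tensors.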
struct ReplayBuffer {
    capacity: usize,
    size: usize,
    pos: usize,
    state_dim: usize,
    action_dim: usize,
    states: Vec<f32>,
    actions: Vec<f32>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    next_states: Vec<f32>,
}

impl ReplayBuffer {
    fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
        Self {
            capacity,
            size: 0,
            pos: 0,
            state_dim,
            action_dim,
            states: vec![0.0; capacity * state_dim],
            actions: vec![0.0; capacity * action_dim],
            rewards: vec![0.0; capacity],
            dones: vec![0.0; capacity],
            next_states: vec![0.0; capacity * state_dim],
        }
    }

    fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
        let i = self.pos;
        let so = i * self.state_dim;
        let ao = i * self.action_dim;
        self.states[so..so + self.state_dim].copy_from_slice(s);
        self.actions[ao..ao + self.action_dim].copy_from_slice(a);
        self.rewards[i] = r;
        self.dones[i] = d;
        self.next_states[so..so + self.state_dim].copy_from_slice(s2);

        self.pos = (self.pos + 1) % self.capacity;
        self.size = self.size.saturating_add(1).min(self.capacity);
    }

    fn can_sample(&self, batch_size: usize) -> bool {
        self.size >= batch_size
    }

    fn sample(
        &self,
        batch_size: usize,
        rng: &mut SmallRng,
    ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
        let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
        let mut r_vec = Vec::with_capacity(batch_size);
        let mut d_vec = Vec::with_capacity(batch_size);
        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);

        for _ in 0..batch_size {
            let idx = rng.sample_index(self.size);
            let so = idx * self.state_dim;
            let ao = idx * self.action_dim;
            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
            a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
            r_vec.push(self.rewards[idx]);
            d_vec.push(self.dones[idx]);
            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
        }

        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
        let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
        (s, a, r, d, s2)
    }
}

fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                let scaled = g.mul_scalar(scale);
                p.set_grad(scaled);
            }
        }
    }
}

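/// Returns the global L2 norm of all available parameter gradients.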
fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}

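/// Returns the L2 norm of all parameter values, logged as a rough indicator of
/// weight drift during training.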
fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let _ng = NoGradTrack::new();
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        for &v in p.data() {
            total_sq += v * v;
        }
    }
    total_sq.sqrt()
}

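/// TD3 (Twin Delayed DDPG) on the toy YardEnv:
///   * twin critics, with the bootstrap target taken from the minimum of the
///     two target critics (clipped double Q-learning),
///   * target policy smoothing via Gaussian noise on the target action,
///   * delayed actor updates (every `policy_delay` environment steps),
///   * Polyak (soft) updates of all target networks.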
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== TD3 Example (YardEnv) ===");

    let state_dim = 3usize;
    let action_dim = 1usize;

    // TD3 hyperparameters.
    let gamma = 0.99f32;
    let tau = 0.005f32; // soft-update coefficient for the target networks
    let policy_noise = 0.2f32; // std of smoothing noise on target actions
    let exploration_noise = 0.1f32; // std of exploration noise during rollout
    let policy_delay = 2usize; // the actor is updated every `policy_delay` steps
    let batch_size = 64usize;
    let start_steps = 500usize; // purely random warm-up steps
    let total_steps = 1500usize;
    let max_grad_norm = 1.0f32;

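    // Online networks plus frozen target copies. Targets start as exact copies
    // and are only moved toward the online networks via soft updates.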
    let mut actor = Actor::new(state_dim, action_dim, Some(11));
    let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
    actor_targ.net.copy_from(&actor.net);
    actor_targ.set_requires_grad_all(false);

    let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
    let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
    let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
    let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
    critic1_targ.net.copy_from(&critic1.net);
    critic2_targ.net.copy_from(&critic2.net);
    critic1_targ.set_requires_grad_all(false);
    critic2_targ.set_requires_grad_all(false);

    let mut actor_opt = Adam::with_learning_rate(1e-3);
    for p in actor.parameters() {
        actor_opt.add_parameter(p);
    }

    let mut critic_opt = Adam::with_learning_rate(1e-4);
    for p in critic1.parameters() {
        critic_opt.add_parameter(p);
    }
    for p in critic2.parameters() {
        critic_opt.add_parameter(p);
    }

    let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
    let mut env = YardEnv::new(1234);
    let mut rng = SmallRng::new(987654321);

    let mut state = env.reset();
    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32;
    let mut best_return = f32::NEG_INFINITY;
    let mut policy_updates: usize = 0;

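    // Interaction / training loop: act in the environment, store the
    // transition, then update the networks once the buffer holds a batch.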
    for t in 0..total_steps {
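        // Action selection: uniform random actions during warm-up, then the
        // deterministic policy plus Gaussian exploration noise, squashed back
        // into [-1, 1].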
        let action_tensor = if t < start_steps {
            let a = rng.uniform(-1.0, 1.0);
            Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
        } else {
            let _ng = NoGradTrack::new();
            let det = actor.forward(&state);
            let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
            tanh_bounded(&det.add_tensor(&noise))
        };
        let action_value = action_tensor.data()[0];

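        // Step the environment and store the transition in the replay buffer.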
        let (next_state, reward, done) = env.step(action_value);
        episode_return += reward;

        let s_slice = state.data().to_vec();
        let a_slice = action_tensor.data().to_vec();
        let s2_slice = next_state.data().to_vec();
        rb.push(
            &s_slice,
            &a_slice,
            reward,
            if done { 1.0 } else { 0.0 },
            &s2_slice,
        );

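        // On episode end: reset the environment and update return statistics.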
        state = if done {
            let st = env.reset();
            ema_return = Some(match ema_return {
                None => episode_return,
                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
            });
            if episode_return > best_return {
                best_return = episode_return;
            }
            println!(
                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
                t,
                episode,
                episode_return,
                ema_return.unwrap_or(episode_return),
                best_return,
                rb.size,
                policy_updates
            );
            episode_return = 0.0;
            episode += 1;
            st
        } else {
            next_state
        };

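        // Learning updates begin once the buffer can provide a full batch.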
        if rb.can_sample(batch_size) {
            let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);

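            // TD3 target: smooth the target action with Gaussian noise, take
            // the elementwise minimum of the two target critics (clipped double
            // Q-learning), and bootstrap only for non-terminal transitions.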
            let target_q = {
                let _ng = NoGradTrack::new();
                let noise =
                    Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
                let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
                let q1_t = critic1_targ.forward(&s2, &a_targ);
                let q2_t = critic2_targ.forward(&s2, &a_targ);

                let q1d = q1_t.data();
                let q2d = q2_t.data();
                let mut min_vec = Vec::with_capacity(batch_size);
                for i in 0..batch_size {
                    let v1 = q1d[i];
                    let v2 = q2d[i];
                    min_vec.push(v1.min(v2));
                }
                let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
            };

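            // Critic update: both critics regress onto the shared target with a
            // summed MSE loss and a single backward pass.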
            {
                let mut params = {
                    let c_params = critic1.parameters();
                    let c2_params = critic2.parameters();
                    let mut tmp: Vec<&mut Tensor> = Vec::new();
                    tmp.extend(c_params);
                    tmp.extend(c2_params);
                    tmp
                };
                critic_opt.zero_grad(&mut params);
            }

            let q1 = critic1.forward(&s, &a);
            let q2 = critic2.forward(&s, &a);
            let diff1 = q1.sub_tensor(&target_q);
            let diff2 = q2.sub_tensor(&target_q);
            let mut critic_loss = diff1
                .pow_scalar(2.0)
                .mean()
                .add_tensor(&diff2.pow_scalar(2.0).mean());

            critic_loss.backward(None);

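            // Step the critic optimizer on parameters that received gradients,
            // clipping by global norm first, and log diagnostics every 100 steps.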
            {
                let params = {
                    let c_params = critic1.parameters();
                    let c2_params = critic2.parameters();
                    let mut tmp: Vec<&mut Tensor> = Vec::new();
                    tmp.extend(c_params);
                    tmp.extend(c2_params);
                    tmp
                };
                let mut with_grads: Vec<&mut Tensor> = Vec::new();
                for p in params {
                    if p.grad_owned().is_some() {
                        with_grads.push(p);
                    }
                }
                if !with_grads.is_empty() {
                    let grad_norm_before = grad_global_norm(&mut with_grads);
                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                    critic_opt.step(&mut with_grads);
                    critic_opt.zero_grad(&mut with_grads);

                    let mut for_norm_params = {
                        let c_params = critic1.parameters();
                        let c2_params = critic2.parameters();
                        let mut tmp: Vec<&mut Tensor> = Vec::new();
                        tmp.extend(c_params);
                        tmp.extend(c2_params);
                        tmp
                    };
                    let param_norm = params_l2_norm(&mut for_norm_params);

                    if t % 100 == 0 {
                        let q1_mean = q1.mean().value();
                        let q2_mean = q2.mean().value();
                        let tq_mean = target_q.mean().value();
                        println!(
                            "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
                            t,
                            critic_loss.value(),
                            q1_mean,
                            q2_mean,
                            tq_mean,
                            grad_norm_before,
                            param_norm
                        );
                    }
                }
            }

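            // Delayed policy update: the actor is trained (by maximizing the
            // first critic's Q-value) only every `policy_delay` steps, followed
            // by soft updates of all target networks.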
            if t % policy_delay == 0 {
                {
                    let mut a_params: Vec<&mut Tensor> = actor.parameters();
                    actor_opt.zero_grad(&mut a_params);
                }

                let a_pred = actor.forward(&s);
                let q_for_actor = critic1.forward(&s, &a_pred);
                let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
                actor_loss.backward(None);

                {
                    let a_params: Vec<&mut Tensor> = actor.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in a_params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let grad_norm_before = grad_global_norm(&mut with_grads);
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        actor_opt.step(&mut with_grads);
                        actor_opt.zero_grad(&mut with_grads);

                        let mut for_norm_params = actor.parameters();
                        let param_norm = params_l2_norm(&mut for_norm_params);

                        policy_updates += 1;
                        if t % 200 == 0 {
                            println!(
                                "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
                                t,
                                actor_loss.value(),
                                grad_norm_before,
                                param_norm,
                                actor_opt.learning_rate(),
                                critic_opt.learning_rate(),
                                policy_updates
                            );
                        }
                    }
                }

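                // Polyak-average the online weights into the target networks.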
                actor_targ.net.soft_update_from(&actor.net, tau);
                critic1_targ.net.soft_update_from(&critic1.net, tau);
                critic2_targ.net.soft_update_from(&critic2.net, tau);
            }

            clear_all_graphs_known();
        }
    }

    println!("=== TD3 training finished ===");
    Ok(())
}