use train_station::{
    gradtrack::clear_all_graphs_known,
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

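/// Minimal deterministic RNG (linear congruential core) so the example runs
/// without external crates. Provides uniform samples in [0, 1] and standard
/// normal samples via the Box-Muller transform.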
struct SmallRng {
    state: u64,
}

impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    fn next_u32(&mut self) -> u32 {
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }

    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }

    fn normal(&mut self) -> f32 {
        let u1 = self.next_f32().clamp(1e-7, 1.0 - 1e-7);
        let u2 = self.next_f32();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        r * theta.cos()
    }
}

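/// Multi-layer perceptron: a chain of `LinearLayer`s with ReLU activations
/// between hidden layers and a linear output layer.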
struct Mlp {
    layers: Vec<LinearLayer>,
}

impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }

    fn forward(&self, input: &Tensor) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last { out.relu() } else { out };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.layers
            .iter_mut()
            .flat_map(|l| l.parameters())
            .collect()
    }
}

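/// Gaussian policy. The MLP produces the action mean; a learned,
/// state-independent `log_std` vector controls the exploration noise.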
struct Actor {
    net: Mlp,
    log_std: Tensor,
}

impl Actor {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
        let log_std = Tensor::from_slice(&vec![0.0; action_dim], vec![action_dim])
            .unwrap()
            .with_requires_grad();
        Self { net, log_std }
    }

    fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
        let mean = self.net.forward(state);
        (
            mean,
            self.log_std
                .view(vec![1, self.log_std.shape().dims()[0] as i32]),
        )
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut ps = self.net.parameters();
        ps.push(&mut self.log_std);
        ps
    }
}

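/// State-value function V(s), approximated by an MLP with a scalar output.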
struct Critic {
    net: Mlp,
}

impl Critic {
    fn new(state_dim: usize, seed: Option<u64>) -> Self {
        Self {
            net: Mlp::new(&[state_dim, 64, 64, 1], seed),
        }
    }

    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state)
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
}

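/// Toy 1-D control task: a point mass with position and velocity is pushed
/// by a bounded action toward the origin. Reward penalizes squared distance
/// from the origin plus a small control-effort cost.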
struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}

impl YardEnv {
    fn new(seed: u64) -> Self {
        let mut e = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        e.reset();
        e
    }

    fn reset(&mut self) -> Tensor {
        self.pos = self.rng.next_f32() - 0.5;
        self.vel = (self.rng.next_f32() * 0.2) - 0.1;
        self.steps = 0;
        self.state_tensor()
    }

    fn state_tensor(&self) -> Tensor {
        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
    }

    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
        let a = action_value.clamp(-1.0, 1.0);
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;
        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}

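/// Flat storage for one on-policy rollout of up to `capacity` transitions.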
struct RolloutBatch {
    states: Vec<f32>,
    actions: Vec<f32>,
    log_probs: Vec<f32>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    values: Vec<f32>,
    next_states: Vec<f32>,
    _state_dim: usize,
}

impl RolloutBatch {
    fn new(capacity: usize, state_dim: usize) -> Self {
        Self {
            states: Vec::with_capacity(capacity * state_dim),
            actions: Vec::with_capacity(capacity),
            log_probs: Vec::with_capacity(capacity),
            rewards: Vec::with_capacity(capacity),
            dones: Vec::with_capacity(capacity),
            values: Vec::with_capacity(capacity),
            next_states: Vec::with_capacity(capacity * state_dim),
            _state_dim: state_dim,
        }
    }

    #[allow(clippy::too_many_arguments)]
    fn push(&mut self, s: &[f32], a: f32, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
        self.states.extend_from_slice(s);
        self.actions.push(a);
        self.log_probs.push(lp);
        self.rewards.push(r);
        self.dones.push(d);
        self.values.push(v);
        self.next_states.extend_from_slice(s2);
    }

    fn len(&self) -> usize {
        self.actions.len()
    }
}

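/// Log-density of a diagonal Gaussian, summed over action dimensions:
/// log p(a) = -0.5 * [(a - mean)^2 / var + ln(2*pi) + 2 * log_std].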
fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
    let std = log_std.exp();
    let var = std.pow_scalar(2.0);
    let diff = action.sub_tensor(mean);
    // Note: the constant term is ln(2*pi), not ln(2) + pi.
    let log_prob = diff
        .pow_scalar(2.0)
        .div_tensor(&var)
        .add_scalar((2.0 * std::f32::consts::PI).ln())
        .add_tensor(&log_std.mul_scalar(2.0))
        .mul_scalar(-0.5);
    log_prob.sum_dims(&[1], true)
}

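/// Generalized Advantage Estimation (GAE-lambda). Walks the rollout
/// backwards, accumulating deltas; `returns_out[t] = adv_out[t] + values[t]`
/// serves as the critic's regression target.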
#[allow(clippy::too_many_arguments)]
fn compute_gae(
    returns_out: &mut [f32],
    adv_out: &mut [f32],
    rewards: &[f32],
    dones: &[f32],
    values: &[f32],
    next_values: &[f32],
    gamma: f32,
    lam: f32,
) {
    let n = rewards.len();
    let mut gae = 0.0f32;
    for t in (0..n).rev() {
        let not_done = 1.0 - dones[t];
        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
        gae = delta + gamma * lam * not_done * gae;
        adv_out[t] = gae;
        returns_out[t] = gae + values[t];
    }
}

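/// Standardize a slice to zero mean and unit variance (used to normalize
/// advantages, which stabilizes the policy update).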
fn normalize_in_place(x: &mut [f32], eps: f32) {
    let n = x.len() as f32;
    if n <= 1.0 {
        return;
    }
    let mean = x.iter().copied().sum::<f32>() / n;
    let var = x
        .iter()
        .map(|v| {
            let d = v - mean;
            d * d
        })
        .sum::<f32>()
        / n;
    let std = (var + eps).sqrt();
    for v in x.iter_mut() {
        *v = (*v - mean) / std;
    }
}

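/// Global-norm gradient clipping: if the combined L2 norm of all gradients
/// exceeds `max_norm`, rescale every gradient by `max_norm / (norm + eps)`.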
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

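/// Combined L2 norm of all parameter gradients (handy for logging).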
fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}

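/// PPO with a clipped surrogate objective and GAE advantages on the toy
/// `YardEnv` continuous-control task.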
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== PPO Continuous Example (YardEnv) ===");

    let state_dim = 3usize;
    let action_dim = 1usize;

    let total_steps = std::env::var("PPO_STEPS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(4000usize);
    let horizon = 128usize;
    let epochs = 4usize;
    let mini_batch_size = 64usize;
    let gamma = 0.99f32;
    let lam = 0.95f32;
    let clip_eps = 0.2f32;
    let vf_coef = 0.5f32;
    let ent_coef = 0.0f32;
    let max_grad_norm = 1.0f32;

    let mut actor = Actor::new(state_dim, action_dim, Some(101));
    let mut critic = Critic::new(state_dim, Some(202));

    let mut actor_opt = Adam::with_learning_rate(3e-4);
    for p in actor.parameters() {
        actor_opt.add_parameter(p);
    }
    let mut critic_opt = Adam::with_learning_rate(3e-4);
    for p in critic.parameters() {
        critic_opt.add_parameter(p);
    }

    let mut env = YardEnv::new(42);
    let mut rng = SmallRng::new(999);
    let mut state = env.reset();

    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32;
    let mut best_return = f32::NEG_INFINITY;

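    // Outer loop: collect a fixed-horizon on-policy rollout, then run
    // several PPO epochs of minibatch updates over it.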
    let mut t = 0usize;
    while t < total_steps {
        let mut batch = RolloutBatch::new(horizon, state_dim);
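        // --- Phase 1: collect up to `horizon` transitions with the current policy ---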
        for _ in 0..horizon {
            let (mean, log_std_row) = actor.forward(&state);
            let mean_v = mean.data()[0];
            let log_std_v = log_std_row.data()[0];
            let std_v = log_std_v.exp();
            let noise = rng.normal();
            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);

            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
            let log_prob_v = log_prob_t.data()[0];

            let (next_state, reward, done) = env.step(action_v);
            episode_return += reward;

            let value_t = critic.forward(&state);
            let value_v = value_t.data()[0];

            batch.push(
                state.data(),
                action_v,
                log_prob_v,
                reward,
                if done { 1.0 } else { 0.0 },
                value_v,
                next_state.data(),
            );

            state = if done {
                let st = env.reset();
                ema_return = Some(match ema_return {
                    None => episode_return,
                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
                });
                if episode_return > best_return {
                    best_return = episode_return;
                }
                println!(
                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
                    t,
                    episode,
                    episode_return,
                    ema_return.unwrap_or(episode_return),
                    best_return
                );
                episode_return = 0.0;
                episode += 1;
                st
            } else {
                next_state
            };

            t += 1;
            if t >= total_steps {
                break;
            }
        }

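        // Bootstrap V(s') for every stored transition; only the scalar
        // predictions are needed by GAE.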
        let next_values: Vec<f32> = {
            let mut out = Vec::with_capacity(batch.len());
            for i in 0..batch.len() {
                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
                let v2 = critic.forward(&s2_t).data()[0];
                out.push(v2);
            }
            out
        };

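        // GAE advantages and value targets, with advantages standardized
        // across the rollout.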
        let mut returns = vec![0.0f32; batch.len()];
        let mut adv = vec![0.0f32; batch.len()];
        compute_gae(
            &mut returns,
            &mut adv,
            &batch.rewards,
            &batch.dones,
            &batch.values,
            &next_values,
            gamma,
            lam,
        );
        normalize_in_place(&mut adv, 1e-8);

        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();

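        // --- Phase 2: PPO update over sequential minibatches ---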
        let num_minibatches = batch.len().div_ceil(mini_batch_size);
        for e in 0..epochs {
            for mb in 0..num_minibatches {
                let start = mb * mini_batch_size;
                let end = (start + mini_batch_size).min(batch.len());
                if start >= end {
                    break;
                }

                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
                let a_mb = actions_t
                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
                    .reshape(vec![(end - start) as i32, action_dim as i32]);
                let oldlp_mb = old_logp_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let ret_mb = returns_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let adv_mb = adv_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);

                {
                    let mut ps = actor.parameters();
                    actor_opt.zero_grad(&mut ps);
                }
                {
                    let mut ps = critic.parameters();
                    critic_opt.zero_grad(&mut ps);
                }

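                // Clipped surrogate objective. Clamping is expressed with relu:
                //   max(x, lo) = relu(x - lo) + lo
                //   min(x, hi) = x - relu(x - hi)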
                let (mean_mb, log_std_row) = actor.forward(&s_mb);
                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp();
                let clip_low =
                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
                        .unwrap();
                let clip_high =
                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
                        .unwrap();
                // max(ratio, 1 - eps)
                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
                // min(max(ratio, 1 - eps), 1 + eps)
                let ratio_clipped =
                    ratio_ge_low.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
                let pg1 = ratio.mul_tensor(&adv_mb);
                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
                // min(pg1, pg2), negated for gradient descent on the surrogate.
                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
                let actor_loss = actor_min.mul_scalar(-1.0).mean();

                let v_pred = critic.forward(&s_mb);
                let v_loss = v_pred
                    .sub_tensor(&ret_mb)
                    .pow_scalar(2.0)
                    .mean()
                    .mul_scalar(vf_coef);

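                // Gaussian entropy bonus: H = log_std + 0.5 * ln(2*pi*e) per
                // action dimension (a no-op here since ent_coef = 0.0).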
                let entropy = log_std_row
                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
                    .sum_dims(&[1], true)
                    .mean()
                    .mul_scalar(ent_coef);

                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
                loss.backward(None);

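                // Clip by global norm, then step each optimizer over the
                // parameters that actually received gradients.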
                {
                    let params = actor.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads); // available for logging
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        actor_opt.step(&mut with_grads);
                        actor_opt.zero_grad(&mut with_grads);
                    }
                }

                {
                    let params = critic.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads); // available for logging
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        critic_opt.step(&mut with_grads);
                        critic_opt.zero_grad(&mut with_grads);
                    }
                }

                if e == 0 && mb == 0 {
                    println!(
                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
                        t,
                        actor_loss.value(),
                        v_loss.value()
                    );
                }

                // Release the autograd graphs built during this minibatch.
                clear_all_graphs_known();
            }
        }
    }

    println!("=== PPO training finished ===");
    Ok(())
}