ppo_discrete/
ppo_discrete.rs

//! PPO (Proximal Policy Optimization) - Discrete actions example using Train Station public API
//!
//! - Discrete `YardEnv` (3 actions: -1, 0, +1)
//! - Actor outputs logits over actions (softmax for probabilities), Critic outputs value
//! - Trajectory collection, GAE advantages, PPO clipped surrogate objective
//! - Gradient clipping, zero_grad, clear_all_graphs between updates
//! - Reuses `basic_linear_layer.rs`; no unsafe code
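//!
//! Clipped surrogate objective (as implemented below):
//!   L_clip = -E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ],
//!   with r_t = exp(log pi(a_t | s_t) - log pi_old(a_t | s_t))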
//!
//! Run:
//!   cargo run --release --example ppo_discrete

use train_station::{
    gradtrack::{clear_all_graphs_known, NoGradTrack},
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

// -------------------------------
// Small RNG
// -------------------------------

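// Tiny linear congruential generator (Numerical Recipes constants), used so
// the example is reproducible without external RNG crates; not cryptographic.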
struct SmallRng {
    state: u64,
}
impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }
    fn next_u32(&mut self) -> u32 {
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }
    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }
}

// -------------------------------
// MLP
// -------------------------------

struct Mlp {
    layers: Vec<LinearLayer>,
}
impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }
    fn forward(&self, input: &Tensor) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last { out.relu() } else { out };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.layers
            .iter_mut()
            .flat_map(|l| l.parameters())
            .collect()
    }
}

// -------------------------------
// Actor (logits) + Critic
// -------------------------------

struct Actor {
    net: Mlp,
}
impl Actor {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        Self {
            net: Mlp::new(&[state_dim, 64, 64, action_dim], seed),
        }
    }
    // logits [B, A]
    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state)
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
}

struct Critic {
    net: Mlp,
}
impl Critic {
    fn new(state_dim: usize, seed: Option<u64>) -> Self {
        Self {
            net: Mlp::new(&[state_dim, 64, 64, 1], seed),
        }
    }
    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state)
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
}

// -------------------------------
// Discrete YardEnv (3 actions -> -1, 0, +1)
// -------------------------------

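// A point mass on a line: the chosen action applies a small push, a weak
// spring term (-0.01 * pos) pulls it toward the origin, and the reward
// penalizes squared distance from the origin plus a small control-effort cost.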
struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}
impl YardEnv {
    const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];
    fn new(seed: u64) -> Self {
        let mut e = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        e.reset();
        e
    }
    fn reset(&mut self) -> Tensor {
        self.pos = self.rng.next_f32() - 0.5;
        self.vel = (self.rng.next_f32() * 0.2) - 0.1;
        self.steps = 0;
        self.state_tensor()
    }
    fn state_tensor(&self) -> Tensor {
        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
    }
    fn step(&mut self, action_idx: usize) -> (Tensor, f32, bool) {
        let a = Self::ACTIONS[action_idx.min(2)];
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;
        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}

// -------------------------------
// Rollout storage
// -------------------------------

struct RolloutBatch {
    states: Vec<f32>,
    actions: Vec<usize>,
    old_logps: Vec<f32>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    values: Vec<f32>,
    next_states: Vec<f32>,
    _state_dim: usize,
}
impl RolloutBatch {
    fn new(cap: usize, sd: usize) -> Self {
        Self {
            states: Vec::with_capacity(cap * sd),
            actions: Vec::with_capacity(cap),
            old_logps: Vec::with_capacity(cap),
            rewards: Vec::with_capacity(cap),
            dones: Vec::with_capacity(cap),
            values: Vec::with_capacity(cap),
            next_states: Vec::with_capacity(cap * sd),
            _state_dim: sd,
        }
    }
    #[allow(clippy::too_many_arguments)]
    fn push(&mut self, s: &[f32], a: usize, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
        self.states.extend_from_slice(s);
        self.actions.push(a);
        self.old_logps.push(lp);
        self.rewards.push(r);
        self.dones.push(d);
        self.values.push(v);
        self.next_states.extend_from_slice(s2);
    }
    fn len(&self) -> usize {
        self.actions.len()
    }
}

// -------------------------------
// Helpers
// -------------------------------

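// Generalized Advantage Estimation, computed backward over the rollout with
// (1 - done) masking across episode boundaries:
//   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
//   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
//   R_t     = A_t + V(s_t)   (the critic's regression target)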
#[allow(clippy::too_many_arguments)]
fn compute_gae(
    returns_out: &mut [f32],
    adv_out: &mut [f32],
    rewards: &[f32],
    dones: &[f32],
    values: &[f32],
    next_values: &[f32],
    gamma: f32,
    lam: f32,
) {
    let n = rewards.len();
    let mut gae = 0.0f32;
    for t in (0..n).rev() {
        let not_done = 1.0 - dones[t];
        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
        gae = delta + gamma * lam * not_done * gae;
        adv_out[t] = gae;
        returns_out[t] = gae + values[t];
    }
}

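// Zero-mean / unit-variance normalization; applied to the advantages below,
// a standard PPO trick that keeps the surrogate loss scale stable per rollout.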
fn normalize_in_place(x: &mut [f32], eps: f32) {
    let n = x.len() as f32;
    if n <= 1.0 {
        return;
    }
    let mean = x.iter().copied().sum::<f32>() / n;
    let var = x
        .iter()
        .map(|v| {
            let d = v - mean;
            d * d
        })
        .sum::<f32>()
        / n;
    let std = (var + eps).sqrt();
    for v in x.iter_mut() {
        *v = (*v - mean) / std;
    }
}

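// Global-norm clipping: if the combined L2 norm of all parameter gradients
// exceeds max_norm, every gradient is scaled by max_norm / (norm + eps).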
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

// log-softmax for selected actions: given logits [B,A] and actions &[usize] -> log_prob [B,1]
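// Subtracting the per-row max before exponentiating keeps log-sum-exp numerically stable.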
fn log_prob_actions(
    logits: &Tensor,
    actions: &[usize],
    batch: usize,
    _action_dim: usize,
) -> Tensor {
    let max_logits = logits.max_dims(&[1], true); // [B,1]
    let shifted = logits.sub_tensor(&max_logits);
    let exp = shifted.exp();
    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
    let log_sum_exp = sum_exp.log(); // [B,1]
    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
    // gather selected action log-probs
    log_softmax.gather(1, actions, &[batch, 1])
}

// probability ratio = exp(new_logp - old_logp)
fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
    new_logp.sub_tensor(old_logp).exp()
}

// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops):
//   max(x, lo) = relu(x - lo) + lo,  min(x, hi) = x - relu(x - hi)
fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
    let b = ratio.shape().dims()[0];
    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low); // max(ratio, low)
    ge_low.sub_tensor(&ge_low.sub_tensor(&high).relu()) // min(ge_low, high)
}

fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}
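
// A general categorical sampler over arbitrary action counts (a sketch; the
// rollout loop in `main` keeps its hard-coded 3-action version). Walks the
// CDF and falls back to the last index to absorb floating-point round-off.
#[allow(dead_code)]
fn sample_categorical(probs: &[f32], u: f32) -> usize {
    let mut cdf = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cdf += p;
        if u < cdf {
            return i;
        }
    }
    probs.len() - 1
}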

// -------------------------------
// Main
// -------------------------------

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== PPO Discrete Example (YardEnv) ===");

    let state_dim = 3usize;
    let action_dim = 3usize;
    let total_steps = std::env::var("PPOD_STEPS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(3500usize);
    let horizon = 128usize;
    let epochs = 4usize;
    let mini_batch_size = 64usize;
    let gamma = 0.99f32;
    let lam = 0.95f32;
    let clip_eps = 0.2f32;
    let vf_coef = 0.5f32;
    let ent_coef = 0.0f32;
    let max_grad_norm = 1.0f32;

    let mut actor = Actor::new(state_dim, action_dim, Some(111));
    let mut critic = Critic::new(state_dim, Some(222));
    let mut actor_opt = Adam::with_learning_rate(3e-4);
    for p in actor.parameters() {
        actor_opt.add_parameter(p);
    }
    let mut critic_opt = Adam::with_learning_rate(3e-4);
    for p in critic.parameters() {
        critic_opt.add_parameter(p);
    }

    let mut env = YardEnv::new(1234);
    let mut rng = SmallRng::new(98765);
    let mut state = env.reset();
    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32;
    let mut best_return = f32::NEG_INFINITY;

    let mut t = 0usize;
    while t < total_steps {
        let mut batch = RolloutBatch::new(horizon, state_dim);
        for _ in 0..horizon {
            // Actor logits and categorical sampling
            let logits = actor.forward(&state); // [1, A]
            let probs = logits.softmax(1); // [1, A]
            // sample action from probs (CPU sampling)
            let p = probs.data();
            let (p0, p1, _p2) = (p[0], p[1], p[2]);
            let u = rng.next_f32();
            let a_idx = if u < p0 {
                0
            } else if u < p0 + p1 {
                1
            } else {
                2
            };

            let old_logp = {
                let _ng = NoGradTrack::new();
                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
                lp.data()[0]
            };

            // Step env
            let (next_state, reward, done) = env.step(a_idx);
            episode_return += reward;

            // Critic value
            let value_t = critic.forward(&state);
            let value_v = value_t.data()[0];

            batch.push(
                state.data(),
                a_idx,
                old_logp,
                reward,
                if done { 1.0 } else { 0.0 },
                value_v,
                next_state.data(),
            );

            state = if done {
                let st = env.reset();
                ema_return = Some(match ema_return {
                    None => episode_return,
                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
                });
                if episode_return > best_return {
                    best_return = episode_return;
                }
                println!(
                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
                    t,
                    episode,
                    episode_return,
                    ema_return.unwrap_or(episode_return),
                    best_return
                );
                episode_return = 0.0;
                episode += 1;
                st
            } else {
                next_state
            };

            t += 1;
            if t >= total_steps {
                break;
            }
        }

        // Bootstrap values for GAE
        let next_values: Vec<f32> = {
            let mut out = Vec::with_capacity(batch.len());
            for i in 0..batch.len() {
                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
                out.push(critic.forward(&s2_t).data()[0]);
            }
            out
        };

        let mut returns = vec![0.0f32; batch.len()];
        let mut adv = vec![0.0f32; batch.len()];
        compute_gae(
            &mut returns,
            &mut adv,
            &batch.rewards,
            &batch.dones,
            &batch.values,
            &next_values,
            gamma,
            lam,
        );
        normalize_in_place(&mut adv, 1e-8);

        // Tensors for training
        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
        let actions_vec = batch.actions.clone();
        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();

        // PPO epochs
        let num_minibatches = batch.len().div_ceil(mini_batch_size);
        for e in 0..epochs {
            for mb in 0..num_minibatches {
                let start = mb * mini_batch_size;
                let end = (start + mini_batch_size).min(batch.len());
                if start >= end {
                    break;
                }

                // Views
                let s_mb = states_t
                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
                    .reshape(vec![(end - start) as i32, state_dim as i32]);
                let oldlp_mb = old_logp_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let ret_mb = returns_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let adv_mb = adv_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let a_slice = &actions_vec[start..end];

                // Zero grads
                {
                    let mut ps = actor.parameters();
                    actor_opt.zero_grad(&mut ps);
                }
                {
                    let mut ps = critic.parameters();
                    critic_opt.zero_grad(&mut ps);
                }

                // Forward
                let logits_mb = actor.forward(&s_mb); // [B,A]
                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
                let pg1 = ratio.mul_tensor(&adv_mb);
                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
                let actor_loss = actor_min.mul_scalar(-1.0).mean();

                let v_pred = critic.forward(&s_mb);
                let v_loss = v_pred
                    .sub_tensor(&ret_mb)
                    .pow_scalar(2.0)
                    .mean()
                    .mul_scalar(vf_coef);

                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
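                // (With ent_coef = 0.0 above, the bonus contributes nothing; it is kept to show the full loss.)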
                let probs_mb = logits_mb.softmax(1);
                let logp_all = probs_mb.add_scalar(1e-8).log();
                let ent = probs_mb
                    .mul_tensor(&logp_all)
                    .sum_dims(&[1], true)
                    .mul_scalar(-1.0)
                    .mean()
                    .mul_scalar(ent_coef);

                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
                loss.backward(None);

                // Step actor
                {
                    let params = actor.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads);
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        actor_opt.step(&mut with_grads);
                        actor_opt.zero_grad(&mut with_grads);
                    }
                }

                // Step critic
                {
                    let params = critic.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads);
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        critic_opt.step(&mut with_grads);
                        critic_opt.zero_grad(&mut with_grads);
                    }
                }

                if e == 0 && mb == 0 {
                    println!(
                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
                        t,
                        actor_loss.value(),
                        v_loss.value()
                    );
                }

                clear_all_graphs_known();
            }
        }
    }

    println!("=== PPO discrete training finished ===");
    Ok(())
}
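
// -------------------------------
// Illustrative sanity checks (a minimal sketch; whether `cargo test` picks up
// tests inside examples depends on this repo's configuration). The expected
// values are hand-computed from the GAE recursion documented above.
// -------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gae_matches_hand_computation() {
        // Two non-terminal steps with gamma = 0.5, lambda = 0.5:
        //   delta_1 = 1 + 0.5 * 1 - 1 = 0.5  -> A_1 = 0.5
        //   delta_0 = 1 + 0.5 * 1 - 1 = 0.5  -> A_0 = 0.5 + 0.25 * 0.5 = 0.625
        let rewards = [1.0f32, 1.0];
        let dones = [0.0f32, 0.0];
        let values = [1.0f32, 1.0];
        let next_values = [1.0f32, 1.0];
        let mut returns = [0.0f32; 2];
        let mut adv = [0.0f32; 2];
        compute_gae(
            &mut returns,
            &mut adv,
            &rewards,
            &dones,
            &values,
            &next_values,
            0.5,
            0.5,
        );
        assert!((adv[1] - 0.5).abs() < 1e-6);
        assert!((adv[0] - 0.625).abs() < 1e-6);
        assert!((returns[0] - 1.625).abs() < 1e-6);
    }

    #[test]
    fn advantage_normalization_is_zero_mean() {
        let mut x = [1.0f32, 2.0, 3.0, 4.0];
        normalize_in_place(&mut x, 1e-8);
        let mean = x.iter().sum::<f32>() / x.len() as f32;
        assert!(mean.abs() < 1e-6);
    }
}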