// kindle/policy.rs

1//! Policy and Value Head.
2//!
3//! - **Policy** `π(z_t) → action distribution`, updated by credit-weighted
4//!   policy gradient with entropy bonus.
5//! - **Value Head** `V(z_t) → V̂`, trained via TD using the credit-adjusted
6//!   reward signal, serving as a variance-reduction baseline.
7
8use meganeura::graph::{Graph, NodeId};
9use meganeura::nn;
10
11/// Stochastic policy network for discrete action spaces.
12pub struct Policy {
13    pub fc1: nn::Linear,
14    pub fc2: nn::Linear,
15}
16
17impl Policy {
18    pub fn new(g: &mut Graph, latent_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
19        Self {
20            fc1: nn::Linear::new(g, "policy.fc1", latent_dim, hidden_dim),
21            fc2: nn::Linear::no_bias(g, "policy.fc2", hidden_dim, action_dim),
22        }
23    }
24
25    /// Forward pass: `[batch, latent_dim] → [batch, action_dim]` (logits).
26    pub fn forward(&self, g: &mut Graph, z: NodeId) -> NodeId {
27        let h = self.fc1.forward(g, z);
28        let h = g.relu(h);
29        self.fc2.forward(g, h)
30    }
31}
32
33/// Value head: estimates cumulative future reward from the current state.
34pub struct ValueHead {
35    pub fc1: nn::Linear,
36    pub fc2: nn::Linear,
37}
38
39impl ValueHead {
40    pub fn new(g: &mut Graph, latent_dim: usize, hidden_dim: usize) -> Self {
41        Self {
42            fc1: nn::Linear::new(g, "value.fc1", latent_dim, hidden_dim),
43            fc2: nn::Linear::no_bias(g, "value.fc2", hidden_dim, 1),
44        }
45    }
46
47    /// Forward pass: `[batch, latent_dim] → [batch, 1]`.
48    pub fn forward(&self, g: &mut Graph, z: NodeId) -> NodeId {
49        let h = self.fc1.forward(g, z);
50        let h = g.relu(h);
51        self.fc2.forward(g, h)
52    }
53}
54
55/// Build the discrete policy + value training graph.
56///
57/// Inputs:
58/// - `"z"`: `[1, latent_dim]` — latent from encoder (detached, fed as input)
59/// - `"action"`: `[1, action_dim]` — one-hot taken action
60/// - `"value_target"`: `[1, 1]` — TD target for value head
61///
62/// Outputs:
63/// - `[0]`: combined loss (policy cross-entropy + value MSE)
64/// - `[1]`: logits `[1, action_dim]` — for action sampling
65/// - `[2]`: value `[1, 1]` — for advantage computation
66pub fn build_policy_graph(latent_dim: usize, action_dim: usize, hidden_dim: usize) -> Graph {
67    let mut g = Graph::new();
68    let z = g.input("z", &[1, latent_dim]);
69    let action = g.input("action", &[1, action_dim]);
70    let value_target = g.input("value_target", &[1, 1]);
71
72    let policy = Policy::new(&mut g, latent_dim, action_dim, hidden_dim);
73    let logits = policy.forward(&mut g, z);
74
75    let value_head = ValueHead::new(&mut g, latent_dim, hidden_dim);
76    let value = value_head.forward(&mut g, z);
77
78    // Policy loss: cross-entropy with one-hot action selects -log π(a|s)
79    let policy_loss = g.cross_entropy_loss(logits, action);
80    let value_loss = g.mse_loss(value, value_target);
81    let total_loss = g.add(policy_loss, value_loss);
82
83    g.set_outputs(vec![total_loss, logits, value]);
84    g
85}
86
87/// Build the continuous policy + value training graph for a diagonal
88/// Gaussian with fixed unit variance.
89///
90/// Inputs:
91/// - `"z"`: `[1, latent_dim]` — latent from encoder
92/// - `"action"`: `[1, action_dim]` — the taken action vector
93/// - `"value_target"`: `[1, 1]` — TD target for value head
94///
95/// Outputs:
96/// - `[0]`: combined loss (mean MSE + value MSE)
97/// - `[1]`: action mean `[1, action_dim]` — sampled by adding Gaussian noise
98/// - `[2]`: value `[1, 1]`
99///
100/// For a fixed-variance Gaussian, the negative log-likelihood of the taken
101/// action is `0.5·(a − μ)² / σ² + const`. With σ² = 1 this reduces to the
102/// MSE between predicted mean and taken action, up to a constant — the
103/// same advantage-weighted LR trick applies.
104pub fn build_continuous_policy_graph(
105    latent_dim: usize,
106    action_dim: usize,
107    hidden_dim: usize,
108) -> Graph {
109    let mut g = Graph::new();
110    let z = g.input("z", &[1, latent_dim]);
111    let action = g.input("action", &[1, action_dim]);
112    let value_target = g.input("value_target", &[1, 1]);
113
114    // The "Policy" struct outputs [1, action_dim] logits; for continuous
115    // actions we reinterpret this as the Gaussian mean μ.
116    let policy = Policy::new(&mut g, latent_dim, action_dim, hidden_dim);
117    let mean = policy.forward(&mut g, z);
118
119    let value_head = ValueHead::new(&mut g, latent_dim, hidden_dim);
120    let value = value_head.forward(&mut g, z);
121
122    // Policy loss: MSE(μ, taken_action) ≡ Gaussian NLL with σ² = 1
123    let policy_loss = g.mse_loss(mean, action);
124    let value_loss = g.mse_loss(value, value_target);
125    let total_loss = g.add(policy_loss, value_loss);
126
127    g.set_outputs(vec![total_loss, mean, value]);
128    g
129}
130
/// Compute softmax probabilities from logits.
///
/// Subtracts the max logit before exponentiating for numerical stability;
/// for non-empty input the result sums to 1.
pub fn softmax_probs(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
138
139/// Sample an action from logits using the Gumbel-max trick.
140pub fn sample_action<R: rand::Rng>(logits: &[f32], rng: &mut R) -> usize {
141    let probs = softmax_probs(logits);
142    let u: f32 = rng.random_range(0.0..1.0);
143    let mut cumulative = 0.0;
144    for (i, &p) in probs.iter().enumerate() {
145        cumulative += p;
146        if u < cumulative {
147            return i;
148        }
149    }
150    probs.len() - 1
151}
152
/// Compute policy entropy: `H[π] = -Σ π_i · log π_i`.
///
/// Probabilities at or below `1e-10` are skipped, treating `0 · ln 0` as 0.
pub fn entropy(logits: &[f32]) -> f32 {
    // Inline, numerically-stable softmax so the computation is
    // self-contained: shift by the max logit before exponentiating.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exp.iter().sum();
    let mut h = 0.0f32;
    for &e in &exp {
        let p = e / sum;
        if p > 1e-10 {
            h -= p * p.ln();
        }
    }
    h
}
162
163/// Sample from a diagonal Gaussian with mean `mu` and fixed std `scale`.
164/// Uses the Box–Muller transform.
165pub fn sample_gaussian_action<R: rand::Rng>(mu: &[f32], scale: f32, rng: &mut R) -> Vec<f32> {
166    use std::f32::consts::TAU;
167    mu.iter()
168        .map(|&m| {
169            let u1: f32 = rng.random_range(1e-7..1.0);
170            let u2: f32 = rng.random_range(0.0..1.0);
171            let noise = (-2.0 * u1.ln()).sqrt() * (TAU * u2).cos();
172            m + scale * noise
173        })
174        .collect()
175}