1use meganeura::graph::{Graph, NodeId};
9use meganeura::nn;
10
/// Two-layer MLP policy head: maps a latent vector to per-action outputs
/// (categorical logits or continuous action means, depending on the graph).
pub struct Policy {
 pub fc1: nn::Linear, // latent_dim -> hidden_dim (with bias; see `Policy::new`)
 pub fc2: nn::Linear, // hidden_dim -> action_dim (created via `no_bias`)
}
16
17impl Policy {
18 pub fn new(g: &mut Graph, latent_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
19 Self {
20 fc1: nn::Linear::new(g, "policy.fc1", latent_dim, hidden_dim),
21 fc2: nn::Linear::no_bias(g, "policy.fc2", hidden_dim, action_dim),
22 }
23 }
24
25 pub fn forward(&self, g: &mut Graph, z: NodeId) -> NodeId {
27 let h = self.fc1.forward(g, z);
28 let h = g.relu(h);
29 self.fc2.forward(g, h)
30 }
31}
32
/// Two-layer MLP value head: maps a latent vector to a scalar state-value
/// estimate.
pub struct ValueHead {
 pub fc1: nn::Linear, // latent_dim -> hidden_dim (with bias; see `ValueHead::new`)
 pub fc2: nn::Linear, // hidden_dim -> 1 (created via `no_bias`)
}
38
39impl ValueHead {
40 pub fn new(g: &mut Graph, latent_dim: usize, hidden_dim: usize) -> Self {
41 Self {
42 fc1: nn::Linear::new(g, "value.fc1", latent_dim, hidden_dim),
43 fc2: nn::Linear::no_bias(g, "value.fc2", hidden_dim, 1),
44 }
45 }
46
47 pub fn forward(&self, g: &mut Graph, z: NodeId) -> NodeId {
49 let h = self.fc1.forward(g, z);
50 let h = g.relu(h);
51 self.fc2.forward(g, h)
52 }
53}
54
55pub fn build_policy_graph(latent_dim: usize, action_dim: usize, hidden_dim: usize) -> Graph {
67 let mut g = Graph::new();
68 let z = g.input("z", &[1, latent_dim]);
69 let action = g.input("action", &[1, action_dim]);
70 let value_target = g.input("value_target", &[1, 1]);
71
72 let policy = Policy::new(&mut g, latent_dim, action_dim, hidden_dim);
73 let logits = policy.forward(&mut g, z);
74
75 let value_head = ValueHead::new(&mut g, latent_dim, hidden_dim);
76 let value = value_head.forward(&mut g, z);
77
78 let policy_loss = g.cross_entropy_loss(logits, action);
80 let value_loss = g.mse_loss(value, value_target);
81 let total_loss = g.add(policy_loss, value_loss);
82
83 g.set_outputs(vec![total_loss, logits, value]);
84 g
85}
86
87pub fn build_continuous_policy_graph(
105 latent_dim: usize,
106 action_dim: usize,
107 hidden_dim: usize,
108) -> Graph {
109 let mut g = Graph::new();
110 let z = g.input("z", &[1, latent_dim]);
111 let action = g.input("action", &[1, action_dim]);
112 let value_target = g.input("value_target", &[1, 1]);
113
114 let policy = Policy::new(&mut g, latent_dim, action_dim, hidden_dim);
117 let mean = policy.forward(&mut g, z);
118
119 let value_head = ValueHead::new(&mut g, latent_dim, hidden_dim);
120 let value = value_head.forward(&mut g, z);
121
122 let policy_loss = g.mse_loss(mean, action);
124 let value_loss = g.mse_loss(value, value_target);
125 let total_loss = g.add(policy_loss, value_loss);
126
127 g.set_outputs(vec![total_loss, mean, value]);
128 g
129}
130
/// Converts raw logits into a probability distribution via a numerically
/// stable softmax (logits are shifted by their maximum before `exp`,
/// so large values cannot overflow). Returns an empty vector for empty input.
pub fn softmax_probs(logits: &[f32]) -> Vec<f32> {
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
138
139pub fn sample_action<R: rand::Rng>(logits: &[f32], rng: &mut R) -> usize {
141 let probs = softmax_probs(logits);
142 let u: f32 = rng.random_range(0.0..1.0);
143 let mut cumulative = 0.0;
144 for (i, &p) in probs.iter().enumerate() {
145 cumulative += p;
146 if u < cumulative {
147 return i;
148 }
149 }
150 probs.len() - 1
151}
152
153pub fn entropy(logits: &[f32]) -> f32 {
155 let probs = softmax_probs(logits);
156 -probs
157 .iter()
158 .filter(|&&p| p > 1e-10)
159 .map(|&p| p * p.ln())
160 .sum::<f32>()
161}
162
163pub fn sample_gaussian_action<R: rand::Rng>(mu: &[f32], scale: f32, rng: &mut R) -> Vec<f32> {
166 use std::f32::consts::TAU;
167 mu.iter()
168 .map(|&m| {
169 let u1: f32 = rng.random_range(1e-7..1.0);
170 let u2: f32 = rng.random_range(0.0..1.0);
171 let noise = (-2.0 * u1.ln()).sqrt() * (TAU * u2).cos();
172 m + scale * noise
173 })
174 .collect()
175}