use meganeura::graph::{Graph, NodeId};
use meganeura::nn;
/// Two-layer MLP policy head mapping a latent vector to per-action outputs
/// (logits for the discrete case, means for the continuous case).
pub struct Policy {
    /// First linear layer, latent_dim -> hidden_dim (with bias; see `Policy::new`).
    pub fc1: nn::Linear,
    /// Output linear layer, hidden_dim -> action_dim (constructed bias-free).
    pub fc2: nn::Linear,
}
impl Policy {
    /// Build the policy head, registering both layers' parameters in `g`
    /// under the names `policy.fc1` / `policy.fc2`.
    pub fn new(g: &mut Graph, latent_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = nn::Linear::new(g, "policy.fc1", latent_dim, hidden_dim);
        let fc2 = nn::Linear::no_bias(g, "policy.fc2", hidden_dim, action_dim);
        Self { fc1, fc2 }
    }

    /// Forward pass: `fc2(relu(fc1(z)))`, returning the output node id.
    pub fn forward(&self, g: &mut Graph, z: NodeId) -> NodeId {
        let pre_activation = self.fc1.forward(g, z);
        let hidden = g.relu(pre_activation);
        self.fc2.forward(g, hidden)
    }
}
/// Two-layer MLP value head mapping a latent vector to a scalar value estimate.
pub struct ValueHead {
    /// First linear layer, latent_dim -> hidden_dim (with bias; see `ValueHead::new`).
    pub fc1: nn::Linear,
    /// Output linear layer, hidden_dim -> 1 (constructed bias-free).
    pub fc2: nn::Linear,
}
impl ValueHead {
    /// Build the value head, registering both layers' parameters in `g`
    /// under the names `value.fc1` / `value.fc2`. The output is scalar (width 1).
    pub fn new(g: &mut Graph, latent_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = nn::Linear::new(g, "value.fc1", latent_dim, hidden_dim);
        let fc2 = nn::Linear::no_bias(g, "value.fc2", hidden_dim, 1);
        Self { fc1, fc2 }
    }

    /// Forward pass: `fc2(relu(fc1(z)))`, returning the scalar value node id.
    pub fn forward(&self, g: &mut Graph, z: NodeId) -> NodeId {
        let pre_activation = self.fc1.forward(g, z);
        let hidden = g.relu(pre_activation);
        self.fc2.forward(g, hidden)
    }
}
/// Assemble a discrete-action actor-critic graph.
///
/// Inputs: `z` (1 x latent_dim), `action` (1 x action_dim, target for the
/// cross-entropy term), `value_target` (1 x 1).
/// Outputs, in order: total loss, action logits, value estimate.
///
/// NOTE: statement order is load-bearing — inputs and module parameters are
/// registered in the graph in declaration order.
pub fn build_policy_graph(latent_dim: usize, action_dim: usize, hidden_dim: usize) -> Graph {
    let mut graph = Graph::new();

    // Graph inputs.
    let z = graph.input("z", &[1, latent_dim]);
    let action = graph.input("action", &[1, action_dim]);
    let value_target = graph.input("value_target", &[1, 1]);

    // Policy and value heads share the latent input `z`.
    let policy = Policy::new(&mut graph, latent_dim, action_dim, hidden_dim);
    let logits = policy.forward(&mut graph, z);
    let critic = ValueHead::new(&mut graph, latent_dim, hidden_dim);
    let value = critic.forward(&mut graph, z);

    // Joint objective: cross-entropy on actions plus squared error on value.
    let policy_loss = graph.cross_entropy_loss(logits, action);
    let value_loss = graph.mse_loss(value, value_target);
    let total_loss = graph.add(policy_loss, value_loss);

    graph.set_outputs(vec![total_loss, logits, value]);
    graph
}
/// Assemble a continuous-action actor-critic graph.
///
/// Identical layout to `build_policy_graph` except the policy head output is
/// treated as an action mean and trained with MSE against `action` instead of
/// cross-entropy.
/// Outputs, in order: total loss, action mean, value estimate.
///
/// NOTE: statement order is load-bearing — inputs and module parameters are
/// registered in the graph in declaration order.
pub fn build_continuous_policy_graph(
    latent_dim: usize,
    action_dim: usize,
    hidden_dim: usize,
) -> Graph {
    let mut graph = Graph::new();

    // Graph inputs.
    let z = graph.input("z", &[1, latent_dim]);
    let action = graph.input("action", &[1, action_dim]);
    let value_target = graph.input("value_target", &[1, 1]);

    // Policy and value heads share the latent input `z`.
    let policy = Policy::new(&mut graph, latent_dim, action_dim, hidden_dim);
    let mean = policy.forward(&mut graph, z);
    let critic = ValueHead::new(&mut graph, latent_dim, hidden_dim);
    let value = critic.forward(&mut graph, z);

    // Joint objective: squared error on both the action mean and the value.
    let policy_loss = graph.mse_loss(mean, action);
    let value_loss = graph.mse_loss(value, value_target);
    let total_loss = graph.add(policy_loss, value_loss);

    graph.set_outputs(vec![total_loss, mean, value]);
    graph
}
/// Numerically stable softmax over raw logits.
///
/// The maximum logit is subtracted before exponentiating so large inputs
/// cannot overflow to infinity. An empty slice yields an empty vector.
pub fn softmax_probs(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let mut probs: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = probs.iter().copied().sum();
    // Normalize in place; the loop body never runs for empty input.
    for p in &mut probs {
        *p /= total;
    }
    probs
}
/// Sample a discrete action index from the softmax distribution over `logits`
/// via inverse-CDF sampling with a single uniform draw.
pub fn sample_action<R: rand::Rng>(logits: &[f32], rng: &mut R) -> usize {
    let probs = softmax_probs(logits);
    let threshold: f32 = rng.random_range(0.0..1.0);
    let mut mass = 0.0;
    probs
        .iter()
        .position(|&p| {
            mass += p;
            threshold < mass
        })
        // Round-off can leave the cumulative mass just under 1; fall back to
        // the last index.
        .unwrap_or(probs.len() - 1)
}
/// Shannon entropy (in nats) of the softmax distribution over `logits`.
///
/// Terms with probability <= 1e-10 are skipped so `p * ln(p)` never produces
/// NaN from `0 * -inf`.
pub fn entropy(logits: &[f32]) -> f32 {
    // Stable softmax, inlined so the computation is self-contained.
    let peak = logits.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let weights: Vec<f32> = logits.iter().map(|&x| (x - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    let mut h = 0.0f32;
    for &w in &weights {
        let p = w / total;
        if p > 1e-10 {
            h -= p * p.ln();
        }
    }
    h
}
/// Sample each action dimension from N(mu_i, scale^2) using the Box-Muller
/// transform (two uniform draws per dimension, cosine branch only).
pub fn sample_gaussian_action<R: rand::Rng>(mu: &[f32], scale: f32, rng: &mut R) -> Vec<f32> {
    use std::f32::consts::TAU;
    let mut actions = Vec::with_capacity(mu.len());
    for &mean in mu {
        // u1 is bounded away from zero so ln(u1) stays finite.
        let u1: f32 = rng.random_range(1e-7..1.0);
        let u2: f32 = rng.random_range(0.0..1.0);
        let standard_normal = (-2.0 * u1.ln()).sqrt() * (TAU * u2).cos();
        actions.push(mean + scale * standard_normal);
    }
    actions
}