use meganeura::graph::{Graph, NodeId};
use meganeura::nn;
/// Small feed-forward model that maps a latent input and an action to a
/// `latent_dim`-sized output node in a [`Graph`].
///
/// The computation built by [`WorldModel::forward`] is:
///
/// ```text
/// h   = relu(z_proj(z_t) + a_proj(action))
/// h   = relu(fc2(h))
/// out = fc_out(h)
/// ```
///
/// Judging by [`WorldModel::loss`] (which compares a `z_pred` against a
/// `z_target`), the output is intended as a predicted latent. Assuming
/// `nn::Linear` is affine, projecting `z_t` and `action` separately and summing
/// is equivalent to one affine layer over their concatenation. Only `z_proj`
/// is created with `nn::Linear::new`; `a_proj` uses `nn::Linear::no_bias`,
/// so the summed pre-activation presumably carries a single bias term.
pub struct WorldModel {
    /// Latent input projection, `latent_dim -> hidden_dim` (built with `nn::Linear::new`).
    pub z_proj: nn::Linear,
    /// Action input projection, `action_dim -> hidden_dim` (built with `nn::Linear::no_bias`).
    pub a_proj: nn::Linear,
    /// Hidden layer, `hidden_dim -> hidden_dim` (built with `nn::Linear::new`).
    pub fc2: nn::Linear,
    /// Output head, `hidden_dim -> latent_dim` (built with `nn::Linear::no_bias`).
    pub fc_out: nn::Linear,
}

impl WorldModel {
    /// Creates the four layers of the model inside `g`.
    ///
    /// Layers are registered under the names `world_model.z_proj`,
    /// `world_model.a_proj`, `world_model.fc2` and `world_model.fc_out`.
    /// Those names are presumably what external code (e.g. weight loading)
    /// keys on, so they must not change.
    ///
    /// # Arguments
    /// * `g` - graph the layer parameters are created in.
    /// * `latent_dim` - width of the latent input and of the model output.
    /// * `action_dim` - width of the action input.
    /// * `hidden_dim` - width of both hidden activations.
    pub fn new(g: &mut Graph, latent_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
        // Each constructor mutates `g`; keep the creation order
        // (z_proj, a_proj, fc2, fc_out) fixed in case registration order matters.
        let z_proj = nn::Linear::new(g, "world_model.z_proj", latent_dim, hidden_dim);
        let a_proj = nn::Linear::no_bias(g, "world_model.a_proj", action_dim, hidden_dim);
        let fc2 = nn::Linear::new(g, "world_model.fc2", hidden_dim, hidden_dim);
        let fc_out = nn::Linear::no_bias(g, "world_model.fc_out", hidden_dim, latent_dim);
        Self {
            z_proj,
            a_proj,
            fc2,
            fc_out,
        }
    }

    /// Appends the forward pass to `g` and returns the output node.
    ///
    /// # Arguments
    /// * `g` - graph the operations are added to.
    /// * `z_t` - latent input node; presumably `(batch, latent_dim)`. TODO confirm.
    /// * `action` - action input node; presumably `(batch, action_dim)`. TODO confirm.
    ///
    /// # Returns
    /// The node produced by `fc_out`, i.e. a `latent_dim`-wide output.
    pub fn forward(&self, g: &mut Graph, z_t: NodeId, action: NodeId) -> NodeId {
        // Fuse both inputs by summing their hidden-space projections.
        let latent_feat = self.z_proj.forward(g, z_t);
        let action_feat = self.a_proj.forward(g, action);
        let fused = g.add(latent_feat, action_feat);
        let hidden = g.relu(fused);

        // One more hidden layer, then the output head (no activation on the output).
        let hidden = self.fc2.forward(g, hidden);
        let hidden = g.relu(hidden);
        self.fc_out.forward(g, hidden)
    }

    /// Adds a loss node comparing `z_pred` with `z_target` via `Graph::mse_loss`
    /// and returns it.
    pub fn loss(g: &mut Graph, z_pred: NodeId, z_target: NodeId) -> NodeId {
        g.mse_loss(z_pred, z_target)
    }
}