use meganeura::graph::{Graph, NodeId};
use meganeura::nn;
/// MLP encoder that fuses an observation vector and a task-conditioning
/// vector into a single latent embedding (see [`Encoder::forward`]).
pub struct Encoder {
// Projects the raw observation into the shared hidden space (with bias).
pub obs_proj: nn::Linear,
// Projects the task vector into the same hidden space (bias-free, so the
// task acts as a pure additive conditioning term).
pub task_proj: nn::Linear,
// RMS normalization applied after the nonlinearity.
pub norm: nn::RmsNorm,
// Final bias-free projection from hidden space to the latent dimension.
pub fc2: nn::Linear,
}
impl Encoder {
    /// Registers all encoder parameters in `g` under the `encoder.*`
    /// namespace and returns the assembled module.
    ///
    /// * `obs_dim` / `task_dim` — input widths of the two branches.
    /// * `hidden_dim` — width of the shared fusion space.
    /// * `latent_dim` — width of the output embedding.
    pub fn new(
        g: &mut Graph,
        obs_dim: usize,
        task_dim: usize,
        latent_dim: usize,
        hidden_dim: usize,
    ) -> Self {
        let obs_proj = nn::Linear::new(g, "encoder.obs_proj", obs_dim, hidden_dim);
        let task_proj = nn::Linear::no_bias(g, "encoder.task_proj", task_dim, hidden_dim);
        let norm = nn::RmsNorm::new(g, "encoder.norm.weight", hidden_dim, 1e-5);
        let fc2 = nn::Linear::no_bias(g, "encoder.fc2", hidden_dim, latent_dim);
        Self { obs_proj, task_proj, norm, fc2 }
    }

    /// Builds the forward graph:
    /// `fc2(rmsnorm(relu(obs_proj(obs) + task_proj(task))))`.
    ///
    /// Returns the node id of the latent embedding.
    pub fn forward(&self, g: &mut Graph, obs: NodeId, task: NodeId) -> NodeId {
        // Fuse the two branches by simple addition in hidden space.
        let fused = {
            let o = self.obs_proj.forward(g, obs);
            let t = self.task_proj.forward(g, task);
            g.add(o, t)
        };
        let activated = g.relu(fused);
        let normed = self.norm.forward(g, activated);
        self.fc2.forward(g, normed)
    }
}
/// Two-layer convolutional encoder: conv → relu → conv → relu →
/// global average pool → linear projection to a latent vector.
pub struct CnnEncoder {
// First strided convolution (input channels → 8).
pub conv1: nn::Conv2d,
// Second strided convolution (8 → 16).
pub conv2: nn::Conv2d,
// Bias-free projection from pooled channels to the latent dimension.
pub fc: nn::Linear,
// Batch size baked in at construction; passed to every conv/pool call.
pub batch: u32,
// Channel count of conv2's output, i.e. the width fed to global_avg_pool
// and `fc` (equals `out_ch2` chosen in `new`).
pub pool_channels: u32,
}
impl CnnEncoder {
    /// Spatial output extent of a convolution:
    /// `(size + 2*pad - kernel) / stride + 1`.
    ///
    /// Additions are performed before the subtraction so the `u32`
    /// arithmetic cannot underflow for any geometrically valid input
    /// (the original inline form computed `size - kernel` first, which
    /// underflows for `size < kernel` even when padding makes the
    /// convolution well-defined). Invalid geometry asserts instead of
    /// silently wrapping.
    fn conv_out(size: u32, kernel: u32, stride: u32, pad: u32) -> u32 {
        let padded = size + 2 * pad;
        assert!(
            padded >= kernel,
            "conv input {size} with pad {pad} is smaller than kernel {kernel}"
        );
        (padded - kernel) / stride + 1
    }

    /// Registers the CNN parameters in `g` under the `encoder.*` namespace.
    ///
    /// * `channels`/`height`/`width` — input image geometry.
    /// * `latent_dim` — width of the output embedding.
    /// * `batch` — fixed batch size used for every forward call.
    ///
    /// Both convolutions use kernel 3, stride 2, padding 1.
    pub fn new(
        g: &mut Graph,
        channels: u32,
        height: u32,
        width: u32,
        latent_dim: usize,
        batch: u32,
    ) -> Self {
        const KERNEL: u32 = 3;
        const STRIDE: u32 = 2;
        const PAD: u32 = 1;
        let out_ch1 = 8u32;
        let out_ch2 = 16u32;
        // conv1 output geometry, needed as conv2's input geometry.
        let h1 = Self::conv_out(height, KERNEL, STRIDE, PAD);
        let w1 = Self::conv_out(width, KERNEL, STRIDE, PAD);
        let conv1 = nn::Conv2d::new(
            g,
            "encoder.conv1",
            channels,
            out_ch1,
            KERNEL,
            height,
            width,
            STRIDE,
            PAD,
        );
        let conv2 = nn::Conv2d::new(
            g,
            "encoder.conv2",
            out_ch1,
            out_ch2,
            KERNEL,
            h1,
            w1,
            STRIDE,
            PAD,
        );
        // After global average pooling only the channel axis survives,
        // so `fc` maps out_ch2 -> latent_dim.
        let fc = nn::Linear::no_bias(g, "encoder.fc", out_ch2 as usize, latent_dim);
        Self {
            conv1,
            conv2,
            fc,
            batch,
            pool_channels: out_ch2,
        }
    }

    /// Builds the forward graph for one batch of observations and returns
    /// the node id of the latent embedding.
    pub fn forward(&self, g: &mut Graph, obs: NodeId) -> NodeId {
        let h = self.conv1.forward(g, obs, self.batch);
        let h = g.relu(h);
        let h = self.conv2.forward(g, h, self.batch);
        let h = g.relu(h);
        // Recover the per-image spatial extent (h*w) from the activation.
        // Assumes shape[0] is the flattened total element count
        // batch * channels * h * w — NOTE(review): confirm against
        // Graph's Conv2d output layout.
        let spatial = {
            let shape = &g.node(h).ty.shape;
            (shape[0] / (self.batch as usize * self.pool_channels as usize)) as u32
        };
        let h = g.global_avg_pool(h, self.batch, self.pool_channels, spatial);
        self.fc.forward(g, h)
    }
}