use burn::prelude::*;
use burn::tensor::Distribution;
use gpc_core::config::PolicyConfig;
use gpc_core::noise::DdpmSchedule;
use gpc_core::tensor_utils;
use gpc_core::traits::Policy;
use crate::network::{DenoisingNetwork, DenoisingNetworkConfig};
/// Diffusion-based action policy.
///
/// Wraps a [`DenoisingNetwork`] together with the dimension bookkeeping that
/// the sampling code needs. Only `network` is a module field; the remaining
/// fields are plain sizes marked `#[module(skip)]`.
#[derive(Module, Debug)]
pub struct DiffusionPolicy<B: Backend> {
    /// Network that maps (noisy actions, observation conditioning, time
    /// embedding) to a noise prediction.
    network: DenoisingNetwork<B>,

    /// Width of one flattened action sequence: `pred_horizon * action_dim`.
    #[module(skip)]
    flat_action_dim: usize,

    /// Width of the flattened observation conditioning:
    /// `obs_horizon * obs_dim`.
    #[module(skip)]
    cond_dim: usize,

    /// Width of the timestep embedding fed to the network.
    #[module(skip)]
    time_embed_dim: usize,

    /// Number of action steps in each sampled sequence.
    #[module(skip)]
    pred_horizon: usize,

    /// Size of a single action vector.
    #[module(skip)]
    action_dim: usize,
}
/// Hyper-parameters used to build a [`DiffusionPolicy`].
///
/// `obs_dim` and `action_dim` have no default and must be supplied; every
/// other field carries the default declared in its `#[config]` attribute.
#[derive(Config, Debug)]
pub struct DiffusionPolicyConfig {
    /// Size of a single observation vector.
    pub obs_dim: usize,

    /// Size of a single action vector.
    pub action_dim: usize,

    /// Number of observation steps used as conditioning (default: 2).
    #[config(default = 2)]
    pub obs_horizon: usize,

    /// Number of action steps predicted per sample (default: 16).
    #[config(default = 16)]
    pub pred_horizon: usize,

    /// Hidden width passed to the denoising network (default: 256).
    #[config(default = 256)]
    pub hidden_dim: usize,

    /// Width of the timestep embedding (default: 128).
    #[config(default = 128)]
    pub time_embed_dim: usize,

    /// Number of blocks in the denoising network (default: 3).
    #[config(default = 3)]
    pub num_res_blocks: usize,
}
impl DiffusionPolicyConfig {
pub fn from_policy_config(config: &PolicyConfig) -> Self {
Self {
obs_dim: config.obs_dim,
action_dim: config.action_dim,
obs_horizon: config.obs_horizon,
pred_horizon: config.pred_horizon,
hidden_dim: config.hidden_dim,
time_embed_dim: 128,
num_res_blocks: config.num_res_blocks,
}
}
pub fn init<B: Backend>(&self, device: &B::Device) -> DiffusionPolicy<B> {
let flat_action_dim = self.pred_horizon * self.action_dim;
let cond_dim = self.obs_horizon * self.obs_dim;
let network_config = DenoisingNetworkConfig {
input_dim: flat_action_dim,
cond_dim,
hidden_dim: self.hidden_dim,
time_embed_dim: self.time_embed_dim,
num_blocks: self.num_res_blocks,
};
DiffusionPolicy {
network: network_config.init(device),
flat_action_dim,
cond_dim,
time_embed_dim: self.time_embed_dim,
pred_horizon: self.pred_horizon,
action_dim: self.action_dim,
}
}
}
impl<B: Backend> DiffusionPolicy<B> {
    /// Predicts the noise contained in `noisy_actions`.
    ///
    /// `timesteps` is converted into an embedding of width `time_embed_dim`
    /// on the same device as `noisy_actions`; that embedding, the noisy
    /// actions and the observation conditioning are then passed to the
    /// denoising network.
    ///
    /// # Arguments
    /// * `noisy_actions` - rank-2 tensor of flattened noisy action sequences.
    /// * `obs_cond` - rank-2 tensor of flattened observation conditioning.
    /// * `timesteps` - rank-1 tensor of diffusion timesteps
    ///   (presumably one entry per batch row; confirm against
    ///   `tensor_utils::timestep_embedding`).
    ///
    /// # Returns
    /// The network output as a rank-2 tensor.
    pub fn predict_noise(
        &self,
        noisy_actions: Tensor<B, 2>,
        obs_cond: Tensor<B, 2>,
        timesteps: Tensor<B, 1>,
    ) -> Tensor<B, 2> {
        let device = noisy_actions.device();
        let time_emb = tensor_utils::timestep_embedding(&timesteps, self.time_embed_dim, &device);
        self.network.forward(noisy_actions, obs_cond, time_emb)
    }

    /// Runs the reverse diffusion loop and returns the sampled actions.
    ///
    /// Sampling starts from standard-normal noise of shape
    /// `[batch_size, flat_action_dim]`, where `batch_size` is the first
    /// dimension of `obs_cond`. For every timestep `t` from
    /// `schedule.num_timesteps - 1` down to `0` it:
    ///
    /// 1. predicts the noise in the current sample,
    /// 2. updates the sample via `schedule.remove_noise`,
    /// 3. for `t > 0` only, adds fresh standard-normal noise scaled by
    ///    `sqrt(schedule.posterior_variance[t])`.
    ///
    /// The result is reshaped to `[batch_size, pred_horizon, action_dim]`.
    fn ddpm_sample(
        &self,
        obs_cond: Tensor<B, 2>,
        schedule: &DdpmSchedule,
        device: &B::Device,
    ) -> Tensor<B, 3> {
        let [batch_size, _] = obs_cond.dims();
        // Shape shared by the initial sample and every per-step noise draw.
        let sample_shape = [batch_size, self.flat_action_dim];

        let mut x_t = Tensor::<B, 2>::random(sample_shape, Distribution::Normal(0.0, 1.0), device);

        for t in (0..schedule.num_timesteps).rev() {
            // Every batch row is denoised at the same timestep `t`.
            let timesteps =
                Tensor::<B, 1>::from_floats(vec![t as f32; batch_size].as_slice(), device);
            let predicted_noise = self.predict_noise(x_t.clone(), obs_cond.clone(), timesteps);
            x_t = schedule.remove_noise(&x_t, &predicted_noise, t);

            // The final step (t == 0) is left noise-free.
            if t > 0 {
                let noise =
                    Tensor::<B, 2>::random(sample_shape, Distribution::Normal(0.0, 1.0), device);
                let sigma = (schedule.posterior_variance[t] as f32).sqrt();
                x_t = x_t + noise * sigma;
            }
        }

        x_t.reshape([batch_size, self.pred_horizon, self.action_dim])
    }
}
impl<B: Backend> Policy<B> for DiffusionPolicy<B> {
    /// Samples action sequences conditioned on `obs_history`.
    ///
    /// A DDPM schedule is built from the default `NoiseScheduleConfig` on
    /// every call. The observation history is flattened to a rank-2
    /// conditioning tensor before sampling. This implementation always
    /// returns `Ok`.
    fn sample(
        &self,
        obs_history: &Tensor<B, 3>,
        device: &B::Device,
    ) -> gpc_core::Result<Tensor<B, 3>> {
        let schedule_config = gpc_core::config::NoiseScheduleConfig::default();
        let schedule = DdpmSchedule::new(&schedule_config);

        let flat_obs = tensor_utils::flatten_last_two(obs_history.clone());
        let actions = self.ddpm_sample(flat_obs, &schedule, device);
        Ok(actions)
    }

    /// Samples `num_candidates` candidate action sequences for `obs_history`.
    ///
    /// The observation history is expanded with `tensor_utils::repeat_batch`
    /// (presumably repeating it along the batch dimension; confirm against
    /// that helper), flattened to rank 2, and sampled with a DDPM schedule
    /// built from the default `NoiseScheduleConfig`. This implementation
    /// always returns `Ok`.
    fn sample_k(
        &self,
        obs_history: &Tensor<B, 3>,
        num_candidates: usize,
        device: &B::Device,
    ) -> gpc_core::Result<Tensor<B, 3>> {
        let schedule_config = gpc_core::config::NoiseScheduleConfig::default();
        let schedule = DdpmSchedule::new(&schedule_config);

        let tiled_obs = tensor_utils::repeat_batch(obs_history, num_candidates);
        let flat_obs = tensor_utils::flatten_last_two(tiled_obs);
        let candidates = self.ddpm_sample(flat_obs, &schedule, device);
        Ok(candidates)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use gpc_core::traits::Policy;

    type TestBackend = NdArray;
    type TestDevice = <TestBackend as Backend>::Device;

    /// Small configuration shared by all tests:
    /// flattened actions are 8 * 2 = 16 wide, conditioning is 2 * 10 = 20 wide.
    fn small_config() -> DiffusionPolicyConfig {
        DiffusionPolicyConfig {
            obs_dim: 10,
            action_dim: 2,
            obs_horizon: 2,
            pred_horizon: 8,
            hidden_dim: 32,
            time_embed_dim: 16,
            num_res_blocks: 1,
        }
    }

    /// Builds a policy from [`small_config`] on `device`.
    fn small_policy(device: &TestDevice) -> DiffusionPolicy<TestBackend> {
        let config = small_config();
        config.init::<TestBackend>(device)
    }

    /// `sample` yields `[batch, pred_horizon, action_dim]`.
    #[test]
    fn test_diffusion_policy_sample_shape() {
        let device = TestDevice::default();
        let policy = small_policy(&device);

        let obs = Tensor::<TestBackend, 3>::zeros([1, 2, 10], &device);
        let actions = policy.sample(&obs, &device).unwrap();

        assert_eq!(actions.dims(), [1, 8, 2]);
    }

    /// `sample_k` with 5 candidates on a single observation yields 5 rows.
    #[test]
    fn test_diffusion_policy_sample_k_shape() {
        let device = TestDevice::default();
        let policy = small_policy(&device);

        let obs = Tensor::<TestBackend, 3>::zeros([1, 2, 10], &device);
        let actions = policy.sample_k(&obs, 5, &device).unwrap();

        assert_eq!(actions.dims(), [5, 8, 2]);
    }

    /// `predict_noise` returns a tensor shaped like its noisy-action input.
    #[test]
    fn test_predict_noise_shape() {
        let device = TestDevice::default();
        let policy = small_policy(&device);

        let noisy = Tensor::<TestBackend, 2>::zeros([4, 16], &device);
        let cond = Tensor::<TestBackend, 2>::zeros([4, 20], &device);
        let t = Tensor::<TestBackend, 1>::from_floats([5.0, 10.0, 50.0, 99.0], &device);

        let noise_pred = policy.predict_noise(noisy, cond, t);

        assert_eq!(noise_pred.dims(), [4, 16]);
    }
}