pub mod kernel;
pub mod replay_buffer;
pub use kernel::*;
pub use replay_buffer::*;
use crate::logging::{EpochLogger, SessionLogger};
use crate::templates::base_algorithm::{AlgorithmError, AlgorithmTrait, TrajectoryData};
use crate::templates::base_replay_buffer::{Batch, GenericReplayBuffer, ReplayBufferError};
use burn_tensor::TensorKind;
use burn_tensor::backend::Backend;
use relayrl_types::prelude::tensor::relayrl::BackendMatcher;
use relayrl_types::prelude::trajectory::RelayRLTrajectory;
use std::any::Any;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
/// Key used to group trajectories and runtime slots by agent.
type AgentKey = String;

/// Fallback key used when neither the trajectory nor any of its actions
/// carries an agent id.
const DEFAULT_AGENT_KEY: &str = "__default_maddpg_agent__";

/// Determines which agent a trajectory belongs to.
///
/// Resolution order: the trajectory's own agent id, then the first action in
/// the trajectory that carries one, and finally [`DEFAULT_AGENT_KEY`].
fn resolve_agent_key(trajectory: &RelayRLTrajectory) -> AgentKey {
    if let Some(id) = trajectory.get_agent_id() {
        return id.to_string();
    }
    for action in trajectory.actions.iter() {
        if let Some(id) = action.get_agent_id() {
            return id.to_string();
        }
    }
    DEFAULT_AGENT_KEY.to_string()
}
/// Draws one batch from `buffer` from synchronous code by bridging into the
/// async `sample_buffer` future.
///
/// Two execution contexts are handled:
/// - Already inside a Tokio runtime: `block_in_place` moves this thread off
///   the cooperative scheduler so `block_on` can drive the future without
///   starving other tasks. NOTE(review): `block_in_place` panics on a
///   current-thread runtime — this assumes callers run a multi-thread Tokio
///   runtime; confirm.
/// - No runtime active: a throwaway current-thread runtime is built solely to
///   drive this single call.
///
/// Errors from building the temporary runtime are surfaced as
/// `ReplayBufferError::BufferSamplingError`.
fn sample_buffer_blocking<RB: GenericReplayBuffer>(
    buffer: &RB,
) -> Result<Batch, ReplayBufferError> {
    if tokio::runtime::Handle::try_current().is_ok() {
        tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(buffer.sample_buffer())
        })
    } else {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| ReplayBufferError::BufferSamplingError(e.to_string()))?
            .block_on(buffer.sample_buffer())
    }
}
/// Maps an agent key to the index of its slot in the runtime's slot list.
#[derive(Default)]
struct AgentRegistry {
    indices: HashMap<AgentKey, usize>,
}

impl AgentRegistry {
    /// Looks up the slot index registered for `key`, if any.
    fn get(&self, key: &str) -> Option<usize> {
        self.indices.get(key).map(|index| *index)
    }

    /// Records `key` at slot `index`, replacing any earlier mapping.
    fn insert(&mut self, key: AgentKey, index: usize) {
        let _ = self.indices.insert(key, index);
    }

    /// Number of agents registered so far.
    fn len(&self) -> usize {
        self.indices.len()
    }
}
/// Hyperparameters for multi-agent DDPG training.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct MADDPGParams {
    // Discount factor passed to the kernel's train_epoch.
    pub gamma: f32,
    // Soft-update coefficient passed to the kernel's train_epoch.
    pub tau: f32,
    // Actor learning rate. NOTE(review): not read in this file — presumably
    // consumed by the kernel; confirm.
    pub actor_lr: f32,
    // Critic learning rate. NOTE(review): not read in this file — presumably
    // consumed by the kernel; confirm.
    pub critic_lr: f32,
    // Batch size handed to each per-agent replay buffer on creation.
    pub batch_size: usize,
    // Buffer capacity. NOTE(review): slot buffers are sized from
    // RuntimeArgs::buffer_size instead — this field appears unused here; confirm.
    pub buffer_size: usize,
    // Policy update frequency passed to the kernel's train_epoch.
    pub policy_frequency: usize,
    // Trajectories each agent must contribute before an epoch triggers.
    pub traj_per_epoch: u64,
    // Number of train_epoch iterations run per epoch.
    pub train_iters: usize,
    // Exploration noise scale. NOTE(review): not read in this file; confirm
    // where it is consumed.
    pub noise_scale: f32,
}
impl Default for MADDPGParams {
fn default() -> Self {
Self {
gamma: 0.99,
tau: 0.01,
actor_lr: 3e-4,
critic_lr: 3e-4,
batch_size: 128,
buffer_size: 1_000_000,
policy_frequency: 1,
traj_per_epoch: 8,
train_iters: 50,
noise_scale: 0.1,
}
}
}
/// Construction-time configuration captured from `MultiagentDDPGAlgorithm::new`.
#[allow(dead_code)]
struct RuntimeArgs {
    // Environment directory supplied at construction.
    env_dir: PathBuf,
    // Destination path for saved models.
    save_model_path: PathBuf,
    // Observation-space dimensionality (used for model I/O shapes).
    obs_dim: usize,
    // Action-space dimensionality (used for model I/O shapes).
    act_dim: usize,
    // Capacity for each per-agent replay buffer.
    buffer_size: usize,
}

impl Default for RuntimeArgs {
    /// Placeholder configuration: empty paths, unit dimensions, and a
    /// one-million-slot buffer.
    fn default() -> Self {
        RuntimeArgs {
            buffer_size: 1_000_000,
            obs_dim: 1,
            act_dim: 1,
            env_dir: PathBuf::from(""),
            save_model_path: PathBuf::from(""),
        }
    }
}
/// Per-agent runtime state: identity, a per-epoch trajectory counter, and a
/// dedicated replay buffer.
struct AgentRuntimeSlot {
    #[allow(dead_code)]
    agent_key: AgentKey,
    // Trajectories received since the last epoch reset; compared against
    // `traj_per_epoch` to decide epoch readiness.
    trajectory_count: u64,
    // This agent's own replay buffer; transitions are never shared across slots.
    replay_buffer: MultiagentDDPGReplayBuffer,
}
impl AgentRuntimeSlot {
    /// Creates a slot with a zeroed trajectory counter.
    fn new(agent_key: AgentKey, replay_buffer: MultiagentDDPGReplayBuffer) -> Self {
        Self {
            agent_key,
            trajectory_count: 0,
            replay_buffer,
        }
    }
}
/// Mutable training-time state shared by the algorithm: logging, epoch
/// counting, agent bookkeeping, and the learning kernel.
struct RuntimeComponents<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    KN: MultiagentDDPGKernelTrait<B, InK, OutK>,
> {
    // Accumulates per-epoch statistics (EpRet, EpLen, losses).
    epoch_logger: EpochLogger,
    // Number of completed training epochs.
    epoch_count: u64,
    // Agent key -> slot index mapping.
    agent_registry: AgentRegistry,
    // One slot per registered agent, indexed via `agent_registry`.
    agent_slots: Vec<AgentRuntimeSlot>,
    // The MADDPG kernel performing the actual gradient updates.
    kernel: KN,
    // Ties the unused type parameters B/InK/OutK to this struct.
    _phantom: PhantomData<(B, InK, OutK)>,
}
impl<B, InK, OutK, KN> Default for RuntimeComponents<B, InK, OutK, KN>
where
B: Backend + BackendMatcher,
InK: TensorKind<B>,
OutK: TensorKind<B>,
KN: MultiagentDDPGKernelTrait<B, InK, OutK> + Default,
{
fn default() -> Self {
Self {
epoch_logger: EpochLogger::new(),
epoch_count: 0,
agent_registry: AgentRegistry::default(),
agent_slots: Vec::new(),
kernel: KN::default(),
_phantom: PhantomData,
}
}
}
/// Bundles the immutable construction arguments with the mutable runtime
/// components of the algorithm.
struct RuntimeParams<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    KN: MultiagentDDPGKernelTrait<B, InK, OutK>,
> {
    #[allow(dead_code)]
    // Configuration captured at construction time (paths, dims, capacity).
    args: RuntimeArgs,
    // Live training state (logger, registry, slots, kernel).
    components: RuntimeComponents<B, InK, OutK, KN>,
}
impl<B, InK, OutK, KN> Default for RuntimeParams<B, InK, OutK, KN>
where
B: Backend + BackendMatcher,
InK: TensorKind<B>,
OutK: TensorKind<B>,
KN: MultiagentDDPGKernelTrait<B, InK, OutK> + Default,
{
fn default() -> Self {
Self {
args: Default::default(),
components: Default::default(),
}
}
}
/// Multi-agent DDPG algorithm: routes trajectories to per-agent replay
/// buffers and drives epoch-based kernel training once every agent has
/// contributed enough trajectories.
pub struct MultiagentDDPGAlgorithm<
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    KN: MultiagentDDPGKernelTrait<B, InK, OutK>,
> {
    // Construction args plus live training state.
    runtime: RuntimeParams<B, InK, OutK, KN>,
    // Training hyperparameters (gamma, tau, batch sizes, ...).
    hyperparams: MADDPGParams,
}
/// Shorthand alias for [`MultiagentDDPGAlgorithm`].
pub type MADDPGAlgorithm<B, InK, OutK, KN> = MultiagentDDPGAlgorithm<B, InK, OutK, KN>;
impl<B, InK, OutK, KN> Default for MultiagentDDPGAlgorithm<B, InK, OutK, KN>
where
B: Backend + BackendMatcher,
InK: TensorKind<B>,
OutK: TensorKind<B>,
KN: MultiagentDDPGKernelTrait<B, InK, OutK> + Default,
{
fn default() -> Self {
Self {
runtime: Default::default(),
hyperparams: Default::default(),
}
}
}
impl<B, InK, OutK, KN> MultiagentDDPGAlgorithm<B, InK, OutK, KN>
where
    B: Backend + BackendMatcher,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    KN: MultiagentDDPGKernelTrait<B, InK, OutK> + Default,
{
    /// Builds a MADDPG algorithm instance and logs the session configuration.
    ///
    /// `hyperparams` falls back to [`MADDPGParams::default`] when `None`.
    /// Errors if the session logger fails to record the new instance.
    #[allow(dead_code)]
    pub(crate) fn new(
        hyperparams: Option<MADDPGParams>,
        env_dir: &Path,
        save_model_path: &Path,
        obs_dim: usize,
        act_dim: usize,
        buffer_size: usize,
        kernel: KN,
    ) -> Result<Self, AlgorithmError> {
        let params = hyperparams.unwrap_or_default();
        let args = RuntimeArgs {
            env_dir: env_dir.to_path_buf(),
            save_model_path: save_model_path.to_path_buf(),
            obs_dim,
            act_dim,
            buffer_size,
        };
        let components = RuntimeComponents::<B, InK, OutK, KN> {
            epoch_logger: EpochLogger::new(),
            epoch_count: 0,
            agent_registry: AgentRegistry::default(),
            agent_slots: Vec::new(),
            kernel,
            _phantom: PhantomData,
        };
        let algorithm = MultiagentDDPGAlgorithm {
            runtime: RuntimeParams { args, components },
            hyperparams: params,
        };
        // Record the session up front so construction never proceeds unlogged.
        SessionLogger::new()
            .log_session(&algorithm)
            .map_err(|e| AlgorithmError::BufferSamplingError(e.to_string()))?;
        Ok(algorithm)
    }

    /// Returns the slot index for `agent_key`, creating a new slot (with its
    /// own replay buffer and a kernel-side registration) on first sight.
    fn register_agent_slot(&mut self, agent_key: AgentKey) -> usize {
        if let Some(existing) = self.runtime.components.agent_registry.get(&agent_key) {
            return existing;
        }
        let buffer = MultiagentDDPGReplayBuffer::new(
            self.runtime.args.buffer_size,
            self.hyperparams.batch_size,
        );
        let slot_index = self.runtime.components.agent_slots.len();
        let slot = AgentRuntimeSlot::new(agent_key.clone(), buffer);
        self.runtime.components.agent_slots.push(slot);
        self.runtime
            .components
            .agent_registry
            .insert(agent_key, slot_index);
        // Keep the kernel's agent count in lockstep with the slot list.
        self.runtime.components.kernel.register_agent();
        slot_index
    }

    /// True once at least one agent is registered and every agent has
    /// accumulated `traj_per_epoch` trajectories this epoch.
    fn all_agents_ready(&self) -> bool {
        if self.runtime.components.agent_registry.len() == 0 {
            return false;
        }
        let required = self.hyperparams.traj_per_epoch;
        self.runtime
            .components
            .agent_slots
            .iter()
            .all(|slot| slot.trajectory_count >= required)
    }

    /// Zeroes every slot's per-epoch trajectory counter.
    fn reset_agent_counts(&mut self) {
        self.runtime
            .components
            .agent_slots
            .iter_mut()
            .for_each(|slot| slot.trajectory_count = 0);
    }
}
#[cfg(feature = "ndarray-backend")]
impl<B, InK, OutK, KN> MultiagentDDPGAlgorithm<B, InK, OutK, KN>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    KN: MultiagentDDPGKernelTrait<B, InK, OutK>
        + crate::templates::base_algorithm::WeightProvider
        + Default,
{
    /// Packages the kernel's policy layers into a `ModelModule` named
    /// "maddpg" with f32 ndarray dtypes and `[1, obs_dim]` / `[1, act_dim]`
    /// input/output shapes.
    ///
    /// Returns `None` when the kernel exposes no policy layer specs.
    pub fn acquire_model_module(&self) -> Option<relayrl_types::model::ModelModule<B>> {
        use relayrl_types::data::tensor::{DType, NdArrayDType};
        let layer_specs = self.runtime.components.kernel.get_pi_layer_specs()?;
        crate::acquire_model_module::<B>(
            "maddpg",
            layer_specs,
            DType::NdArray(NdArrayDType::F32),
            DType::NdArray(NdArrayDType::F32),
            vec![1, self.runtime.args.obs_dim],
            vec![1, self.runtime.args.act_dim],
            None,
        )
    }
}
impl<B, InK, OutK, KN, T> AlgorithmTrait<T> for MultiagentDDPGAlgorithm<B, InK, OutK, KN>
where
    B: Backend + BackendMatcher<Backend = B>,
    InK: TensorKind<B>,
    OutK: TensorKind<B>,
    KN: MultiagentDDPGKernelTrait<B, InK, OutK>
        + crate::templates::base_algorithm::WeightProvider
        + Default,
    T: TrajectoryData,
{
    /// No-op: model persistence is not implemented for this algorithm.
    fn save(&self, _filename: &str) {}

    /// Ingests one trajectory, routing it to the owning agent's replay
    /// buffer, and runs a training epoch once every registered agent has
    /// contributed `traj_per_epoch` trajectories.
    ///
    /// Returns `Ok(true)` when an epoch was trained and logged, `Ok(false)`
    /// otherwise. Errors when the trajectory cannot be converted or inserted,
    /// or when the buffer's return payload is not the expected `(f32, i32)`.
    async fn receive_trajectory(&mut self, trajectory: T) -> Result<bool, AlgorithmError> {
        let extracted_traj: RelayRLTrajectory = trajectory.into_relayrl().ok_or_else(|| {
            AlgorithmError::TrajectoryInsertionError("Missing RelayRL trajectory".to_string())
        })?;
        // Unknown agents get a fresh slot (replay buffer + kernel registration).
        let agent_key = resolve_agent_key(&extracted_traj);
        let agent_index = self.register_agent_slot(agent_key);
        let slot = &mut self.runtime.components.agent_slots[agent_index];
        slot.trajectory_count += 1;
        // The buffer hands back an opaque payload; by contract it is
        // (episode_return, episode_length), enforced by the downcast below.
        let result: Box<dyn Any> = slot
            .replay_buffer
            .insert_trajectory(extracted_traj)
            .await
            .map_err(|e| AlgorithmError::TrajectoryInsertionError(format!("{e}")))?;
        let (episode_return, episode_length) = match result.downcast::<(f32, i32)>() {
            Ok(payload) => *payload,
            Err(_) => {
                return Err(AlgorithmError::TrajectoryInsertionError(
                    "Unexpected replay buffer return payload".to_string(),
                ));
            }
        };
        self.runtime
            .components
            .epoch_logger
            .store("EpRet", episode_return);
        self.runtime
            .components
            .epoch_logger
            .store("EpLen", episode_length as f32);
        if self.all_agents_ready() {
            self.runtime.components.epoch_count += 1;
            // NOTE(review): train_model samples buffers via a blocking bridge
            // while we are inside an async fn; sample_buffer_blocking uses
            // block_in_place, which requires a multi-thread Tokio runtime —
            // confirm with callers.
            <Self as AlgorithmTrait<T>>::train_model(self);
            <Self as AlgorithmTrait<T>>::log_epoch(self);
            self.reset_agent_counts();
            return Ok(true);
        }
        Ok(false)
    }

    /// Runs `train_iters` kernel updates. Each iteration re-samples one batch
    /// per agent (skipping agents whose sampling fails) and stores the
    /// resulting actor/critic losses in the epoch logger.
    fn train_model(&mut self) {
        let gamma = self.hyperparams.gamma;
        let tau = self.hyperparams.tau;
        let policy_frequency = self.hyperparams.policy_frequency;
        let train_iters = self.hyperparams.train_iters;
        for _ in 0..train_iters {
            let mut agent_batches = Vec::with_capacity(self.runtime.components.agent_slots.len());
            for slot in &self.runtime.components.agent_slots {
                // A sampling failure is treated as "agent not ready" and skipped.
                let batch = match sample_buffer_blocking(&slot.replay_buffer) {
                    Ok(b) => b,
                    Err(_) => continue,
                };
                if let Some(agent_batch) = AgentBatch::from_batch(batch) {
                    agent_batches.push(agent_batch);
                }
            }
            // No agent produced a batch: abandon the remaining iterations too.
            if agent_batches.is_empty() {
                return;
            }
            let metrics = self.runtime.components.kernel.train_epoch(
                &agent_batches,
                gamma,
                tau,
                policy_frequency,
            );
            self.runtime
                .components
                .epoch_logger
                .store("ActorLoss", metrics.actor_loss);
            self.runtime
                .components
                .epoch_logger
                .store("CriticLoss", metrics.critic_loss);
        }
    }

    /// Emits the tabular epoch summary (epoch index plus the accumulated
    /// return/length/loss statistics) and flushes the logger.
    fn log_epoch(&mut self) {
        self.runtime
            .components
            .epoch_logger
            .log_tabular("Epoch", Some(self.runtime.components.epoch_count as f32));
        self.runtime
            .components
            .epoch_logger
            .log_tabular("EpRet", None);
        self.runtime
            .components
            .epoch_logger
            .log_tabular("EpLen", None);
        self.runtime
            .components
            .epoch_logger
            .log_tabular("ActorLoss", None);
        self.runtime
            .components
            .epoch_logger
            .log_tabular("CriticLoss", None);
        self.runtime.components.epoch_logger.dump_tabular();
    }

    /// Returns the current policy as a `ModelModule` for backend `B2`, or
    /// `None` when `B2` is not this algorithm's own backend `B` or no layer
    /// specs are available.
    ///
    /// NOTE(review): this fn is compiled under `tch-backend` as well, but
    /// `acquire_model_module` is only defined under `ndarray-backend` —
    /// confirm the feature matrix builds with `tch-backend` alone.
    #[cfg(all(
        any(feature = "tch-model", feature = "onnx-model"),
        any(feature = "ndarray-backend", feature = "tch-backend")
    ))]
    fn acquire_model<B2: Backend + BackendMatcher<Backend = B2>>(
        &self,
    ) -> Option<relayrl_types::model::ModelModule<B2>>
    where
        B: 'static,
        B2: 'static,
    {
        use std::any::TypeId;
        // Runtime type-equality gate: past this point B and B2 are the same type.
        if TypeId::of::<B>() != TypeId::of::<B2>() {
            return None;
        }
        let module_b = self.acquire_model_module()?;
        // SAFETY: the TypeId check above guarantees B == B2, so
        // ModelModule<B> and ModelModule<B2> are the same concrete type.
        // transmute_copy duplicates the value bitwise and `forget` prevents
        // the original from being dropped, so ownership transfers exactly
        // once with no double-drop.
        unsafe {
            let module_b2: relayrl_types::model::ModelModule<B2> =
                std::mem::transmute_copy(&module_b);
            std::mem::forget(module_b);
            Some(module_b2)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{AgentRegistry, MADDPGParams};

    /// The registry must hand back exactly the index each agent was stored
    /// under and count distinct agents.
    #[test]
    fn agent_registry_tracks_distinct_agents() {
        let mut registry = AgentRegistry::default();
        registry.insert(String::from("agent-a"), 0);
        registry.insert(String::from("agent-b"), 1);
        assert_eq!(registry.len(), 2);
        assert_eq!(registry.get("agent-a"), Some(0));
        assert_eq!(registry.get("agent-b"), Some(1));
    }

    /// Default hyperparameters must be usable for continuous-control training
    /// out of the box.
    #[test]
    fn maddpg_params_default_continuous() {
        let defaults = MADDPGParams::default();
        assert!(defaults.gamma > 0.0);
        assert!(defaults.tau > 0.0);
        assert!(defaults.traj_per_epoch > 0);
    }
}