use crate::error::{NeuralError, Result};
use crate::reinforcement::environments::{Action, Environment, Info, Observation, Reward};
use scirs2_core::ndarray::prelude::*;
use scirs2_core::random::{rng, Rng, RngExt};
use std::collections::HashMap;
/// Trait for environments in which several agents act simultaneously.
///
/// All per-agent vectors returned by these methods are indexed consistently:
/// entry `i` always refers to the same agent across observations, rewards,
/// done flags, infos and the space descriptions.
pub trait MultiAgentEnvironment: Send + Sync {
/// Number of agents acting in this environment.
fn num_agents(&self) -> usize;
/// Reset the environment and return the initial observation for each agent.
fn reset(&mut self) -> Result<Vec<Observation>>;
/// Advance the environment one step given one action per agent.
///
/// Returns per-agent `(observations, rewards, done flags, info maps)`.
fn step(
&mut self,
actions: &[Action],
) -> Result<(Vec<Observation>, Vec<Reward>, Vec<bool>, Vec<Info>)>;
/// Length of each agent's observation vector.
fn observation_spaces(&self) -> Vec<usize>;
/// Length of each agent's action vector.
fn action_spaces(&self) -> Vec<usize>;
/// Whether each agent's action space is continuous.
fn continuous_actions(&self) -> Vec<bool>;
}
/// A discrete grid world in which each agent must reach its own goal cell
/// while avoiding obstacles and other agents.
pub struct MultiAgentGridWorld {
    /// Grid width in cells.
    width: usize,
    /// Grid height in cells.
    height: usize,
    /// Current (x, y) cell of each agent; index is the agent id.
    agent_positions: Vec<(usize, usize)>,
    /// Goal (x, y) cell for each agent; index matches `agent_positions`.
    goal_positions: Vec<(usize, usize)>,
    /// Cells that no agent may occupy.
    obstacles: Vec<(usize, usize)>,
    /// Radius of the square local view window each agent observes.
    observation_radius: usize,
    /// Steps taken in the current episode.
    step_count: usize,
    /// Episode length limit; every agent is marked done once reached.
    max_steps: usize,
    /// When true, all agents' normalized positions are appended to each
    /// agent's observation as a simple communication channel.
    communication_enabled: bool,
}
impl MultiAgentGridWorld {
pub fn new(
width: usize,
height: usize,
num_agents: usize,
observation_radius: usize,
communication_enabled: bool,
) -> Self {
let mut agent_positions = Vec::new();
let mut goal_positions = Vec::new();
let mut rng = rng();
for _ in 0..num_agents {
let x = rng.random_range(0..width);
let y = rng.random_range(0..height);
agent_positions.push((x..y));
loop {
let gx = rng.random_range(0..width);
let gy = rng.random_range(0..height);
if !agent_positions.contains(&(gx..gy)) {
goal_positions.push((gx, gy));
break;
}
}
}
let mut obstacles = Vec::new();
let num_obstacles = (width * height) / 10;
for _ in 0..num_obstacles {
let ox = rng.random_range(0..width);
let oy = rng.random_range(0..height);
if !agent_positions.contains(&(ox..oy)) && !goal_positions.contains(&(ox, oy)) {
obstacles.push((ox, oy));
Self {
width,
height,
agent_positions,
goal_positions,
obstacles,
observation_radius,
step_count: 0,
max_steps: 100,
communication_enabled,
}
fn get_local_observation(&self, agentid: usize) -> Array1<f32> {
let (ax, ay) = self.agent_positions[agent_id];
let r = self.observation_radius as i32;
let obs_size = (2 * self.observation_radius + 1).pow(2);
let mut observation = Array1::zeros(obs_size * 4); let mut idx = 0;
for dy in -r..=r {
for dx in -r..=r {
let x = ax as i32 + dx;
let y = ay as i32 + dy;
if x >= 0 && x < self.width as i32 && y >= 0 && y < self.height as i32 {
let pos = (x as usize, y as usize);
observation[idx] = 1.0;
if self.obstacles.contains(&pos) {
observation[idx] = 0.0;
observation[idx + obs_size] = 1.0;
}
for (i, &agent_pos) in self.agent_positions.iter().enumerate() {
if i != agent_id && agent_pos == pos {
observation[idx] = 0.0;
observation[idx + 2 * obs_size] = 1.0;
}
if self.goal_positions.contains(&pos) {
observation[idx + 3 * obs_size] = 1.0;
} else {
observation[idx] = 0.0;
observation[idx + obs_size] = 1.0;
idx += 1;
if self.communication_enabled {
let comm_size = self.agent_positions.len() * 2; let mut comm_data = Array1::zeros(comm_size);
for (i, &(x, y)) in self.agent_positions.iter().enumerate() {
comm_data[i * 2] = x as f32 / self.width as f32;
comm_data[i * 2 + 1] = y as f32 / self.height as f32;
let mut full_obs = Array1::zeros(observation.len() + comm_data.len());
full_obs
.slice_mut(s![..observation.len()])
.assign(&observation);
.slice_mut(s![observation.len()..])
.assign(&comm_data);
return full_obs;
observation
fn is_valid_position(&self, pos: (usize, usize), exclude_agent: Option<usize>) -> bool {
if self.obstacles.contains(&pos) {
return false;
for (i, &agent_pos) in self.agent_positions.iter().enumerate() {
if Some(i) != exclude_agent && agent_pos == pos {
return false;
true
impl MultiAgentEnvironment for MultiAgentGridWorld {
fn num_agents(&self) -> usize {
self.agent_positions.len()
fn reset(&mut self) -> Result<Vec<Observation>> {
for i in 0..self.agent_positions.len() {
let x = rng.random_range(0..self.width);
let y = rng.random_range(0..self.height);
if self.is_valid_position((x..y), Some(i)) {
self.agent_positions[i] = (x, y);
let gx = rng.random_range(0..self.width);
let gy = rng.random_range(0..self.height);
if self.is_valid_position((gx..gy), None)
&& !self.goal_positions.contains(&(gx, gy))
{
self.goal_positions[i] = (gx, gy);
self.step_count = 0;
let mut observations = Vec::new();
for i in 0..self.num_agents() {
observations.push(self.get_local_observation(i));
Ok(observations)
) -> Result<(Vec<Observation>, Vec<Reward>, Vec<bool>, Vec<Info>)> {
if actions.len() != self.num_agents() {
return Err(NeuralError::InvalidArgument(format!(
"Expected {} actions, got {}",
self.num_agents(),
actions.len()
)));
let mut rewards = vec![0.0; self.num_agents()];
let mut dones = vec![false; self.num_agents()];
let mut infos = vec![HashMap::new(); self.num_agents()];
for (i, action) in actions.iter().enumerate() {
let (x, y) = self.agent_positions[i];
let action_idx = if action[0] < 0.2 {
0 } else if action[0] < 0.4 {
1 } else if action[0] < 0.6 {
2 } else if action[0] < 0.8 {
3 } else {
4 };
let new_pos = match action_idx {
0 => (x, y.saturating_sub(1)), 1 => (x, (y + 1).min(self.height - 1)), 2 => (x.saturating_sub(1), y), 3 => ((x + 1).min(self.width - 1), y), _ => (x, y), if self.is_valid_position(new_pos, Some(i)) {
self.agent_positions[i] = new_pos;
if self.agent_positions[i] == self.goal_positions[i] {
rewards[i] = 10.0;
dones[i] = true;
rewards[i] = -0.01;
let old_dist = ((x as f32 - self.goal_positions[i].0 as f32).powi(2)
+ (y as f32 - self.goal_positions[i].1 as f32).powi(2))
.sqrt();
let new_dist = ((self.agent_positions[i].0 as f32
- self.goal_positions[i].0 as f32)
.powi(2)
+ (self.agent_positions[i].1 as f32 - self.goal_positions[i].1 as f32).powi(2))
if new_dist < old_dist {
rewards[i] += 0.1;
infos[i].insert("position_x".to_string(), self.agent_positions[i].0 as f32);
infos[i].insert("position_y".to_string(), self.agent_positions[i].1 as f32);
infos[i].insert("goal_x".to_string(), self.goal_positions[i].0 as f32);
infos[i].insert("goal_y".to_string(), self.goal_positions[i].1 as f32);
self.step_count += 1;
let episode_done = self.step_count >= self.max_steps || dones.iter().all(|&d| d);
if episode_done {
for done in &mut dones {
*done = true;
Ok((observations, rewards, dones, infos))
fn observation_spaces(&self) -> Vec<usize> {
let obs_size = (2 * self.observation_radius + 1).pow(2) * 4; let comm_size = if self.communication_enabled {
self.agent_positions.len() * 2
} else {
0
};
vec![obs_size + comm_size; self.num_agents()]
fn action_spaces(&self) -> Vec<usize> {
vec![1; self.num_agents()] fn continuous_actions(&self) -> Vec<bool> {
vec![false; self.num_agents()] pub struct PursuitEvasion {
width: f32,
height: f32,
pursuers: Vec<Agent>,
evaders: Vec<Agent>,
max_speed: f32,
capture_radius: f32,
/// A single point agent in the pursuit-evasion arena.
#[derive(Debug, Clone)]
struct Agent {
    /// Current (x, y) position in world units.
    position: (f32, f32),
    /// Current (vx, vy) velocity in world units per second.
    velocity: (f32, f32),
    /// Whether this agent (an evader) has been captured.
    captured: bool,
}

impl Agent {
    /// Create a stationary, uncaptured agent at (x, y).
    fn new(x: f32, y: f32) -> Self {
        Self {
            position: (x, y),
            velocity: (0.0, 0.0),
            captured: false,
        }
    }

    /// Euclidean distance between this agent and `other`.
    fn distance_to(&self, other: &Agent) -> f32 {
        let dx = self.position.0 - other.position.0;
        let dy = self.position.1 - other.position.1;
        (dx * dx + dy * dy).sqrt()
    }
}
impl PursuitEvasion {
width: f32,
height: f32,
num_pursuers: usize,
num_evaders: usize,
max_speed: f32,
capture_radius: f32,
let mut pursuers = Vec::new();
for _ in 0..num_pursuers {
let x = rng.random_range(0.0..width);
let y = rng.random_range(0.0..height);
pursuers.push(Agent::new(x..y));
let mut evaders = Vec::new();
for _ in 0..num_evaders {
evaders.push(Agent::new(x, y));
pursuers,
evaders,
max_speed,
capture_radius,
max_steps: 500,
fn get_pursuer_observation(&self, pursuerid: usize) -> Array1<f32> {
let pursuer = &self.pursuers[pursuer_id];
let mut obs = Vec::new();
obs.push(pursuer.position.0 / self.width);
obs.push(pursuer.position.1 / self.height);
obs.push(pursuer.velocity.0 / self.max_speed);
obs.push(pursuer.velocity.1 / self.max_speed);
for (i, other) in self.pursuers.iter().enumerate() {
if i != pursuer_id {
let dx = (other.position.0 - pursuer.position.0) / self.width;
let dy = (other.position.1 - pursuer.position.1) / self.height;
obs.push(dx);
obs.push(dy);
for evader in &self.evaders {
if !evader.captured {
let dx = (evader.position.0 - pursuer.position.0) / self.width;
let dy = (evader.position.1 - pursuer.position.1) / self.height;
obs.push(if evader.captured { 0.0 } else { 1.0 });
obs.push(0.0);
Array1::from_vec(obs)
fn get_evader_observation(&self, evaderid: usize) -> Array1<f32> {
let evader = &self.evaders[evader_id];
obs.push(evader.position.0 / self.width);
obs.push(evader.position.1 / self.height);
obs.push(evader.velocity.0 / self.max_speed);
obs.push(evader.velocity.1 / self.max_speed);
for pursuer in &self.pursuers {
let dx = (pursuer.position.0 - evader.position.0) / self.width;
let dy = (pursuer.position.1 - evader.position.1) / self.height;
obs.push(dx);
obs.push(dy);
for (i, other) in self.evaders.iter().enumerate() {
if i != evader_id && !other.captured {
let dx = (other.position.0 - evader.position.0) / self.width;
let dy = (other.position.1 - evader.position.1) / self.height;
impl MultiAgentEnvironment for PursuitEvasion {
self.pursuers.len() + self.evaders.len()
for pursuer in &mut self.pursuers {
pursuer.position.0 = rng.random_range(0.0..self.width);
pursuer.position.1 = rng.random_range(0.0..self.height);
pursuer.velocity = (0.0..0.0);
pursuer.captured = false;
for evader in &mut self.evaders {
evader.position.0 = rng.random_range(0.0..self.width);
evader.position.1 = rng.random_range(0.0..self.height);
evader.velocity = (0.0..0.0);
evader.captured = false;
for i in 0..self.pursuers.len() {
observations.push(self.get_pursuer_observation(i));
for i in 0..self.evaders.len() {
observations.push(self.get_evader_observation(i));
let dt = 0.1; for (i, action) in actions.iter().take(self.pursuers.len()).enumerate() {
let pursuer = &mut self.pursuers[i];
let ax = action[0] * self.max_speed;
let ay = action[1] * self.max_speed;
pursuer.velocity.0 = (pursuer.velocity.0 + ax * dt) * 0.9;
pursuer.velocity.1 = (pursuer.velocity.1 + ay * dt) * 0.9;
let speed = (pursuer.velocity.0 * pursuer.velocity.0
+ pursuer.velocity.1 * pursuer.velocity.1)
if speed > self.max_speed {
pursuer.velocity.0 *= self.max_speed / speed;
pursuer.velocity.1 *= self.max_speed / speed;
pursuer.position.0 += pursuer.velocity.0 * dt;
pursuer.position.1 += pursuer.velocity.1 * dt;
pursuer.position.0 = pursuer.position.0.max(0.0).min(self.width);
pursuer.position.1 = pursuer.position.1.max(0.0).min(self.height);
for (i, action) in actions.iter().skip(self.pursuers.len()).enumerate() {
if self.evaders[i].captured {
continue;
let evader = &mut self.evaders[i];
evader.velocity.0 = (evader.velocity.0 + ax * dt) * 0.9;
evader.velocity.1 = (evader.velocity.1 + ay * dt) * 0.9;
let speed = (evader.velocity.0 * evader.velocity.0
+ evader.velocity.1 * evader.velocity.1)
evader.velocity.0 *= self.max_speed / speed;
evader.velocity.1 *= self.max_speed / speed;
evader.position.0 += evader.velocity.0 * dt;
evader.position.1 += evader.velocity.1 * dt;
evader.position.0 = evader.position.0.max(0.0).min(self.width);
evader.position.1 = evader.position.1.max(0.0).min(self.height);
if evader.captured {
for pursuer in &self.pursuers {
if pursuer.distance_to(evader) < self.capture_radius {
evader.captured = true;
let captured_count = self.evaders.iter().filter(|e| e.captured).count();
rewards[i] = captured_count as f32 * 10.0 - 0.01;
rewards[self.pursuers.len() + i] = -10.0;
rewards[self.pursuers.len() + i] = 0.1;
let all_captured = self.evaders.iter().all(|e| e.captured);
let episode_done = all_captured || self.step_count >= self.max_steps;
let dones = vec![episode_done; self.num_agents()];
for (i, info) in infos.iter_mut().enumerate() {
info.insert("step".to_string(), self.step_count as f32);
info.insert("captured_count".to_string(), captured_count as f32);
let mut spaces = Vec::new();
for _ in 0..self.pursuers.len() {
let obs_size = 4 + (self.pursuers.len() - 1) * 2 + self.evaders.len() * 3; spaces.push(obs_size);
for _ in 0..self.evaders.len() {
self.pursuers.len() * 2 + (self.evaders.len() - 1) * 2; spaces
vec![2; self.num_agents()] vec![true; self.num_agents()] pub struct MultiAgentWrapper<E: MultiAgentEnvironment> {
env: E,
agent_id: usize,
impl<E: MultiAgentEnvironment> MultiAgentWrapper<E> {
pub fn new(_env: E, agentid: usize) -> Result<Self> {
if agent_id >= env.num_agents() {
"Agent ID {} out of range (0-{})",
agent_id,
env.num_agents() - 1
Ok(Self { env, agent_id })
/// Get the underlying multi-agent environment
pub fn get_env(&self) -> &E {
&self.env
/// Get mutable reference to the underlying environment
pub fn get_env_mut(&mut self) -> &mut E {
&mut self.env
impl<E: MultiAgentEnvironment> Environment for MultiAgentWrapper<E> {
fn reset(&mut self) -> Result<Observation> {
let observations = self.env.reset()?;
Ok(observations[self.agent_id].clone())
fn step(&mut self, action: &Action) -> Result<(Observation, Reward, bool, Info)> {
// Create dummy actions for other agents (zeros)
let mut actions = Vec::new();
for i in 0..self.env.num_agents() {
if i == self.agent_id {
actions.push(action.clone());
let action_size = self.env.action_spaces()[i];
actions.push(Array1::zeros(action_size));
let (observations, rewards, dones, infos) = self.env.step(&actions)?;
Ok((
observations[self.agent_id].clone(),
rewards[self.agent_id],
dones[self.agent_id],
infos[self.agent_id].clone(),
))
fn observation_space(&self) -> usize {
self.env.observation_spaces()[self.agent_id]
fn action_space(&self) -> usize {
self.env.action_spaces()[self.agent_id]
fn continuous_actions(&self) -> bool {
self.env.continuous_actions()[self.agent_id]
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the grid world resets and steps with the right number of
    /// per-agent outputs.
    #[test]
    fn test_multi_agent_grid_world() {
        let mut env = MultiAgentGridWorld::new(5, 5, 2, 1, false);
        assert_eq!(env.num_agents(), 2);
        let observations = env.reset().expect("Operation failed");
        assert_eq!(observations.len(), 2);
        let actions = vec![
            Array1::from_vec(vec![0.1]), // up
            Array1::from_vec(vec![0.9]), // stay
        ];
        let (next_obs, rewards, dones, infos) = env.step(&actions).expect("Operation failed");
        assert_eq!(next_obs.len(), 2);
        assert_eq!(rewards.len(), 2);
        assert_eq!(dones.len(), 2);
        assert_eq!(infos.len(), 2);
    }

    /// Smoke test: pursuit-evasion produces one observation/reward per agent.
    #[test]
    fn test_pursuit_evasion() {
        let mut env = PursuitEvasion::new(10.0, 10.0, 2, 1, 1.0, 0.5);
        assert_eq!(env.num_agents(), 3); // 2 pursuers + 1 evader
        let observations = env.reset().expect("Operation failed");
        assert_eq!(observations.len(), 3);
        let actions = vec![
            Array1::from_vec(vec![0.5, 0.5]),  // pursuer 1
            Array1::from_vec(vec![-0.5, 0.0]), // pursuer 2
            Array1::from_vec(vec![0.0, -0.5]), // evader
        ];
        let (next_obs, rewards, _dones, _infos) = env.step(&actions).expect("Operation failed");
        assert_eq!(next_obs.len(), 3);
        assert_eq!(rewards.len(), 3);
    }

    /// Smoke test: the single-agent wrapper projects multi-agent results
    /// down to one agent.
    #[test]
    fn test_multi_agent_wrapper() {
        let env = MultiAgentGridWorld::new(3, 3, 2, 1, false);
        let mut wrapper = MultiAgentWrapper::new(env, 0).expect("Operation failed");
        let obs = wrapper.reset().expect("Operation failed");
        assert!(obs.len() > 0);
        let action = Array1::from_vec(vec![0.5]);
        let (next_obs, reward, _done, info) = wrapper.step(&action).expect("Operation failed");
        assert!(next_obs.len() > 0);
        assert!(reward.is_finite());
        assert!(info.contains_key("position_x"));
    }
}