//! rl_traits/multi_agent.rs — multi-agent environment traits (Parallel and AEC APIs).
1use std::collections::HashMap;
2use std::hash::Hash;
3
4use rand::Rng;
5
6use crate::episode::{EpisodeStatus, StepResult};
7
/// A multi-agent environment where all agents act simultaneously each step.
///
/// Mirrors the semantics of PettingZoo's Parallel API, adapted for Rust's
/// type system. Use this when all agents observe and act at every step —
/// cooperative navigation, competitive games, mixed-team tasks.
///
/// # Design principles
///
/// - **`possible_agents` vs `agents`**: `possible_agents()` is the fixed
///   universe of all agent IDs. `agents()` is the live subset for the current
///   episode. After `reset()`, `agents == possible_agents`. Agents are removed
///   from `agents` when their episode ends; the episode is over when `agents`
///   is empty.
///
/// - **Joint step**: `step()` takes exactly one action per agent in `agents()`
///   and returns one [`StepResult`] per active agent. Providing actions for
///   terminated agents, or omitting active agents, is a contract violation:
///   implementations may panic or produce unspecified results. (This is a
///   logic error, not memory unsafety.)
///
/// - **Homogeneous agents**: all agents share `Observation`, `Action`, and
///   `Info` types. Heterogeneous agents can be modelled with enum wrappers
///   over the per-type variants.
///
/// - **Bevy-compatible by design**: `AgentId: Eq + Hash + Send + Sync +
///   'static` means Bevy `Entity` is a valid agent ID directly, enabling
///   free ECS-based parallelisation across agents in bevy-gym.
///
/// - **No `render()`**: visualisation is bevy-gym's concern.
///
/// - **No `close()`**: implement `Drop` if your environment holds resources.
///
/// # Example
///
/// ```rust
/// use std::collections::HashMap;
/// use rl_traits::{ParallelEnvironment, StepResult, EpisodeStatus};
/// use rand::Rng;
///
/// struct CoopGame {
///     active: Vec<usize>,
/// }
///
/// impl ParallelEnvironment for CoopGame {
///     type AgentId = usize;
///     type Observation = f32;
///     type Action = bool; // cooperate or defect
///     type Info = ();
///
///     fn possible_agents(&self) -> &[usize] { &[0, 1] }
///     fn agents(&self) -> &[usize] { &self.active }
///
///     fn step(&mut self, _actions: HashMap<usize, bool>)
///         -> HashMap<usize, StepResult<f32, ()>>
///     {
///         self.active.iter()
///             .map(|&id| (id, StepResult::new(0.0_f32, 1.0, EpisodeStatus::Continuing, ())))
///             .collect()
///     }
///
///     fn reset(&mut self, _seed: Option<u64>) -> HashMap<usize, (f32, ())> {
///         self.active = vec![0, 1];
///         self.active.iter().map(|&id| (id, (0.0_f32, ()))).collect()
///     }
///
///     fn sample_action(&self, _agent: &usize, rng: &mut impl Rng) -> bool {
///         rng.gen()
///     }
/// }
/// ```
pub trait ParallelEnvironment {
    /// Identifier for each agent.
    ///
    /// Common choices: `usize` (index), `&'static str` (name), or a Bevy
    /// `Entity` for direct ECS integration without an extra lookup.
    type AgentId: Eq + Hash + Clone + Send + Sync + 'static;

    /// The observation type produced by `step()` and `reset()`.
    ///
    /// `Send + Sync + 'static` are required for Bevy ECS compatibility.
    type Observation: Clone + Send + Sync + 'static;

    /// The action type consumed by `step()`.
    type Action: Clone + Send + Sync + 'static;

    /// Auxiliary information returned alongside observations.
    ///
    /// Use `()` if you don't need it — `Default` is implemented for `()`.
    type Info: Default + Clone + Send + Sync + 'static;

    /// The complete, fixed set of agent IDs for this environment.
    ///
    /// Does not change between episodes or as agents terminate mid-episode.
    /// Use `agents()` for the currently live set.
    fn possible_agents(&self) -> &[Self::AgentId];

    /// The agents currently active in this episode.
    ///
    /// Starts equal to `possible_agents()` after `reset()`. Shrinks as agents
    /// terminate or are truncated; never grows. Empty when the episode is over.
    fn agents(&self) -> &[Self::AgentId];

    /// Advance the environment by one step using joint actions.
    ///
    /// `actions` must contain exactly one entry per agent in `self.agents()`
    /// (see the contract notes on the trait). After this call, agents whose
    /// result was done are removed from `agents()`.
    fn step(
        &mut self,
        actions: HashMap<Self::AgentId, Self::Action>,
    ) -> HashMap<Self::AgentId, StepResult<Self::Observation, Self::Info>>;

    /// Reset the environment to an initial state, starting a new episode.
    ///
    /// If `seed` is `Some(u64)`, the environment should use it to seed its
    /// internal RNG for deterministic reproduction of episodes.
    /// Returns the initial observation and info for every agent.
    fn reset(
        &mut self,
        seed: Option<u64>,
    ) -> HashMap<Self::AgentId, (Self::Observation, Self::Info)>;

    /// Sample a random action for the given agent.
    ///
    /// The `rng` is caller-supplied so exploration randomness can be seeded
    /// and tracked independently from environment randomness.
    fn sample_action(&self, agent: &Self::AgentId, rng: &mut impl Rng) -> Self::Action;

    /// A global state observation of the full environment.
    ///
    /// Used by centralised-training / decentralised-execution algorithms
    /// (e.g. MADDPG, QMIX) that condition a centralised critic on the full
    /// state while individual policies see only local observations.
    /// Returns `None` by default; override if your environment supports it.
    fn state(&self) -> Option<Self::Observation> {
        None
    }

    /// Returns `true` when all agents have finished (active set is empty).
    fn is_done(&self) -> bool {
        self.agents().is_empty()
    }

    /// Number of currently active agents.
    fn num_agents(&self) -> usize {
        self.agents().len()
    }

    /// Maximum number of agents that could ever be active simultaneously.
    fn max_num_agents(&self) -> usize {
        self.possible_agents().len()
    }
}
158
/// A multi-agent environment with Agent Environment Cycle (turn-based) semantics.
///
/// Mirrors the semantics of PettingZoo's AEC API, adapted for Rust's type
/// system. Use this when agents act one at a time — board games, card games,
/// or any domain where simultaneous action is not meaningful.
///
/// # Design principles
///
/// - **Turn-based execution**: one agent acts per `step()` call.
///   `agent_selection()` identifies whose turn it is. After each call,
///   the selection advances to the next active agent.
///
/// - **Persistent state**: the environment tracks each agent's most recent
///   reward, status, and info as mutable state. Read it via `agent_state()`
///   before deciding on an action. `last()` is a convenience that combines
///   `observe()` and `agent_state()` for the current agent.
///
/// - **Cycling out terminated agents**: when `agent_state()` reports a done
///   status for `agent_selection()`, pass `None` to `step()` to advance the
///   turn without applying an action. The type signature makes this contract
///   explicit — passing `Some(action)` for a done agent is a contract
///   violation: implementations may panic or behave in an unspecified way.
///   (This is a logic error, not memory unsafety.)
///
/// - **Bevy-compatible by design**: same `Send + Sync + 'static` bounds as
///   [`ParallelEnvironment`]. The turn-based nature is inherently sequential,
///   so ECS parallelisation applies less directly than with `ParallelEnvironment`.
///
/// - **No `render()`**: visualisation is bevy-gym's concern.
///
/// - **No `close()`**: implement `Drop` if your environment holds resources.
///
/// # Example
///
/// Typical AEC loop:
///
/// ```rust,ignore
/// env.reset(None);
/// while !env.is_done() {
///     let (obs, _reward, status, _info) = env.last();
///     let action = if status.is_done() {
///         None // cycle the terminated agent out
///     } else {
///         Some(policy.act(env.agent_selection(), &obs.unwrap()))
///     };
///     env.step(action);
/// }
/// ```
pub trait AecEnvironment {
    /// Identifier for each agent. Same semantics as [`ParallelEnvironment::AgentId`].
    type AgentId: Eq + Hash + Clone + Send + Sync + 'static;

    /// The observation type produced by `observe()`.
    ///
    /// `Send + Sync + 'static` are required for Bevy ECS compatibility.
    type Observation: Clone + Send + Sync + 'static;

    /// The action type consumed by `step()`.
    type Action: Clone + Send + Sync + 'static;

    /// Auxiliary information returned alongside observations.
    ///
    /// Use `()` if you don't need it — `Default` is implemented for `()`.
    type Info: Default + Clone + Send + Sync + 'static;

    /// The complete, fixed set of agent IDs for this environment.
    ///
    /// Does not change between episodes or as agents terminate mid-episode.
    fn possible_agents(&self) -> &[Self::AgentId];

    /// The agents currently active in this episode.
    ///
    /// Starts equal to `possible_agents()` after `reset()`. Shrinks as agents
    /// terminate or are truncated; never grows. Empty when the episode is over.
    fn agents(&self) -> &[Self::AgentId];

    /// The agent whose turn it currently is to act.
    fn agent_selection(&self) -> &Self::AgentId;

    /// Execute the current agent's action and advance to the next agent.
    ///
    /// Pass `None` when `agent_state(agent_selection())` reports a done status,
    /// to cycle the agent out without applying an action. Pass `Some(action)`
    /// otherwise (see the contract notes on the trait).
    fn step(&mut self, action: Option<Self::Action>);

    /// Reset the environment to an initial state, starting a new episode.
    ///
    /// Unlike [`ParallelEnvironment::reset`], this returns nothing. Retrieve
    /// initial observations via `observe()` after calling `reset()`.
    /// If `seed` is `Some(u64)`, it is used to seed the internal RNG.
    fn reset(&mut self, seed: Option<u64>);

    /// Retrieve the current observation for the given agent.
    ///
    /// Returns `None` if the agent has terminated or been truncated — their
    /// last observation is no longer valid.
    fn observe(&self, agent: &Self::AgentId) -> Option<Self::Observation>;

    /// Retrieve the persistent `(reward, status, info)` for the given agent.
    ///
    /// This state is updated each time the agent acts and persists until its
    /// next turn. It reflects what the agent received as a result of its last
    /// action.
    fn agent_state(&self, agent: &Self::AgentId) -> (f64, EpisodeStatus, Self::Info);

    /// Sample a random action for the given agent.
    ///
    /// The `rng` is caller-supplied so exploration randomness can be seeded
    /// and tracked independently from environment randomness.
    fn sample_action(&self, agent: &Self::AgentId, rng: &mut impl Rng) -> Self::Action;

    /// Returns the full state for the currently selected agent.
    ///
    /// Convenience wrapper around `observe(agent_selection())` and
    /// `agent_state(agent_selection())`. This is the idiomatic way to read the
    /// current agent's situation at the top of the AEC loop.
    fn last(&self) -> (Option<Self::Observation>, f64, EpisodeStatus, Self::Info) {
        // Clone the id first: `observe`/`agent_state` take `&self`, so we must
        // not hold a borrow of `agent_selection()`'s return across those calls.
        let agent = self.agent_selection().clone();
        let obs = self.observe(&agent);
        let (reward, status, info) = self.agent_state(&agent);
        (obs, reward, status, info)
    }

    /// Returns `true` when all agents have finished (active set is empty).
    fn is_done(&self) -> bool {
        self.agents().is_empty()
    }

    /// Number of currently active agents.
    fn num_agents(&self) -> usize {
        self.agents().len()
    }

    /// Maximum number of agents that could ever be active simultaneously.
    fn max_num_agents(&self) -> usize {
        self.possible_agents().len()
    }
}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::episode::{EpisodeStatus, StepResult};

    // ── ParallelEnvironment mock ─────────────────────────────────────────────

    /// Stub parallel env: only the accessors are real; `step`/`reset` are
    /// never reached by these tests, which exercise the default methods.
    struct ParallelMock {
        agents: &'static [usize],
    }

    impl ParallelEnvironment for ParallelMock {
        type AgentId = usize;
        type Observation = ();
        type Action = ();
        type Info = ();

        fn possible_agents(&self) -> &[usize] {
            &[0, 1, 2]
        }

        fn agents(&self) -> &[usize] {
            self.agents
        }

        fn step(&mut self, _actions: HashMap<usize, ()>) -> HashMap<usize, StepResult<(), ()>> {
            unimplemented!()
        }

        fn reset(&mut self, _seed: Option<u64>) -> HashMap<usize, ((), ())> {
            unimplemented!()
        }

        fn sample_action(&self, _agent: &usize, _rng: &mut impl rand::Rng) {}
    }

    // ── AecEnvironment mock ──────────────────────────────────────────────────

    /// Stub AEC env: agent 0 is alive (obs 99, reward 2.5); all others are
    /// terminated with no observation.
    struct AecMock {
        agents: &'static [usize],
        current: usize,
    }

    impl AecEnvironment for AecMock {
        type AgentId = usize;
        type Observation = i32;
        type Action = ();
        type Info = String;

        fn possible_agents(&self) -> &[usize] {
            &[0, 1]
        }

        fn agents(&self) -> &[usize] {
            self.agents
        }

        fn agent_selection(&self) -> &usize {
            &self.current
        }

        fn step(&mut self, _action: Option<()>) {
            unimplemented!()
        }

        fn reset(&mut self, _seed: Option<u64>) {
            unimplemented!()
        }

        fn observe(&self, agent: &usize) -> Option<i32> {
            match agent {
                0 => Some(99),
                _ => None,
            }
        }

        fn agent_state(&self, agent: &usize) -> (f64, EpisodeStatus, String) {
            if *agent == 0 {
                (2.5, EpisodeStatus::Continuing, "alive".to_string())
            } else {
                (0.0, EpisodeStatus::Terminated, String::new())
            }
        }

        fn sample_action(&self, _agent: &usize, _rng: &mut impl rand::Rng) {}
    }

    // ── ParallelEnvironment default methods ──────────────────────────────────

    #[test]
    fn parallel_is_done_when_agents_empty() {
        let env = ParallelMock { agents: &[] };
        assert!(env.is_done());
    }

    #[test]
    fn parallel_not_done_with_active_agents() {
        let env = ParallelMock { agents: &[0, 1] };
        assert!(!env.is_done());
    }

    #[test]
    fn parallel_num_agents_reflects_active_set() {
        let env = ParallelMock { agents: &[0, 1] };
        assert_eq!(env.num_agents(), 2);
    }

    #[test]
    fn parallel_max_num_agents_reflects_possible_set() {
        // possible_agents is always [0, 1, 2], however small the active set.
        let env = ParallelMock { agents: &[0] };
        assert_eq!(env.max_num_agents(), 3);
    }

    #[test]
    fn parallel_state_returns_none_by_default() {
        let env = ParallelMock { agents: &[0] };
        assert!(env.state().is_none());
    }

    // ── AecEnvironment default methods ───────────────────────────────────────

    #[test]
    fn aec_is_done_when_agents_empty() {
        let env = AecMock { agents: &[], current: 0 };
        assert!(env.is_done());
    }

    #[test]
    fn aec_not_done_with_active_agents() {
        let env = AecMock { agents: &[0, 1], current: 0 };
        assert!(!env.is_done());
    }

    #[test]
    fn aec_num_and_max_num_agents() {
        let env = AecMock { agents: &[0], current: 0 };
        assert_eq!(env.num_agents(), 1);
        assert_eq!(env.max_num_agents(), 2); // possible_agents is [0, 1]
    }

    #[test]
    fn aec_last_composes_observe_and_agent_state_for_current_agent() {
        // Agent 0 is selected — alive, observing 99, last reward 2.5.
        let env = AecMock { agents: &[0, 1], current: 0 };
        let (obs, reward, status, info) = env.last();
        assert_eq!(obs, Some(99));
        assert_eq!(reward, 2.5);
        assert_eq!(status, EpisodeStatus::Continuing);
        assert_eq!(info, "alive");
    }

    #[test]
    fn aec_last_returns_none_obs_for_terminated_agent() {
        // Agent 1 is selected — terminated, so there is no valid observation.
        let env = AecMock { agents: &[1], current: 1 };
        let (obs, _reward, status, _info) = env.last();
        assert_eq!(obs, None);
        assert!(status.is_terminal());
    }
}
426}