//! Multi-agent environment traits for rl_traits.
//!
//! Defines [`ParallelEnvironment`] (all agents act simultaneously each step,
//! mirroring PettingZoo's Parallel API) and [`AecEnvironment`] (turn-based
//! Agent Environment Cycle semantics, mirroring PettingZoo's AEC API).
1use std::collections::HashMap;
2use std::hash::Hash;
3
4use rand::Rng;
5
6use crate::episode::{EpisodeStatus, StepResult};
7
8/// A multi-agent environment where all agents act simultaneously each step.
9///
10/// Mirrors the semantics of PettingZoo's Parallel API, adapted for Rust's
11/// type system. Use this when all agents observe and act at every step —
12/// cooperative navigation, competitive games, mixed-team tasks.
13///
14/// # Design principles
15///
16/// - **`possible_agents` vs `agents`**: `possible_agents()` is the fixed
17///   universe of all agent IDs. `agents()` is the live subset for the current
18///   episode. After `reset()`, `agents == possible_agents`. Agents are removed
19///   from `agents` when their episode ends; the episode is over when `agents`
20///   is empty.
21///
22/// - **Joint step**: `step()` takes exactly one action per agent in `agents()`
23///   and returns one [`StepResult`] per active agent. Providing actions for
24///   terminated agents or omitting active agents is undefined behaviour.
25///
26/// - **Homogeneous agents**: all agents share `Observation`, `Action`, and
27///   `Info` types. Heterogeneous agents can be modelled with enum wrappers
28///   over the per-type variants.
29///
30/// - **Bevy-compatible by design**: `AgentId: Eq + Hash + Send + Sync +
31///   'static` means Bevy `Entity` is a valid agent ID directly, enabling
32///   free ECS-based parallelisation across agents in bevy-gym.
33///
34/// - **No `render()`**: visualisation is bevy-gym's concern.
35///
36/// - **No `close()`**: implement `Drop` if your environment holds resources.
37///
38/// # Example
39///
40/// ```rust
41/// use std::collections::HashMap;
42/// use rl_traits::{ParallelEnvironment, StepResult, EpisodeStatus};
43/// use rand::Rng;
44///
45/// struct CoopGame {
46///     active: Vec<usize>,
47/// }
48///
49/// impl ParallelEnvironment for CoopGame {
50///     type AgentId = usize;
51///     type Observation = f32;
52///     type Action = bool;   // cooperate or defect
53///     type Info = ();
54///
55///     fn possible_agents(&self) -> &[usize] { &[0, 1] }
56///     fn agents(&self) -> &[usize] { &self.active }
57///
58///     fn step(&mut self, _actions: HashMap<usize, bool>)
59///         -> HashMap<usize, StepResult<f32, ()>>
60///     {
61///         self.active.iter()
62///             .map(|&id| (id, StepResult::new(0.0_f32, 1.0, EpisodeStatus::Continuing, ())))
63///             .collect()
64///     }
65///
66///     fn reset(&mut self, _seed: Option<u64>) -> HashMap<usize, (f32, ())> {
67///         self.active = vec![0, 1];
68///         self.active.iter().map(|&id| (id, (0.0_f32, ()))).collect()
69///     }
70///
71///     fn sample_action(&self, _agent: &usize, rng: &mut impl Rng) -> bool {
72///         rng.gen()
73///     }
74/// }
75/// ```
76pub trait ParallelEnvironment {
77    /// Identifier for each agent.
78    ///
79    /// Common choices: `usize` (index), `&'static str` (name), or a Bevy
80    /// `Entity` for direct ECS integration without an extra lookup.
81    type AgentId: Eq + Hash + Clone + Send + Sync + 'static;
82
83    /// The observation type produced by `step()` and `reset()`.
84    ///
85    /// `Send + Sync + 'static` are required for Bevy ECS compatibility.
86    type Observation: Clone + Send + Sync + 'static;
87
88    /// The action type consumed by `step()`.
89    type Action: Clone + Send + Sync + 'static;
90
91    /// Auxiliary information returned alongside observations.
92    ///
93    /// Use `()` if you don't need it — `Default` is implemented for `()`.
94    type Info: Default + Clone + Send + Sync + 'static;
95
96    /// The complete, fixed set of agent IDs for this environment.
97    ///
98    /// Does not change between episodes or as agents terminate mid-episode.
99    /// Use `agents()` for the currently live set.
100    fn possible_agents(&self) -> &[Self::AgentId];
101
102    /// The agents currently active in this episode.
103    ///
104    /// Starts equal to `possible_agents()` after `reset()`. Shrinks as agents
105    /// terminate or are truncated; never grows. Empty when the episode is over.
106    fn agents(&self) -> &[Self::AgentId];
107
108    /// Advance the environment by one step using joint actions.
109    ///
110    /// `actions` must contain exactly one entry per agent in `self.agents()`.
111    /// After this call, agents whose result was done are removed from `agents()`.
112    fn step(
113        &mut self,
114        actions: HashMap<Self::AgentId, Self::Action>,
115    ) -> HashMap<Self::AgentId, StepResult<Self::Observation, Self::Info>>;
116
117    /// Reset the environment to an initial state, starting a new episode.
118    ///
119    /// If `seed` is `Some(u64)`, the environment should use it to seed its
120    /// internal RNG for deterministic reproduction of episodes.
121    /// Returns the initial observation and info for every agent.
122    fn reset(
123        &mut self,
124        seed: Option<u64>,
125    ) -> HashMap<Self::AgentId, (Self::Observation, Self::Info)>;
126
127    /// Sample a random action for the given agent.
128    ///
129    /// The `rng` is caller-supplied so exploration randomness can be seeded
130    /// and tracked independently from environment randomness.
131    fn sample_action(&self, agent: &Self::AgentId, rng: &mut impl Rng) -> Self::Action;
132
133    /// A global state observation of the full environment.
134    ///
135    /// Used by centralised-training / decentralised-execution algorithms
136    /// (e.g. MADDPG, QMIX) that condition a centralised critic on the full
137    /// state while individual policies see only local observations.
138    /// Returns `None` by default; override if your environment supports it.
139    fn state(&self) -> Option<Self::Observation> {
140        None
141    }
142
143    /// Returns `true` when all agents have finished (active set is empty).
144    fn is_done(&self) -> bool {
145        self.agents().is_empty()
146    }
147
148    /// Number of currently active agents.
149    fn num_agents(&self) -> usize {
150        self.agents().len()
151    }
152
153    /// Maximum number of agents that could ever be active simultaneously.
154    fn max_num_agents(&self) -> usize {
155        self.possible_agents().len()
156    }
157}
158
159/// A multi-agent environment with Agent Environment Cycle (turn-based) semantics.
160///
161/// Mirrors the semantics of PettingZoo's AEC API, adapted for Rust's type
162/// system. Use this when agents act one at a time — board games, card games,
163/// or any domain where simultaneous action is not meaningful.
164///
165/// # Design principles
166///
167/// - **Turn-based execution**: one agent acts per `step()` call.
168///   `agent_selection()` identifies whose turn it is. After each call,
169///   the selection advances to the next active agent.
170///
171/// - **Persistent state**: the environment tracks each agent's most recent
172///   reward, status, and info as mutable state. Read it via `agent_state()`
173///   before deciding on an action. `last()` is a convenience that combines
174///   `observe()` and `agent_state()` for the current agent.
175///
176/// - **Cycling out terminated agents**: when `agent_state()` reports a done
177///   status for `agent_selection()`, pass `None` to `step()` to advance the
178///   turn without applying an action. The type signature makes this contract
179///   explicit — passing `Some(action)` for a done agent is undefined behaviour.
180///
181/// - **Bevy-compatible by design**: same `Send + Sync + 'static` bounds as
182///   [`ParallelEnvironment`]. The turn-based nature is inherently sequential,
183///   so ECS parallelisation applies less directly than with `ParallelEnvironment`.
184///
185/// - **No `render()`**: visualisation is bevy-gym's concern.
186///
187/// - **No `close()`**: implement `Drop` if your environment holds resources.
188///
189/// # Example
190///
191/// Typical AEC loop:
192///
193/// ```rust,ignore
194/// env.reset(None);
195/// while !env.is_done() {
196///     let (obs, _reward, status, _info) = env.last();
197///     let action = if status.is_done() {
198///         None  // cycle the terminated agent out
199///     } else {
200///         Some(policy.act(env.agent_selection(), &obs.unwrap()))
201///     };
202///     env.step(action);
203/// }
204/// ```
205pub trait AecEnvironment {
206    /// Identifier for each agent. Same semantics as [`ParallelEnvironment::AgentId`].
207    type AgentId: Eq + Hash + Clone + Send + Sync + 'static;
208
209    /// The observation type produced by `observe()`.
210    ///
211    /// `Send + Sync + 'static` are required for Bevy ECS compatibility.
212    type Observation: Clone + Send + Sync + 'static;
213
214    /// The action type consumed by `step()`.
215    type Action: Clone + Send + Sync + 'static;
216
217    /// Auxiliary information returned alongside observations.
218    ///
219    /// Use `()` if you don't need it — `Default` is implemented for `()`.
220    type Info: Default + Clone + Send + Sync + 'static;
221
222    /// The complete, fixed set of agent IDs for this environment.
223    ///
224    /// Does not change between episodes or as agents terminate mid-episode.
225    fn possible_agents(&self) -> &[Self::AgentId];
226
227    /// The agents currently active in this episode.
228    ///
229    /// Starts equal to `possible_agents()` after `reset()`. Shrinks as agents
230    /// terminate or are truncated; never grows. Empty when the episode is over.
231    fn agents(&self) -> &[Self::AgentId];
232
233    /// The agent whose turn it currently is to act.
234    fn agent_selection(&self) -> &Self::AgentId;
235
236    /// Execute the current agent's action and advance to the next agent.
237    ///
238    /// Pass `None` when `agent_state(agent_selection())` reports a done status,
239    /// to cycle the agent out without applying an action. Pass `Some(action)`
240    /// otherwise.
241    fn step(&mut self, action: Option<Self::Action>);
242
243    /// Reset the environment to an initial state, starting a new episode.
244    ///
245    /// Unlike [`ParallelEnvironment::reset`], this returns nothing. Retrieve
246    /// initial observations via `observe()` after calling `reset()`.
247    /// If `seed` is `Some(u64)`, it is used to seed the internal RNG.
248    fn reset(&mut self, seed: Option<u64>);
249
250    /// Retrieve the current observation for the given agent.
251    ///
252    /// Returns `None` if the agent has terminated or been truncated — their
253    /// last observation is no longer valid.
254    fn observe(&self, agent: &Self::AgentId) -> Option<Self::Observation>;
255
256    /// Retrieve the persistent `(reward, status, info)` for the given agent.
257    ///
258    /// This state is updated each time the agent acts and persists until its
259    /// next turn. It reflects what the agent received as a result of its last
260    /// action.
261    fn agent_state(&self, agent: &Self::AgentId) -> (f64, EpisodeStatus, Self::Info);
262
263    /// Sample a random action for the given agent.
264    ///
265    /// The `rng` is caller-supplied so exploration randomness can be seeded
266    /// and tracked independently from environment randomness.
267    fn sample_action(&self, agent: &Self::AgentId, rng: &mut impl Rng) -> Self::Action;
268
269    /// Returns the full state for the currently selected agent.
270    ///
271    /// Convenience wrapper around `observe(agent_selection())` and
272    /// `agent_state(agent_selection())`. This is the idiomatic way to read the
273    /// current agent's situation at the top of the AEC loop.
274    fn last(&self) -> (Option<Self::Observation>, f64, EpisodeStatus, Self::Info) {
275        let agent = self.agent_selection().clone();
276        let obs = self.observe(&agent);
277        let (reward, status, info) = self.agent_state(&agent);
278        (obs, reward, status, info)
279    }
280
281    /// Returns `true` when all agents have finished (active set is empty).
282    fn is_done(&self) -> bool {
283        self.agents().is_empty()
284    }
285
286    /// Number of currently active agents.
287    fn num_agents(&self) -> usize {
288        self.agents().len()
289    }
290
291    /// Maximum number of agents that could ever be active simultaneously.
292    fn max_num_agents(&self) -> usize {
293        self.possible_agents().len()
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::episode::{EpisodeStatus, StepResult};

    // ── ParallelEnvironment mock ─────────────────────────────────────────────

    struct ParallelMock {
        agents: &'static [usize],
    }

    impl ParallelEnvironment for ParallelMock {
        type AgentId = usize;
        type Observation = ();
        type Action = ();
        type Info = ();

        fn possible_agents(&self) -> &[usize] { &[0, 1, 2] }
        fn agents(&self) -> &[usize] { self.agents }

        fn step(&mut self, _: HashMap<usize, ()>) -> HashMap<usize, StepResult<(), ()>> {
            unimplemented!()
        }
        fn reset(&mut self, _: Option<u64>) -> HashMap<usize, ((), ())> {
            unimplemented!()
        }
        fn sample_action(&self, _: &usize, _: &mut impl rand::Rng) {}
    }

    // ── AecEnvironment mock ──────────────────────────────────────────────────

    struct AecMock {
        agents: &'static [usize],
        current: usize,
    }

    impl AecEnvironment for AecMock {
        type AgentId = usize;
        type Observation = i32;
        type Action = ();
        type Info = String;

        fn possible_agents(&self) -> &[usize] { &[0, 1] }
        fn agents(&self) -> &[usize] { self.agents }
        fn agent_selection(&self) -> &usize { &self.current }

        fn step(&mut self, _: Option<()>) { unimplemented!() }
        fn reset(&mut self, _: Option<u64>) { unimplemented!() }

        fn observe(&self, agent: &usize) -> Option<i32> {
            if *agent == 0 { Some(99) } else { None }
        }

        fn agent_state(&self, agent: &usize) -> (f64, EpisodeStatus, String) {
            match agent {
                0 => (2.5, EpisodeStatus::Continuing, "alive".to_string()),
                _ => (0.0, EpisodeStatus::Terminated, String::new()),
            }
        }

        fn sample_action(&self, _: &usize, _: &mut impl rand::Rng) {}
    }

    // ── ParallelEnvironment ──────────────────────────────────────────────────

    #[test]
    fn parallel_is_done_when_agents_empty() {
        assert!(ParallelMock { agents: &[] }.is_done());
    }

    #[test]
    fn parallel_not_done_with_active_agents() {
        assert!(!ParallelMock { agents: &[0, 1] }.is_done());
    }

    #[test]
    fn parallel_num_agents_reflects_active_set() {
        assert_eq!(ParallelMock { agents: &[0, 1] }.num_agents(), 2);
    }

    #[test]
    fn parallel_max_num_agents_reflects_possible_set() {
        // possible_agents is always [0, 1, 2] regardless of the active set
        assert_eq!(ParallelMock { agents: &[0] }.max_num_agents(), 3);
    }

    #[test]
    fn parallel_state_returns_none_by_default() {
        assert!(ParallelMock { agents: &[0] }.state().is_none());
    }

    // ── AecEnvironment ───────────────────────────────────────────────────────

    #[test]
    fn aec_is_done_when_agents_empty() {
        assert!(AecMock { agents: &[], current: 0 }.is_done());
    }

    #[test]
    fn aec_not_done_with_active_agents() {
        assert!(!AecMock { agents: &[0, 1], current: 0 }.is_done());
    }

    #[test]
    fn aec_num_and_max_num_agents() {
        let env = AecMock { agents: &[0], current: 0 };
        assert_eq!(env.num_agents(), 1);
        assert_eq!(env.max_num_agents(), 2); // possible_agents is [0, 1]
    }

    #[test]
    fn aec_last_composes_observe_and_agent_state_for_current_agent() {
        // Agent 0 is selected — has observation Some(99) and is Continuing.
        let env = AecMock { agents: &[0, 1], current: 0 };
        let (obs, reward, status, info) = env.last();
        assert_eq!(obs, Some(99));
        assert_eq!(reward, 2.5);
        assert_eq!(status, EpisodeStatus::Continuing);
        assert_eq!(info, "alive");
    }

    #[test]
    fn aec_last_returns_none_obs_for_terminated_agent() {
        // Agent 1 is selected — terminated, so observe() returns None.
        let env = AecMock { agents: &[1], current: 1 };
        let (obs, _reward, status, _info) = env.last();
        assert_eq!(obs, None);
        assert!(status.is_terminal());
    }
}
422        let (obs, _reward, status, _info) = env.last();
423        assert_eq!(obs, None);
424        assert!(status.is_terminal());
425    }
426}