Skip to main content

pursuit/
pursuit.rs

//! Pursuit — a two-predator cooperative tracking task.
//!
//! Two predators on a 1-D grid of length 10 cooperate to catch a randomly
//! moving prey. This example validates the [`rl_traits::ParallelEnvironment`]
//! API: per-agent observations, joint actions, and the `Terminated` /
//! `Truncated` distinction across agents.
//!
//! Run with:
//! ```text
//! cargo run --example pursuit
//! ```
12
13use std::collections::HashMap;
14
15use rand::rngs::SmallRng;
16use rand::{Rng, SeedableRng as _};
17use rl_traits::{EpisodeStatus, ParallelEnvironment, StepResult};
18
// ── Constants ────────────────────────────────────────────────────────────────

/// Number of cells on the 1-D grid; positions range over `0..GRID_LEN`.
const GRID_LEN: u8 = 10;

/// Episode length cap: once this many steps have been taken without a capture,
/// the episode is truncated.
const MAX_STEPS: usize = 200;

// Observations are normalised by dividing by `GRID_LEN - 1`, and the right-hand
// predator starts at `GRID_LEN - 1`. A grid with fewer than two cells would
// therefore yield NaN observations (GRID_LEN == 1) or a u8 underflow
// (GRID_LEN == 0), so reject such a configuration at compile time.
const _: () = assert!(GRID_LEN >= 2, "GRID_LEN must be at least 2");

// ── Observation ──────────────────────────────────────────────────────────────

/// `[predator_pos / (GRID_LEN - 1), prey_pos / (GRID_LEN - 1)]`
///
/// Both values are normalised to `[0.0, 1.0]`.
pub type PursuitObs = [f32; 2];
30
31// ── Environment ──────────────────────────────────────────────────────────────
32
33pub struct Pursuit {
34    active: Vec<u8>,      // currently active predator IDs (0 and/or 1)
35    pos: [u8; 2],         // predator positions indexed by ID
36    prey: u8,             // prey position
37    step: usize,
38    rng: SmallRng,
39}
40
41impl Pursuit {
42    pub fn new(seed: u64) -> Self {
43        Self {
44            active: vec![0, 1],
45            pos: [0, GRID_LEN - 1],
46            prey: GRID_LEN / 2,
47            step: 0,
48            rng: SmallRng::seed_from_u64(seed),
49        }
50    }
51
52    fn obs(&self, id: u8) -> PursuitObs {
53        let scale = (GRID_LEN - 1) as f32;
54        [self.pos[id as usize] as f32 / scale, self.prey as f32 / scale]
55    }
56
57    fn clamp_move(pos: u8, action: u8) -> u8 {
58        match action {
59            0 => pos.saturating_sub(1),
60            _ => (pos + 1).min(GRID_LEN - 1),
61        }
62    }
63}
64
65impl ParallelEnvironment for Pursuit {
66    type AgentId = u8;
67    type Observation = PursuitObs;
68    type Action = u8;   // 0 = move left, 1 = move right
69    type Info = ();
70
71    fn possible_agents(&self) -> &[u8] { &[0, 1] }
72    fn agents(&self) -> &[u8] { &self.active }
73
74    fn step(&mut self, actions: HashMap<u8, u8>)
75        -> HashMap<u8, StepResult<PursuitObs, ()>>
76    {
77        for (&id, &action) in &actions {
78            self.pos[id as usize] = Self::clamp_move(self.pos[id as usize], action);
79        }
80
81        // Prey moves randomly, bouncing at the walls.
82        self.prey = Self::clamp_move(self.prey, self.rng.gen_range(0..2));
83        self.step += 1;
84
85        let caught = self.active.iter().any(|&id| self.pos[id as usize] == self.prey);
86
87        let status = if caught {
88            EpisodeStatus::Terminated
89        } else if self.step >= MAX_STEPS {
90            EpisodeStatus::Truncated
91        } else {
92            EpisodeStatus::Continuing
93        };
94
95        // Build results before mutating the active list.
96        let results = self.active.iter().map(|&id| {
97            let reward = if caught && self.pos[id as usize] == self.prey { 1.0 } else { 0.0 };
98            (id, StepResult::new(self.obs(id), reward, status.clone(), ()))
99        }).collect();
100
101        if status.is_done() {
102            self.active.clear();
103        }
104
105        results
106    }
107
108    fn reset(&mut self, seed: Option<u64>) -> HashMap<u8, (PursuitObs, ())> {
109        if let Some(s) = seed {
110            self.rng = SmallRng::seed_from_u64(s);
111        }
112        self.active = vec![0, 1];
113        self.pos = [0, GRID_LEN - 1];
114        self.prey = GRID_LEN / 2;
115        self.step = 0;
116        [0u8, 1u8].iter().map(|&id| (id, (self.obs(id), ()))).collect()
117    }
118
119    fn sample_action(&self, _agent: &u8, rng: &mut impl Rng) -> u8 {
120        rng.gen_range(0..2)
121    }
122}
123
124// ── Demo loop ────────────────────────────────────────────────────────────────
125
126fn run_episode(env: &mut Pursuit, rng: &mut SmallRng) -> ([f64; 2], EpisodeStatus, usize) {
127    env.reset(None);
128    let mut returns = [0.0_f64; 2];
129    let mut steps = 0;
130    let mut outcome = EpisodeStatus::Continuing;
131
132    while !env.is_done() {
133        let actions = env.agents().iter()
134            .map(|&id| (id, env.sample_action(&id, rng)))
135            .collect();
136
137        let results = env.step(actions);
138        steps += 1;
139
140        for (id, result) in &results {
141            returns[*id as usize] += result.reward;
142            if result.status.is_done() {
143                outcome = result.status.clone();
144            }
145        }
146    }
147
148    (returns, outcome, steps)
149}
150
151fn main() {
152    const NUM_EPISODES: usize = 10;
153    const ENV_SEED: u64 = 42;
154
155    let mut env = Pursuit::new(ENV_SEED);
156    let mut rng = SmallRng::seed_from_u64(0);
157
158    println!("Pursuit — random predators, {NUM_EPISODES} episodes\n");
159    println!("{:<8} {:>8} {:>8} {:>7} {:>10}", "Episode", "Pred. 0", "Pred. 1", "Steps", "Outcome");
160    println!("{}", "-".repeat(46));
161
162    let mut caught = 0;
163
164    for ep in 1..=NUM_EPISODES {
165        let (returns, outcome, steps) = run_episode(&mut env, &mut rng);
166
167        let label = match outcome {
168            EpisodeStatus::Terminated => { caught += 1; "Caught" }
169            EpisodeStatus::Truncated  => "Escaped",
170            EpisodeStatus::Continuing => unreachable!(),
171        };
172
173        println!("{ep:<8} {:>8.1} {:>8.1} {:>7} {:>10}",
174            returns[0], returns[1], steps, label);
175    }
176
177    println!("{}", "-".repeat(46));
178    println!("Prey caught in {caught}/{NUM_EPISODES} episodes");
179}