1use std::collections::HashMap;
14
15use rand::rngs::SmallRng;
16use rand::{Rng, SeedableRng as _};
17use rl_traits::{EpisodeStatus, ParallelEnvironment, StepResult};
18
/// Number of cells on the one-dimensional track; positions range over
/// `0..=GRID_LEN - 1`.
const GRID_LEN: u8 = 10;
/// Step count at which an uncaught prey ends the episode by truncation
/// (reported as "Escaped").
const MAX_STEPS: usize = 200;

// Observations are normalised by dividing by `GRID_LEN - 1`, so a
// single-cell grid would produce NaN observations (and zero cells would
// underflow). Reject such configurations at compile time.
const _: () = assert!(GRID_LEN >= 2);

/// Per-predator observation: `[own position, prey position]`, each divided
/// by `GRID_LEN - 1` so the values lie in `[0, 1]`.
pub type PursuitObs = [f32; 2];
30
31pub struct Pursuit {
34 active: Vec<u8>, pos: [u8; 2], prey: u8, step: usize,
38 rng: SmallRng,
39}
40
41impl Pursuit {
42 pub fn new(seed: u64) -> Self {
43 Self {
44 active: vec![0, 1],
45 pos: [0, GRID_LEN - 1],
46 prey: GRID_LEN / 2,
47 step: 0,
48 rng: SmallRng::seed_from_u64(seed),
49 }
50 }
51
52 fn obs(&self, id: u8) -> PursuitObs {
53 let scale = (GRID_LEN - 1) as f32;
54 [self.pos[id as usize] as f32 / scale, self.prey as f32 / scale]
55 }
56
57 fn clamp_move(pos: u8, action: u8) -> u8 {
58 match action {
59 0 => pos.saturating_sub(1),
60 _ => (pos + 1).min(GRID_LEN - 1),
61 }
62 }
63}
64
65impl ParallelEnvironment for Pursuit {
66 type AgentId = u8;
67 type Observation = PursuitObs;
68 type Action = u8; type Info = ();
70
71 fn possible_agents(&self) -> &[u8] { &[0, 1] }
72 fn agents(&self) -> &[u8] { &self.active }
73
74 fn step(&mut self, actions: HashMap<u8, u8>)
75 -> HashMap<u8, StepResult<PursuitObs, ()>>
76 {
77 for (&id, &action) in &actions {
78 self.pos[id as usize] = Self::clamp_move(self.pos[id as usize], action);
79 }
80
81 self.prey = Self::clamp_move(self.prey, self.rng.gen_range(0..2));
83 self.step += 1;
84
85 let caught = self.active.iter().any(|&id| self.pos[id as usize] == self.prey);
86
87 let status = if caught {
88 EpisodeStatus::Terminated
89 } else if self.step >= MAX_STEPS {
90 EpisodeStatus::Truncated
91 } else {
92 EpisodeStatus::Continuing
93 };
94
95 let results = self.active.iter().map(|&id| {
97 let reward = if caught && self.pos[id as usize] == self.prey { 1.0 } else { 0.0 };
98 (id, StepResult::new(self.obs(id), reward, status.clone(), ()))
99 }).collect();
100
101 if status.is_done() {
102 self.active.clear();
103 }
104
105 results
106 }
107
108 fn reset(&mut self, seed: Option<u64>) -> HashMap<u8, (PursuitObs, ())> {
109 if let Some(s) = seed {
110 self.rng = SmallRng::seed_from_u64(s);
111 }
112 self.active = vec![0, 1];
113 self.pos = [0, GRID_LEN - 1];
114 self.prey = GRID_LEN / 2;
115 self.step = 0;
116 [0u8, 1u8].iter().map(|&id| (id, (self.obs(id), ()))).collect()
117 }
118
119 fn sample_action(&self, _agent: &u8, rng: &mut impl Rng) -> u8 {
120 rng.gen_range(0..2)
121 }
122}
123
124fn run_episode(env: &mut Pursuit, rng: &mut SmallRng) -> ([f64; 2], EpisodeStatus, usize) {
127 env.reset(None);
128 let mut returns = [0.0_f64; 2];
129 let mut steps = 0;
130 let mut outcome = EpisodeStatus::Continuing;
131
132 while !env.is_done() {
133 let actions = env.agents().iter()
134 .map(|&id| (id, env.sample_action(&id, rng)))
135 .collect();
136
137 let results = env.step(actions);
138 steps += 1;
139
140 for (id, result) in &results {
141 returns[*id as usize] += result.reward;
142 if result.status.is_done() {
143 outcome = result.status.clone();
144 }
145 }
146 }
147
148 (returns, outcome, steps)
149}
150
151fn main() {
152 const NUM_EPISODES: usize = 10;
153 const ENV_SEED: u64 = 42;
154
155 let mut env = Pursuit::new(ENV_SEED);
156 let mut rng = SmallRng::seed_from_u64(0);
157
158 println!("Pursuit — random predators, {NUM_EPISODES} episodes\n");
159 println!("{:<8} {:>8} {:>8} {:>7} {:>10}", "Episode", "Pred. 0", "Pred. 1", "Steps", "Outcome");
160 println!("{}", "-".repeat(46));
161
162 let mut caught = 0;
163
164 for ep in 1..=NUM_EPISODES {
165 let (returns, outcome, steps) = run_episode(&mut env, &mut rng);
166
167 let label = match outcome {
168 EpisodeStatus::Terminated => { caught += 1; "Caught" }
169 EpisodeStatus::Truncated => "Escaped",
170 EpisodeStatus::Continuing => unreachable!(),
171 };
172
173 println!("{ep:<8} {:>8.1} {:>8.1} {:>7} {:>10}",
174 returns[0], returns[1], steps, label);
175 }
176
177 println!("{}", "-".repeat(46));
178 println!("Prey caught in {caught}/{NUM_EPISODES} episodes");
179}