//! `burn_rl/policy/async_policy.rs` — an asynchronous policy that funnels
//! inference requests through a dedicated autobatching server thread.

1use std::{
2    sync::{
3        Arc,
4        atomic::{AtomicUsize, Ordering},
5        mpsc::{self, Sender},
6    },
7    thread::spawn,
8};
9
10use burn_core::prelude::Backend;
11
12use crate::{ActionContext, Batchable, Policy, PolicyState};
13
/// State owned by the inference-server thread: the wrapped policy plus the
/// queues of pending requests that are autobatched before inference.
#[derive(Clone)]
struct PolicyInferenceServer<B: Backend, P: Policy<B>> {
    // `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size.
    num_agents: Arc<AtomicUsize>,
    // Upper bound on how many queued requests are batched into a single inference pass.
    max_autobatch_size: usize,
    // The policy that actually performs inference on the batched observations.
    inner_policy: P,
    // Pending `action` requests, flushed together as one batch.
    batch_action: Vec<ActionItem<P::Observation, P::Action, P::ActionContext>>,
    // Pending `forward` (action-distribution) requests, flushed together as one batch.
    batch_logits: Vec<ForwardItem<P::Observation, P::ActionDistribution>>,
}
23
24impl<B, P> PolicyInferenceServer<B, P>
25where
26    B: Backend,
27    P: Policy<B>,
28    P::Observation: Clone + Batchable,
29    P::ActionDistribution: Clone + Batchable,
30    P::Action: Clone + Batchable,
31    P::ActionContext: Clone,
32{
33    pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self {
34        Self {
35            num_agents: Arc::new(AtomicUsize::new(0)),
36            max_autobatch_size,
37            inner_policy,
38            batch_action: vec![],
39            batch_logits: vec![],
40        }
41    }
42
43    pub fn push_action(&mut self, item: ActionItem<P::Observation, P::Action, P::ActionContext>) {
44        self.batch_action.push(item);
45        if self.len_actions()
46            >= self
47                .num_agents
48                .load(Ordering::Relaxed)
49                .min(self.max_autobatch_size)
50        {
51            self.flush_actions();
52        }
53    }
54
55    pub fn push_logits(&mut self, item: ForwardItem<P::Observation, P::ActionDistribution>) {
56        self.batch_logits.push(item);
57        if self.len_logits()
58            >= self
59                .num_agents
60                .load(Ordering::Relaxed)
61                .min(self.max_autobatch_size)
62        {
63            self.flush_logits();
64        }
65    }
66
67    pub fn len_actions(&self) -> usize {
68        self.batch_action.len()
69    }
70
71    pub fn len_logits(&self) -> usize {
72        self.batch_logits.len()
73    }
74
75    pub fn flush_actions(&mut self) {
76        if self.len_actions() == 0 {
77            return;
78        }
79        let input: Vec<_> = self
80            .batch_action
81            .iter()
82            .map(|m| m.inference_state.clone())
83            .collect();
84        // Only deterministic if all actions are requested as deterministic.
85        let deterministic = self.batch_action.iter().all(|item| item.deterministic);
86        let (actions, context) = self
87            .inner_policy
88            .action(P::Observation::batch(input), deterministic);
89        let actions: Vec<_> = actions.unbatch();
90
91        for (i, item) in self.batch_action.iter().enumerate() {
92            item.sender
93                .send(ActionContext {
94                    context: vec![context[i].clone()],
95                    action: actions[i].clone(),
96                })
97                .expect("Autobatcher should be able to send resulting actions.");
98        }
99        self.batch_action.clear();
100    }
101
102    pub fn flush_logits(&mut self) {
103        if self.len_logits() == 0 {
104            return;
105        }
106        let input: Vec<_> = self
107            .batch_logits
108            .iter()
109            .map(|m| m.inference_state.clone())
110            .collect();
111        let output = self.inner_policy.forward(P::Observation::batch(input));
112        let logits: Vec<_> = output.unbatch();
113        for (i, item) in self.batch_logits.iter().enumerate() {
114            item.sender
115                .send(logits[i].clone())
116                .expect("Autobatcher should be able to send resulting probabilities.");
117        }
118        self.batch_logits.clear();
119    }
120
121    pub fn update_policy(&mut self, policy_update: P::PolicyState) {
122        if self.len_actions() > 0 {
123            self.flush_actions();
124        }
125        if self.len_logits() > 0 {
126            self.flush_logits();
127        }
128        self.inner_policy.update(policy_update);
129    }
130
131    pub fn state(&self) -> P::PolicyState {
132        self.inner_policy.state()
133    }
134
135    pub fn increment_agents(&mut self, num: usize) {
136        self.num_agents.fetch_add(num, Ordering::Relaxed);
137    }
138
139    pub fn decrement_agents(&mut self, num: usize) {
140        self.num_agents.fetch_sub(num, Ordering::Relaxed);
141        if self.len_actions()
142            >= self
143                .num_agents
144                .load(Ordering::Relaxed)
145                .min(self.max_autobatch_size)
146        {
147            self.flush_actions();
148        }
149        if self.len_logits()
150            >= self
151                .num_agents
152                .load(Ordering::Relaxed)
153                .min(self.max_autobatch_size)
154        {
155            self.flush_logits();
156        }
157    }
158}
159
/// Messages handled by the inference-server thread spawned by `AsyncPolicy`.
enum InferenceMessage<B: Backend, P: Policy<B>> {
    /// Queue an observation for a (possibly deterministic) action.
    ActionMessage(ActionItem<P::Observation, P::Action, P::ActionContext>),
    /// Queue an observation for a raw forward pass (action distribution).
    ForwardMessage(ForwardItem<P::Observation, P::ActionDistribution>),
    /// Update the server's policy; pending batches are flushed first.
    PolicyUpdate(P::PolicyState),
    /// Request the current policy state, replied to on the enclosed sender.
    PolicyRequest(Sender<P::PolicyState>),
    /// Register this many additional agents with the autobatcher.
    IncrementAgents(usize),
    /// Unregister this many agents; may trigger a flush of pending batches.
    DecrementAgents(usize),
}
168
/// A queued `action` request: the observation plus the channel on which the
/// batched result is sent back to the blocked caller.
#[derive(Clone)]
struct ActionItem<S, A, C> {
    // Replies with the action and its context once the batch is flushed.
    sender: Sender<ActionContext<A, Vec<C>>>,
    // Observation to run inference on.
    inference_state: S,
    // Whether the caller requested a deterministic action.
    deterministic: bool,
}
175
/// A queued `forward` request: the observation plus the channel on which the
/// resulting action distribution is sent back to the blocked caller.
#[derive(Clone)]
struct ForwardItem<S, O> {
    // Replies with the action distribution once the batch is flushed.
    sender: Sender<O>,
    // Observation to run inference on.
    inference_state: S,
}
181
/// An asynchronous policy using an inference server with autobatching.
///
/// Cloning is cheap: clones share the same channel to the single server
/// thread, so many agents can submit requests concurrently.
#[derive(Clone)]
pub struct AsyncPolicy<B: Backend, P: Policy<B>> {
    // Channel to the background inference-server thread.
    inference_state_sender: Sender<InferenceMessage<B, P>>,
}
187
188impl<B, P> AsyncPolicy<B, P>
189where
190    B: Backend,
191    P: Policy<B> + Clone + Send + 'static,
192    P::ActionContext: Clone + Send,
193    P::PolicyState: Send,
194    P::Observation: Clone + Send + Batchable,
195    P::ActionDistribution: Clone + Send + Batchable,
196    P::Action: Clone + Send + Batchable,
197{
198    /// Create the policy.
199    ///
200    /// # Arguments
201    ///
202    /// * `autobatch_size` - Number of observations to accumulate before running a pass of inference.
203    /// * `inner_policy` - The policy used to take actions.
204    pub fn new(autobatch_size: usize, inner_policy: P) -> Self {
205        let (sender, receiver) = std::sync::mpsc::channel();
206        let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy.clone());
207        spawn(move || {
208            loop {
209                match receiver.recv() {
210                    Ok(msg) => match msg {
211                        InferenceMessage::ActionMessage(item) => autobatcher.push_action(item),
212                        InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item),
213                        InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update),
214                        InferenceMessage::PolicyRequest(sender) => sender
215                            .send(autobatcher.state())
216                            .expect("Autobatcher should be able to send current policy state."),
217                        InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num),
218                        InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num),
219                    },
220                    Err(err) => {
221                        log::error!("Error in AsyncPolicy : {}", err);
222                        break;
223                    }
224                }
225            }
226        });
227
228        Self {
229            inference_state_sender: sender,
230        }
231    }
232
233    /// Increment the number of agents using the inference server.
234    pub fn increment_agents(&self, num: usize) {
235        self.inference_state_sender
236            .send(InferenceMessage::IncrementAgents(num))
237            .expect("Can send message to autobatcher.")
238    }
239
240    /// Decrement the number of agents using the inference server.
241    pub fn decrement_agents(&self, num: usize) {
242        self.inference_state_sender
243            .send(InferenceMessage::DecrementAgents(num))
244            .expect("Can send message to autobatcher.")
245    }
246}
247
248impl<B, P> Policy<B> for AsyncPolicy<B, P>
249where
250    B: Backend,
251    P: Policy<B> + Send + 'static,
252{
253    type ActionContext = P::ActionContext;
254    type PolicyState = P::PolicyState;
255
256    type Observation = P::Observation;
257    type ActionDistribution = P::ActionDistribution;
258    type Action = P::Action;
259
260    fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
261        let (action_sender, action_receiver) = std::sync::mpsc::channel();
262        let item = ForwardItem {
263            sender: action_sender,
264            inference_state: states,
265        };
266        self.inference_state_sender
267            .send(InferenceMessage::ForwardMessage(item))
268            .expect("Should be able to send message to inference_server");
269        action_receiver
270            .recv()
271            .expect("AsyncPolicy should receive queued probabilities.")
272    }
273
274    fn action(
275        &mut self,
276        states: Self::Observation,
277        deterministic: bool,
278    ) -> (Self::Action, Vec<Self::ActionContext>) {
279        let (action_sender, action_receiver) = std::sync::mpsc::channel();
280        let item = ActionItem {
281            sender: action_sender,
282            inference_state: states,
283            deterministic,
284        };
285        self.inference_state_sender
286            .send(InferenceMessage::ActionMessage(item))
287            .expect("should be able to send message to inference_server.");
288        let action = action_receiver
289            .recv()
290            .expect("AsyncPolicy should receive queued actions.");
291        (action.action, action.context)
292    }
293
294    fn update(&mut self, update: Self::PolicyState) {
295        self.inference_state_sender
296            .send(InferenceMessage::PolicyUpdate(update))
297            .expect("AsyncPolicy should be able to send policy state.")
298    }
299
300    fn state(&self) -> Self::PolicyState {
301        let (sender, receiver) = mpsc::channel();
302        self.inference_state_sender
303            .send(InferenceMessage::PolicyRequest(sender))
304            .expect("should be able to send message to inference_server.");
305        receiver
306            .recv()
307            .expect("AsyncPolicy should be able to receive policy state.")
308    }
309
310    fn load_record(self, _record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
311        // Not needed for now
312        todo!()
313    }
314}
315
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    // NOTE(review): these tests assert on thread progress after fixed 10 ms
    // sleeps — they assume the spawned threads get scheduled within that
    // window, which may be flaky on heavily loaded CI machines.
    use std::thread::JoinHandle;
    use std::time::Duration;

    use crate::TestBackend;
    use crate::tests::{MockAction, MockObservation, MockPolicy};

    use super::*;

    // With many agents registered (1000), the effective flush threshold is
    // the autobatch size (8): callers block until the 8th request arrives.
    #[test]
    fn test_multiple_actions_before_flush() {
        // Spawns a thread that blocks in `action` until its batch flushes.
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<()>>,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                thread_policy.action(MockObservation(vec![0.]), false);
            });
            handles.push(handle);
        }

        let policy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(1000);

        // One queued request: still below the threshold, so it must block.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());

        // Seven queued requests: still below the threshold of 8.
        for _ in 0..6 {
            launch_thread(&policy, &mut handles);
        }
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..7 {
            assert!(!handles[i].is_finished());
        }

        // The eighth request reaches the threshold: all callers unblock.
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..8 {
            assert!(handles[i].is_finished());
        }

        // The queue was cleared by the flush: a new request blocks again.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
    }

    // Same batching behaviour as above, but for `forward` (distribution)
    // requests, which are queued on a separate batch.
    #[test]
    fn test_multiple_forward_before_flush() {
        // Spawns a thread that blocks in `forward` until its batch flushes.
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<()>>,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                thread_policy.forward(MockObservation(vec![0.]));
            });
            handles.push(handle);
        }

        let policy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(1000);

        // Below threshold: the single caller blocks.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());

        // Seven requests: still below the threshold of 8.
        for _ in 0..6 {
            launch_thread(&policy, &mut handles);
        }
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..7 {
            assert!(!handles[i].is_finished());
        }

        // Eighth request triggers the flush: everyone unblocks.
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..8 {
            assert!(handles[i].is_finished());
        }

        // Fresh queue after the flush: a new request blocks again.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
    }

    // A batch is deterministic only when *every* request in it asked for a
    // deterministic action (MockAction encodes that flag as 0/1 — presumably;
    // verify against the MockPolicy implementation).
    #[test]
    fn test_async_policy_deterministic_behaviour() {
        // Spawns a thread requesting an action with the given determinism flag.
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<MockAction>>,
            deterministic: bool,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                let (action, _) = thread_policy.action(MockObservation(vec![0.]), deterministic);
                action
            });
            handles.push(handle);
        }

        let policy = AsyncPolicy::new(2, MockPolicy::new());
        policy.increment_agents(1000);

        // Mixed batch (one deterministic, one not): resolved non-deterministic.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles, true);
        launch_thread(&policy, &mut handles, false);
        for _ in 0..2 {
            let action = handles.pop().unwrap().join().unwrap();
            assert_eq!(action.0, vec![0]);
        }

        // Uniformly deterministic batch: resolved deterministic.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles, true);
        launch_thread(&policy, &mut handles, true);
        for _ in 0..2 {
            let action = handles.pop().unwrap().join().unwrap();
            assert_eq!(action.0, vec![1]);
        }
    }

    // With fewer agents than the autobatch size, the flush threshold drops
    // to the agent count, and decrementing agents can itself trigger a flush.
    #[test]
    fn flush_when_running_agents_smaller_than_autobatch_size() {
        // Spawns a thread that blocks in `action` until its batch flushes.
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<()>>,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                thread_policy.action(MockObservation(vec![0.]), false);
            });
            handles.push(handle);
        }

        let policy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(3);

        // Two requests, threshold is min(3, 8) = 3: both block.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
        assert!(!handles[1].is_finished());

        // Third request hits the agent-count threshold: all flush.
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..3 {
            assert!(handles[i].is_finished());
        }

        // Two requests again: below threshold of 3, both block.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
        assert!(!handles[1].is_finished());

        // Dropping to 2 agents lowers the threshold to 2 and flushes the
        // already-queued pair without a new request arriving.
        policy.decrement_agents(1);
        std::thread::sleep(Duration::from_millis(10));
        assert!(handles[0].is_finished());
        assert!(handles[1].is_finished());
    }
}
485}