1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicUsize, Ordering},
5 mpsc::{self, Sender},
6 },
7 thread::spawn,
8};
9
10use burn_core::prelude::Backend;
11
12use crate::{ActionContext, Batchable, Policy, PolicyState};
13
/// Autobatching inference server: collects incoming `action` / `forward`
/// requests and serves them to the wrapped policy as single batched calls.
/// Runs on a dedicated thread owned by [`AsyncPolicy`].
#[derive(Clone)]
struct PolicyInferenceServer<B: Backend, P: Policy<B>> {
    // Number of currently registered agents; batches flush once every live
    // agent has a request queued (capped by `max_autobatch_size`).
    // NOTE(review): clones share this counter (Arc) but get independent batch
    // buffers — confirm cloning the server is actually intended.
    num_agents: Arc<AtomicUsize>,
    // Hard cap on how many requests are grouped into one forward pass.
    max_autobatch_size: usize,
    // The wrapped policy that performs the actual inference.
    inner_policy: P,
    // Pending `action` requests awaiting a batched flush.
    batch_action: Vec<ActionItem<P::Observation, P::Action, P::ActionContext>>,
    // Pending `forward` (action-distribution) requests awaiting a batched flush.
    batch_logits: Vec<ForwardItem<P::Observation, P::ActionDistribution>>,
}
23
24impl<B, P> PolicyInferenceServer<B, P>
25where
26 B: Backend,
27 P: Policy<B>,
28 P::Observation: Clone + Batchable,
29 P::ActionDistribution: Clone + Batchable,
30 P::Action: Clone + Batchable,
31 P::ActionContext: Clone,
32{
33 pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self {
34 Self {
35 num_agents: Arc::new(AtomicUsize::new(0)),
36 max_autobatch_size,
37 inner_policy,
38 batch_action: vec![],
39 batch_logits: vec![],
40 }
41 }
42
43 pub fn push_action(&mut self, item: ActionItem<P::Observation, P::Action, P::ActionContext>) {
44 self.batch_action.push(item);
45 if self.len_actions()
46 >= self
47 .num_agents
48 .load(Ordering::Relaxed)
49 .min(self.max_autobatch_size)
50 {
51 self.flush_actions();
52 }
53 }
54
55 pub fn push_logits(&mut self, item: ForwardItem<P::Observation, P::ActionDistribution>) {
56 self.batch_logits.push(item);
57 if self.len_logits()
58 >= self
59 .num_agents
60 .load(Ordering::Relaxed)
61 .min(self.max_autobatch_size)
62 {
63 self.flush_logits();
64 }
65 }
66
67 pub fn len_actions(&self) -> usize {
68 self.batch_action.len()
69 }
70
71 pub fn len_logits(&self) -> usize {
72 self.batch_logits.len()
73 }
74
75 pub fn flush_actions(&mut self) {
76 if self.len_actions() == 0 {
77 return;
78 }
79 let input: Vec<_> = self
80 .batch_action
81 .iter()
82 .map(|m| m.inference_state.clone())
83 .collect();
84 let deterministic = self.batch_action.iter().all(|item| item.deterministic);
86 let (actions, context) = self
87 .inner_policy
88 .action(P::Observation::batch(input), deterministic);
89 let actions: Vec<_> = actions.unbatch();
90
91 for (i, item) in self.batch_action.iter().enumerate() {
92 item.sender
93 .send(ActionContext {
94 context: vec![context[i].clone()],
95 action: actions[i].clone(),
96 })
97 .expect("Autobatcher should be able to send resulting actions.");
98 }
99 self.batch_action.clear();
100 }
101
102 pub fn flush_logits(&mut self) {
103 if self.len_logits() == 0 {
104 return;
105 }
106 let input: Vec<_> = self
107 .batch_logits
108 .iter()
109 .map(|m| m.inference_state.clone())
110 .collect();
111 let output = self.inner_policy.forward(P::Observation::batch(input));
112 let logits: Vec<_> = output.unbatch();
113 for (i, item) in self.batch_logits.iter().enumerate() {
114 item.sender
115 .send(logits[i].clone())
116 .expect("Autobatcher should be able to send resulting probabilities.");
117 }
118 self.batch_logits.clear();
119 }
120
121 pub fn update_policy(&mut self, policy_update: P::PolicyState) {
122 if self.len_actions() > 0 {
123 self.flush_actions();
124 }
125 if self.len_logits() > 0 {
126 self.flush_logits();
127 }
128 self.inner_policy.update(policy_update);
129 }
130
131 pub fn state(&self) -> P::PolicyState {
132 self.inner_policy.state()
133 }
134
135 pub fn increment_agents(&mut self, num: usize) {
136 self.num_agents.fetch_add(num, Ordering::Relaxed);
137 }
138
139 pub fn decrement_agents(&mut self, num: usize) {
140 self.num_agents.fetch_sub(num, Ordering::Relaxed);
141 if self.len_actions()
142 >= self
143 .num_agents
144 .load(Ordering::Relaxed)
145 .min(self.max_autobatch_size)
146 {
147 self.flush_actions();
148 }
149 if self.len_logits()
150 >= self
151 .num_agents
152 .load(Ordering::Relaxed)
153 .min(self.max_autobatch_size)
154 {
155 self.flush_logits();
156 }
157 }
158}
159
/// Messages accepted by the background inference thread spawned in
/// [`AsyncPolicy::new`].
enum InferenceMessage<B: Backend, P: Policy<B>> {
    /// Queue an `action` request; the reply travels over the item's sender.
    ActionMessage(ActionItem<P::Observation, P::Action, P::ActionContext>),
    /// Queue a `forward` request; the reply travels over the item's sender.
    ForwardMessage(ForwardItem<P::Observation, P::ActionDistribution>),
    /// Apply a new policy state (pending batches are flushed first).
    PolicyUpdate(P::PolicyState),
    /// Ask for the current policy state, delivered over the enclosed sender.
    PolicyRequest(Sender<P::PolicyState>),
    /// Register additional agents (raises the flush threshold).
    IncrementAgents(usize),
    /// Deregister agents (lowers the threshold; may trigger a flush).
    DecrementAgents(usize),
}
168
/// One queued `action` request: the caller's observation plus the channel on
/// which the batched result is delivered back.
#[derive(Clone)]
struct ActionItem<S, A, C> {
    // Delivers the action (and its context) back to the blocked caller.
    sender: Sender<ActionContext<A, Vec<C>>>,
    // The caller's observation, batched with others before inference.
    inference_state: S,
    // Whether this caller asked for a deterministic action; the server only
    // acts deterministically when every request in the batch did.
    deterministic: bool,
}
175
/// One queued `forward` request: the caller's observation plus the channel on
/// which the resulting action distribution is delivered back.
#[derive(Clone)]
struct ForwardItem<S, O> {
    // Delivers the forward output back to the blocked caller.
    sender: Sender<O>,
    // The caller's observation, batched with others before inference.
    inference_state: S,
}
181
/// A clonable handle to a policy running on a dedicated inference thread.
/// Every trait method forwards its request over a channel and blocks on the
/// reply, letting many agent threads share one autobatched policy instance.
#[derive(Clone)]
pub struct AsyncPolicy<B: Backend, P: Policy<B>> {
    // Channel into the background `PolicyInferenceServer` message loop.
    inference_state_sender: Sender<InferenceMessage<B, P>>,
}
187
188impl<B, P> AsyncPolicy<B, P>
189where
190 B: Backend,
191 P: Policy<B> + Clone + Send + 'static,
192 P::ActionContext: Clone + Send,
193 P::PolicyState: Send,
194 P::Observation: Clone + Send + Batchable,
195 P::ActionDistribution: Clone + Send + Batchable,
196 P::Action: Clone + Send + Batchable,
197{
198 pub fn new(autobatch_size: usize, inner_policy: P) -> Self {
205 let (sender, receiver) = std::sync::mpsc::channel();
206 let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy.clone());
207 spawn(move || {
208 loop {
209 match receiver.recv() {
210 Ok(msg) => match msg {
211 InferenceMessage::ActionMessage(item) => autobatcher.push_action(item),
212 InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item),
213 InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update),
214 InferenceMessage::PolicyRequest(sender) => sender
215 .send(autobatcher.state())
216 .expect("Autobatcher should be able to send current policy state."),
217 InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num),
218 InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num),
219 },
220 Err(err) => {
221 log::error!("Error in AsyncPolicy : {}", err);
222 break;
223 }
224 }
225 }
226 });
227
228 Self {
229 inference_state_sender: sender,
230 }
231 }
232
233 pub fn increment_agents(&self, num: usize) {
235 self.inference_state_sender
236 .send(InferenceMessage::IncrementAgents(num))
237 .expect("Can send message to autobatcher.")
238 }
239
240 pub fn decrement_agents(&self, num: usize) {
242 self.inference_state_sender
243 .send(InferenceMessage::DecrementAgents(num))
244 .expect("Can send message to autobatcher.")
245 }
246}
247
248impl<B, P> Policy<B> for AsyncPolicy<B, P>
249where
250 B: Backend,
251 P: Policy<B> + Send + 'static,
252{
253 type ActionContext = P::ActionContext;
254 type PolicyState = P::PolicyState;
255
256 type Observation = P::Observation;
257 type ActionDistribution = P::ActionDistribution;
258 type Action = P::Action;
259
260 fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
261 let (action_sender, action_receiver) = std::sync::mpsc::channel();
262 let item = ForwardItem {
263 sender: action_sender,
264 inference_state: states,
265 };
266 self.inference_state_sender
267 .send(InferenceMessage::ForwardMessage(item))
268 .expect("Should be able to send message to inference_server");
269 action_receiver
270 .recv()
271 .expect("AsyncPolicy should receive queued probabilities.")
272 }
273
274 fn action(
275 &mut self,
276 states: Self::Observation,
277 deterministic: bool,
278 ) -> (Self::Action, Vec<Self::ActionContext>) {
279 let (action_sender, action_receiver) = std::sync::mpsc::channel();
280 let item = ActionItem {
281 sender: action_sender,
282 inference_state: states,
283 deterministic,
284 };
285 self.inference_state_sender
286 .send(InferenceMessage::ActionMessage(item))
287 .expect("should be able to send message to inference_server.");
288 let action = action_receiver
289 .recv()
290 .expect("AsyncPolicy should receive queued actions.");
291 (action.action, action.context)
292 }
293
294 fn update(&mut self, update: Self::PolicyState) {
295 self.inference_state_sender
296 .send(InferenceMessage::PolicyUpdate(update))
297 .expect("AsyncPolicy should be able to send policy state.")
298 }
299
300 fn state(&self) -> Self::PolicyState {
301 let (sender, receiver) = mpsc::channel();
302 self.inference_state_sender
303 .send(InferenceMessage::PolicyRequest(sender))
304 .expect("should be able to send message to inference_server.");
305 receiver
306 .recv()
307 .expect("AsyncPolicy should be able to receive policy state.")
308 }
309
310 fn load_record(self, _record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
311 todo!()
313 }
314}
315
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use std::thread::JoinHandle;
    use std::time::Duration;

    use crate::TestBackend;
    use crate::tests::{MockAction, MockObservation, MockPolicy};

    use super::*;

    // NOTE(review): these tests use short sleeps to give the inference thread
    // time to run; they may be flaky on a heavily loaded machine.

    /// With more agents than the autobatch cap, a flush must happen exactly
    /// when `autobatch_size` (8) action requests are queued — not before.
    #[test]
    fn test_multiple_actions_before_flush() {
        // Spawns a thread that blocks in `action` until the batch flushes.
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<()>>,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                thread_policy.action(MockObservation(vec![0.]), false);
            });
            handles.push(handle);
        }

        let policy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(1000);

        // One queued request: no flush yet, the caller stays blocked.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());

        // Seven queued requests total: still below the autobatch size of 8.
        for _ in 0..6 {
            launch_thread(&policy, &mut handles);
        }
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..7 {
            assert!(!handles[i].is_finished());
        }

        // Eighth request fills the batch and releases all eight callers.
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..8 {
            assert!(handles[i].is_finished());
        }

        // After a flush the queue restarts empty: a new request blocks again.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
    }

    /// Same flush-at-autobatch-size behaviour as above, but for the logits
    /// (`forward`) queue, which is batched independently of the action queue.
    #[test]
    fn test_multiple_forward_before_flush() {
        // Spawns a thread that blocks in `forward` until the batch flushes.
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<()>>,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                thread_policy.forward(MockObservation(vec![0.]));
            });
            handles.push(handle);
        }

        let policy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(1000);

        // One queued request: no flush yet.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());

        // Seven requests total: still below the autobatch size.
        for _ in 0..6 {
            launch_thread(&policy, &mut handles);
        }
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..7 {
            assert!(!handles[i].is_finished());
        }

        // Eighth request triggers the flush; everyone completes.
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..8 {
            assert!(handles[i].is_finished());
        }

        // Fresh request after the flush blocks again.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
    }

    /// A batch acts deterministically only when *every* queued request asked
    /// for it (the server ANDs the flags via `.all()`). MockPolicy apparently
    /// encodes the effective flag in the returned action (0 = stochastic,
    /// 1 = deterministic) — verify against the mock's definition.
    #[test]
    fn test_async_policy_deterministic_behaviour() {
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<MockAction>>,
            deterministic: bool,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                let (action, _) = thread_policy.action(MockObservation(vec![0.]), deterministic);
                action
            });
            handles.push(handle);
        }

        let policy = AsyncPolicy::new(2, MockPolicy::new());
        policy.increment_agents(1000);

        // Mixed flags (true + false): the batch falls back to stochastic.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles, true);
        launch_thread(&policy, &mut handles, false);
        for _ in 0..2 {
            let action = handles.pop().unwrap().join().unwrap();
            assert_eq!(action.0, vec![0]);
        }

        // Both requests deterministic: the batch is deterministic.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles, true);
        launch_thread(&policy, &mut handles, true);
        for _ in 0..2 {
            let action = handles.pop().unwrap().join().unwrap();
            assert_eq!(action.0, vec![1]);
        }
    }

    /// When fewer agents are registered than the autobatch cap, the effective
    /// threshold is the agent count; deregistering agents can also trigger an
    /// immediate flush of already-queued requests.
    #[test]
    fn flush_when_running_agents_smaller_than_autobatch_size() {
        fn launch_thread(
            policy: &AsyncPolicy<TestBackend, MockPolicy>,
            handles: &mut Vec<JoinHandle<()>>,
        ) {
            let mut thread_policy = policy.clone();
            let handle = spawn(move || {
                thread_policy.action(MockObservation(vec![0.]), false);
            });
            handles.push(handle);
        }

        // Cap is 8 but only 3 agents: effective threshold is min(3, 8) = 3.
        let policy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(3);

        // Two queued requests: below the 3-agent threshold, still blocked.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
        assert!(!handles[1].is_finished());

        // Third request meets the threshold and releases all three callers.
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        for i in 0..3 {
            assert!(handles[i].is_finished());
        }

        // Two new requests queue up below the threshold again.
        let mut handles = vec![];
        launch_thread(&policy, &mut handles);
        launch_thread(&policy, &mut handles);
        std::thread::sleep(Duration::from_millis(10));
        assert!(!handles[0].is_finished());
        assert!(!handles[1].is_finished());

        // Dropping to 2 agents lowers the threshold to 2 and flushes the
        // already-queued pair immediately.
        policy.decrement_agents(1);
        std::thread::sleep(Duration::from_millis(10));
        assert!(handles[0].is_finished());
        assert!(handles[1].is_finished());
    }
}