1use std::{any::Any, cmp::max};
7
8use quanta::{Clock, Instant};
9
10use threadpool::ThreadPool;
11
12use crossbeam::channel::{unbounded, Receiver};
13
14use ncomm_core::{Executor, ExecutorState, Node};
15
16use crate::{insert_into, NodeWrapper};
17
/// Executor that runs the `update` step of its nodes on a [`ThreadPool`]
/// while the calling thread acts as the scheduler.
///
/// Each node is held in a [`NodeWrapper`] whose `priority` is compared against
/// the number of microseconds elapsed since `start_instant`; once that much
/// time has passed the node is handed to the pool for an update.
pub struct ThreadPoolExecutor<ID: PartialEq> {
    /// Nodes currently waiting for their next update. The scheduler only ever
    /// inspects and pops the *last* element; nodes coming back from the pool
    /// are re-added through `insert_into` (presumably keeping the vector
    /// ordered by priority — confirm against `insert_into`).
    backing: Vec<NodeWrapper<ID>>,
    /// Clock used for every elapsed-time measurement made by the executor.
    clock: Clock,
    /// Worker pool on which node updates are executed.
    pool: ThreadPool,
    /// Current lifecycle state of the executor.
    state: ExecutorState,
    /// Reference instant for scheduling; set at construction and refreshed by
    /// `start`.
    start_instant: Instant,
    /// Channel polled by `check_interrupt` for interrupt requests.
    interrupt: Receiver<bool>,
    /// Most recent value received on `interrupt` (reset to `false` by `start`).
    interrupted: bool,
}
46
47impl<ID: PartialEq> ThreadPoolExecutor<ID> {
48 pub fn new(threads: usize, interrupt: Receiver<bool>) -> Self {
50 let clock = Clock::new();
51 let now = clock.now();
52 let pool = ThreadPool::new(max(1, threads.saturating_sub(1)));
53
54 Self {
55 backing: Vec::new(),
56 clock,
57 pool,
58 state: ExecutorState::Stopped,
59 start_instant: now,
60 interrupt,
61 interrupted: false,
62 }
63 }
64
65 pub fn new_with(
67 threads: usize,
68 interrupt: Receiver<bool>,
69 mut nodes: Vec<Box<dyn Node<ID>>>,
70 ) -> Self {
71 let mut backing = Vec::new();
72 for node in nodes.drain(..) {
73 backing.push(NodeWrapper { priority: 0, node });
74 }
75
76 let clock = Clock::new();
77 let now = clock.now();
78 let pool = ThreadPool::new(max(1, threads.saturating_sub(1)));
79
80 Self {
81 backing,
82 clock,
83 pool,
84 state: ExecutorState::Stopped,
85 start_instant: now,
86 interrupt,
87 interrupted: false,
88 }
89 }
90}
91
92impl<ID: PartialEq + 'static> Executor<ID> for ThreadPoolExecutor<ID> {
93 type Context = Box<dyn Any>;
95
96 fn start(&mut self) {
102 for node_wrapper in self.backing.iter_mut() {
103 node_wrapper.priority = 0;
104 node_wrapper.node.start();
105 }
106
107 self.interrupted = false;
108 self.state = ExecutorState::Started;
109 self.start_instant = self.clock.now();
110 }
111
112 fn update_for_ms(&mut self, ms: u128) {
113 self.start();
115
116 self.state = ExecutorState::Running;
118 let (node_tx, node_rx) = unbounded();
119 while self
120 .clock
121 .now()
122 .duration_since(self.start_instant)
123 .as_millis()
124 < ms
125 && !self.check_interrupt()
126 {
127 if self.backing.last().is_some()
128 && self
129 .clock
130 .now()
131 .duration_since(self.start_instant)
132 .as_micros()
133 >= self.backing.last().unwrap().priority
134 {
135 let mut node_wrapper = self.backing.pop().unwrap();
136 let node_tx = node_tx.clone();
137 self.pool.execute(move || {
138 node_wrapper.node.update();
139 node_wrapper.priority += node_wrapper.node.get_update_delay_us();
140 node_tx.send(node_wrapper).unwrap();
141 });
142 }
143
144 if let Ok(node_wrapper) = node_rx.try_recv() {
145 insert_into(&mut self.backing, node_wrapper);
146 }
147 }
148
149 for node_wrapper in self.backing.iter_mut() {
151 node_wrapper.priority = 0;
152 node_wrapper.node.shutdown();
153 }
154 self.state = ExecutorState::Stopped;
155 }
156
157 fn update_loop(&mut self) {
158 self.start();
160
161 self.state = ExecutorState::Running;
163 let (node_tx, node_rx) = unbounded();
164 while !self.check_interrupt() {
165 if self.backing.last().is_some()
166 && self
167 .clock
168 .now()
169 .duration_since(self.start_instant)
170 .as_micros()
171 >= self.backing.last().unwrap().priority
172 {
173 let mut node_wrapper = self.backing.pop().unwrap();
174 let node_tx = node_tx.clone();
175 self.pool.execute(move || {
176 node_wrapper.node.update();
177 node_wrapper.priority += node_wrapper.node.get_update_delay_us();
178 node_tx.send(node_wrapper).unwrap();
179 });
180 }
181
182 if let Ok(node_wrapper) = node_rx.try_recv() {
183 insert_into(&mut self.backing, node_wrapper);
184 }
185 }
186
187 for node_wrapper in self.backing.iter_mut() {
189 node_wrapper.priority = 0;
190 node_wrapper.node.shutdown();
191 }
192 self.state = ExecutorState::Stopped;
193 }
194
195 fn check_interrupt(&mut self) -> bool {
197 if let Ok(interrupt) = self.interrupt.try_recv() {
198 self.interrupted = interrupt;
199 }
200 self.interrupted
201 }
202
203 fn add_node(&mut self, node: Box<dyn Node<ID>>) {
210 if let Some(idx) = self
211 .backing
212 .iter()
213 .position(|node_wrapper| node_wrapper.node.get_id().eq(&node.get_id()))
214 {
215 self.backing.remove(idx);
216 }
217
218 if self.state == ExecutorState::Stopped {
219 self.backing.push(NodeWrapper { priority: 0, node });
220 } else if self.state == ExecutorState::Started {
221 insert_into(
222 &mut self.backing,
223 NodeWrapper {
224 priority: self
225 .clock
226 .now()
227 .duration_since(self.start_instant)
228 .as_micros(),
229 node,
230 },
231 );
232 }
233 }
234
235 fn remove_node(&mut self, id: &ID) -> Option<Box<dyn Node<ID>>> {
239 if self.state != ExecutorState::Running {
240 let idx = self
241 .backing
242 .iter()
243 .position(|node_wrapper| node_wrapper.node.get_id().eq(id));
244 if let Some(idx) = idx {
245 Some(self.backing.remove(idx).destroy())
246 } else {
247 None
248 }
249 } else {
250 None
251 }
252 }
253}
254
// Unit tests for `ThreadPoolExecutor`, driven by a minimal counting node.
#[cfg(test)]
mod tests {
    use super::*;

    use std::{any::Any, thread, time::Duration};

    /// Lifecycle state recorded by `SimpleNode` so tests can see which hook
    /// ran last.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum State {
        Stopped,
        Started,
        Updating,
    }

    /// Test node that counts its updates and remembers the last lifecycle
    /// hook that was invoked.
    struct SimpleNode {
        // Identifier returned from `get_id`.
        id: u8,
        // Value returned from `get_update_delay_us` (microseconds).
        update_delay: u128,
        // Number of `update` calls so far (wrapping).
        num: u8,
        // Last lifecycle hook that ran.
        state: State,
    }

    impl SimpleNode {
        /// Builds a stopped node with an update counter of zero.
        pub fn new(id: u8, update_delay: u128) -> Self {
            Self {
                id,
                update_delay,
                num: 0,
                state: State::Stopped,
            }
        }
    }

    impl Node<u8> for SimpleNode {
        fn get_id(&self) -> u8 {
            self.id
        }

        fn start(&mut self) {
            self.state = State::Started;
        }

        fn update(&mut self) {
            self.state = State::Updating;
            self.num = self.num.wrapping_add(1);
        }

        fn shutdown(&mut self) {
            self.state = State::Stopped;
        }

        fn get_update_delay_us(&self) -> u128 {
            self.update_delay
        }
    }

    /// `start` must start every node, clear the interrupt flag, move the
    /// executor to `Started` and refresh `start_instant`.
    #[test]
    fn test_start() {
        // The sender half is dropped right away, so no interrupt can arrive.
        let (_, rx) = unbounded();

        let mut executor = ThreadPoolExecutor::new_with(
            3,
            rx,
            vec![
                Box::new(SimpleNode::new(0, 10_000)),
                Box::new(SimpleNode::new(1, 25_000)),
            ],
        );
        let original_start_instant = executor.start_instant;

        executor.start();

        for node_wrapper in executor.backing.iter() {
            assert_eq!(node_wrapper.priority, 0);
            let simple_node: &dyn Any = &node_wrapper.node;
            // NOTE(review): the value behind `&dyn Any` is the boxed trait
            // object, which is reinterpreted here as `Box<SimpleNode>` with no
            // type check (nightly `downcast_ref_unchecked`) — confirm this is
            // sound for the toolchain in use.
            let simple_node: &Box<SimpleNode> = unsafe { simple_node.downcast_ref_unchecked() };
            assert_eq!(simple_node.state, State::Started);
        }

        assert!(!executor.interrupted);
        assert_eq!(executor.state, ExecutorState::Started);
        assert!(executor.start_instant > original_start_instant);
    }

    /// `update_for_ms` must run for roughly the requested time, update the
    /// nodes at their configured rate and leave them shut down.
    #[test]
    fn test_update_for_ms() {
        // The sender half is dropped right away, so no interrupt can arrive.
        let (_, rx) = unbounded();

        let mut executor = ThreadPoolExecutor::new_with(
            3,
            rx,
            vec![
                Box::new(SimpleNode::new(0, 10_000)),
                Box::new(SimpleNode::new(1, 25_000)),
            ],
        );

        let start = executor.clock.now();
        executor.update_for_ms(100);
        let end = executor.clock.now();

        for node_wrapper in executor.backing.iter() {
            assert_eq!(node_wrapper.priority, 0);
            let simple_node: &dyn Any = &node_wrapper.node;
            // NOTE(review): unchecked reinterpretation of the boxed trait
            // object as `Box<SimpleNode>` — confirm soundness.
            let simple_node: &Box<SimpleNode> = unsafe { simple_node.downcast_ref_unchecked() };
            assert_eq!(simple_node.state, State::Stopped);
            // In 100 ms a 10 ms node runs about 10 times and a 25 ms node
            // about 4 times; one update either way is tolerated.
            assert!([3, 4, 5, 9, 10, 11].contains(&simple_node.num));
        }

        // Wall-clock tolerance of 5 ms around the requested 100 ms.
        assert!(Duration::from_millis(95) < end - start);
        assert!(end - start < Duration::from_millis(105));
    }

    /// A `true` sent on the interrupt channel must be reported by
    /// `check_interrupt`.
    #[test]
    fn test_check_interrupt() {
        let (tx, rx) = unbounded();

        let mut executor = ThreadPoolExecutor::new_with(
            3,
            rx,
            vec![
                Box::new(SimpleNode::new(0, 10_000)),
                Box::new(SimpleNode::new(1, 25_000)),
            ],
        );

        tx.send(true).unwrap();

        assert!(executor.check_interrupt());
    }

    /// Adding a node with a new id grows the set of nodes.
    #[test]
    fn test_add_node() {
        let (_, rx) = unbounded();

        let mut executor = ThreadPoolExecutor::new_with(
            3,
            rx,
            vec![
                Box::new(SimpleNode::new(0, 10_000)),
                Box::new(SimpleNode::new(1, 25_000)),
            ],
        );

        executor.add_node(Box::new(SimpleNode::new(2, 1_000)));

        assert_eq!(executor.backing.len(), 3);
    }

    /// Adding a node whose id already exists replaces the existing node.
    #[test]
    fn test_add_node_same_id() {
        let (_, rx) = unbounded();

        let mut executor = ThreadPoolExecutor::new_with(
            3,
            rx,
            vec![
                Box::new(SimpleNode::new(0, 10_000)),
                Box::new(SimpleNode::new(1, 25_000)),
            ],
        );

        executor.add_node(Box::new(SimpleNode::new(0, 1_000)));

        assert_eq!(executor.backing.len(), 2);
        let node_zero = executor
            .backing
            .iter()
            .find(|node_wrapper| node_wrapper.node.get_id().eq(&0))
            .unwrap();
        // The surviving node 0 carries the delay of the replacement.
        assert_eq!(node_zero.node.get_update_delay_us(), 1_000);
    }

    /// Removing a node by id leaves only the other node behind.
    #[test]
    fn test_remove_node() {
        let (_, rx) = unbounded();

        let mut executor = ThreadPoolExecutor::new_with(
            3,
            rx,
            vec![
                Box::new(SimpleNode::new(0, 10_000)),
                Box::new(SimpleNode::new(1, 25_000)),
            ],
        );

        executor.remove_node(&0);

        assert_eq!(executor.backing.len(), 1);
        assert_eq!(executor.backing[0].node.get_id(), 1);
    }

    /// `update_loop` must keep running until interrupted and then leave the
    /// nodes shut down and the executor `Stopped`.
    #[test]
    fn test_update_loop() {
        let (tx, rx) = unbounded();

        let mut executor = ThreadPoolExecutor::new_with(
            2,
            rx,
            vec![
                Box::new(SimpleNode::new(0, 10_000)),
                Box::new(SimpleNode::new(1, 25_000)),
            ],
        );

        // Run the loop on another thread and hand the executor back when it
        // returns, so it can be inspected afterwards.
        let handle = thread::spawn(move || {
            executor.update_loop();
            executor
        });

        // Let the loop run for about 100 ms before interrupting it.
        thread::sleep(Duration::from_millis(100));
        tx.send(true).unwrap();

        let executor = handle.join().unwrap();
        for node_wrapper in executor.backing.iter() {
            assert_eq!(node_wrapper.priority, 0);
            let simple_node: &dyn Any = &node_wrapper.node;
            // NOTE(review): unchecked reinterpretation of the boxed trait
            // object as `Box<SimpleNode>` — confirm soundness.
            let simple_node: &Box<SimpleNode> = unsafe { simple_node.downcast_ref_unchecked() };
            assert_eq!(simple_node.state, State::Stopped);
            // Same expected update counts as in `test_update_for_ms`.
            assert!([3, 4, 5, 9, 10, 11].contains(&simple_node.num));
        }

        assert!(executor.interrupted);
        assert_eq!(executor.state, ExecutorState::Stopped);
    }
}