// fips_md/codegen/callback_thread.rs

use std::{
    any::Any,
    collections::{BTreeSet, HashMap},
    fmt,
    sync::{self, mpsc, Arc},
    thread::{spawn, JoinHandle},
};

use anyhow::Result;

use super::{analysis::BarrierID, GlobalContext};

8
9pub type CallbackType = fn(&Arc<GlobalContext>, &mut Box<dyn Any + Send>);
10pub type CallbackStateType = Box<dyn Any + Send>;
11
12pub(crate) enum CallbackMessage {
13 Call(BarrierID),
14 Register(BTreeSet<BarrierID>, CallbackType, CallbackStateType),
15 Unregister(BTreeSet<BarrierID>),
16 Quit
17}
18
19struct CallbackRegistry {
21 registry: HashMap<BTreeSet<BarrierID>, (CallbackType, CallbackStateType)>
22}
23
24impl CallbackRegistry {
25 fn new() -> Self {
26 Self {
27 registry: HashMap::new()
28 }
29 }
30
31 fn insert(&mut self, barriers: BTreeSet<BarrierID>,
32 callback: CallbackType, state: CallbackStateType)
33 {
34 self.registry.insert(barriers, (callback, state));
36 }
37
38 fn get_mut(&mut self, barrier: BarrierID) -> Option<&mut (CallbackType, CallbackStateType)> {
48 for (barriers, callback_data) in self.registry.iter_mut() {
49 if barriers.contains(&barrier) {
50 return Some(callback_data)
51 }
52 }
53 None
54 }
55
56 fn remove(&mut self, barriers: &BTreeSet<BarrierID>) -> Option<(CallbackType, CallbackStateType)> {
57 self.registry.remove(barriers)
58 }
59}
60
61pub(crate) struct CallbackThread {
62 sender: mpsc::Sender<CallbackMessage>,
64 unregister_receiver: mpsc::Receiver<(CallbackType, CallbackStateType)>,
66 thread: JoinHandle<()>
68}
69
70impl CallbackThread {
71 pub fn new(call_end_barrier: Arc<sync::Barrier>, num_workers: usize,
73 global_context: Arc<GlobalContext>)
74 -> Self {
75 let (sender, receiver) = mpsc::channel();
77 let (unregister_sender, unregister_receiver) = mpsc::channel();
78 let thread = spawn(move || {
80 let mut callbacks = CallbackRegistry::new();
82
83 loop {
84 let mut callback_barrier = None;
86 for i in 0..num_workers {
87 let message = receiver.recv().expect("Callback channel senders have hung up");
88 match message {
89 CallbackMessage::Quit => { return },
90 CallbackMessage::Register(barriers, callback, state) => {
91 if i != 0 {
92 panic!("Got callback registration while waiting for workers")
93 }
94 callbacks.insert(barriers, callback, state);
95 break;
96 },
97 CallbackMessage::Unregister(barriers) => {
98 if i != 0 {
99 panic!("Got callback unregistration while waiting for workers")
100 }
101 let callback_data = callbacks.remove(&barriers)
103 .expect("Cannot unregister barrier set: No callback defined.");
104 unregister_sender.send(callback_data)
105 .expect("Callback return channel has hung up.");
106 break;
107 },
108 CallbackMessage::Call(barrier) => {
109 match &mut callback_barrier {
110 None => {
111 callback_barrier = Some(barrier);
112 continue;
113 }
114 Some(previous_barrier) => {
115 if *previous_barrier == barrier {
117 continue;
118 }
119 else {
121 panic!("Got conflicting callback barrier IDs!");
122 }
123 }}
124 }}
125 }
126 if let Some(callback_barrier) = callback_barrier {
127 if let Some((callback, state)) = callbacks.get_mut(callback_barrier) {
129 callback(&global_context, state);
130 }
131 call_end_barrier.wait();
133 }
134 }
135 });
136 Self {
137 sender,
138 unregister_receiver,
139 thread
140 }
141 }
142
143 pub(crate) fn get_sender(&self) -> mpsc::Sender<CallbackMessage> {
145 self.sender.clone()
146 }
147
148 pub(crate) fn join(self) {
149 self.sender.send(CallbackMessage::Quit).expect("Callback thread has hung up");
150 self.thread.join().expect("Callback thread has panicked at some point");
151 }
152
153 pub fn register_callback(&mut self, barriers: BTreeSet<BarrierID>,
154 callback: CallbackType,
155 callback_state: CallbackStateType)
156 -> Result<()> {
157 self.sender.send(CallbackMessage::Register(barriers, callback, callback_state))
159 .expect("Callback thread has hung up");
160 Ok(())
161 }
162
163 pub fn unregister_callback(&mut self, barriers: BTreeSet<BarrierID>)
164 -> Result<(CallbackType, CallbackStateType)> {
165 self.sender.send(CallbackMessage::Unregister(barriers))
166 .expect("Callback thread has hung up");
167 let callback_data = self.unregister_receiver.recv()
168 .expect("Callback thread has hung up");
169 Ok(callback_data)
170 }
171}