1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
//! Dedicated thread that executes "call to Rust" callbacks on behalf of the worker threads.
use std::{any::Any, collections::{BTreeSet, HashMap}, sync::{self, Arc, mpsc}, thread::{JoinHandle, spawn}};
use anyhow::Result;
use super::{GlobalContext, analysis::BarrierID};
/// Opaque, thread-transferable state owned by a registered callback.
pub type CallbackStateType = Box<dyn Any + Send>;

/// Signature of a "call to Rust" callback: it receives the shared global context
/// and mutable access to the state it was registered with.
pub type CallbackType = fn(&Arc<GlobalContext>, &mut CallbackStateType);

/// Messages understood by the callback thread.
pub(crate) enum CallbackMessage {
    /// A worker requests the callback associated with this barrier.
    Call(BarrierID),
    /// Associate a callback and its state with a set of barriers.
    Register(BTreeSet<BarrierID>, CallbackType, CallbackStateType),
    /// Remove the callback registered for exactly this barrier set.
    Unregister(BTreeSet<BarrierID>),
    /// Stop the callback thread.
    Quit,
}
/// Registry for callbacks: every callback is associated with a set of one or more barriers.
///
/// A call for a single barrier is dispatched to the callback whose barrier set contains it,
/// so the registered barrier sets must be pairwise disjoint (checked in debug builds).
struct CallbackRegistry {
    /// Maps each registered barrier set to its callback and that callback's state
    registry: HashMap<BTreeSet<BarrierID>, (CallbackType, CallbackStateType)>,
}

impl CallbackRegistry {
    /// Create an empty registry.
    fn new() -> Self {
        Self {
            registry: HashMap::new(),
        }
    }

    /// Register `callback` together with its `state` for all barriers in `barriers`.
    ///
    /// Registering the exact same barrier set again replaces (and drops) the previous
    /// callback and state.
    ///
    /// # Panics
    /// In debug builds, panics if `barriers` overlaps a *different* registered barrier set.
    /// Lookups scan the map in unspecified order, so overlapping sets would make the
    /// dispatched callback nondeterministic.
    fn insert(&mut self, barriers: BTreeSet<BarrierID>,
              callback: CallbackType, state: CallbackStateType)
    {
        debug_assert!(
            self.registry
                .keys()
                .filter(|existing| **existing != barriers)
                .all(|existing| existing.is_disjoint(&barriers)),
            "Callback barrier sets must be disjoint"
        );
        self.registry.insert(barriers, (callback, state));
    }

    /// Find the callback (and its mutable state) whose barrier set contains `barrier`.
    ///
    /// Returns `None` if no registered barrier set contains `barrier`.
    fn get_mut(&mut self, barrier: BarrierID) -> Option<&mut (CallbackType, CallbackStateType)> {
        self.registry
            .iter_mut()
            .find(|(barriers, _)| barriers.contains(&barrier))
            .map(|(_, callback_data)| callback_data)
    }

    /// Remove and return the callback registered for exactly the barrier set `barriers`.
    ///
    /// Returns `None` if no callback was registered for that exact set.
    fn remove(&mut self, barriers: &BTreeSet<BarrierID>) -> Option<(CallbackType, CallbackStateType)> {
        self.registry.remove(barriers)
    }
}
pub(crate) struct CallbackThread {
/// Sender for tasks to the callback thread
sender: mpsc::Sender<CallbackMessage>,
/// Receiver for getting callback data out of the callback thread
unregister_receiver: mpsc::Receiver<(CallbackType, CallbackStateType)>,
/// The actual handle to the callback thread
thread: JoinHandle<()>
}
impl CallbackThread {
/// Create a new callback thread that will wait on the given barrier after each call
pub fn new(call_end_barrier: Arc<sync::Barrier>, num_workers: usize,
global_context: Arc<GlobalContext>)
-> Self {
// Create MPSC channel for communication between workers and callback thread
let (sender, receiver) = mpsc::channel();
let (unregister_sender, unregister_receiver) = mpsc::channel();
// Spawn callback thread
let thread = spawn(move || {
// Callback registry
let mut callbacks = CallbackRegistry::new();
loop {
// Wait until every worker has ordered us to do the callback
let mut callback_barrier = None;
for i in 0..num_workers {
let message = receiver.recv().expect("Callback channel senders have hung up");
match message {
CallbackMessage::Quit => { return },
CallbackMessage::Register(barriers, callback, state) => {
if i != 0 {
panic!("Got callback registration while waiting for workers")
}
callbacks.insert(barriers, callback, state);
break;
},
CallbackMessage::Unregister(barriers) => {
if i != 0 {
panic!("Got callback unregistration while waiting for workers")
}
// TODO: Fail more gracefully?
let callback_data = callbacks.remove(&barriers)
.expect("Cannot unregister barrier set: No callback defined.");
unregister_sender.send(callback_data)
.expect("Callback return channel has hung up.");
break;
},
CallbackMessage::Call(barrier) => {
match &mut callback_barrier {
None => {
callback_barrier = Some(barrier);
continue;
}
Some(previous_barrier) => {
// Verify that we got the right call barrier
if *previous_barrier == barrier {
continue;
}
// Panic otherwise
else {
panic!("Got conflicting callback barrier IDs!");
}
}}
}}
}
if let Some(callback_barrier) = callback_barrier {
// Now callback_barrier is consistent and we can call the correct function
if let Some((callback, state)) = callbacks.get_mut(callback_barrier) {
callback(&global_context, state);
}
// Finally unlock the call_end barrier
call_end_barrier.wait();
}
}
});
Self {
sender,
unregister_receiver,
thread
}
}
/// Get a new sender for communication with the callback thread
pub(crate) fn get_sender(&self) -> mpsc::Sender<CallbackMessage> {
self.sender.clone()
}
pub(crate) fn join(self) {
self.sender.send(CallbackMessage::Quit).expect("Callback thread has hung up");
self.thread.join().expect("Callback thread has panicked at some point");
}
pub fn register_callback(&mut self, barriers: BTreeSet<BarrierID>,
callback: CallbackType,
callback_state: CallbackStateType)
-> Result<()> {
// TODO: Check if callback already existed
self.sender.send(CallbackMessage::Register(barriers, callback, callback_state))
.expect("Callback thread has hung up");
Ok(())
}
pub fn unregister_callback(&mut self, barriers: BTreeSet<BarrierID>)
-> Result<(CallbackType, CallbackStateType)> {
self.sender.send(CallbackMessage::Unregister(barriers))
.expect("Callback thread has hung up");
let callback_data = self.unregister_receiver.recv()
.expect("Callback thread has hung up");
Ok(callback_data)
}
}