fips_md/codegen/
callback_thread.rs

//! Thread for executing "call to Rust" callbacks.
2
3use std::{any::Any, collections::{BTreeSet, HashMap}, sync::{self, Arc, mpsc}, thread::{JoinHandle, spawn}};
4
5use anyhow::Result;
6
7use super::{GlobalContext, analysis::BarrierID};
8
9pub type CallbackType = fn(&Arc<GlobalContext>, &mut Box<dyn Any + Send>);
10pub type CallbackStateType = Box<dyn Any + Send>;
11
12pub(crate) enum CallbackMessage {
13    Call(BarrierID),
14    Register(BTreeSet<BarrierID>, CallbackType, CallbackStateType),
15    Unregister(BTreeSet<BarrierID>),
16    Quit
17}
18
19/// Registry for callbacks: Every callback can be associated with one ore more barriers
20struct CallbackRegistry {
21    registry: HashMap<BTreeSet<BarrierID>, (CallbackType, CallbackStateType)>
22}
23
24impl CallbackRegistry {
25    fn new() -> Self {
26        Self {
27            registry: HashMap::new()
28        }
29    }
30
31    fn insert(&mut self, barriers: BTreeSet<BarrierID>,
32        callback: CallbackType, state: CallbackStateType)
33    {
34        // TODO: Assert disjointness of barriers in registry?
35        self.registry.insert(barriers, (callback, state));
36    }
37
38    // fn get(&self, barrier: BarrierID) -> Option<&(fn(GlobalContext, &mut Box<dyn Any + Send>), Box<dyn Any>)> {
39    //     for (barriers, callback_data) in self.registry.iter() {
40    //         if barriers.contains(&barrier) {
41    //             return Some(callback_data)
42    //         }
43    //     }
44    //     None
45    // }
46
47    fn get_mut(&mut self, barrier: BarrierID) -> Option<&mut (CallbackType, CallbackStateType)> {
48        for (barriers, callback_data) in self.registry.iter_mut() {
49            if barriers.contains(&barrier) {
50                return Some(callback_data)
51            }
52        }
53        None
54    }
55
56    fn remove(&mut self, barriers: &BTreeSet<BarrierID>) -> Option<(CallbackType, CallbackStateType)> {
57        self.registry.remove(barriers)
58    }
59}
60
61pub(crate) struct CallbackThread {
62    /// Sender for tasks to the callback thread
63    sender: mpsc::Sender<CallbackMessage>,
64    /// Receiver for getting callback data out of the callback thread
65    unregister_receiver: mpsc::Receiver<(CallbackType, CallbackStateType)>,
66    /// The actual handle to the callback thread
67    thread: JoinHandle<()>
68}
69
70impl CallbackThread {
71    /// Create a new callback thread that will wait on the given barrier after each call
72    pub fn new(call_end_barrier: Arc<sync::Barrier>, num_workers: usize,
73        global_context: Arc<GlobalContext>) 
74    -> Self {
75        // Create MPSC channel for communication between workers and callback thread
76        let (sender, receiver) = mpsc::channel();
77        let (unregister_sender, unregister_receiver) = mpsc::channel();
78        // Spawn callback thread
79        let thread = spawn(move || {
80            // Callback registry
81            let mut callbacks = CallbackRegistry::new();
82            
83            loop {
84                // Wait until every worker has ordered us to do the callback
85                let mut callback_barrier = None;
86                for i in 0..num_workers {
87                    let message = receiver.recv().expect("Callback channel senders have hung up");
88                    match message {
89                        CallbackMessage::Quit => { return },
90                        CallbackMessage::Register(barriers, callback, state) => {
91                            if i != 0 {
92                                panic!("Got callback registration while waiting for workers")
93                            }
94                            callbacks.insert(barriers, callback, state);
95                            break;
96                        },
97                        CallbackMessage::Unregister(barriers) => {
98                            if i != 0 {
99                                panic!("Got callback unregistration while waiting for workers")
100                            }
101                            // TODO: Fail more gracefully?
102                            let callback_data = callbacks.remove(&barriers)
103                                .expect("Cannot unregister barrier set: No callback defined.");
104                            unregister_sender.send(callback_data)
105                                .expect("Callback return channel has hung up.");
106                            break;
107                        },
108                        CallbackMessage::Call(barrier) => {
109                            match &mut callback_barrier {
110                                None => {
111                                    callback_barrier = Some(barrier);
112                                    continue;
113                                }
114                                Some(previous_barrier) => {
115                                    // Verify that we got the right call barrier
116                                    if *previous_barrier == barrier {
117                                        continue;
118                                    }
119                                    // Panic otherwise
120                                    else {
121                                        panic!("Got conflicting callback barrier IDs!");
122                                    }
123                            }}
124                    }}
125                }
126                if let Some(callback_barrier) = callback_barrier {
127                    // Now callback_barrier is consistent and we can call the correct function
128                    if let Some((callback, state)) = callbacks.get_mut(callback_barrier) {
129                        callback(&global_context, state);
130                    }
131                    // Finally unlock the call_end barrier
132                    call_end_barrier.wait();
133                }
134            }
135        });
136        Self {
137            sender,
138            unregister_receiver,
139            thread
140        }
141    }
142
143    /// Get a new sender for communication with the callback thread
144    pub(crate) fn get_sender(&self) -> mpsc::Sender<CallbackMessage> {
145        self.sender.clone()
146    }
147
148    pub(crate) fn join(self) {
149        self.sender.send(CallbackMessage::Quit).expect("Callback thread has hung up");
150        self.thread.join().expect("Callback thread has panicked at some point");
151    }
152
153    pub fn register_callback(&mut self, barriers: BTreeSet<BarrierID>,
154        callback: CallbackType,
155        callback_state: CallbackStateType) 
156    -> Result<()> {
157        // TODO: Check if callback already existed
158        self.sender.send(CallbackMessage::Register(barriers, callback, callback_state))
159            .expect("Callback thread has hung up");
160        Ok(())
161    }
162
163    pub fn unregister_callback(&mut self, barriers: BTreeSet<BarrierID>) 
164    -> Result<(CallbackType, CallbackStateType)> {
165        self.sender.send(CallbackMessage::Unregister(barriers))
166            .expect("Callback thread has hung up");
167        let callback_data = self.unregister_receiver.recv()
168            .expect("Callback thread has hung up");
169        Ok(callback_data)
170    }
171}