thread_local_collect/tlm/restr/
probed.rs1#![doc = include_str!("../../../examples/tlmrestr_probed_i32_accumulator.rs")]
17pub use super::control_restr::ControlRestrG;
25
26use super::control_restr::WithTakeTls;
27use crate::tlm::probed::{Control as ControlInner, Holder as HolderInner, Probed};
28
29pub type Control<U> = ControlRestrG<Probed<U, Option<U>>, U>;
35
36impl<U> WithTakeTls<Probed<U, Option<U>>, U> for Control<U>
37where
38 U: 'static,
39{
40 fn take_tls(control: &ControlInner<U, Option<U>>) {
41 control.take_tls();
42 }
43}
44
45impl<U> Control<U>
46where
47 U: Clone,
48{
49 pub fn probe_tls(&self) -> U {
50 self.control
51 .probe_tls()
52 .expect("accumulator guaranteed to never be None")
53 }
54}
55
56pub type Holder<U> = HolderInner<U, Option<U>>;
60
#[cfg(test)]
// Tests `unwrap` join handles and mutex locks on purpose: any failure there is a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::{Control, Holder};
    use crate::dev_support::{assert_eq_and_println, ThreadGater};
    use std::{
        collections::HashMap,
        fmt::Debug,
        sync::Mutex,
        thread::{self, ThreadId},
    };

    #[derive(Debug, Clone, PartialEq)]
    struct Foo(String);

    /// A single key/value pair handed to one call of `op`.
    type Data = (i32, Foo);

    /// Accumulated value: for each thread id, the key/value pairs that thread aggregated.
    type AccValue = HashMap<ThreadId, HashMap<i32, Foo>>;

    /// Per-thread aggregation step.
    ///
    /// Inserts `data` into the inner map for `tid`, creating that inner map if it is
    /// absent. An existing value for the same key is overwritten.
    fn op(data: Data, acc: &mut AccValue, tid: ThreadId) {
        println!(
            "`op` called from {:?} with data {:?}",
            thread::current().id(),
            data
        );

        let (k, v) = data;
        // Single entry lookup; `v` is moved rather than cloned.
        acc.entry(tid).or_default().insert(k, v);
    }

    /// Reduction step: merges `acc2` into `acc1`.
    ///
    /// For a thread id present in both, the entry from `acc2` wins.
    fn op_r(mut acc1: AccValue, acc2: AccValue) -> AccValue {
        println!(
            "`op_r` called from {:?} with acc1={:?} and acc2={:?}",
            thread::current().id(),
            acc1,
            acc2
        );

        acc1.extend(acc2);
        acc1
    }

    thread_local! {static MY_TL: Holder<AccValue> = Holder::new();}

    const NTHREADS: usize = 5;

    /// Aggregates on the test thread and on `NTHREADS` scoped threads, then checks that:
    /// - `drain_tls` returns every thread's data;
    /// - a second drain is empty;
    /// - the same control can be reused afterwards.
    #[test]
    fn own_thread_and_explicit_joins_no_probe() {
        let mut control = Control::new(&MY_TL, HashMap::new, op_r);

        let tid_own = thread::current().id();

        let map_own = {
            let value1 = Foo("a".to_owned());
            let value2 = Foo("b".to_owned());
            let map_own = HashMap::from([(1, value1.clone()), (2, value2.clone())]);

            control.aggregate_data((1, value1), op);
            control.aggregate_data((2, value2), op);

            map_own
        };

        let tid_map_pairs = thread::scope(|s| {
            let hs = (0..NTHREADS)
                .map(|i| {
                    let value1 = Foo("a".to_owned() + &i.to_string());
                    let value2 = Foo("b".to_owned() + &i.to_string());
                    let map_i = HashMap::from([(1, value1.clone()), (2, value2.clone())]);

                    s.spawn(|| {
                        control.aggregate_data((1, value1), op);
                        control.aggregate_data((2, value2), op);

                        let tid_spawned = thread::current().id();
                        (tid_spawned, map_i)
                    })
                })
                .collect::<Vec<_>>();

            hs.into_iter()
                .map(|h| h.join().unwrap())
                .collect::<Vec<_>>()
        });

        {
            let map = std::iter::once((tid_own, map_own))
                .chain(tid_map_pairs)
                .collect::<HashMap<_, _>>();

            {
                let acc = control.drain_tls();
                assert_eq_and_println(&acc, &map, "Accumulator check");
            }

            {
                // Draining again with no aggregation in between must yield an empty accumulator.
                let acc = control.drain_tls();
                assert_eq_and_println(&acc, &HashMap::new(), "empty accumulator expected");
            }
        }

        {
            // Reuse the same control after it has been drained.
            let map_own = {
                let value1 = Foo("c".to_owned());
                let value2 = Foo("d".to_owned());
                let map_own = HashMap::from([(11, value1.clone()), (22, value2.clone())]);

                control.aggregate_data((11, value1), op);
                control.aggregate_data((22, value2), op);

                map_own
            };

            let (tid_spawned, map_spawned) = thread::scope(|s| {
                let control = &control;

                let value1 = Foo("x".to_owned());
                let value2 = Foo("y".to_owned());
                let map_spawned = HashMap::from([(11, value1.clone()), (22, value2.clone())]);

                let tid = s
                    .spawn(move || {
                        control.aggregate_data((11, value1), op);
                        control.aggregate_data((22, value2), op);
                        thread::current().id()
                    })
                    .join()
                    .unwrap();

                (tid, map_spawned)
            });

            let map = HashMap::from([(tid_own, map_own), (tid_spawned, map_spawned)]);
            let acc = control.drain_tls();
            assert_eq_and_println(&acc, &map, "drain_tls - control reused");
        }
    }

    /// Aggregates on the test thread only and checks the drained accumulator.
    #[test]
    fn own_thread_only_no_probe() {
        let mut control = Control::new(&MY_TL, HashMap::new, op_r);

        control.aggregate_data((1, Foo("a".to_owned())), op);
        control.aggregate_data((2, Foo("b".to_owned())), op);

        let tid_own = thread::current().id();
        let map_own = HashMap::from([(1, Foo("a".to_owned())), (2, Foo("b".to_owned()))]);

        let map = HashMap::from([(tid_own, map_own)]);

        let acc = control.drain_tls();
        assert_eq_and_println(&acc, &map, "Accumulator check");
    }

    /// Interleaves a spawned thread with the test thread using gaters.
    ///
    /// The gating lets `probe_tls` be checked after each spawned-thread insert; the
    /// final `drain_tls` is checked after the join.
    #[test]
    fn own_thread_and_explicit_join_with_probe() {
        let mut control = Control::new(&MY_TL, HashMap::new, op_r);

        let main_tid = thread::current().id();
        println!("main_tid={:?}", main_tid);

        let main_thread_gater = ThreadGater::new("main");
        let spawned_thread_gater = ThreadGater::new("spawned");

        // Expected accumulator contents, updated by both threads in turn.
        // `try_lock().unwrap()` is used on purpose: the gaters are meant to serialize all
        // accesses, so a contended lock means the test's sequencing is broken.
        let expected_acc_mutex = Mutex::new(HashMap::new());

        let assert_acc = |acc: &AccValue, msg: &str| {
            let exp = expected_acc_mutex.try_lock().unwrap().clone();
            assert_eq_and_println(acc, &exp, msg);
        };

        thread::scope(|s| {
            let h = s.spawn(|| {
                let spawned_tid = thread::current().id();
                println!("spawned tid={:?}", spawned_tid);

                let mut my_map = HashMap::<i32, Foo>::new();

                // One step per gate:
                // 1. wait for the main thread to open `gate`;
                // 2. aggregate one value and record it in the expected map;
                // 3. open the spawned-side `gate`.
                // The lock guard is a temporary, so it is released before the gate opens.
                let mut process_value = |gate: u8, k: i32, v: Foo| {
                    main_thread_gater.wait_for(gate);
                    control.aggregate_data((k, v.clone()), op);
                    my_map.insert(k, v);
                    expected_acc_mutex
                        .try_lock()
                        .unwrap()
                        .insert(spawned_tid, my_map.clone());
                    spawned_thread_gater.open(gate);
                };

                process_value(0, 1, Foo("aa".to_owned()));
                process_value(1, 2, Foo("bb".to_owned()));
                process_value(2, 3, Foo("cc".to_owned()));
                process_value(3, 4, Foo("dd".to_owned()));
            });

            {
                control.aggregate_data((1, Foo("a".to_owned())), op);
                control.aggregate_data((2, Foo("b".to_owned())), op);
                let my_map = HashMap::from([(1, Foo("a".to_owned())), (2, Foo("b".to_owned()))]);

                // Update the expected map and release the lock *before* opening gate 0.
                // Once the gate opens, the spawned thread calls `try_lock` on the same mutex
                // and would panic if this thread still held the guard.
                let map = {
                    let mut expected = expected_acc_mutex.try_lock().unwrap();
                    expected.insert(main_tid, my_map);
                    expected.clone()
                };
                let acc = control.probe_tls();
                assert_eq_and_println(
                    &acc,
                    &map,
                    "Accumulator after main thread inserts and probe_tls",
                );
                main_thread_gater.open(0);
            }

            // After each of the spawned thread's first three inserts, probe and compare,
            // then let the spawned thread proceed to its next insert.
            for (gate, ordinal) in [(0u8, "1st"), (1, "2nd"), (2, "3rd")] {
                spawned_thread_gater.wait_for(gate);
                let acc = control.probe_tls();
                assert_acc(
                    &acc,
                    &format!("Accumulator after {ordinal} spawned thread insert and probe_tls"),
                );
                main_thread_gater.open(gate + 1);
            }

            h.join().unwrap();
        });

        {
            let acc = control.drain_tls();
            assert_acc(
                &acc,
                "Accumulator after 4th spawned thread insert and drain_tls",
            );
        }

        {
            // Draining again with no aggregation in between must yield an empty accumulator.
            let acc = control.drain_tls();
            assert_eq_and_println(&acc, &HashMap::new(), "empty accumulator expected");
        }
    }

    /// Draining a control that never aggregated anything yields an empty accumulator.
    #[test]
    fn no_thread() {
        let mut control = Control::new(&MY_TL, HashMap::new, op_r);
        let acc = control.drain_tls();
        assert_eq!(acc, HashMap::new(), "empty accumulator expected");
    }
}