many_cpus/
processor_set.rs1use std::{sync::LazyLock, thread};
2
3use itertools::Itertools;
4use nonempty::NonEmpty;
5
6use crate::{
7 HardwareTrackerClient, HardwareTrackerClientFacade, Processor, ProcessorSetBuilder,
8 pal::{Platform, PlatformFacade},
9};
10
11extern crate alloc;
13
14static ALL_PROCESSORS: LazyLock<ProcessorSet> = LazyLock::new(|| {
15 ProcessorSetBuilder::default()
16 .take_all()
17 .expect("there must be at least one processor - how could this code run if not")
18});
19
20#[derive(Clone, Debug)]
40pub struct ProcessorSet {
41 processors: NonEmpty<Processor>,
42
43 tracker_client: HardwareTrackerClientFacade,
46
47 pal: PlatformFacade,
48}
49
50impl ProcessorSet {
51 pub fn all() -> &'static Self {
57 &ALL_PROCESSORS
58 }
59
60 #[cfg_attr(test, mutants::skip)] pub fn builder() -> ProcessorSetBuilder {
63 ProcessorSetBuilder::default()
64 }
65
66 pub fn to_builder(&self) -> ProcessorSetBuilder {
69 ProcessorSetBuilder::with_internals(self.tracker_client.clone(), self.pal.clone())
70 .filter(|p| self.processors.contains(p))
71 }
72
73 pub(crate) fn new(
74 processors: NonEmpty<Processor>,
75 tracker_client: HardwareTrackerClientFacade,
76 pal: PlatformFacade,
77 ) -> Self {
78 Self {
79 processors,
80 tracker_client,
81 pal,
82 }
83 }
84
85 pub fn from_processors(processors: NonEmpty<Processor>) -> Self {
87 let pal = processors.first().pal.clone();
88 Self::new(processors, HardwareTrackerClientFacade::real(), pal)
89 }
90
91 pub fn from_processor(processor: Processor) -> Self {
93 let pal = processor.pal.clone();
94 Self::new(
95 NonEmpty::singleton(processor),
96 HardwareTrackerClientFacade::real(),
97 pal,
98 )
99 }
100
101 #[expect(clippy::len_without_is_empty)] pub fn len(&self) -> usize {
104 self.processors.len()
105 }
106
107 pub fn processors(&self) -> &NonEmpty<Processor> {
109 &self.processors
110 }
111
112 pub fn pin_current_thread_to(&self) {
121 self.pal.pin_current_thread_to(&self.processors);
122
123 if self.processors.len() == 1 {
124 let processor = self.processors.first();
126
127 self.tracker_client
128 .update_pin_status(Some(processor.id()), Some(processor.memory_region_id()));
129 } else if self
130 .processors
131 .iter()
132 .map(|p| p.memory_region_id())
133 .unique()
134 .count()
135 == 1
136 {
137 let memory_region_id = self.processors.first().memory_region_id();
139
140 self.tracker_client
141 .update_pin_status(None, Some(memory_region_id));
142 } else {
143 self.tracker_client.update_pin_status(None, None);
145 }
146 }
147
148 pub fn spawn_threads<E, R>(&self, entrypoint: E) -> Box<[thread::JoinHandle<R>]>
154 where
155 E: Fn(Processor) -> R + Send + Clone + 'static,
156 R: Send + 'static,
157 {
158 self.processors()
159 .iter()
160 .map(|processor| {
161 thread::spawn({
162 let processor = processor.clone();
163 let entrypoint = entrypoint.clone();
164 let tracker_client = self.tracker_client.clone();
165 let pal = self.pal.clone();
166
167 move || {
168 let set = Self::new(
169 NonEmpty::from_vec(vec![processor.clone()])
170 .expect("we provide 1-item vec as input, so it must be non-empty"),
171 tracker_client.clone(),
172 pal.clone(),
173 );
174 set.pin_current_thread_to();
175 entrypoint(processor)
176 }
177 })
178 })
179 .collect::<Vec<_>>()
180 .into_boxed_slice()
181 }
182
183 pub fn spawn_thread<E, R>(&self, entrypoint: E) -> thread::JoinHandle<R>
192 where
193 E: FnOnce(ProcessorSet) -> R + Send + 'static,
194 R: Send + 'static,
195 {
196 let set = self.clone();
197
198 thread::spawn(move || {
199 set.pin_current_thread_to();
200 entrypoint(set)
201 })
202 }
203}
204
205impl From<Processor> for ProcessorSet {
206 fn from(value: Processor) -> Self {
207 Self::from_processor(value)
208 }
209}
210
211impl From<NonEmpty<Processor>> for ProcessorSet {
212 fn from(value: NonEmpty<Processor>) -> Self {
213 Self::from_processors(value)
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::{
        num::NonZero,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use nonempty::nonempty;

    use crate::{
        EfficiencyClass, MockHardwareTrackerClient,
        pal::{FakeProcessor, MockPlatform},
    };

    use super::*;

    /// Exercises construction, inspection, pinning and both thread-spawning helpers.
    ///
    /// Uses a mocked platform and a mocked hardware tracker client. The set contains two fake
    /// processors that share memory region 0.
    #[test]
    fn smoke_test() {
        let mut platform = MockPlatform::new();

        // Two pins of the full two-processor set: one on the test thread via
        // pin_current_thread_to(), and one inside spawn_thread().
        platform
            .expect_pin_current_thread_to_core()
            .withf(|p| p.len() == 2)
            .return_const(());

        platform
            .expect_pin_current_thread_to_core()
            .withf(|p| p.len() == 2)
            .return_const(());

        // Two single-processor pins: spawn_threads() pins each of its two threads to one
        // processor.
        platform
            .expect_pin_current_thread_to_core()
            .withf(|p| p.len() == 1)
            .return_const(());

        platform
            .expect_pin_current_thread_to_core()
            .withf(|p| p.len() == 1)
            .return_const(());

        let platform = PlatformFacade::from_mock(platform);

        // Two processors in the same memory region, with different efficiency classes.
        let pal_processors = nonempty![
            FakeProcessor {
                index: 0,
                memory_region: 0,
                efficiency_class: EfficiencyClass::Efficiency,
            },
            FakeProcessor {
                index: 1,
                memory_region: 0,
                efficiency_class: EfficiencyClass::Performance,
            }
        ];

        let processors = pal_processors.map({
            let platform = platform.clone();
            move |p| Processor::new(p.into(), platform.clone())
        });

        let mut tracker_client = MockHardwareTrackerClient::new();

        // The two multi-processor pins cover a single memory region, so only the region is
        // reported and no processor is reported.
        tracker_client
            .expect_update_pin_status()
            .times(2)
            .withf(|processor, memory_region| {
                processor.is_none() && matches!(memory_region, Some(0))
            })
            .return_const(());

        // spawn_threads() pins one thread to processor 0 ...
        tracker_client
            .expect_update_pin_status()
            .times(1)
            .withf(|processor, memory_region| {
                matches!(processor, Some(0)) && matches!(memory_region, Some(0))
            })
            .return_const(());

        // ... and another thread to processor 1.
        tracker_client
            .expect_update_pin_status()
            .times(1)
            .withf(|processor, memory_region| {
                matches!(processor, Some(1)) && matches!(memory_region, Some(0))
            })
            .return_const(());

        let tracker_client = HardwareTrackerClientFacade::from_mock(tracker_client);

        let processor_set = ProcessorSet::new(processors, tracker_client, platform);

        assert_eq!(processor_set.len(), 2);

        // Processors are exposed in the order in which they were provided.
        let mut processor_iter = processor_set.processors().iter();

        let p1 = processor_iter.next().unwrap();
        assert_eq!(p1.id(), 0);
        assert_eq!(p1.memory_region_id(), 0);
        assert_eq!(p1.efficiency_class(), EfficiencyClass::Efficiency);

        let p2 = processor_iter.next().unwrap();
        assert_eq!(p2.id(), 1);
        assert_eq!(p2.memory_region_id(), 0);
        assert_eq!(p2.efficiency_class(), EfficiencyClass::Performance);

        assert!(processor_iter.next().is_none());

        // Pin the test thread to both processors.
        processor_set.pin_current_thread_to();

        // Counts how many times the spawn_thread() entrypoint runs.
        let threads_spawned = Arc::new(AtomicUsize::new(0));

        let threads_spawned_clone = Arc::clone(&threads_spawned);

        // Moving a non-Copy value into the entrypoint checks that spawn_thread() accepts a
        // FnOnce closure that consumes its captures.
        let non_copy_value = "foo".to_string();

        fn process_string(_s: String) {}

        processor_set
            .spawn_thread({
                move |processor_set| {
                    // The spawned thread receives a clone of the whole set.
                    assert_eq!(processor_set.len(), 2);

                    process_string(non_copy_value);

                    threads_spawned_clone.fetch_add(1, Ordering::Relaxed);
                }
            })
            .join()
            .unwrap();

        assert_eq!(threads_spawned.load(Ordering::Relaxed), 1);

        // Fresh counter for spawn_threads(), which should start one thread per processor.
        let threads_spawned = Arc::new(AtomicUsize::new(0));

        processor_set
            .spawn_threads({
                let threads_spawned = Arc::clone(&threads_spawned);
                move |_| {
                    threads_spawned.fetch_add(1, Ordering::Relaxed);
                }
            })
            .into_vec()
            .into_iter()
            .for_each(|h| h.join().unwrap());

        assert_eq!(threads_spawned.load(Ordering::Relaxed), 2);

        // A clone contains the same number of processors.
        let cloned_processor_set = processor_set.clone();

        assert_eq!(cloned_processor_set.len(), 2);
    }

    /// A builder derived from a set via to_builder() yields the same processors again.
    ///
    /// Runs against the real platform. NOTE(review): presumably excluded under Miri because
    /// real platform calls are not supported there; confirm.
    #[cfg(not(miri))]
    #[test]
    fn to_builder_preserves_processors() {
        let set = ProcessorSet::builder()
            .take(NonZero::new(1).unwrap())
            .unwrap();

        let builder = set.to_builder();

        let set2 = builder.take_all().unwrap();
        assert_eq!(set2.len(), 1);

        let processor1 = set.processors().first();
        let processor2 = set2.processors().first();

        assert_eq!(processor1, processor2);
    }

    /// After pinning to one processor, where_available_for_current_thread() sees only that
    /// processor.
    ///
    /// The work happens on a dedicated thread; presumably this keeps the test-runner thread's
    /// affinity unchanged.
    #[cfg(not(miri))]
    #[test]
    fn inherit_on_pinned() {
        thread::spawn(|| {
            let one = ProcessorSet::builder()
                .take(NonZero::new(1).unwrap())
                .unwrap();

            one.pin_current_thread_to();

            let current_thread_allowed = ProcessorSet::builder()
                .where_available_for_current_thread()
                .take_all()
                .unwrap();

            assert_eq!(current_thread_allowed.len(), 1);
            assert_eq!(
                current_thread_allowed.processors().first(),
                one.processors().first()
            );
        })
        .join()
        .unwrap();
    }
}