// agave_thread_manager/native_thread_runtime.rs

1use {
2    crate::{
3        policy::{apply_policy, parse_policy, CoreAllocation},
4        MAX_THREAD_NAME_CHARS,
5    },
6    anyhow::bail,
7    log::warn,
8    serde::{Deserialize, Serialize},
9    solana_metrics::datapoint_info,
10    std::{
11        ops::Deref,
12        sync::{
13            atomic::{AtomicUsize, Ordering},
14            Arc, Mutex,
15        },
16    },
17};
18
19#[derive(Clone, Debug, Serialize, Deserialize)]
20#[serde(default)]
21pub struct NativeConfig {
22    pub core_allocation: CoreAllocation,
23    pub max_threads: usize,
24    /// Priority in range 0..99
25    pub priority: u8,
26    pub policy: String,
27    pub stack_size_bytes: usize,
28}
29
30impl Default for NativeConfig {
31    fn default() -> Self {
32        Self {
33            core_allocation: CoreAllocation::OsDefault,
34            max_threads: 16,
35            priority: crate::policy::DEFAULT_PRIORITY,
36            policy: "OTHER".to_owned(),
37            stack_size_bytes: 2 * 1024 * 1024,
38        }
39    }
40}
41
42#[derive(Debug)]
43pub struct NativeThreadRuntimeInner {
44    pub id_count: AtomicUsize,
45    pub running_count: Arc<AtomicUsize>,
46    pub config: NativeConfig,
47    pub name: String,
48}
49
50#[derive(Debug, Clone)]
51pub struct NativeThreadRuntime {
52    inner: Arc<NativeThreadRuntimeInner>,
53}
54
55impl Deref for NativeThreadRuntime {
56    type Target = NativeThreadRuntimeInner;
57
58    fn deref(&self) -> &Self::Target {
59        &self.inner
60    }
61}
62
/// Handle to a thread spawned by a [`NativeThreadRuntime`].
///
/// Joining the thread, whether explicitly or on drop, decrements the
/// runtime's running-thread counter.
pub struct JoinHandle<T> {
    /// Underlying std handle; becomes `None` once the thread has been joined.
    std_handle: Option<std::thread::JoinHandle<T>>,
    /// Running-thread counter shared with the runtime that spawned the thread.
    running_count: Arc<AtomicUsize>,
}
67
68impl<T> JoinHandle<T> {
69    fn join_inner(&mut self) -> std::thread::Result<T> {
70        match self.std_handle.take() {
71            Some(jh) => {
72                let result = jh.join();
73                let rc = self.running_count.fetch_sub(1, Ordering::Relaxed);
74                datapoint_info!("thread-manager-native", ("threads-running", rc, i64),);
75                result
76            }
77            None => {
78                panic!("Thread already joined");
79            }
80        }
81    }
82
83    pub fn join(mut self) -> std::thread::Result<T> {
84        self.join_inner()
85    }
86
87    pub fn is_finished(&self) -> bool {
88        match self.std_handle {
89            Some(ref jh) => jh.is_finished(),
90            None => true,
91        }
92    }
93}
94
95impl<T> Drop for JoinHandle<T> {
96    fn drop(&mut self) {
97        if self.std_handle.is_some() {
98            warn!("Attempting to drop a Join Handle of a running thread will leak thread IDs, please join your  threads!");
99            self.join_inner().expect("Child thread panicked");
100        }
101    }
102}
103
104impl NativeThreadRuntime {
105    pub fn new(name: String, cfg: NativeConfig) -> Self {
106        debug_assert!(name.len() < MAX_THREAD_NAME_CHARS, "Thread name too long");
107        Self {
108            inner: Arc::new(NativeThreadRuntimeInner {
109                id_count: AtomicUsize::new(0),
110                running_count: Arc::new(AtomicUsize::new(0)),
111                config: cfg,
112                name,
113            }),
114        }
115    }
116
117    pub fn spawn<F, T>(&self, f: F) -> anyhow::Result<JoinHandle<T>>
118    where
119        F: FnOnce() -> T,
120        F: Send + 'static,
121        T: Send + 'static,
122    {
123        let n = self.id_count.fetch_add(1, Ordering::Relaxed);
124        let name = format!("{}-{}", &self.name, n);
125        self.spawn_named(name, f)
126    }
127
128    pub fn spawn_named<F, T>(&self, name: String, f: F) -> anyhow::Result<JoinHandle<T>>
129    where
130        F: FnOnce() -> T,
131        F: Send + 'static,
132        T: Send + 'static,
133    {
134        debug_assert!(name.len() < MAX_THREAD_NAME_CHARS, "Thread name too long");
135        let spawned = self.running_count.load(Ordering::Relaxed);
136        if spawned >= self.config.max_threads {
137            bail!("All allowed threads in this pool are already spawned");
138        }
139
140        let core_alloc = self.config.core_allocation.clone();
141        let priority = self.config.priority;
142        let policy = parse_policy(&self.config.policy);
143        let chosen_cores_mask = Mutex::new(self.config.core_allocation.as_core_mask_vector());
144        let jh = std::thread::Builder::new()
145            .name(name)
146            .stack_size(self.config.stack_size_bytes)
147            .spawn(move || {
148                apply_policy(&core_alloc, policy, priority, &chosen_cores_mask);
149                f()
150            })?;
151        let rc = self.running_count.fetch_add(1, Ordering::Relaxed);
152        datapoint_info!("thread-manager-native", ("threads-running", rc as i64, i64),);
153        Ok(JoinHandle {
154            std_handle: Some(jh),
155            running_count: self.running_count.clone(),
156        })
157    }
158
159    #[cfg(feature = "dev-context-only-utils")]
160    pub fn new_for_tests(name: &str) -> Self {
161        Self::new(name.to_owned(), NativeConfig::default())
162    }
163}