fast_able/fast_thread_pool/
task_executor.rs

1use std::{sync::Arc, thread};
2
3use core_affinity::CoreId;
4use crossbeam::atomic::AtomicCell;
5
6/// Channel type definition based on features
7/// Use crossbeam::channel when crossbeam_channel feature is enabled, otherwise use std::sync::mpsc
8/// 基于 features 的 channel 类型定义
9/// 启用 crossbeam_channel feature 时使用 crossbeam::channel,否则使用 std::sync::mpsc
10
11#[cfg(feature = "crossbeam_channel")]
12pub use crossbeam::channel::{bounded, unbounded, Receiver, Sender, TryRecvError};
13
14#[cfg(not(feature = "crossbeam_channel"))]
15pub use std::sync::mpsc::{
16    channel as unbounded, sync_channel as bounded, Receiver, TryRecvError,
17};
18
19// 为 std::sync::mpsc 定义统一的 Sender 类型
20#[cfg(all(not(feature = "crossbeam_channel"), not(feature = "thread_task_bounded")))]
21pub use std::sync::mpsc::Sender;
22
23#[cfg(all(not(feature = "crossbeam_channel"), feature = "thread_task_bounded"))]
24pub use std::sync::mpsc::SyncSender as Sender;
25
26/// Get default bounded channel capacity
27/// Returns 4 times the number of CPU cores, at least 64, at most 1024
28/// 获取默认的有界 channel 容量
29/// 返回 CPU 核心数的 4 倍,至少为 64,最多为 1024
30fn get_default_bounded_capacity() -> usize {
31    let cpu_count = num_cpus::get();
32    let capacity = (cpu_count * 100).max(128).min(4096);
33    capacity
34}
35
/// Task executor backed by a single worker thread: jobs submitted with `spawn` (or
/// `try_spawn`) are sent over a channel and run on that thread.
///
/// The channel implementation is chosen by cargo features: `crossbeam_channel` selects
/// crossbeam's channels instead of `std::sync::mpsc`, and `thread_task_bounded` selects a
/// bounded job queue instead of an unbounded one.
pub struct TaskExecutor {
    /// Sending half of the job queue; the worker thread owns the receiving half.
    jobs: Sender<Box<dyn FnOnce(&usize) + Send + 'static>>,
    /// Worker thread handle. It is stored but never joined; dropping it detaches the thread.
    _handle: thread::JoinHandle<()>,
    /// Number of submitted jobs that have not finished yet: incremented on submission and
    /// decremented once the job has run.
    pub count: Arc<AtomicCell<i64>>,
    /// Id of the core the worker thread is bound to (the `CoreId::id` given at construction).
    core: usize,
}
44
45impl std::fmt::Debug for TaskExecutor {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("TaskExecutor")
48            .field("_handle", &self._handle)
49            .field("count", &self.count)
50            .field("core", &self.core)
51            .finish()
52    }
53}
54
55impl TaskExecutor {
56    /// 创建新的任务执行器
57    /// realtime: 实时内核优先级(1-99,越高越优先),输入 -1 时不开启
58    pub fn new(core: CoreId, realtime: i32) -> TaskExecutor {
59        // 根据 thread_task_bounded feature 选择使用有界或无界 channel
60        #[cfg(feature = "thread_task_bounded")]
61        let (tx, rx) = {
62            let capacity = get_default_bounded_capacity();
63            bounded::<Box<dyn FnOnce(&usize) + Send + 'static>>(capacity)
64        };
65
66        #[cfg(not(feature = "thread_task_bounded"))]
67        let (tx, rx) = unbounded::<Box<dyn FnOnce(&usize) + Send + 'static>>();
68
69        let count = Arc::new(AtomicCell::new(0_i64));
70        let task_count = count.clone();
71
72        let _handle = thread::spawn(move || {
73            // 绑核和开启实时内核
74            super::set_core_affinity_and_realtime(core.id, realtime);
75            let core_id = core.id;
76
77            // 在worker线程启动前设置线程级别的panic hook
78            let old_hook = std::panic::take_hook();
79            std::panic::set_hook(Box::new(move |panic_info| {
80                let thread = std::thread::current();
81                let thread_name = thread.name().unwrap_or("unnamed");
82
83                // 获取panic消息
84                let panic_message = if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
85                    s.to_string()
86                } else if let Some(s) = panic_info.payload().downcast_ref::<String>() {
87                    s.clone()
88                } else {
89                    format!(
90                        "Unknown panic payload type: {:?}",
91                        panic_info.payload().type_id()
92                    )
93                };
94
95                // 获取panic位置信息
96                let location_info = if let Some(location) = panic_info.location() {
97                    format!(
98                        "file: '{}', line: {}, column: {}",
99                        location.file(),
100                        location.line(),
101                        location.column()
102                    )
103                } else {
104                    "unknown location".to_string()
105                };
106
107                // 输出详细的panic信息
108                error!(
109                    "PANIC in TaskExecutor worker thread!\n\
110                     ┌─ Thread Info ─────────────────────────────────────┐\n\
111                     │ Thread Name: {}\n\
112                     │ Core ID: {}\n\
113                     │ Thread ID: {:?}\n\
114                     ├─ Panic Details ──────────────────────────────────┤\n\
115                     │ Message: {}\n\
116                     │ Location: {}\n\
117                     └──────────────────────────────────────────────────┘",
118                    thread_name,
119                    core_id,
120                    thread.id(),
121                    panic_message,
122                    location_info
123                );
124
125                // 调用原来的hook以保持默认行为
126                old_hook(panic_info);
127            }));
128
129            Self::run_worker_loop(rx, task_count, core_id);
130        });
131
132        TaskExecutor {
133            jobs: tx,
134            _handle,
135            count,
136            core: core.id,
137        }
138    }
139
140    /// 创建带自定义容量的任务执行器(仅在启用 thread_task_bounded feature 时有效)
141    /// capacity: 有界 channel 的容量
142    /// realtime: 实时内核优先级(1-99,越高越优先),输入 -1 时不开启
143    #[cfg(feature = "thread_task_bounded")]
144    pub fn new_with_capacity(core: CoreId, capacity: usize, realtime: i32) -> TaskExecutor {
145        let (tx, rx) = bounded::<Box<dyn FnOnce(&usize) + Send + 'static>>(capacity);
146        let count = Arc::new(AtomicCell::new(0_i64));
147        let task_count = count.clone();
148
149        let _handle = thread::spawn(move || {
150            // 绑核和开启实时内核
151            super::set_core_affinity_and_realtime(core.id, realtime);
152            let core_id = core.id;
153
154            // 在worker线程启动前设置线程级别的panic hook
155            let old_hook = std::panic::take_hook();
156            std::panic::set_hook(Box::new(move |panic_info| {
157                let thread = std::thread::current();
158                let thread_name = thread.name().unwrap_or("unnamed");
159
160                // 获取panic消息
161                let panic_message = if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
162                    s.to_string()
163                } else if let Some(s) = panic_info.payload().downcast_ref::<String>() {
164                    s.clone()
165                } else {
166                    format!(
167                        "Unknown panic payload type: {:?}",
168                        panic_info.payload().type_id()
169                    )
170                };
171
172                // 获取panic位置信息
173                let location_info = if let Some(location) = panic_info.location() {
174                    format!(
175                        "file: '{}', line: {}, column: {}",
176                        location.file(),
177                        location.line(),
178                        location.column()
179                    )
180                } else {
181                    "unknown location".to_string()
182                };
183
184                // 输出详细的panic信息
185                error!(
186                    "PANIC in TaskExecutor worker thread!\n\
187                     ┌─ Thread Info ─────────────────────────────────────┐\n\
188                     │ Thread Name: {}\n\
189                     │ Core ID: {}\n\
190                     │ Thread ID: {:?}\n\
191                     ├─ Panic Details ──────────────────────────────────┤\n\
192                     │ Message: {}\n\
193                     │ Location: {}\n\
194                     └──────────────────────────────────────────────────┘",
195                    thread_name,
196                    core_id,
197                    thread.id(),
198                    panic_message,
199                    location_info
200                );
201
202                // 调用原来的hook以保持默认行为
203                old_hook(panic_info);
204            }));
205
206            Self::run_worker_loop(rx, task_count, core_id);
207        });
208
209        TaskExecutor {
210            jobs: tx,
211            _handle,
212            count,
213            core: core.id,
214        }
215    }
216
217    /// 工作线程主循环
218    fn run_worker_loop(
219        rx: Receiver<Box<dyn FnOnce(&usize) + Send + 'static>>,
220        task_count: Arc<AtomicCell<i64>>,
221        core_id: usize,
222    ) {
223        #[cfg(feature = "thread_dispatch")]
224        {
225            let mut empty_count = 0;
226            loop {
227                match rx.try_recv() {
228                    Ok(job) => {
229                        job(&core_id);
230                        task_count.fetch_sub(1);
231                        empty_count = 0;
232                    }
233                    Err(TryRecvError::Empty) => {
234                        empty_count += 1;
235                        if empty_count > 1000 {
236                            empty_count = 0;
237                            // 空闲次数过多时,阻塞等待任务
238                            if let Ok(job) = rx.recv() {
239                                job(&core_id);
240                                task_count.fetch_sub(1);
241                            }
242                        }
243                    }
244                    Err(TryRecvError::Disconnected) => {
245                        error!("TaskExecutor disconnected: {}", core_id);
246                        break;
247                    }
248                }
249            }
250        }
251
252        #[cfg(not(feature = "thread_dispatch"))]
253        loop {
254            if let Ok(job) = rx.try_recv() {
255                job(&core_id);
256                task_count.fetch_sub(1);
257            }
258        }
259    }
260
261    /// 提交任务到线程池
262    #[inline(always)]
263    pub fn spawn<F>(&self, f: F)
264    where
265        F: FnOnce(&usize) + Send + 'static,
266    {
267        self.count.fetch_add(1);
268        
269        if let Err(e) = self.jobs.send(Box::new(f)) {
270            error!("TaskExecutor send error: {:?}", e);
271            // 如果发送失败,直接在当前线程执行
272            e.0(&0);
273            self.count.fetch_sub(1);
274        }
275    }
276
277    /// 尝试提交任务到线程池(非阻塞)
278    /// 仅在启用 thread_task_bounded feature 时提供此方法
279    /// 返回 true 表示成功提交,false 表示队列已满
280    #[cfg(all(feature = "thread_task_bounded", feature = "crossbeam_channel"))]
281    #[inline(always)]
282    pub fn try_spawn<F>(&self, f: F) -> bool
283    where
284        F: FnOnce(&usize) + Send + 'static,
285    {
286        match self.jobs.try_send(Box::new(f)) {
287            Ok(_) => {
288                self.count.fetch_add(1);
289                true
290            }
291            Err(_) => false, // 队列已满或通道已关闭
292        }
293    }
294
295    /// 尝试提交任务到线程池(非阻塞)- std::sync::mpsc 版本
296    /// 仅在启用 thread_task_bounded feature 且使用 std mpsc 时提供此方法
297    /// 返回 true 表示成功提交,false 表示队列已满
298    #[cfg(all(feature = "thread_task_bounded", not(feature = "crossbeam_channel")))]
299    #[inline(always)]
300    pub fn try_spawn<F>(&self, f: F) -> bool
301    where
302        F: FnOnce(&usize) + Send + 'static,
303    {
304        match self.jobs.try_send(Box::new(f)) {
305            Ok(_) => {
306                self.count.fetch_add(1);
307                true
308            }
309            Err(_) => false, // 队列已满或通道已关闭
310        }
311    }
312}