#include "megbrain/utils/thread_pool.h"
#include <chrono>
using namespace mgb;
#if MGB_HAVE_THREAD
/*!
 * Create a pool that uses \p threads_num threads in total.
 *
 * The calling thread of add_task() counts as one of them, so only
 * threads_num - 1 Worker objects are created here; a request for 0 threads
 * is clamped to 1, in which case no workers exist and add_task() runs
 * everything inline on the caller.
 */
ThreadPool::ThreadPool(size_t threads_num)
        : m_nr_threads(threads_num),
          m_main_affinity_flag{false},
          m_stop{false},
          m_active{false} {
    if (threads_num < 1) {
        m_nr_threads = 1;
    }
    if (m_nr_threads > 1) {
        if (m_nr_threads > static_cast<uint32_t>(sys::get_cpu_count())) {
            // Argument order must follow the format string: the requested
            // thread count ("got") first, then the detected core count.
            mgb_log_debug(
                    "The number of threads is bigger than number of "
                    "physical cpu cores, got: %zu core_number: %zu",
                    nr_threads(), static_cast<size_t>(sys::get_cpu_count()));
        }
        // Worker i reports thread id i to the task; id m_nr_threads - 1 is
        // reserved for the thread that calls add_task().
        for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
            // Each Worker runs this loop (presumably on its own thread --
            // Worker is defined elsewhere).
            m_workers.push_back(new Worker([this, i]() {
                while (!m_stop) {
                    // Active phase: poll for work, yielding between polls.
                    while (m_active) {
                        // Apply a pending set_affinity() request to this
                        // worker exactly once.
                        if (m_workers[i]->affinity_flag &&
                            m_core_binding_function != nullptr) {
                            m_core_binding_function(i);
                            m_workers[i]->affinity_flag = false;
                        }
                        if (m_workers[i]->work_flag.load(std::memory_order_acquire)) {
                            // Claim sub-tasks by decrementing the shared
                            // counter; fetch_sub returns parallelism,
                            // parallelism-1, ..., so indices are handed out
                            // as 0, 1, 2, ... until the counter reaches 0.
                            int index = -1;
                            while ((index = m_task_iter.fetch_sub(
                                            1, std::memory_order_acq_rel)) > 0) {
                                m_task(static_cast<size_t>(m_nr_parallelism - index),
                                       i);
                            }
                            // Signal sync() that this worker no longer
                            // touches the current task.
                            m_workers[i]->work_flag.store(
                                    false, std::memory_order_release);
                        }
                        std::this_thread::yield();
                    }
                    // Idle phase: sleep until the pool is activated or stopped.
                    {
                        std::unique_lock<std::mutex> lock(m_mutex);
                        if (!m_stop && !m_active) {
                            m_cv.wait(lock, [this] { return m_stop || m_active; });
                        }
                    }
                }
            }));
        }
    }
}
/*!
 * Invoke task_elem.task(index, thread_id) once for every index in
 * [0, task_elem.nr_parallelism) and return when all invocations are done.
 *
 * With more than one sub-task and more than one thread, the sub-tasks are
 * shared between the worker threads (ids 0 .. m_nr_threads - 2) and the
 * calling thread (id m_nr_threads - 1). Otherwise they run inline on the
 * caller in index order with thread id 0. An empty task (nr_parallelism == 0)
 * returns immediately without waking the workers.
 */
void ThreadPool::add_task(const TaskElem& task_elem) {
    // Apply a pending set_affinity() request to the calling thread.
    if (m_main_affinity_flag && m_core_binding_function != nullptr) {
        std::lock_guard<std::mutex> lock(m_mutex_task);
        m_core_binding_function(m_nr_threads - 1);
        m_main_affinity_flag = false;
    }
    size_t parallelism = task_elem.nr_parallelism;
    // Nothing worth distributing: run inline (or not at all for 0) rather than
    // waking every worker and spinning in sync() for an empty work set.
    if (parallelism <= 1 || m_nr_threads == 1) {
        for (size_t i = 0; i < parallelism; i++) {
            task_elem.task(i, 0);
        }
        return;
    }

    std::lock_guard<std::mutex> lock(m_mutex_task);
    mgb_assert(
            m_task_iter.load(std::memory_order_acquire) <= 0,
            "The init value of m_all_sub_task is not zero.");
    active();
    m_nr_parallelism = parallelism;
    m_task_iter.exchange(parallelism, std::memory_order_relaxed);
    // Capturing task_elem by reference is safe: sync() below does not return
    // until every worker has cleared its work_flag, which a worker only does
    // after it has stopped calling m_task.
    m_task = [&task_elem](size_t index, size_t thread_id) {
        task_elem.task(index, thread_id);
    };
    // Plain assignment to the atomic flag is a seq_cst (hence release) store,
    // pairing with the worker's acquire load so it sees the state set above.
    for (uint32_t i = 0; i < m_nr_threads - 1; i++) {
        m_workers[i]->work_flag = true;
    }
    // The calling thread claims sub-tasks too, using the last thread id.
    int index = -1;
    while ((index = m_task_iter.fetch_sub(1, std::memory_order_acq_rel)) > 0) {
        m_task(static_cast<size_t>(m_nr_parallelism - index), m_nr_threads - 1);
    }
    sync();
}
/*!
 * Install \p affinity_cb as the core-binding function and request that every
 * thread of the pool re-bind itself. The binding is applied lazily: each
 * worker calls affinity_cb(worker_id) the next time it polls while active,
 * and the caller of add_task() calls affinity_cb(m_nr_threads - 1) on its
 * next add_task().
 */
void ThreadPool::set_affinity(AffinityCallBack affinity_cb) {
    mgb_assert(affinity_cb, "The affinity callback must not be nullptr");
    std::lock_guard<std::mutex> task_guard(m_mutex_task);
    m_core_binding_function = affinity_cb;
    // The pool owns exactly m_nr_threads - 1 workers, so this marks each one.
    for (auto worker_ptr : m_workers) {
        worker_ptr->affinity_flag = true;
    }
    m_main_affinity_flag = true;
}
//! Total number of threads used by the pool, including the calling thread
//! (always at least 1).
size_t ThreadPool::nr_threads() const {
    const size_t total_threads = m_nr_threads;
    return total_threads;
}
void ThreadPool::sync() {
bool no_finished = false;
do {
no_finished = false;
for (uint32_t i = 0; i < m_nr_threads - 1; ++i) {
if (m_workers[i]->work_flag) {
no_finished = true;
break;
}
}
if (no_finished) {
std::this_thread::yield();
}
} while (no_finished);
}
/*!
 * Switch the pool to the active state and wake any workers sleeping on m_cv.
 * While active, workers poll for work instead of sleeping. This is a no-op
 * when the pool is already active.
 */
void ThreadPool::active() {
    if (m_active) {
        return;
    }
    std::lock_guard<std::mutex> state_guard(m_mutex);
    m_active = true;
    m_cv.notify_all();
}
/*!
 * Switch the pool to the inactive state. Workers leave their polling loop
 * and go back to waiting on m_cv until active() is called again. m_mutex_task
 * is taken first, which serializes this with add_task()'s dispatch section.
 */
void ThreadPool::deactive() {
    std::lock_guard<std::mutex> task_guard(m_mutex_task);
    std::lock_guard<std::mutex> state_guard(m_mutex);
    m_active = false;
}
/*!
 * Stop the pool. Sets the stop flag, clears the active flag, and wakes all
 * sleeping workers so their loops can exit, then deletes every Worker.
 * Deleting a Worker presumably joins its thread; Worker is defined elsewhere.
 */
ThreadPool::~ThreadPool() {
    std::lock_guard<std::mutex> task_guard(m_mutex_task);
    {
        std::lock_guard<std::mutex> state_guard(m_mutex);
        m_stop = true;
        m_active = false;
        m_cv.notify_all();
    }
    for (auto worker_ptr : m_workers) {
        delete worker_ptr;
    }
}
#else
/*!
 * Build without thread support: run every sub-task sequentially on the
 * calling thread, in index order, always reporting thread id 0.
 */
void ThreadPool::add_task(const TaskElem& task_elem) {
    size_t index = 0;
    while (index < task_elem.nr_parallelism) {
        task_elem.task(index, 0);
        ++index;
    }
}
/*!
 * Build without thread support: the caller is the only thread, so it is
 * bound immediately with thread id 0.
 */
void ThreadPool::set_affinity(AffinityCallBack affinity_cb) {
    mgb_assert(affinity_cb != nullptr, "The affinity callback is nullptr");
    const size_t only_thread_id = 0;
    affinity_cb(only_thread_id);
}
#endif