#include "common/task_system/task_scheduler.h"
#include "main/client_context.h"
#include "main/database.h"
#include "processor/processor.h"
#if defined(__APPLE__)
#include <pthread.h>
#include <pthread/qos.h>
#endif
namespace lbug {
namespace common {
#ifndef __SINGLE_THREADED__
/**
 * Multi-threaded constructor: starts `numWorkerThreads` pooled worker threads, each of which
 * executes runWorkerThread() until the scheduler is destroyed.
 *
 * @param numWorkerThreads number of worker threads to start.
 * @param threadQos        (Apple only) QoS class value stored for the worker threads; each
 *                         worker reads it at the top of runWorkerThread().
 */
#if defined(__APPLE__)
TaskScheduler::TaskScheduler(uint64_t numWorkerThreads, uint32_t threadQos)
#else
TaskScheduler::TaskScheduler(uint64_t numWorkerThreads)
#endif
    : stopWorkerThreads{false}, nextScheduledTaskID{0} {
#if defined(__APPLE__)
    // Must be stored before any worker starts, because runWorkerThread() reads it immediately.
    this->threadQos = threadQos;
#endif
    // The counter has the same width as numWorkerThreads so the comparison can never wrap.
    for (uint64_t workerIdx = 0; workerIdx < numWorkerThreads; ++workerIdx) {
        // Capture only `this` explicitly: the lambda outlives this constructor's scope, so a
        // blanket by-reference capture would make it easy to accidentally capture a local.
        workerThreads.emplace_back([this] { runWorkerThread(); });
    }
}
/**
 * Multi-threaded destructor: raises the stop flag under the scheduler mutex, wakes every
 * thread waiting on the scheduler's condition variable, and joins all worker threads.
 */
TaskScheduler::~TaskScheduler() {
    {
        // Set the flag while holding the mutex; the lock is released at the end of this scope,
        // before the waiters are notified.
        lock_t guard{taskSchedulerMtx};
        stopWorkerThreads = true;
    }
    cv.notify_all();
    for (auto& worker : workerThreads) {
        worker.join();
    }
}
// Schedules `task` after recursively scheduling each of its children, then blocks the calling
// thread until the task reports completion. If the task ends up holding an exception, the task
// is removed from the queue and the exception is rethrown to the caller.
//
// task:                  the task to run; its `children` are scheduled (and waited on) first.
// context:               its clientContext is queried for a timeout and is interrupted when the
//                        timeout has expired or (without a timeout) when the task has errored.
// launchNewWorkerThread: when true, an additional dedicated thread is started on this task via
//                        runTask and joined before this function returns.
void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task,
processor::ExecutionContext* context, bool launchNewWorkerThread) {
// Run every child to completion first; stop early (without scheduling `task` itself) as soon
// as a child reports terminate().
for (auto& dependency : task->children) {
scheduleTaskAndWaitOrError(dependency, context);
if (dependency->terminate()) {
return;
}
}
std::thread newWorkerThread;
if (launchNewWorkerThread) {
// Register on behalf of the thread that is about to start; runTask performs the matching
// deRegisterThreadAndFinalizeTask() call.
task->registerThread();
newWorkerThread = std::thread(runTask, task.get());
}
// Publish the task and wake the pooled workers waiting on the scheduler's condition variable.
auto scheduledTask = pushTaskIntoQueue(task);
cv.notify_all();
std::unique_lock<std::mutex> taskLck{task->taskMtx, std::defer_lock};
// Wait loop: each iteration re-checks completion under the task's mutex before sleeping on the
// task's own condition variable.
while (true) {
taskLck.lock();
bool timedWait = false;
auto timeout = 0u;
if (task->isCompletedNoLock()) {
taskLck.unlock();
break;
}
if (context->clientContext->hasTimeout()) {
timeout = context->clientContext->getTimeoutRemainingInMS();
if (timeout == 0) {
// No time left: interrupt the client context, then wait (untimed) below.
context->clientContext->interrupt();
} else {
timedWait = true;
}
} else if (task->hasExceptionNoLock()) {
// NOTE(review): this branch is only reached when no timeout is configured, so an errored
// task does not trigger interrupt() here while a timeout is active - confirm intended.
context->clientContext->interrupt();
}
if (timedWait) {
// Sleep for at most the remaining timeout so the loop can re-evaluate it afterwards.
task->cv.wait_for(taskLck, std::chrono::milliseconds(timeout));
} else {
task->cv.wait(taskLck);
}
taskLck.unlock();
}
if (launchNewWorkerThread) {
newWorkerThread.join();
}
if (task->hasException()) {
// Drop the failed task from the queue before propagating its exception to the caller.
removeErroringTask(scheduledTask->ID);
std::rethrow_exception(task->getExceptionPtr());
}
}
/**
 * Main loop of a pooled worker thread.
 *
 * Each iteration, under the scheduler mutex:
 *   1. finishes bookkeeping for the task run in the previous iteration (records a captured
 *      exception on the task, then deregisters this thread from it);
 *   2. waits on the scheduler's condition variable until getTaskAndRegister() yields a task or
 *      the stop flag is raised.
 * The task is then run outside the mutex. Any exception thrown by the task is captured and
 * attached to the task at the start of the next iteration. The function returns once the stop
 * flag has been observed.
 */
void TaskScheduler::runWorkerThread() {
#if defined(__APPLE__)
    // Apply the QoS class stored by the constructor, unless it is the default/unspecified class.
    const auto qosClass = static_cast<qos_class_t>(threadQos);
    if (qosClass != QOS_CLASS_DEFAULT && qosClass != QOS_CLASS_UNSPECIFIED) {
        auto pthreadQosStatus = pthread_set_qos_class_self_np(qosClass, 0);
        UNUSED(pthreadQosStatus);
    }
#endif
    std::unique_lock<std::mutex> lck{taskSchedulerMtx, std::defer_lock};
    std::exception_ptr exceptionPtr = nullptr;
    std::shared_ptr<ScheduledTask> scheduledTask = nullptr;
    while (true) {
        lck.lock();
        if (scheduledTask != nullptr) {
            // Bookkeeping for the task executed in the previous iteration.
            if (exceptionPtr != nullptr) {
                scheduledTask->task->setException(exceptionPtr);
                exceptionPtr = nullptr;
            }
            scheduledTask->task->deRegisterThreadAndFinalizeTask();
            scheduledTask = nullptr;
        }
        cv.wait(lck, [&] {
            scheduledTask = getTaskAndRegister();
            return scheduledTask != nullptr || stopWorkerThreads;
        });
        // Snapshot the stop flag while the mutex is still held: the destructor writes it under
        // the same mutex, so reading it after unlock() would be an unsynchronized access.
        const bool shouldStop = stopWorkerThreads;
        lck.unlock();
        if (shouldStop) {
            // NOTE(review): a task picked up in the same wake-up is left registered here; this
            // only happens while the scheduler is being destroyed.
            return;
        }
        try {
            scheduledTask->task->run();
        } catch (...) {
            // Catch everything, not only std::exception: an exception escaping a std::thread
            // entry point calls std::terminate. The error is reported through the task instead.
            exceptionPtr = std::current_exception();
        }
    }
}
#else
// Single-threaded build: the worker-count argument is accepted but ignored and no threads are
// started; only the stop flag and the task ID counter are initialised.
TaskScheduler::TaskScheduler(uint64_t /*numWorkerThreads*/)
    : stopWorkerThreads{false}, nextScheduledTaskID{0} {}
// Single-threaded build: there are no worker threads to wake or join, so the destructor only
// raises the stop flag.
TaskScheduler::~TaskScheduler() {
    this->stopWorkerThreads = true;
}
// Single-threaded build: runs `task` inline on the calling thread after running each of its
// children the same way. The third parameter (launch a new worker thread) is ignored.
// Rethrows the task's stored exception, if any, after asking the queue to drop the task.
void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task,
    processor::ExecutionContext* context, bool /*launchNewWorkerThread*/) {
    // Children first; bail out without running `task` as soon as one reports terminate().
    for (auto& child : task->children) {
        scheduleTaskAndWaitOrError(child, context);
        if (child->terminate()) {
            return;
        }
    }
    // runTask performs the matching deregistration for this registerThread() call.
    task->registerThread();
    runTask(task.get());
    if (!task->hasException()) {
        return;
    }
    removeErroringTask(task->ID);
    std::rethrow_exception(task->getExceptionPtr());
}
#endif
std::shared_ptr<ScheduledTask> TaskScheduler::pushTaskIntoQueue(const std::shared_ptr<Task>& task) {
lock_t lck{taskSchedulerMtx};
auto scheduledTask = std::make_shared<ScheduledTask>(task, nextScheduledTaskID++);
taskQueue.push_back(scheduledTask);
return scheduledTask;
}
// Walks the task queue from the front and returns the first entry whose task accepts a new
// thread registration, or nullptr when no entry does (including when the queue is empty).
// Entries whose task refuses registration and reports successful completion are erased on
// the way; other refusing entries are left in place.
// This function takes no lock itself; the visible caller invokes it from a condition-variable
// predicate while holding the scheduler mutex.
std::shared_ptr<ScheduledTask> TaskScheduler::getTaskAndRegister() {
    for (auto it = taskQueue.begin(); it != taskQueue.end();) {
        auto candidate = (*it)->task;
        if (candidate->registerThread()) {
            return *it;
        }
        if (candidate->isCompletedSuccessfully()) {
            // erase() hands back the iterator to the following entry.
            it = taskQueue.erase(it);
        } else {
            ++it;
        }
    }
    return nullptr;
}
// Removes the first queue entry whose ID equals `scheduledTaskID`, holding the scheduler
// mutex for the duration. Does nothing when no entry matches.
void TaskScheduler::removeErroringTask(uint64_t scheduledTaskID) {
    lock_t guard{taskSchedulerMtx};
    auto pos = taskQueue.begin();
    // Advance to the first matching entry (or to end() if there is none).
    while (pos != taskQueue.end() && (*pos)->ID != scheduledTaskID) {
        ++pos;
    }
    if (pos != taskQueue.end()) {
        taskQueue.erase(pos);
    }
}
/**
 * Runs `task` on the current thread and then deregisters the thread from it.
 *
 * The caller must have registered a thread on the task beforehand; this function performs the
 * matching deRegisterThreadAndFinalizeTask() call on both the success and the failure path.
 * Any exception thrown while running is stored on the task via setException() rather than
 * propagated.
 *
 * @param task the task to execute; must be non-null.
 */
void TaskScheduler::runTask(Task* task) {
    try {
        task->run();
        task->deRegisterThreadAndFinalizeTask();
    } catch (...) {
        // Catch every exception type: this function is used as a std::thread entry point, where
        // an escaping exception calls std::terminate and would leave the thread registered.
        task->setException(std::current_exception());
        task->deRegisterThreadAndFinalizeTask();
    }
}
// Convenience accessor: follows client context -> database -> query processor and returns
// that processor's task scheduler.
TaskScheduler* TaskScheduler::Get(const main::ClientContext& context) {
    auto&& database = context.getDatabase();
    auto&& queryProcessor = database->getQueryProcessor();
    return queryProcessor->getTaskScheduler();
}
} }