#ifndef HIGHS_TASK_H_
#define HIGHS_TASK_H_
#include <atomic>
#include <cassert>
#include <cstring>
#include <type_traits>
#include "parallel/HighsSpinMutex.h"
class HighsSplitDeque;
// A fixed-size (kMaxTaskSize bytes) slot holding a type-erased callable plus
// an atomic state word; used as the element type of the work-stealing deque
// (HighsSplitDeque). The state word packs a HighsSplitDeque pointer together
// with "finished" and "cancelled" flags in its low two bits, so a stealer can
// register itself, an owner can register for completion notification, and
// completion/cancellation can be observed, each with a single atomic
// read-modify-write operation.
class HighsTask {
  friend class HighsSplitDeque;

 public:
  enum Constants {
    // Total size of one task slot in bytes; the callable storage is this
    // minus sizeof(Metadata).
    kMaxTaskSize = 64,
  };

  // Empty tag type. Not referenced in this header; presumably thrown by task
  // functors to abort execution of a cancelled task -- confirm against
  // HighsSplitDeque / scheduler usage.
  class Interrupt {};

 private:
  // Low-bit flags stored in the state word alongside the pointer. This
  // encoding assumes HighsSplitDeque objects are aligned to at least 4 bytes
  // so a pointer's low two bits are always zero -- TODO confirm.
  // NOTE(review): the constants are uint64_t while the state word is
  // atomic<uintptr_t>; equivalent on 64-bit targets.
  static constexpr uint64_t kFinishedFlag = 1;
  static constexpr uint64_t kCancelFlag = 2;
  // Mask that strips the flag bits, leaving the raw pointer value.
  static constexpr uint64_t kPtrMask = ~(kFinishedFlag | kCancelFlag);

  struct Metadata {
    // Packed task state: a HighsSplitDeque* (stealer, or owner waiting for
    // notification) combined with kFinishedFlag/kCancelFlag in the low bits.
    // 0 means: not stolen, not finished, not cancelled.
    std::atomic<uintptr_t> stealer;
  };

  // Abstract interface for the type-erased callable stored in taskData.
  // No virtual destructor: instances are placement-constructed into taskData
  // and never destroyed (setTaskData() statically requires the wrapped
  // functor to be trivially destructible).
  class CallableBase {
   public:
    virtual void operator()() = 0;
  };

  // Concrete wrapper storing a copy of the user's functor F in-place.
  template <typename F>
  class Callable : public CallableBase {
    F functor;

   public:
    Callable(F&& functor) : functor(std::forward<F>(functor)) {}

    virtual void operator()() override {
      // Move the functor to the stack before invoking it, so the running call
      // does not touch the task slot -- presumably this lets the deque slot
      // be reused/overwritten while the task executes; confirm with
      // HighsSplitDeque.
      F callFunctor = std::move(functor);
      callFunctor();
    }
  };

  // Raw in-place storage for a Callable<F>; sized so that taskData plus
  // metadata together occupy exactly kMaxTaskSize bytes.
  char taskData[kMaxTaskSize - sizeof(Metadata)];
  Metadata metadata;

  // Reinterprets the in-place storage as the callable constructed there by
  // setTaskData(). Uses a union for the pointer conversion; relies on the
  // CallableBase subobject living at offset 0 of Callable<F> (checked by the
  // assert in setTaskData()).
  CallableBase& getCallable() {
    union {
      CallableBase* callablePtr;
      char* storagePtr;
    } u;
    u.storagePtr = this->taskData;
    return *u.callablePtr;
  }

  // Atomically marks the task finished, releasing the functor's side effects
  // to threads that subsequently observe kFinishedFlag with acquire ordering.
  // Returns the deque pointer recorded in the state word (an owner registered
  // via requestNotifyWhenFinished()) if it differs from the stealer itself,
  // otherwise nullptr (no one to signal).
  HighsSplitDeque* markAsFinished(HighsSplitDeque* stealer) {
    uintptr_t state =
        metadata.stealer.exchange(kFinishedFlag, std::memory_order_release);
    HighsSplitDeque* waitingOwner =
        reinterpret_cast<HighsSplitDeque*>(state & kPtrMask);
    if (waitingOwner != stealer) return waitingOwner;
    return nullptr;
  }

  // Steal-side execution: registers `stealer` in the state word and runs the
  // callable only if the prior state was exactly 0, i.e. the task was neither
  // cancelled nor finished and no other pointer was recorded. The acquire
  // ordering presumably pairs with the owner's publication of the task data
  // -- confirm against HighsSplitDeque. Always marks the task finished;
  // returns the owner deque to be signalled, or nullptr.
  HighsSplitDeque* run(HighsSplitDeque* stealer) {
    uintptr_t state = metadata.stealer.fetch_or(
        reinterpret_cast<uintptr_t>(stealer), std::memory_order_acquire);
    if (state == 0) getCallable()();
    return markAsFinished(stealer);
  }

 public:
  // Stores the functor `f` into the in-place task storage and resets the
  // state word to 0 (not stolen / not finished / not cancelled). The relaxed
  // store is presumably safe because the task is published to other threads
  // through the deque's own synchronization -- confirm.
  template <typename F>
  void setTaskData(F&& f) {
    static_assert(sizeof(F) <= sizeof(taskData),
                  "given task type exceeds maximum size allowed for deque\n");
    static_assert(std::is_trivially_destructible<F>::value,
                  "given task type must be trivially destructible\n");
    metadata.stealer.store(0, std::memory_order_relaxed);
    new (taskData) Callable<F>(std::forward<F>(f));
    // Verify the CallableBase subobject sits at offset 0 of Callable<F>,
    // which getCallable() depends on (single inheritance, no virtual bases).
    assert(static_cast<CallableBase*>(reinterpret_cast<Callable<F>*>(
               taskData)) == reinterpret_cast<CallableBase*>(taskData));
  }

  // Sets the cancel flag. A nonzero state word makes both run() overloads
  // skip execution of the callable.
  void cancel() {
    metadata.stealer.fetch_or(kCancelFlag, std::memory_order_release);
  }

  // Owner-side execution: runs the callable only if no stealer has
  // registered itself and the task was not cancelled. Does not set
  // kFinishedFlag -- presumably the owning thread tracks completion of its
  // own tasks without it; confirm.
  void run() {
    if (metadata.stealer.load(std::memory_order_relaxed) == 0) getCallable()();
  }

  // Called by the owner of a stolen task: atomically swaps the recorded
  // stealer pointer for the owner pointer by XOR-ing with (owner ^ stealer);
  // the flag bits are unaffected since both pointers have zero low bits.
  // Returns true if the task was not yet finished at that instant (so
  // markAsFinished() will later see and return `owner`); returns false if it
  // already finished, in which case the XOR leaves a stale pointer value in
  // the word -- presumably never read again; confirm.
  bool requestNotifyWhenFinished(HighsSplitDeque* owner,
                                 HighsSplitDeque* stealer) {
    uintptr_t xormask = reinterpret_cast<uintptr_t>(owner) ^
                        reinterpret_cast<uintptr_t>(stealer);
    uintptr_t state =
        metadata.stealer.fetch_xor(xormask, std::memory_order_relaxed);
    assert(stealer != nullptr);
    return (state & kFinishedFlag) == 0;
  }

  // Returns true once markAsFinished() has run; the acquire load makes the
  // functor's side effects visible to the caller.
  bool isFinished() const {
    uintptr_t state = metadata.stealer.load(std::memory_order_acquire);
    return state & kFinishedFlag;
  }

  // Returns true if cancel() has been called on this task.
  bool isCancelled() const {
    uintptr_t state = metadata.stealer.load(std::memory_order_relaxed);
    return state & kCancelFlag;
  }

  // Returns nullptr if the task is already finished; otherwise spin-waits
  // until a stealer has registered itself (state word nonzero apart from the
  // cancel flag) and returns that stealer's deque. If `cancelled` is non-null
  // it receives the cancel flag's value.
  HighsSplitDeque* getStealerIfUnfinished(bool* cancelled = nullptr) {
    uintptr_t state = metadata.stealer.load(std::memory_order_acquire);
    if (state & kFinishedFlag)
      return nullptr;
    else {
      // Stealer pointer not yet recorded: busy-wait for it to appear.
      while ((state & ~kCancelFlag) == 0) {
        HighsSpinMutex::yieldProcessor();
        state = metadata.stealer.load(std::memory_order_acquire);
      }
    }
    // The stealer may have finished the task while we were spinning.
    if (state & kFinishedFlag) return nullptr;
    if (cancelled) *cancelled = state & kCancelFlag;
    return reinterpret_cast<HighsSplitDeque*>(state & kPtrMask);
  }
};
#endif