#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_
#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_
#include <stddef.h>  // size_t
#include <stdint.h>

#include <atomic>
#include <utility>  // std::forward

#include "hwy/base.h"
#include "hwy/cache_control.h"
#ifndef HWY_ENABLE_MONITORX
#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \
(HWY_COMPILER_GCC_ACTUAL >= 502) || defined(__MWAITX__))
#define HWY_ENABLE_MONITORX 1
#else
#define HWY_ENABLE_MONITORX 0
#endif
#endif
#ifndef HWY_ENABLE_UMONITOR
#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \
(HWY_COMPILER_GCC_ACTUAL >= 901) || defined(__WAITPKG__))
#define HWY_ENABLE_UMONITOR 1
#else
#define HWY_ENABLE_UMONITOR 0
#endif
#endif
#ifndef HWY_ENABLE_SPIN_ASM
#if (HWY_COMPILER_CLANG || HWY_COMPILER_GCC) && HWY_ARCH_X86_64
#define HWY_ENABLE_SPIN_ASM 1
#else
#define HWY_ENABLE_SPIN_ASM 0
#endif
#endif
#if HWY_ENABLE_MONITORX || HWY_ENABLE_UMONITOR
#if HWY_ENABLE_SPIN_ASM
#define HWY_INLINE_SPIN HWY_INLINE
#else
#define HWY_INLINE_SPIN
#include <x86intrin.h>
#endif
#include "hwy/x86_cpuid.h"
#endif
namespace hwy {
// Outcome of a spin-wait: the value finally observed in the watched atomic,
// plus how many wait iterations elapsed before it changed.
struct SpinResult {
  uint32_t current;  // last value loaded from the watched atomic
  uint32_t reps;     // number of spin iterations before `current` was seen
};
// Available spin-wait implementations, listed in DetectSpin's order of
// preference. Values start at 1 so each can serve as a bit position in
// DetectSpin's `disabled` mask (1 << static_cast<int>(type)).
// kSentinel marks one-past-the-last valid entry; ToString returns nullptr
// for it.
enum class SpinType : uint8_t {
  kMonitorX = 1, kUMonitor, kPause,
  kSentinel };
// Returns a human-readable name for `type`, or nullptr for kSentinel.
static inline const char* ToString(SpinType type) {
  switch (type) {
    case SpinType::kMonitorX:
      return "MonitorX_C1";
    case SpinType::kUMonitor:
      return "UMonitor_C0.2";
    case SpinType::kPause:
      return "Pause";
    case SpinType::kSentinel:
      return nullptr;
  }
  // All enumerators are handled above, but an out-of-range value (e.g. from a
  // cast) would otherwise fall off the end of a non-void function, which is
  // undefined behavior and triggers -Wreturn-type on some compilers.
  return nullptr;
}
// Portable fallback spin-wait: plain acquire-load plus PAUSE loop. Always
// available; used when neither MONITORX nor UMONITOR is supported/enabled.
struct SpinPause {
  SpinType Type() const { return SpinType::kPause; }

  // Spins until `watched` holds a value other than `prev`. Returns that value
  // together with the number of PAUSE iterations that were required.
  HWY_INLINE SpinResult UntilDifferent(
      const uint32_t prev, const std::atomic<uint32_t>& watched) const {
    uint32_t iterations = 0;
    while (true) {
      const uint32_t observed = watched.load(std::memory_order_acquire);
      if (observed != prev) return SpinResult{observed, iterations};
      hwy::Pause();
      ++iterations;
    }
  }

  // Spins until `watched` equals `expected`; returns the iteration count.
  HWY_INLINE size_t UntilEqual(const uint32_t expected,
                               const std::atomic<uint32_t>& watched) const {
    size_t iterations = 0;
    while (true) {
      const uint32_t observed = watched.load(std::memory_order_acquire);
      if (observed == expected) return iterations;
      hwy::Pause();
      ++iterations;
    }
  }
};
#if HWY_ENABLE_MONITORX || HWY_IDE
#if !HWY_ENABLE_SPIN_ASM
HWY_PUSH_ATTRIBUTES("mwaitx")
#endif
// Spin-wait via AMD's MONITORX/MWAITX instructions ("MonitorX_C1" in
// ToString). The loops follow the standard monitor protocol: load, arm the
// monitor, load *again*, and only then wait — the re-check after Monitor()
// ensures a store that lands between the first load and arming the monitor
// is not missed (the monitor would not have been armed yet for it).
class SpinMonitorX {
 public:
  SpinType Type() const { return SpinType::kMonitorX; }

  // Spins until `watched` != `prev`; returns the observed value and the
  // number of wait iterations.
  HWY_INLINE_SPIN SpinResult UntilDifferent(
      const uint32_t prev, const std::atomic<uint32_t>& watched) const {
    for (uint32_t reps = 0;; ++reps) {
      uint32_t current = watched.load(std::memory_order_acquire);
      if (current != prev) return SpinResult{current, reps};
      Monitor(&watched);
      // Re-check after arming: a write may have raced with Monitor().
      current = watched.load(std::memory_order_acquire);
      if (current != prev) return SpinResult{current, reps};
      Wait();
    }
  }

  // Spins until `watched` == `expected`; returns the iteration count.
  HWY_INLINE_SPIN size_t UntilEqual(
      const uint32_t expected, const std::atomic<uint32_t>& watched) const {
    for (size_t reps = 0;; ++reps) {
      uint32_t current = watched.load(std::memory_order_acquire);
      if (current == expected) return reps;
      Monitor(&watched);
      // Re-check after arming: a write may have raced with Monitor().
      current = watched.load(std::memory_order_acquire);
      if (current == expected) return reps;
      Wait();
    }
  }

 private:
  // Arms the address-range monitor on the line containing `addr`
  // (EAX = address; ECX/EDX = extensions/hints, both zero here).
  static HWY_INLINE void Monitor(const void* addr) {
#if HWY_ENABLE_SPIN_ASM
    asm volatile("monitorx" ::"a"(addr), "c"(0), "d"(0));
#else
    _mm_monitorx(const_cast<void*>(addr), 0, 0);
#endif
  }

  // Waits until the monitored line is written (or another wake condition).
  static HWY_INLINE void Wait() {
#if HWY_ENABLE_SPIN_ASM
    asm volatile("mwaitx" ::"a"(kHints), "b"(0), "c"(kExtensions));
#else
    _mm_mwaitx(kExtensions, kHints, 0);
#endif
  }

  // Hint 0 — presumably selects the C1 sleep state, matching the
  // "MonitorX_C1" name above; confirm against the AMD manual.
  static constexpr unsigned kHints = 0x0;
  // No extensions (e.g. no EBX timeout is enabled).
  static constexpr unsigned kExtensions = 0;
};
#if !HWY_ENABLE_SPIN_ASM
HWY_POP_ATTRIBUTES
#endif
#endif
#if HWY_ENABLE_UMONITOR || HWY_IDE
#if !HWY_ENABLE_SPIN_ASM
HWY_PUSH_ATTRIBUTES("waitpkg")
#endif
// Spin-wait via the user-mode UMONITOR/UMWAIT instructions (WAITPKG,
// "UMonitor_C0.2" in ToString). Same monitor protocol as SpinMonitorX:
// load, arm monitor, re-check, then wait — the second load closes the race
// window between the first load and arming the monitor.
class SpinUMonitor {
 public:
  SpinType Type() const { return SpinType::kUMonitor; }

  // Spins until `watched` != `prev`; returns the observed value and the
  // number of wait iterations.
  HWY_INLINE_SPIN SpinResult UntilDifferent(
      const uint32_t prev, const std::atomic<uint32_t>& watched) const {
    for (uint32_t reps = 0;; ++reps) {
      uint32_t current = watched.load(std::memory_order_acquire);
      if (current != prev) return SpinResult{current, reps};
      Monitor(&watched);
      // Re-check after arming: a write may have raced with Monitor().
      current = watched.load(std::memory_order_acquire);
      if (current != prev) return SpinResult{current, reps};
      Wait();
    }
  }

  // Spins until `watched` == `expected`; returns the iteration count.
  HWY_INLINE_SPIN size_t UntilEqual(
      const uint32_t expected, const std::atomic<uint32_t>& watched) const {
    for (size_t reps = 0;; ++reps) {
      uint32_t current = watched.load(std::memory_order_acquire);
      if (current == expected) return reps;
      Monitor(&watched);
      // Re-check after arming: a write may have raced with Monitor().
      current = watched.load(std::memory_order_acquire);
      if (current == expected) return reps;
      Wait();
    }
  }

 private:
  // Arms the user-mode monitor on the line containing `addr` (RCX = address).
  static HWY_INLINE void Monitor(const void* addr) {
#if HWY_ENABLE_SPIN_ASM
    asm volatile("umonitor %%rcx" ::"c"(addr));
#else
    _umonitor(const_cast<void*>(addr));
#endif
  }

  // Waits for a write to the monitored line; ECX = control word,
  // EDX:EAX = TSC deadline (split into high/low halves for the asm form).
  static HWY_INLINE void Wait() {
#if HWY_ENABLE_SPIN_ASM
    asm volatile("umwait %%ecx" ::"c"(kControl), "d"(kDeadline >> 32),
                 "a"(kDeadline & 0xFFFFFFFFu));
#else
    _umwait(kControl, kDeadline);
#endif
  }

  // Control 0 — presumably selects the C0.2 state, matching the
  // "UMonitor_C0.2" name above; confirm against the Intel SDM.
  static constexpr unsigned kControl = 0;
  // All-ones TSC deadline: the farthest representable timeout, i.e.
  // effectively wait until the line is written.
  static constexpr uint64_t kDeadline = ~uint64_t{0};
};
#if !HWY_ENABLE_SPIN_ASM
HWY_POP_ATTRIBUTES
#endif
#endif
// Returns the best spin implementation supported by this CPU and build,
// honoring `disabled`: a bitmask of (1 << static_cast<int>(SpinType)) values
// the caller wants to exclude. Pause is the last resort and cannot be
// disabled.
static inline SpinType DetectSpin(int disabled = 0) {
  const auto HWY_MAYBE_UNUSED allowed = [disabled](SpinType candidate) {
    const int bit = 1 << static_cast<int>(candidate);
    return (disabled & bit) == 0;
  };
#if HWY_ENABLE_MONITORX
  // MONITORX/MWAITX is AMD-specific: CPUID 0x80000001, ECX bit 29.
  if (allowed(SpinType::kMonitorX) && x86::IsAMD()) {
    uint32_t abcd[4];
    x86::Cpuid(0x80000001U, 0, abcd);
    if (x86::IsBitSet(abcd[2], 29)) return SpinType::kMonitorX;
  }
#endif
#if HWY_ENABLE_UMONITOR
  // UMONITOR/UMWAIT (WAITPKG): CPUID leaf 7, ECX bit 5.
  if (allowed(SpinType::kUMonitor) && x86::MaxLevel() >= 7) {
    uint32_t abcd[4];
    x86::Cpuid(7, 0, abcd);
    if (x86::IsBitSet(abcd[2], 5)) return SpinType::kUMonitor;
  }
#endif
  if (!allowed(SpinType::kPause)) {
    HWY_WARN("Ignoring attempt to disable Pause, it is the only option left.");
  }
  return SpinType::kPause;
}
// Invokes `func` with an instance of the spin implementation selected by
// `spin_type` (typically obtained from DetectSpin), followed by `args...`.
// This converts the runtime choice into a compile-time type so the spin
// loops inside `func` can be fully inlined. Note std::forward requires
// <utility>.
template <class Func, typename... Args>
HWY_INLINE void CallWithSpin(SpinType spin_type, Func&& func, Args&&... args) {
  switch (spin_type) {
#if HWY_ENABLE_MONITORX
    case SpinType::kMonitorX:
      func(SpinMonitorX(), std::forward<Args>(args)...);
      break;
#endif
#if HWY_ENABLE_UMONITOR
    case SpinType::kUMonitor:
      func(SpinUMonitor(), std::forward<Args>(args)...);
      break;
#endif
    // Fall back to Pause, including when the requested implementation was
    // compiled out (its case label does not exist then).
    case SpinType::kPause:
    default:
      func(SpinPause(), std::forward<Args>(args)...);
      break;
  }
}
}
#endif