#include "core/pool.h"
#include "mem/cow.h"
#include "mem/heap.h"
#include "mem/sys.h"
#include <stdatomic.h>
#include <stdint.h>
#include <string.h>
#include <sched.h>

/* Elements per task at the default dispatch granularity. */
#define TASK_GRAIN ((int64_t)RAY_DISPATCH_MORSELS * RAY_MORSEL_ELEMS)
/* Hard cap on the task ring; kept a power of two so slots can be
 * indexed with `idx & (task_cap - 1)`. */
#define MAX_RING_CAP (1u << 16)

/* CPU spin-wait hint, factored out of the per-architecture #if blocks
 * that were repeated at every spin loop in this file. */
static inline void ray_cpu_relax(void) {
#if defined(__x86_64__) || defined(__i386__)
  __builtin_ia32_pause();
#elif defined(__aarch64__)
  __asm__ volatile("yield" ::: "memory");
#endif
}
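/* Startup context handed to each worker thread; heap-allocated by
 * ray_pool_create and freed by the worker once it has copied the fields. */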
typedef struct {
ray_pool_t* pool;
uint32_t worker_id;
} worker_ctx_t;
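/* Worker entry point. Each wakeup on work_ready drains the current batch:
 * workers claim tasks by atomically bumping task_tail and stop once the
 * claimed index passes task_count. Cancelled tasks are skipped but still
 * counted down in pending so the dispatcher's wait can complete. */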
static void worker_loop(void* arg) {
  /* Copy the context and free it immediately: it was heap-allocated by
   * ray_pool_create solely to pass these two fields. */
  worker_ctx_t wctx = *(worker_ctx_t*)arg;
  ray_sys_free(arg);
  ray_pool_t* pool = wctx.pool;
  ray_heap_init(); /* per-thread allocator */
  ray_rc_sync = true;
for (;;) {
ray_sem_wait(&pool->work_ready);
if (atomic_load_explicit(&pool->shutdown, memory_order_acquire))
break;
for (;;) {
uint32_t idx = atomic_fetch_add_explicit(&pool->task_tail, 1,
memory_order_acq_rel);
if (idx >= atomic_load_explicit(&pool->task_count,
memory_order_acquire))
break;
if (RAY_UNLIKELY(atomic_load_explicit(&pool->cancelled,
memory_order_relaxed))) {
atomic_fetch_sub_explicit(&pool->pending, 1,
memory_order_acq_rel);
continue;
}
ray_pool_task_t* t = &pool->tasks[idx & (pool->task_cap - 1)];
t->fn(t->ctx, wctx.worker_id, t->start, t->end);
atomic_fetch_sub_explicit(&pool->pending, 1,
memory_order_acq_rel);
}
}
ray_heap_destroy();
}
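/* Create a pool with n_workers background threads (0 = auto-detect).
 * The calling thread always participates as worker 0 during dispatch,
 * so a pool with zero background threads still runs work serially. */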
ray_err_t ray_pool_create(ray_pool_t* pool, uint32_t n_workers) {
memset(pool, 0, sizeof(*pool));
atomic_init(&pool->shutdown, 0);
atomic_init(&pool->task_tail, 0);
atomic_init(&pool->task_count, 0);
atomic_init(&pool->pending, 0);
atomic_init(&pool->cancelled, 0);
  if (n_workers == 0) {
    /* Auto-detect: leave one hardware thread for the dispatching thread. */
    uint32_t ncpu = ray_thread_count();
    n_workers = (ncpu > 1) ? ncpu - 1 : 0;
  }
pool->n_workers = n_workers;
  pool->task_cap = 1024; /* initial ring capacity; grown on demand */
pool->tasks = (ray_pool_task_t*)ray_sys_alloc(pool->task_cap * sizeof(ray_pool_task_t));
if (!pool->tasks) return RAY_ERR_OOM;
  pool->task_head = 0;
ray_err_t err = ray_sem_init(&pool->work_ready, 0);
if (err != RAY_OK) {
ray_sys_free(pool->tasks);
return err;
}
if (n_workers > 0) {
pool->threads = (ray_thread_t*)ray_sys_alloc(n_workers * sizeof(ray_thread_t));
if (!pool->threads) {
ray_sem_destroy(&pool->work_ready);
ray_sys_free(pool->tasks);
return RAY_ERR_OOM;
}
    for (uint32_t i = 0; i < n_workers; i++) {
      worker_ctx_t* wctx = (worker_ctx_t*)ray_sys_alloc(sizeof(worker_ctx_t));
      if (wctx) {
        wctx->pool = pool;
        wctx->worker_id = i + 1; /* id 0 belongs to the dispatching thread */
        err = ray_thread_create(&pool->threads[i], worker_loop, wctx);
        if (err != RAY_OK) ray_sys_free(wctx);
      } else {
        err = RAY_ERR_OOM;
      }
      if (err != RAY_OK) {
        /* Unwind: wake and join the i workers already started. */
        atomic_store_explicit(&pool->shutdown, 1, memory_order_release);
        for (uint32_t j = 0; j < i; j++)
          ray_sem_signal(&pool->work_ready);
        for (uint32_t j = 0; j < i; j++)
          ray_thread_join(pool->threads[j]);
        ray_sys_free(pool->threads);
        ray_sem_destroy(&pool->work_ready);
        ray_sys_free(pool->tasks);
        return err;
      }
    }
  }
return RAY_OK;
}
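/* Shut down and join all workers, then release every resource. Must not
 * be called while a dispatch is in flight. */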
void ray_pool_free(ray_pool_t* pool) {
if (!pool) return;
atomic_store_explicit(&pool->shutdown, 1, memory_order_release);
for (uint32_t i = 0; i < pool->n_workers; i++) {
ray_sem_signal(&pool->work_ready);
}
for (uint32_t i = 0; i < pool->n_workers; i++) {
ray_thread_join(pool->threads[i]);
}
ray_sys_free(pool->threads);
ray_sem_destroy(&pool->work_ready);
ray_sys_free(pool->tasks);
memset(pool, 0, sizeof(*pool));
}
/* Grow the task ring (power-of-two capacity) to hold at least n_tasks
 * entries, capped at MAX_RING_CAP. Failure to grow is non-fatal: callers
 * fall back to fewer, coarser tasks that fit the current capacity. */
static void pool_grow_tasks(ray_pool_t* pool, uint32_t n_tasks) {
  if (n_tasks <= pool->task_cap) return;
  uint32_t new_cap = pool->task_cap;
  while (new_cap < n_tasks && new_cap < MAX_RING_CAP) new_cap *= 2;
  if (new_cap == pool->task_cap) return;
  ray_pool_task_t* new_tasks = (ray_pool_task_t*)ray_sys_realloc(
      pool->tasks, new_cap * sizeof(ray_pool_task_t));
  if (new_tasks) {
    pool->tasks = new_tasks;
    pool->task_cap = new_cap;
  }
}
/* Split [0, total_elems) into TASK_GRAIN-sized ranges, run them across the
 * pool, and block until every range has finished. */
void ray_pool_dispatch(ray_pool_t* pool, ray_pool_fn fn, void* ctx,
                       int64_t total_elems) {
  if (total_elems <= 0) return;
  int64_t grain = TASK_GRAIN;
  /* Clamp so that total_elems + grain - 1 below cannot overflow int64_t. */
  if (RAY_UNLIKELY(total_elems > INT64_MAX - grain + 1))
    total_elems = INT64_MAX - grain + 1;
  uint32_t n_tasks = (uint32_t)((total_elems + grain - 1) / grain);
  pool_grow_tasks(pool, n_tasks);
  if (n_tasks > pool->task_cap) {
    /* The ring could not grow far enough; coarsen the grain instead. */
    n_tasks = pool->task_cap;
    grain = (total_elems + n_tasks - 1) / n_tasks;
  }
for (uint32_t i = 0; i < n_tasks; i++) {
int64_t start = (int64_t)i * grain;
int64_t end = start + grain;
if (end > total_elems) end = total_elems;
uint32_t slot = i & (pool->task_cap - 1);
pool->tasks[slot].fn = fn;
pool->tasks[slot].ctx = ctx;
pool->tasks[slot].start = start;
pool->tasks[slot].end = end;
}
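  /* Publish the batch: the release stores order the slot writes above
   * before the counters that workers read after waking. */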
pool->task_head = n_tasks;
atomic_store_explicit(&pool->task_count, n_tasks, memory_order_release);
atomic_store_explicit(&pool->task_tail, 0, memory_order_release);
atomic_store_explicit(&pool->pending, n_tasks, memory_order_release);
atomic_store_explicit(&ray_parallel_flag, 1, memory_order_release);
ray_rc_sync = true;
for (uint32_t i = 0; i < pool->n_workers; i++) {
ray_sem_signal(&pool->work_ready);
}
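  /* The dispatching thread participates as worker 0. */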
for (;;) {
uint32_t idx = atomic_fetch_add_explicit(&pool->task_tail, 1,
memory_order_acq_rel);
if (idx >= n_tasks) break;
if (RAY_UNLIKELY(atomic_load_explicit(&pool->cancelled,
memory_order_relaxed))) {
atomic_fetch_sub_explicit(&pool->pending, 1, memory_order_acq_rel);
continue;
}
ray_pool_task_t* t = &pool->tasks[idx & (pool->task_cap - 1)];
t->fn(t->ctx, 0, t->start, t->end);
atomic_fetch_sub_explicit(&pool->pending, 1, memory_order_acq_rel);
}
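  /* All tasks are claimed; spin until stragglers finish theirs. */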
{
unsigned spin_count = 0;
while (atomic_load_explicit(&pool->pending, memory_order_acquire) > 0) {
      ray_cpu_relax();
      if (++spin_count % 1024 == 0) sched_yield();
}
}
  atomic_store_explicit(&ray_parallel_flag, 0, memory_order_release);
  /* Full fence so every worker write is visible before single-threaded
   * code resumes with ray_rc_sync off. */
  atomic_thread_fence(memory_order_seq_cst);
  ray_rc_sync = false;
}
/* Dispatch exactly n_tasks single-index tasks: task i receives the range
 * [i, i + 1). For irregular work where the caller partitions itself. */
void ray_pool_dispatch_n(ray_pool_t* pool, ray_pool_fn fn, void* ctx,
                         uint32_t n_tasks) {
  if (n_tasks == 0) return;
  pool_grow_tasks(pool, n_tasks);
  if (n_tasks > pool->task_cap) n_tasks = pool->task_cap; /* excess tasks are dropped */
for (uint32_t i = 0; i < n_tasks; i++) {
uint32_t slot = i & (pool->task_cap - 1);
pool->tasks[slot].fn = fn;
pool->tasks[slot].ctx = ctx;
pool->tasks[slot].start = (int64_t)i;
pool->tasks[slot].end = (int64_t)i + 1;
}
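  /* Publish and drain exactly as in ray_pool_dispatch. */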
pool->task_head = n_tasks;
atomic_store_explicit(&pool->task_count, n_tasks, memory_order_release);
atomic_store_explicit(&pool->task_tail, 0, memory_order_release);
atomic_store_explicit(&pool->pending, n_tasks, memory_order_release);
atomic_store_explicit(&ray_parallel_flag, 1, memory_order_release);
ray_rc_sync = true;
for (uint32_t i = 0; i < pool->n_workers; i++) {
ray_sem_signal(&pool->work_ready);
}
for (;;) {
uint32_t idx = atomic_fetch_add_explicit(&pool->task_tail, 1,
memory_order_acq_rel);
if (idx >= n_tasks) break;
if (RAY_UNLIKELY(atomic_load_explicit(&pool->cancelled,
memory_order_relaxed))) {
atomic_fetch_sub_explicit(&pool->pending, 1, memory_order_acq_rel);
continue;
}
ray_pool_task_t* t = &pool->tasks[idx & (pool->task_cap - 1)];
t->fn(t->ctx, 0, t->start, t->end);
atomic_fetch_sub_explicit(&pool->pending, 1, memory_order_acq_rel);
}
{
unsigned spin_count = 0;
while (atomic_load_explicit(&pool->pending, memory_order_acquire) > 0) {
      ray_cpu_relax();
      if (++spin_count % 1024 == 0) sched_yield();
}
}
atomic_store_explicit(&ray_parallel_flag, 0, memory_order_release);
atomic_thread_fence(memory_order_seq_cst);
ray_rc_sync = false;
}
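/* Global pool singleton. g_pool_init_state is a small state machine:
 * 0 = uninitialized, 1 = initializing, 2 = ready, 3 = destroying. */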
static ray_pool_t g_pool;
static _Atomic(uint32_t) g_pool_init_state = 0;
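/* Return the global pool, creating it lazily with auto-detected workers.
 * Returns NULL if creation fails or while the pool is being torn down. */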
ray_pool_t* ray_pool_get(void) {
uint32_t state = atomic_load_explicit(&g_pool_init_state, memory_order_acquire);
if (state == 2) return &g_pool;
if (state == 0) {
uint32_t expected = 0;
if (atomic_compare_exchange_strong_explicit(&g_pool_init_state, &expected, 1,
memory_order_acq_rel,
memory_order_acquire)) {
ray_err_t err = ray_pool_create(&g_pool, 0);
if (err == RAY_OK) {
atomic_store_explicit(&g_pool_init_state, 2, memory_order_release);
return &g_pool;
}
atomic_store_explicit(&g_pool_init_state, 0, memory_order_release);
return NULL;
}
}
{
unsigned spin_count = 0;
for (;;) {
uint32_t s = atomic_load_explicit(&g_pool_init_state, memory_order_acquire);
if (s == 2) return &g_pool;
if (s == 0) return NULL;
      ray_cpu_relax();
      if (++spin_count % 1024 == 0) sched_yield();
}
}
}
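/* Explicitly initialize the global pool with n_workers threads (0 = auto).
 * Safe to call from multiple threads; returns RAY_OK if the pool already
 * exists. */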
ray_err_t ray_pool_init(uint32_t n_workers) {
  for (;;) {
    uint32_t expected = 0;
    if (atomic_compare_exchange_strong_explicit(&g_pool_init_state, &expected, 1,
                                                memory_order_acq_rel,
                                                memory_order_acquire))
      break; /* we own initialization */
    if (expected != 1)
      return RAY_OK; /* already initialized (2) or being torn down (3) */
    /* Another thread is initializing. Wait for it, then retry: its attempt
     * may fail and reset the state to 0. */
    while (atomic_load_explicit(&g_pool_init_state, memory_order_acquire) == 1)
      ray_cpu_relax();
  }
ray_err_t err = ray_pool_create(&g_pool, n_workers);
if (err == RAY_OK) {
atomic_store_explicit(&g_pool_init_state, 2, memory_order_release);
} else {
atomic_store_explicit(&g_pool_init_state, 0, memory_order_release);
}
return err;
}
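/* Tear down the global pool. The transition 2 -> 3 keeps concurrent
 * ray_pool_get callers from handing out the pool mid-destruction. */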
void ray_pool_destroy(void) {
uint32_t expected = 2;
if (!atomic_compare_exchange_strong_explicit(&g_pool_init_state, &expected, 3,
memory_order_acq_rel,
memory_order_acquire))
return;
ray_pool_free(&g_pool);
atomic_store_explicit(&g_pool_init_state, 0, memory_order_release);
}
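/* Request cooperative cancellation: remaining tasks in the current batch
 * are skipped (but still counted down in pending). The flag is sticky;
 * nothing in this file clears it. */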
void ray_cancel(void) {
ray_pool_t* pool = ray_pool_get();
if (pool)
atomic_store_explicit(&pool->cancelled, 1, memory_order_release);
}