#include "ops/rowsel.h"
#include "ops/ops.h"
#include "core/pool.h"
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
/*
 * Allocate a row-selection block for `nrows` rows, of which `total_pass` are
 * selected. `idx_count` is the number of uint16_t per-row index entries
 * reserved for partially selected (RAY_SEL_MIX) segments.
 *
 * Only the header fields (total_pass, nrows, n_segs, _pad) are set here; the
 * caller must fill the flag, offset and index arrays.
 *
 * Returns NULL when:
 *   - the counts are inconsistent (negative, total_pass > nrows,
 *     idx_count > total_pass);
 *   - the segment count or idx_count would not fit the 32-bit n_segs /
 *     offset fields;
 *   - allocation fails.
 */
ray_t* ray_rowsel_new(int64_t nrows, int64_t total_pass, int64_t idx_count) {
    if (nrows < 0 || total_pass < 0 || total_pass > nrows ||
        idx_count < 0 || idx_count > total_pass) return NULL;

    /* Ceil-divide without forming nrows + RAY_MORSEL_ELEMS - 1, which is
       signed overflow (UB) for nrows close to INT64_MAX.
       nrows == 0 naturally yields 0 segments. */
    int64_t n_segs = nrows / RAY_MORSEL_ELEMS + (nrows % RAY_MORSEL_ELEMS != 0);

    /* n_segs is stored as uint32_t, and the per-segment offsets into the
       index array are uint32_t. Larger values would be silently truncated
       and corrupt the layout, so reject them instead. */
    if (n_segs > (int64_t)UINT32_MAX || idx_count > (int64_t)UINT32_MAX)
        return NULL;

    size_t payload = ray_rowsel_payload_bytes(nrows, idx_count);
    ray_t* block = ray_alloc(payload);
    if (!block) return NULL;

    ray_rowsel_t* m = ray_rowsel_meta(block);
    m->total_pass = total_pass;
    m->nrows = nrows;
    m->n_segs = (uint32_t)n_segs;
    m->_pad = 0;
    return block;
}
/* Release a row-selection block through ray_release. A NULL block is
   accepted and ignored. */
void ray_rowsel_release(ray_t* block) {
    if (block == NULL)
        return;
    ray_release(block);
}
/* Context for pass 1 of ray_rowsel_from_pred: counting selected rows per
   segment of RAY_MORSEL_ELEMS rows. */
typedef struct {
    const uint8_t* pred_data;  /* one byte per row; nonzero = row passes */
    int64_t nrows;             /* number of rows in pred_data */
    uint32_t* popcount;        /* output: one count per segment */
} rowsel_pass1_ctx_t;

/* Worker: for every segment in [start_seg, end_seg), store the number of
   nonzero predicate bytes in that segment into popcount[seg]. The last
   segment may be shorter than RAY_MORSEL_ELEMS. */
static void rowsel_pass1_fn(void* vctx, uint32_t worker_id,
                            int64_t start_seg, int64_t end_seg) {
    const rowsel_pass1_ctx_t* ctx = (const rowsel_pass1_ctx_t*)vctx;
    (void)worker_id;

    int64_t seg = start_seg;
    while (seg < end_seg) {
        const int64_t lo = seg * RAY_MORSEL_ELEMS;
        int64_t hi = lo + RAY_MORSEL_ELEMS;
        if (hi > ctx->nrows)
            hi = ctx->nrows;

        uint32_t hits = 0;
        for (int64_t row = lo; row < hi; ++row)
            hits += ctx->pred_data[row] ? 1u : 0u;

        ctx->popcount[seg] = hits;
        ++seg;
    }
}
/* Context for pass 2 of ray_rowsel_from_pred: writing segment-relative row
   offsets for partially selected (RAY_SEL_MIX) segments. */
typedef struct {
    const uint8_t* pred_data;     /* one byte per row; nonzero = row passes */
    int64_t nrows;                /* number of rows in pred_data */
    const uint8_t* seg_flags;     /* per-segment RAY_SEL_* classification */
    const uint32_t* seg_offsets;  /* start of each segment's slice in idx */
    uint16_t* idx;                /* output: packed segment-relative offsets */
} rowsel_pass2_ctx_t;

/* Worker: for each RAY_SEL_MIX segment in [start_seg, end_seg), write the
   offsets (relative to the segment start) of its passing rows, in ascending
   order, into idx starting at seg_offsets[seg]. Other segments are skipped. */
static void rowsel_pass2_fn(void* vctx, uint32_t worker_id,
                            int64_t start_seg, int64_t end_seg) {
    const rowsel_pass2_ctx_t* ctx = (const rowsel_pass2_ctx_t*)vctx;
    (void)worker_id;

    for (int64_t seg = start_seg; seg < end_seg; ++seg) {
        if (ctx->seg_flags[seg] == RAY_SEL_MIX) {
            const int64_t lo = seg * RAY_MORSEL_ELEMS;
            const int64_t cap = lo + RAY_MORSEL_ELEMS;
            const int64_t hi = cap < ctx->nrows ? cap : ctx->nrows;
            uint16_t* dst = ctx->idx + ctx->seg_offsets[seg];

            for (int64_t row = lo; row < hi; ++row) {
                if (!ctx->pred_data[row])
                    continue;
                *dst++ = (uint16_t)(row - lo);
            }
        }
    }
}
/*
 * Build a row selection from a RAY_BOOL predicate vector.
 *
 * Rows are grouped into segments of RAY_MORSEL_ELEMS rows.
 *   - Pass 1 counts the passing rows of every segment.
 *   - Each segment is then classified as RAY_SEL_NONE (no row passes),
 *     RAY_SEL_ALL (every row passes) or RAY_SEL_MIX (some pass).
 *   - Pass 2 records segment-relative uint16_t row offsets, for MIX
 *     segments only.
 * Both passes go through the thread pool when one is available and
 * nrows >= RAY_PARALLEL_THRESHOLD.
 *
 * Returns:
 *   - NULL if pred is NULL or not RAY_BOOL;
 *   - an empty selection for a zero-length predicate;
 *   - NULL when every row passes (no selection needed), which callers cannot
 *     distinguish from an allocation failure (also NULL);
 *   - otherwise a newly allocated selection block.
 */
ray_t* ray_rowsel_from_pred(ray_t* pred) {
    if (pred == NULL || pred->type != RAY_BOOL)
        return NULL;

    const int64_t nrows = pred->len;
    if (nrows == 0)
        return ray_rowsel_new(0, 0, 0);

    const uint8_t* bits = (const uint8_t*)ray_data(pred);
    const uint32_t nseg = (uint32_t)((nrows + RAY_MORSEL_ELEMS - 1) / RAY_MORSEL_ELEMS);

    /* Scratch space for the per-segment counts of pass 1. */
    ray_t* counts_blk = ray_alloc((size_t)nseg * sizeof(uint32_t));
    if (counts_blk == NULL)
        return NULL;
    uint32_t* counts = (uint32_t*)ray_data(counts_blk);

    rowsel_pass1_ctx_t count_ctx = {
        .pred_data = bits,
        .nrows = nrows,
        .popcount = counts,
    };
    ray_pool_t* pool = ray_pool_get();
    const bool parallel = (pool != NULL) && (nrows >= RAY_PARALLEL_THRESHOLD);

    /* Pass 1: per-segment popcount. */
    if (parallel)
        ray_pool_dispatch(pool, rowsel_pass1_fn, &count_ctx, (int64_t)nseg);
    else
        rowsel_pass1_fn(&count_ctx, 0, 0, (int64_t)nseg);

    /* Totals: rows passing overall, and rows needing explicit indices
       (those in segments that are neither empty nor full). */
    int64_t passed = 0;
    int64_t n_idx = 0;
    for (uint32_t s = 0; s < nseg; ++s) {
        const int64_t lo = (int64_t)s * RAY_MORSEL_ELEMS;
        int64_t hi = lo + RAY_MORSEL_ELEMS;
        if (hi > nrows)
            hi = nrows;
        const int64_t width = hi - lo;
        const uint32_t k = counts[s];

        passed += k;
        if (k != 0 && (int64_t)k != width)
            n_idx += k;
    }

    /* Every row passes: signal "no selection needed" with NULL. */
    if (passed == nrows) {
        ray_release(counts_blk);
        return NULL;
    }

    ray_t* sel = ray_rowsel_new(nrows, passed, n_idx);
    if (sel == NULL) {
        ray_release(counts_blk);
        return NULL;
    }

    uint8_t* flags = ray_rowsel_flags(sel);
    uint32_t* offs = ray_rowsel_offsets(sel);

    /* Classify each segment and lay out MIX index slices back to back. */
    uint32_t running = 0;
    for (uint32_t s = 0; s < nseg; ++s) {
        const int64_t lo = (int64_t)s * RAY_MORSEL_ELEMS;
        int64_t hi = lo + RAY_MORSEL_ELEMS;
        if (hi > nrows)
            hi = nrows;
        const int64_t width = hi - lo;
        const uint32_t k = counts[s];

        offs[s] = running;
        if (k == 0) {
            flags[s] = RAY_SEL_NONE;
        } else if ((int64_t)k == width) {
            flags[s] = RAY_SEL_ALL;
        } else {
            flags[s] = RAY_SEL_MIX;
            running += k;
        }
    }
    offs[nseg] = running;

    /* Pass 2 is only needed when at least one MIX segment exists. */
    if (running != 0) {
        rowsel_pass2_ctx_t fill_ctx = {
            .pred_data = bits,
            .nrows = nrows,
            .seg_flags = flags,
            .seg_offsets = offs,
            .idx = ray_rowsel_idx(sel),
        };
        if (parallel)
            ray_pool_dispatch(pool, rowsel_pass2_fn, &fill_ctx, (int64_t)nseg);
        else
            rowsel_pass2_fn(&fill_ctx, 0, 0, (int64_t)nseg);
    }

    ray_release(counts_blk);
    return sel;
}
/*
 * Read-only context shared by the workers of ray_rowsel_to_indices.
 * flat_offsets[seg] is the position in `out` where segment `seg` starts
 * writing. It is int64_t because it counts selected rows (total_pass is
 * int64_t), which can exceed the range of uint32_t.
 */
typedef struct {
    const uint8_t*  flags;         /* per-segment RAY_SEL_* classification */
    const uint32_t* offsets;       /* per-segment start into idx (n_segs + 1 entries) */
    const uint16_t* idx;           /* segment-relative offsets of MIX segments */
    const int64_t*  flat_offsets;  /* per-segment start position in out */
    int64_t*        out;           /* destination: absolute row numbers */
    int64_t         nrows;
} rowsel_to_idx_ctx_t;

/*
 * Worker: for every segment in [start_seg, end_seg), write the absolute row
 * numbers of its selected rows into out, starting at flat_offsets[seg].
 *   - RAY_SEL_NONE segments write nothing.
 *   - RAY_SEL_ALL segments write every row of the segment.
 *   - Other segments expand their slice of idx.
 */
static void rowsel_to_idx_fn(void* vctx, uint32_t worker_id,
                             int64_t start_seg, int64_t end_seg) {
    (void)worker_id;
    const rowsel_to_idx_ctx_t* c = (const rowsel_to_idx_ctx_t*)vctx;
    int64_t nrows = c->nrows;
    for (int64_t seg = start_seg; seg < end_seg; seg++) {
        uint8_t f = c->flags[seg];
        if (f == RAY_SEL_NONE) continue;
        int64_t base = seg * RAY_MORSEL_ELEMS;
        int64_t j = c->flat_offsets[seg];
        if (f == RAY_SEL_ALL) {
            int64_t end = base + RAY_MORSEL_ELEMS;
            if (end > nrows) end = nrows;
            for (int64_t r = base; r < end; r++) c->out[j++] = r;
        } else {
            const uint16_t* slice = c->idx + c->offsets[seg];
            uint32_t n = c->offsets[seg + 1] - c->offsets[seg];
            for (uint32_t i = 0; i < n; i++) c->out[j++] = base + slice[i];
        }
    }
}

/*
 * Expand a row selection into a flat array of total_pass int64_t absolute
 * row numbers, in segment order. Within a segment, rows follow the order of
 * the ALL range or of the MIX slice of idx.
 *
 * Returns NULL if sel is NULL or an allocation fails; otherwise a newly
 * allocated block. Expansion runs on the thread pool when one is available
 * and nrows >= RAY_PARALLEL_THRESHOLD.
 */
ray_t* ray_rowsel_to_indices(ray_t* sel) {
    if (!sel) return NULL;
    ray_rowsel_t* m = ray_rowsel_meta(sel);
    const uint8_t* flags = ray_rowsel_flags(sel);
    const uint32_t* offsets = ray_rowsel_offsets(sel);
    const uint16_t* idx = ray_rowsel_idx(sel);
    int64_t nrows = m->nrows;
    int64_t total_pass = m->total_pass;
    uint32_t n_segs = m->n_segs;

    ray_t* block = ray_alloc((size_t)total_pass * sizeof(int64_t));
    if (!block) return NULL;
    int64_t* out = (int64_t*)ray_data(block);
    if (total_pass == 0 || n_segs == 0) return block;

    /* Exclusive prefix sum of per-segment output counts, so workers can write
       independently. Kept in int64_t: a uint32_t running total wraps once
       more than UINT32_MAX rows are selected, which would send later segments
       to wrong, overlapping positions in out. */
    ray_t* fo_block = ray_alloc((size_t)n_segs * sizeof(int64_t));
    if (!fo_block) { ray_release(block); return NULL; }
    int64_t* flat_offsets = (int64_t*)ray_data(fo_block);
    int64_t cum = 0;
    for (uint32_t s = 0; s < n_segs; s++) {
        flat_offsets[s] = cum;
        uint8_t f = flags[s];
        if (f == RAY_SEL_NONE) continue;
        if (f == RAY_SEL_ALL) {
            int64_t base = (int64_t)s * RAY_MORSEL_ELEMS;
            int64_t end = base + RAY_MORSEL_ELEMS;
            if (end > nrows) end = nrows;
            cum += end - base;
        } else {
            cum += (int64_t)(offsets[s + 1] - offsets[s]);
        }
    }

    rowsel_to_idx_ctx_t ctx = {
        .flags = flags,
        .offsets = offsets,
        .idx = idx,
        .flat_offsets = flat_offsets,
        .out = out,
        .nrows = nrows,
    };
    ray_pool_t* pool = ray_pool_get();
    if (pool && nrows >= RAY_PARALLEL_THRESHOLD)
        ray_pool_dispatch(pool, rowsel_to_idx_fn, &ctx, (int64_t)n_segs);
    else
        rowsel_to_idx_fn(&ctx, 0, 0, (int64_t)n_segs);
    ray_release(fo_block);
    return block;
}
/*
 * ray_rowsel_refine: intersect an existing row selection with a boolean
 * predicate. A row is kept only if it is selected in `existing` AND its
 * pred byte is nonzero.
 *
 * existing  selection to narrow; it is only read, never released here.
 *           If NULL, the call is forwarded to ray_rowsel_from_pred(pred).
 * pred      RAY_BOOL vector whose length must equal existing's nrows.
 *
 * Returns a newly allocated selection block, or NULL when:
 *   - pred is NULL, not RAY_BOOL, or has the wrong length;
 *   - every row survives (total_pass == nrows, i.e. no selection needed);
 *   - an allocation fails.
 * Unlike ray_rowsel_from_pred, both passes here run serially on the calling
 * thread (no ray_pool_dispatch).
 */
ray_t* ray_rowsel_refine(ray_t* existing, ray_t* pred) {
if (!existing) return ray_rowsel_from_pred(pred);
if (!pred || pred->type != RAY_BOOL) return NULL;
ray_rowsel_t* em = ray_rowsel_meta(existing);
int64_t nrows = em->nrows;
if (pred->len != nrows) return NULL;
const uint8_t* pred_data = (const uint8_t*)ray_data(pred);
const uint8_t* e_flags = ray_rowsel_flags(existing);
const uint32_t* e_offsets = ray_rowsel_offsets(existing);
const uint16_t* e_idx = ray_rowsel_idx(existing);
uint32_t n_segs = em->n_segs;
/* Scratch per-segment survivor counts. Zeroed so segments skipped below
   (existing RAY_SEL_NONE) read as 0. */
ray_t* pop_block = ray_alloc((size_t)n_segs * sizeof(uint32_t));
if (!pop_block) return NULL;
uint32_t* popcount = (uint32_t*)ray_data(pop_block);
memset(popcount, 0, (size_t)n_segs * sizeof(uint32_t));
int64_t total_pass = 0;
int64_t idx_count = 0;
/* Pass 1: count survivors per segment, testing only rows the existing
   selection kept. */
for (uint32_t s = 0; s < n_segs; s++) {
uint8_t f = e_flags[s];
if (f == RAY_SEL_NONE) continue;
int64_t base = (int64_t)s * RAY_MORSEL_ELEMS;
int64_t end = base + RAY_MORSEL_ELEMS;
if (end > nrows) end = nrows;
int64_t seg_len = end - base;
uint32_t n = 0;
if (f == RAY_SEL_ALL) {
/* Every row of the segment was selected: test them all. */
for (int64_t r = base; r < end; r++)
n += pred_data[r] != 0;
} else {
/* Partially selected: walk the existing segment-relative index slice. */
const uint16_t* src = e_idx + e_offsets[s];
uint32_t src_n = e_offsets[s + 1] - e_offsets[s];
for (uint32_t i = 0; i < src_n; i++) {
int64_t r = base + src[i];
n += pred_data[r] != 0;
}
}
popcount[s] = n;
total_pass += n;
/* Only segments that end up neither empty nor full need explicit indices. */
if (n != 0 && (int64_t)n != seg_len)
idx_count += n;
}
/* Everything survives: return NULL, the same "no selection needed" signal
   that ray_rowsel_from_pred uses. */
if (total_pass == nrows) {
ray_release(pop_block);
return NULL;
}
ray_t* block = ray_rowsel_new(nrows, total_pass, idx_count);
if (!block) {
ray_release(pop_block);
return NULL;
}
uint8_t* seg_flags = ray_rowsel_flags(block);
uint32_t* seg_offsets = ray_rowsel_offsets(block);
uint16_t* idx_out = ray_rowsel_idx(block);
/* Pass 2: classify each segment and, for MIX segments, write surviving
   segment-relative offsets back to back into idx_out. This loop visits the
   same rows with the same test as pass 1, so dn ends equal to pc. */
uint32_t cum = 0;
for (uint32_t s = 0; s < n_segs; s++) {
seg_offsets[s] = cum;
int64_t base = (int64_t)s * RAY_MORSEL_ELEMS;
int64_t end = base + RAY_MORSEL_ELEMS;
if (end > nrows) end = nrows;
int64_t seg_len = end - base;
uint32_t pc = popcount[s];
if (pc == 0) {
seg_flags[s] = RAY_SEL_NONE;
continue;
}
if ((int64_t)pc == seg_len) {
seg_flags[s] = RAY_SEL_ALL;
continue;
}
seg_flags[s] = RAY_SEL_MIX;
uint16_t* dst = idx_out + cum;
uint32_t dn = 0;
uint8_t f = e_flags[s];
/* NOTE(review): the uint16_t casts below assume RAY_MORSEL_ELEMS <= 65536
   so segment-relative offsets fit; confirm in the header. */
if (f == RAY_SEL_ALL) {
for (int64_t r = base; r < end; r++)
if (pred_data[r])
dst[dn++] = (uint16_t)(r - base);
} else {
const uint16_t* src = e_idx + e_offsets[s];
uint32_t src_n = e_offsets[s + 1] - e_offsets[s];
for (uint32_t i = 0; i < src_n; i++) {
int64_t r = base + src[i];
if (pred_data[r])
dst[dn++] = (uint16_t)(r - base);
}
}
cum += pc;
}
/* Sentinel end offset: the last segment's slice ends at seg_offsets[n_segs]. */
seg_offsets[n_segs] = cum;
ray_release(pop_block);
return block;
}