#include "ops/internal.h"
#include "lang/internal.h"
#include "ops/ops.h"
#include "mem/sys.h"
int sort_cmp(const sort_cmp_ctx_t* ctx, int64_t a, int64_t b) {
for (uint8_t k = 0; k < ctx->n_sort; k++) {
ray_t* col = ctx->vecs[k];
if (!col) continue;
int cmp = 0;
int null_cmp = 0;
int desc = ctx->desc ? ctx->desc[k] : 0;
int nf = ctx->nulls_first ? ctx->nulls_first[k] : desc;
int a_null = ray_vec_is_null(col, a);
int b_null = ray_vec_is_null(col, b);
if (a_null || b_null) {
null_cmp = 1;
if (a_null && b_null) cmp = 0;
else if (a_null) cmp = nf ? -1 : 1;
else cmp = nf ? 1 : -1;
} else if (col->type == RAY_F64) {
double va = ((double*)ray_data(col))[a];
double vb = ((double*)ray_data(col))[b];
if (va < vb) cmp = -1;
else if (va > vb) cmp = 1;
} else if (col->type == RAY_I64 || col->type == RAY_TIMESTAMP) {
int64_t va = ((int64_t*)ray_data(col))[a];
int64_t vb = ((int64_t*)ray_data(col))[b];
if (va < vb) cmp = -1;
else if (va > vb) cmp = 1;
} else if (col->type == RAY_I32) {
int32_t va = ((int32_t*)ray_data(col))[a];
int32_t vb = ((int32_t*)ray_data(col))[b];
if (va < vb) cmp = -1;
else if (va > vb) cmp = 1;
} else if (RAY_IS_SYM(col->type)) {
int64_t va = ray_read_sym(ray_data(col), a, col->type, col->attrs);
int64_t vb = ray_read_sym(ray_data(col), b, col->type, col->attrs);
ray_t* sa = ray_sym_str(va);
ray_t* sb = ray_sym_str(vb);
if (sa && sb) cmp = ray_str_cmp(sa, sb);
} else if (col->type == RAY_I16) {
int16_t va = ((int16_t*)ray_data(col))[a];
int16_t vb = ((int16_t*)ray_data(col))[b];
if (va < vb) cmp = -1;
else if (va > vb) cmp = 1;
} else if (col->type == RAY_BOOL || col->type == RAY_U8) {
uint8_t va = ((uint8_t*)ray_data(col))[a];
uint8_t vb = ((uint8_t*)ray_data(col))[b];
if (va < vb) cmp = -1;
else if (va > vb) cmp = 1;
} else if (col->type == RAY_DATE || col->type == RAY_TIME) {
int32_t va = ((int32_t*)ray_data(col))[a];
int32_t vb = ((int32_t*)ray_data(col))[b];
if (va < vb) cmp = -1;
else if (va > vb) cmp = 1;
} else if (col->type == RAY_GUID) {
const uint8_t* base = (const uint8_t*)ray_data(col);
cmp = memcmp(base + a * 16, base + b * 16, 16);
} else if (col->type == RAY_STR) {
const ray_str_t* elems;
const char* pool;
str_resolve(col, &elems, &pool);
cmp = ray_str_t_cmp(&elems[a], pool, &elems[b], pool);
}
if (desc && !null_cmp) cmp = -cmp;
if (cmp != 0) return cmp;
}
return 0;
}
static void key_sift_down(uint64_t* keys, int64_t* idx, int64_t n, int64_t i) {
for (;;) {
int64_t largest = i, l = 2*i+1, r = 2*i+2;
if (l < n && keys[l] > keys[largest]) largest = l;
if (r < n && keys[r] > keys[largest]) largest = r;
if (largest == i) return;
uint64_t tk = keys[i]; keys[i] = keys[largest]; keys[largest] = tk;
int64_t ti = idx[i]; idx[i] = idx[largest]; idx[largest] = ti;
i = largest;
}
}
static void key_heapsort(uint64_t* keys, int64_t* idx, int64_t n) {
for (int64_t i = n/2 - 1; i >= 0; i--)
key_sift_down(keys, idx, n, i);
for (int64_t i = n - 1; i > 0; i--) {
uint64_t tk = keys[0]; keys[0] = keys[i]; keys[i] = tk;
int64_t ti = idx[0]; idx[0] = idx[i]; idx[i] = ti;
key_sift_down(keys, idx, i, 0);
}
}
static void key_insertion_sort(uint64_t* keys, int64_t* idx, int64_t n) {
for (int64_t i = 1; i < n; i++) {
uint64_t kk = keys[i];
int64_t ii = idx[i];
int64_t j = i - 1;
while (j >= 0 && keys[j] > kk) {
keys[j+1] = keys[j];
idx[j+1] = idx[j];
j--;
}
keys[j+1] = kk;
idx[j+1] = ii;
}
}
static void key_introsort_impl(uint64_t* keys, int64_t* idx,
int64_t n, int depth) {
while (n > 32) {
if (depth == 0) {
key_heapsort(keys, idx, n);
return;
}
depth--;
int64_t mid = n / 2;
uint64_t a = keys[0], b = keys[mid], c = keys[n-1];
int64_t pi;
if (a < b) pi = (b < c) ? mid : (a < c ? n-1 : 0);
else pi = (a < c) ? 0 : (b < c ? n-1 : mid);
uint64_t pk = keys[pi]; keys[pi] = keys[n-1]; keys[n-1] = pk;
int64_t pv = idx[pi]; idx[pi] = idx[n-1]; idx[n-1] = pv;
int64_t lo = 0;
for (int64_t i = 0; i < n - 1; i++) {
if (keys[i] < pk) {
uint64_t tk = keys[i]; keys[i] = keys[lo]; keys[lo] = tk;
int64_t ti = idx[i]; idx[i] = idx[lo]; idx[lo] = ti;
lo++;
}
}
keys[n-1] = keys[lo]; keys[lo] = pk;
idx[n-1] = idx[lo]; idx[lo] = pv;
if (lo < n - 1 - lo) {
key_introsort_impl(keys, idx, lo, depth);
keys += lo + 1; idx += lo + 1; n -= lo + 1;
} else {
key_introsort_impl(keys + lo + 1, idx + lo + 1, n - lo - 1, depth);
n = lo;
}
}
key_insertion_sort(keys, idx, n);
}
void key_introsort(uint64_t* keys, int64_t* idx, int64_t n) {
if (n <= 1) return;
int depth = 0;
for (int64_t nn = n; nn > 1; nn >>= 1) depth++;
depth *= 2;
key_introsort_impl(keys, idx, n, depth);
}
typedef struct {
const uint64_t* keys;
int64_t* pw_unsorted;
} sortedness_ctx_t;
static void sortedness_fn(void* arg, uint32_t wid, int64_t start, int64_t end) {
sortedness_ctx_t* c = (sortedness_ctx_t*)arg;
const uint64_t* keys = c->keys;
int64_t unsorted = 0;
for (int64_t i = start + 1; i < end; i++) {
if (keys[i] < keys[i - 1]) unsorted++;
}
c->pw_unsorted[wid] += unsorted;
}
double detect_sortedness(ray_pool_t* pool, const uint64_t* keys, int64_t n) {
if (n <= 1) return 0.0;
int64_t total_unsorted;
if (pool && n > SMALL_POOL_THRESHOLD) {
uint32_t nw = ray_pool_total_workers(pool);
int64_t pw[nw];
memset(pw, 0, (size_t)nw * sizeof(int64_t));
sortedness_ctx_t ctx = { .keys = keys, .pw_unsorted = pw };
ray_pool_dispatch(pool, sortedness_fn, &ctx, n);
total_unsorted = 0;
for (uint32_t t = 0; t < nw; t++)
total_unsorted += pw[t];
int64_t grain = RAY_DISPATCH_MORSELS * RAY_MORSEL_ELEMS;
for (int64_t b = grain; b < n; b += grain) {
if (keys[b] < keys[b - 1])
total_unsorted++;
}
} else {
total_unsorted = 0;
for (int64_t i = 1; i < n; i++) {
if (keys[i] < keys[i - 1]) total_unsorted++;
}
}
return (double)total_unsorted / (double)(n - 1);
}
typedef struct {
const uint64_t* keys;
uint64_t* pw_or;
} key_range_ctx_t;
static void key_range_fn(void* arg, uint32_t wid, int64_t start, int64_t end) {
key_range_ctx_t* c = (key_range_ctx_t*)arg;
const uint64_t* keys = c->keys;
uint64_t local_or = c->pw_or[wid];
uint64_t first = keys[start];
for (int64_t i = start; i < end; i++)
local_or |= keys[i] ^ first;
c->pw_or[wid] = local_or;
}
uint8_t compute_key_nbytes(ray_pool_t* pool, const uint64_t* keys,
int64_t n, uint8_t type_max) {
if (n <= 1) return 1;
uint64_t diff;
if (pool && n > SMALL_POOL_THRESHOLD) {
uint32_t nw = ray_pool_total_workers(pool);
uint64_t pw_or[nw];
memset(pw_or, 0, nw * sizeof(uint64_t));
key_range_ctx_t ctx = { .keys = keys, .pw_or = pw_or };
ray_pool_dispatch(pool, key_range_fn, &ctx, n);
diff = 0;
for (uint32_t w = 0; w < nw; w++) diff |= pw_or[w];
uint64_t first = keys[0];
int64_t chunk = (n + nw - 1) / nw;
for (uint32_t w = 1; w < nw; w++) {
int64_t wstart = (int64_t)w * chunk;
if (wstart < n) diff |= keys[wstart] ^ first;
}
} else {
diff = 0;
uint64_t first = keys[0];
for (int64_t i = 1; i < n; i++)
diff |= keys[i] ^ first;
}
uint8_t nb = 0;
while (diff) { nb++; diff >>= 8; }
if (nb < 1) nb = 1;
return nb < type_max ? nb : type_max;
}
static void radix_hist_fn(void* arg, uint32_t wid, int64_t start, int64_t end) {
(void)wid; (void)end;
radix_pass_ctx_t* c = (radix_pass_ctx_t*)arg;
int64_t task = start;
uint32_t* h = c->hist + task * 256;
memset(h, 0, 256 * sizeof(uint32_t));
int64_t chunk = (c->n + c->n_tasks - 1) / c->n_tasks;
int64_t lo = task * chunk;
int64_t hi = lo + chunk;
if (hi > c->n) hi = c->n;
if (lo >= hi) return;
const uint64_t* keys = c->keys;
uint8_t shift = c->shift;
for (int64_t i = lo; i < hi; i++)
h[(keys[i] >> shift) & 0xFF]++;
}
#define SWC_N 8
static void radix_scatter_fn(void* arg, uint32_t wid, int64_t start, int64_t end) {
(void)wid; (void)end;
radix_pass_ctx_t* c = (radix_pass_ctx_t*)arg;
int64_t task = start;
int64_t chunk = (c->n + c->n_tasks - 1) / c->n_tasks;
int64_t lo = task * chunk;
int64_t hi = lo + chunk;
if (hi > c->n) hi = c->n;
if (lo >= hi) return;
int64_t* off = c->offsets + task * 256;
const uint64_t* k_in = c->keys;
const int64_t* i_in = c->idx;
uint64_t* k_out = c->keys_out;
int64_t* i_out = c->idx_out;
uint8_t shift = c->shift;
uint64_t kbuf[256][SWC_N];
int64_t ibuf[256][SWC_N];
uint8_t bcnt[256];
memset(bcnt, 0, 256);
for (int64_t i = lo; i < hi; i++) {
uint8_t byte = (k_in[i] >> shift) & 0xFF;
uint8_t bp = bcnt[byte];
kbuf[byte][bp] = k_in[i];
ibuf[byte][bp] = i_in[i];
if (++bp == SWC_N) {
int64_t pos = off[byte];
memcpy(&k_out[pos], kbuf[byte], SWC_N * sizeof(uint64_t));
memcpy(&i_out[pos], ibuf[byte], SWC_N * sizeof(int64_t));
off[byte] = pos + SWC_N;
bp = 0;
}
bcnt[byte] = bp;
}
for (int b = 0; b < 256; b++) {
int64_t pos = off[b];
for (uint8_t j = 0; j < bcnt[b]; j++) {
k_out[pos + j] = kbuf[b][j];
i_out[pos + j] = ibuf[b][j];
}
off[b] = pos + bcnt[b];
}
}
#undef SWC_N
int64_t* radix_sort_run(ray_pool_t* pool,
uint64_t* keys, int64_t* indices,
uint64_t* keys_tmp, int64_t* idx_tmp,
int64_t n, uint8_t n_bytes,
uint64_t** sorted_keys_out) {
uint32_t n_tasks = pool ? ray_pool_total_workers(pool) : 1;
if (n_tasks < 1) n_tasks = 1;
ray_t *hist_hdr = NULL, *off_hdr = NULL;
uint32_t* hist = (uint32_t*)scratch_alloc(&hist_hdr,
(size_t)n_tasks * 256 * sizeof(uint32_t));
int64_t* offsets = (int64_t*)scratch_alloc(&off_hdr,
(size_t)n_tasks * 256 * sizeof(int64_t));
if (!hist || !offsets) {
scratch_free(hist_hdr); scratch_free(off_hdr);
return NULL;
}
uint64_t* src_k = keys, *dst_k = keys_tmp;
int64_t* src_i = indices, *dst_i = idx_tmp;
for (uint8_t bp = 0; bp < n_bytes; bp++) {
uint8_t shift = bp * 8;
radix_pass_ctx_t ctx = {
.keys = src_k, .idx = src_i,
.keys_out = dst_k, .idx_out = dst_i,
.n = n, .shift = shift, .n_tasks = n_tasks,
.hist = hist, .offsets = offsets,
};
if (pool && n_tasks > 1)
ray_pool_dispatch_n(pool, radix_hist_fn, &ctx, n_tasks);
else
radix_hist_fn(&ctx, 0, 0, 1);
bool uniform = false;
for (int b = 0; b < 256; b++) {
uint32_t total = 0;
for (uint32_t t = 0; t < n_tasks; t++)
total += hist[t * 256 + b];
if (total == (uint32_t)n) { uniform = true; break; }
}
if (uniform) continue;
int64_t running = 0;
for (int b = 0; b < 256; b++) {
for (uint32_t t = 0; t < n_tasks; t++) {
offsets[t * 256 + b] = running;
running += hist[t * 256 + b];
}
}
if (pool && n_tasks > 1)
ray_pool_dispatch_n(pool, radix_scatter_fn, &ctx, n_tasks);
else
radix_scatter_fn(&ctx, 0, 0, 1);
uint64_t* tk = src_k; src_k = dst_k; dst_k = tk;
int64_t* ti = src_i; src_i = dst_i; dst_i = ti;
}
scratch_free(hist_hdr);
scratch_free(off_hdr);
if (sorted_keys_out) *sorted_keys_out = src_k;
return src_i;
}
#define PSWC_N 8
static void packed_scatter_fn(void* arg, uint32_t wid, int64_t start, int64_t end) {
(void)wid; (void)end;
radix_pass_ctx_t* c = (radix_pass_ctx_t*)arg;
int64_t task = start;
int64_t chunk = (c->n + c->n_tasks - 1) / c->n_tasks;
int64_t lo = task * chunk;
int64_t hi = lo + chunk;
if (hi > c->n) hi = c->n;
if (lo >= hi) return;
int64_t* off = c->offsets + task * 256;
const uint64_t* in = c->keys;
uint64_t* out = c->keys_out;
uint8_t shift = c->shift;
uint64_t buf[256][PSWC_N];
uint8_t bcnt[256];
memset(bcnt, 0, 256);
for (int64_t i = lo; i < hi; i++) {
uint8_t byte = (in[i] >> shift) & 0xFF;
uint8_t bp = bcnt[byte];
buf[byte][bp] = in[i];
if (++bp == PSWC_N) {
int64_t pos = off[byte];
memcpy(&out[pos], buf[byte], PSWC_N * sizeof(uint64_t));
off[byte] = pos + PSWC_N;
bp = 0;
}
bcnt[byte] = bp;
}
for (int b = 0; b < 256; b++) {
int64_t pos = off[b];
for (uint8_t j = 0; j < bcnt[b]; j++)
out[pos + j] = buf[b][j];
off[b] = pos + bcnt[b];
}
}
#undef PSWC_N
uint64_t* packed_radix_sort_run(ray_pool_t* pool,
uint64_t* data, uint64_t* tmp,
int64_t n, uint8_t n_bytes) {
uint32_t n_tasks = pool ? ray_pool_total_workers(pool) : 1;
if (n_tasks < 1) n_tasks = 1;
ray_t *hist_hdr = NULL, *off_hdr = NULL;
uint32_t* hist = (uint32_t*)scratch_alloc(&hist_hdr,
(size_t)n_tasks * 256 * sizeof(uint32_t));
int64_t* offsets = (int64_t*)scratch_alloc(&off_hdr,
(size_t)n_tasks * 256 * sizeof(int64_t));
if (!hist || !offsets) {
scratch_free(hist_hdr); scratch_free(off_hdr);
return NULL;
}
uint64_t* src = data, *dst = tmp;
for (uint8_t bp = 0; bp < n_bytes; bp++) {
uint8_t shift = bp * 8;
radix_pass_ctx_t ctx = {
.keys = src, .keys_out = dst,
.n = n, .shift = shift, .n_tasks = n_tasks,
.hist = hist, .offsets = offsets,
};
if (pool && n_tasks > 1)
ray_pool_dispatch_n(pool, radix_hist_fn, &ctx, n_tasks);
else
radix_hist_fn(&ctx, 0, 0, 1);
bool uniform = false;
for (int b = 0; b < 256; b++) {
uint32_t total = 0;
for (uint32_t t = 0; t < n_tasks; t++)
total += hist[t * 256 + b];
if (total == (uint32_t)n) { uniform = true; break; }
}
if (uniform) continue;
int64_t running = 0;
for (int b = 0; b < 256; b++) {
for (uint32_t t = 0; t < n_tasks; t++) {
offsets[t * 256 + b] = running;
running += hist[t * 256 + b];
}
}
if (pool && n_tasks > 1)
ray_pool_dispatch_n(pool, packed_scatter_fn, &ctx, n_tasks);
else
packed_scatter_fn(&ctx, 0, 0, 1);
uint64_t* t2 = src; src = dst; dst = t2;
}
scratch_free(hist_hdr);
scratch_free(off_hdr);
return src;
}
typedef struct {
uint64_t* keys;
uint8_t key_bits;
uint64_t key_mask;
int64_t* pw_unsorted;
int64_t* pw_not_reverse;
} packed_detect_ctx_t;
static void packed_detect_fn(void* arg, uint32_t wid,
int64_t start, int64_t end) {
packed_detect_ctx_t* c = (packed_detect_ctx_t*)arg;
uint64_t* k = c->keys;
uint8_t kb = c->key_bits;
uint64_t km = c->key_mask;
int64_t unsorted = 0, not_rev = 0;
uint64_t prev = (start > 0) ? (k[start - 1] & km) : 0;
for (int64_t i = start; i < end; i++) {
uint64_t cur = k[i] & km;
if (i > start) {
if (cur < prev) unsorted++;
if (cur > prev) not_rev++;
}
k[i] = cur | ((uint64_t)i << kb);
prev = cur;
}
c->pw_unsorted[wid] += unsorted;
c->pw_not_reverse[wid] += not_rev;
}
typedef struct {
const uint64_t* sorted;
int64_t* indices;
uint64_t* keys_out;
uint8_t key_bits;
uint64_t idx_mask;
uint64_t key_mask;
bool extract_keys;
} packed_unpack_ctx_t;
static void packed_unpack_fn(void* arg, uint32_t wid,
int64_t start, int64_t end) {
(void)wid;
packed_unpack_ctx_t* c = (packed_unpack_ctx_t*)arg;
for (int64_t i = start; i < end; i++) {
uint64_t v = c->sorted[i];
c->indices[i] = (int64_t)((v >> c->key_bits) & c->idx_mask);
if (c->extract_keys)
c->keys_out[i] = v & c->key_mask;
}
}
static int64_t* bucket_lsb_sort(uint64_t* keys, int64_t* idx,
uint64_t* ktmp, int64_t* itmp,
int64_t n, uint8_t n_bytes) {
if (n <= 64) {
key_introsort(keys, idx, n);
return idx;
}
uint64_t* src_k = keys, *dst_k = ktmp;
int64_t* src_i = idx, *dst_i = itmp;
for (uint8_t bp = 0; bp < n_bytes; bp++) {
uint8_t shift = bp * 8;
uint32_t hist[256];
memset(hist, 0, sizeof(hist));
for (int64_t i = 0; i < n; i++)
hist[(src_k[i] >> shift) & 0xFF]++;
bool uniform = false;
for (int b = 0; b < 256; b++) {
if (hist[b] == (uint32_t)n) { uniform = true; break; }
}
if (uniform) continue;
int64_t off[256];
off[0] = 0;
for (int b = 1; b < 256; b++)
off[b] = off[b-1] + (int64_t)hist[b-1];
for (int64_t i = 0; i < n; i++) {
uint8_t byte = (src_k[i] >> shift) & 0xFF;
int64_t pos = off[byte]++;
dst_k[pos] = src_k[i];
dst_i[pos] = src_i[i];
}
uint64_t* tk = src_k; src_k = dst_k; dst_k = tk;
int64_t* ti = src_i; src_i = dst_i; dst_i = ti;
}
return src_i;
}
typedef struct {
uint64_t* data_k;
int64_t* data_i;
uint64_t* tmp_k;
int64_t* tmp_i;
int64_t bucket_offsets[257];
uint8_t n_bytes;
} msd_bucket_ctx_t;
static void msd_bucket_sort_fn(void* arg, uint32_t wid,
int64_t start, int64_t end) {
(void)wid;
msd_bucket_ctx_t* c = (msd_bucket_ctx_t*)arg;
for (int64_t b = start; b < end; b++) {
int64_t off = c->bucket_offsets[b];
int64_t cnt = c->bucket_offsets[b + 1] - off;
if (cnt <= 1) continue;
int64_t* sorted = bucket_lsb_sort(
c->data_k + off, c->data_i + off,
c->tmp_k + off, c->tmp_i + off,
cnt, c->n_bytes);
if (sorted != c->data_i + off) {
memcpy(c->data_k + off, c->tmp_k + off,
(size_t)cnt * sizeof(uint64_t));
memcpy(c->data_i + off, c->tmp_i + off,
(size_t)cnt * sizeof(int64_t));
}
}
}
int64_t* msd_radix_sort_run(ray_pool_t* pool,
uint64_t* keys, int64_t* indices,
uint64_t* keys_tmp, int64_t* idx_tmp,
int64_t n, uint8_t n_bytes,
uint64_t** sorted_keys_out) {
if (n_bytes <= 5 || n <= 1000000) {
return radix_sort_run(pool, keys, indices, keys_tmp, idx_tmp,
n, n_bytes, sorted_keys_out);
}
uint32_t n_tasks = pool ? ray_pool_total_workers(pool) : 1;
if (n_tasks < 1) n_tasks = 1;
ray_t *hist_hdr = NULL, *off_hdr = NULL;
uint32_t* hist = (uint32_t*)scratch_alloc(&hist_hdr,
(size_t)n_tasks * 256 * sizeof(uint32_t));
int64_t* offsets = (int64_t*)scratch_alloc(&off_hdr,
(size_t)n_tasks * 256 * sizeof(int64_t));
if (!hist || !offsets) {
scratch_free(hist_hdr); scratch_free(off_hdr);
return radix_sort_run(pool, keys, indices, keys_tmp, idx_tmp,
n, n_bytes, sorted_keys_out);
}
uint8_t msd_byte = n_bytes - 1;
uint8_t shift = msd_byte * 8;
radix_pass_ctx_t ctx = {
.keys = keys, .idx = indices,
.keys_out = keys_tmp, .idx_out = idx_tmp,
.n = n, .shift = shift, .n_tasks = n_tasks,
.hist = hist, .offsets = offsets,
};
if (pool && n_tasks > 1)
ray_pool_dispatch_n(pool, radix_hist_fn, &ctx, n_tasks);
else
radix_hist_fn(&ctx, 0, 0, 1);
bool uniform = false;
for (int b = 0; b < 256; b++) {
uint32_t total = 0;
for (uint32_t t = 0; t < n_tasks; t++)
total += hist[t * 256 + b];
if (total == (uint32_t)n) { uniform = true; break; }
}
if (uniform) {
scratch_free(hist_hdr); scratch_free(off_hdr);
return msd_radix_sort_run(pool, keys, indices, keys_tmp, idx_tmp,
n, n_bytes - 1, sorted_keys_out);
}
int64_t bucket_offsets[257];
{
int64_t running = 0;
for (int b = 0; b < 256; b++) {
bucket_offsets[b] = running;
for (uint32_t t = 0; t < n_tasks; t++) {
offsets[t * 256 + b] = running;
running += hist[t * 256 + b];
}
}
bucket_offsets[256] = running;
}
if (pool && n_tasks > 1)
ray_pool_dispatch_n(pool, radix_scatter_fn, &ctx, n_tasks);
else
radix_scatter_fn(&ctx, 0, 0, 1);
scratch_free(hist_hdr);
scratch_free(off_hdr);
uint8_t remaining_bytes = msd_byte;
msd_bucket_ctx_t bctx = {
.data_k = keys_tmp, .data_i = idx_tmp,
.tmp_k = keys, .tmp_i = indices,
.n_bytes = remaining_bytes,
};
memcpy(bctx.bucket_offsets, bucket_offsets, sizeof(bucket_offsets));
if (pool)
ray_pool_dispatch_n(pool, msd_bucket_sort_fn, &bctx, 256);
else
msd_bucket_sort_fn(&bctx, 0, 0, 256);
if (sorted_keys_out) *sorted_keys_out = keys_tmp;
return idx_tmp;
}
void radix_encode_fn(void* arg, uint32_t wid, int64_t start, int64_t end) {
(void)wid;
radix_encode_ctx_t* c = (radix_encode_ctx_t*)arg;
if (c->indices) {
int64_t* idx = c->indices;
for (int64_t i = start; i < end; i++) idx[i] = i;
}
if (c->n_keys <= 1) {
switch (c->type) {
case RAY_I64: case RAY_TIMESTAMP: {
const int64_t* d = (const int64_t*)c->data;
bool has_nulls = c->col && (c->col->attrs & RAY_ATTR_HAS_NULLS);
bool nf = c->nulls_first;
bool desc = c->desc;
uint64_t null_e = (nf ^ desc) ? 0 : UINT64_MAX;
if (desc) {
for (int64_t i = start; i < end; i++) {
if (has_nulls && ray_vec_is_null(c->col, i))
c->keys[i] = ~null_e;
else
c->keys[i] = ~((uint64_t)d[i] ^ ((uint64_t)1 << 63));
}
} else {
for (int64_t i = start; i < end; i++) {
if (has_nulls && ray_vec_is_null(c->col, i))
c->keys[i] = null_e;
else
c->keys[i] = (uint64_t)d[i] ^ ((uint64_t)1 << 63);
}
}
break;
}
case RAY_F64: {
const double* d = (const double*)c->data;
bool nf = c->nulls_first;
bool desc = c->desc;
uint64_t nan_e = (nf ^ desc) ? 0 : UINT64_MAX;
for (int64_t i = start; i < end; i++) {
uint64_t bits;
memcpy(&bits, &d[i], 8);
if ((bits & 0x7FF0000000000000ULL) == 0x7FF0000000000000ULL &&
(bits & 0x000FFFFFFFFFFFFFULL)) {
c->keys[i] = desc ? ~nan_e : nan_e;
} else {
uint64_t mask = -(bits >> 63) | ((uint64_t)1 << 63);
uint64_t e = bits ^ mask;
c->keys[i] = desc ? ~e : e;
}
}
break;
}
case RAY_I32: case RAY_DATE: case RAY_TIME: {
const int32_t* d = (const int32_t*)c->data;
bool has_nulls = c->col && (c->col->attrs & RAY_ATTR_HAS_NULLS);
bool nf = c->nulls_first;
bool desc = c->desc;
uint64_t null_e = (nf ^ desc) ? 0 : UINT64_MAX;
if (desc) {
for (int64_t i = start; i < end; i++) {
if (has_nulls && ray_vec_is_null(c->col, i))
c->keys[i] = ~null_e;
else
c->keys[i] = ~((uint64_t)((uint32_t)d[i] ^ ((uint32_t)1 << 31)));
}
} else {
for (int64_t i = start; i < end; i++) {
if (has_nulls && ray_vec_is_null(c->col, i))
c->keys[i] = null_e;
else
c->keys[i] = (uint64_t)((uint32_t)d[i] ^ ((uint32_t)1 << 31));
}
}
break;
}
case RAY_SYM: {
const uint32_t* rank = c->enum_rank;
if (c->desc) {
for (int64_t i = start; i < end; i++) {
uint32_t raw = (uint32_t)ray_read_sym(c->data, i, c->type, c->col_attrs);
c->keys[i] = ~(uint64_t)rank[raw];
}
} else {
for (int64_t i = start; i < end; i++) {
uint32_t raw = (uint32_t)ray_read_sym(c->data, i, c->type, c->col_attrs);
c->keys[i] = (uint64_t)rank[raw];
}
}
break;
}
case RAY_I16: {
const int16_t* d = (const int16_t*)c->data;
if (c->desc) {
for (int64_t i = start; i < end; i++)
c->keys[i] = ~((uint64_t)((uint16_t)d[i] ^ ((uint16_t)1 << 15)));
} else {
for (int64_t i = start; i < end; i++)
c->keys[i] = (uint64_t)((uint16_t)d[i] ^ ((uint16_t)1 << 15));
}
break;
}
case RAY_BOOL: case RAY_U8: {
const uint8_t* d = (const uint8_t*)c->data;
if (c->desc) {
for (int64_t i = start; i < end; i++)
c->keys[i] = ~(uint64_t)d[i];
} else {
for (int64_t i = start; i < end; i++)
c->keys[i] = (uint64_t)d[i];
}
break;
}
}
} else {
for (int64_t i = start; i < end; i++) {
uint64_t composite = 0;
for (uint8_t k = 0; k < c->n_keys; k++) {
ray_t* col = c->vecs[k];
int64_t val;
if (c->enum_ranks[k]) {
uint32_t raw = (uint32_t)ray_read_sym(ray_data(col), i, col->type, col->attrs);
val = (int64_t)c->enum_ranks[k][raw];
} else if (col->type == RAY_I64 || col->type == RAY_TIMESTAMP) {
val = ((const int64_t*)ray_data(col))[i];
} else if (col->type == RAY_F64) {
uint64_t bits;
memcpy(&bits, &((const double*)ray_data(col))[i], 8);
uint64_t mask = -(bits >> 63) | ((uint64_t)1 << 63);
val = (int64_t)(bits ^ mask);
} else if (col->type == RAY_I32 || col->type == RAY_DATE || col->type == RAY_TIME) {
val = (int64_t)((const int32_t*)ray_data(col))[i];
} else if (col->type == RAY_I16) {
val = (int64_t)((const int16_t*)ray_data(col))[i];
} else if (col->type == RAY_BOOL || col->type == RAY_U8) {
val = (int64_t)((const uint8_t*)ray_data(col))[i];
} else {
val = 0;
}
uint64_t part = (uint64_t)val - (uint64_t)c->mins[k];
if (c->descs[k]) part = (uint64_t)c->ranges[k] - part;
composite |= part << c->bit_shifts[k];
}
c->keys[i] = composite;
}
}
}
#define RAY_STRSORT_KEY_PARTS_MAX 4
#define RAY_STRSORT_BASE_CASE 24
#define RAY_STRSORT_PROBE_HEAD 1024
typedef struct {
uint64_t parts[RAY_STRSORT_KEY_PARTS_MAX];
uint32_t row;
uint32_t len;
} ray_strkey_t;
static inline uint64_t strkey_lex_u64(uint64_t v) {
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
return __builtin_bswap64(v);
#else
return v;
#endif
}
static inline uint64_t strkey_load_part(const char* src, int64_t len, int offset) {
int64_t remaining = len - offset;
if (remaining <= 0) return 0;
uint64_t raw = 0;
int64_t take = remaining < 8 ? remaining : 8;
memcpy(&raw, src + offset, (size_t)take);
return strkey_lex_u64(raw);
}
static int strkey_cmp(const ray_strkey_t* a, const ray_strkey_t* b,
int parts,
const ray_str_t* elems, const char* pool) {
for (int p = 0; p < parts; p++) {
if (a->parts[p] < b->parts[p]) return -1;
if (a->parts[p] > b->parts[p]) return 1;
}
int64_t parts_bytes = (int64_t)parts * 8;
if ((int64_t)a->len <= parts_bytes && (int64_t)b->len <= parts_bytes) {
return (int)a->len - (int)b->len;
}
const ray_str_t* sa = &elems[a->row];
const ray_str_t* sb = &elems[b->row];
const char* pa = ray_str_t_ptr(sa, pool);
const char* pb = ray_str_t_ptr(sb, pool);
int64_t la = (int64_t)sa->len - parts_bytes; if (la < 0) la = 0;
int64_t lb = (int64_t)sb->len - parts_bytes; if (lb < 0) lb = 0;
int64_t m = la < lb ? la : lb;
int r = m ? memcmp(pa + parts_bytes, pb + parts_bytes, (size_t)m) : 0;
if (r != 0) return r;
return (la > lb) - (la < lb);
}
static void strkey_insertion_sort(ray_strkey_t* a, int64_t n, int parts,
const ray_str_t* elems, const char* pool) {
for (int64_t i = 1; i < n; i++) {
ray_strkey_t cur = a[i];
int64_t j = i - 1;
while (j >= 0 && strkey_cmp(&a[j], &cur, parts, elems, pool) > 0) {
a[j + 1] = a[j];
j--;
}
a[j + 1] = cur;
}
}
static inline uint8_t strkey_byte_at(const ray_strkey_t* k, int bp) {
int part = bp >> 3;
int shift = 56 - ((bp & 7) << 3);
return (uint8_t)(k->parts[part] >> shift);
}
static int strsort_probe_parts(const int64_t* indices, int64_t n_live,
const ray_str_t* elems) {
int64_t max_len = 0;
for (int64_t i = 0; i < n_live; i++) {
int64_t l = (int64_t)elems[indices[i]].len;
if (l > max_len) max_len = l;
}
int64_t pcalc = (max_len + 7) / 8;
if (pcalc < 1) pcalc = 1;
if (pcalc > RAY_STRSORT_KEY_PARTS_MAX) pcalc = RAY_STRSORT_KEY_PARTS_MAX;
return (int)pcalc;
}
typedef struct {
ray_strkey_t* out;
const int64_t* indices;
const ray_str_t* elems;
const char* pool;
int parts;
} strsort_build_ctx_t;
static void strsort_build_fn(void* vctx, uint32_t wid, int64_t s, int64_t e) {
(void)wid;
strsort_build_ctx_t* c = (strsort_build_ctx_t*)vctx;
for (int64_t i = s; i < e; i++) {
int64_t row = c->indices[i];
const ray_str_t* str = &c->elems[row];
c->out[i].row = (uint32_t)row;
c->out[i].len = str->len;
int64_t len = str->len;
const char* src = len ? ray_str_t_ptr(str, c->pool) : NULL;
for (int p = 0; p < RAY_STRSORT_KEY_PARTS_MAX; p++) {
c->out[i].parts[p] = (p < c->parts)
? strkey_load_part(src, len, p * 8)
: 0;
}
}
}
static void strsort_build_keys(ray_strkey_t* out, int64_t n_live,
const int64_t* indices,
const ray_str_t* elems, const char* pool,
int parts) {
strsort_build_ctx_t c = { out, indices, elems, pool, parts };
ray_pool_t* p = ray_pool_get();
if (p && n_live >= RAY_PARALLEL_THRESHOLD) {
ray_pool_dispatch(p, strsort_build_fn, &c, n_live);
} else {
strsort_build_fn(&c, 0, 0, n_live);
}
}
typedef struct {
int64_t* out;
const ray_strkey_t* keys;
} strsort_emit_ctx_t;
static void strsort_emit_fn(void* vctx, uint32_t wid, int64_t s, int64_t e) {
(void)wid;
strsort_emit_ctx_t* c = (strsort_emit_ctx_t*)vctx;
for (int64_t i = s; i < e; i++) c->out[i] = (int64_t)c->keys[i].row;
}
static inline int strkey_cmp_packed(const ray_strkey_t* a,
const ray_strkey_t* b, int parts) {
for (int p = 0; p < parts; p++) {
if (a->parts[p] < b->parts[p]) return -1;
if (a->parts[p] > b->parts[p]) return 1;
}
return (int)a->len - (int)b->len;
}
static int strsort_detect_runs(const ray_strkey_t* keys, int64_t n,
int parts, int parts_bytes) {
if (n < 2) return 0;
bool asc = true, desc = true;
for (int64_t i = 1; i < n; i++) {
int r = strkey_cmp_packed(&keys[i - 1], &keys[i], parts);
if (r == 0) {
if ((int64_t)keys[i - 1].len > parts_bytes ||
(int64_t)keys[i].len > parts_bytes) {
return 0;
}
} else if (r > 0) {
asc = false;
} else {
desc = false;
}
if (!asc && !desc) return 0;
}
if (asc) return 1;
if (desc) return -1;
return 0;
}
typedef struct {
const ray_strkey_t* src;
ray_strkey_t* dst;
int64_t n;
uint32_t n_tasks;
uint32_t* hist;
int64_t* offsets;
} strsort_top_ctx_t;
static void strsort_top_hist_fn(void* vctx, uint32_t wid,
int64_t start, int64_t end) {
(void)wid; (void)end;
strsort_top_ctx_t* c = (strsort_top_ctx_t*)vctx;
int64_t task = start;
uint32_t* h = c->hist + task * 256;
memset(h, 0, 256 * sizeof(uint32_t));
int64_t chunk = (c->n + c->n_tasks - 1) / c->n_tasks;
int64_t lo = task * chunk;
int64_t hi = lo + chunk;
if (hi > c->n) hi = c->n;
if (lo >= hi) return;
const ray_strkey_t* src = c->src;
for (int64_t i = lo; i < hi; i++) {
h[strkey_byte_at(&src[i], 0)]++;
}
}
static void strsort_top_scatter_fn(void* vctx, uint32_t wid,
int64_t start, int64_t end) {
(void)wid; (void)end;
strsort_top_ctx_t* c = (strsort_top_ctx_t*)vctx;
int64_t task = start;
int64_t chunk = (c->n + c->n_tasks - 1) / c->n_tasks;
int64_t lo = task * chunk;
int64_t hi = lo + chunk;
if (hi > c->n) hi = c->n;
if (lo >= hi) return;
int64_t* off = c->offsets + task * 256;
const ray_strkey_t* src = c->src;
ray_strkey_t* dst = c->dst;
for (int64_t i = lo; i < hi; i++) {
uint8_t b = strkey_byte_at(&src[i], 0);
dst[off[b]++] = src[i];
}
}
typedef struct {
ray_strkey_t* keys;
const int64_t* starts;
const int64_t* counts;
int parts_bytes;
int64_t base_offset;
const ray_str_t* elems;
const char* pool;
int parts;
int start_bp;
} strsort_bucket_ctx_t;
static void strsort_aflag(ray_strkey_t* keys, int64_t n, int bp,
int parts_bytes, int64_t base_offset,
const ray_str_t* elems, const char* pool,
int parts);
static void strsort_bucket_fn(void* vctx, uint32_t wid, int64_t s, int64_t e) {
(void)wid;
strsort_bucket_ctx_t* c = (strsort_bucket_ctx_t*)vctx;
for (int64_t b = s; b < e; b++) {
int64_t cnt = c->counts[b];
if (cnt <= 1) continue;
strsort_aflag(c->keys + c->starts[b], cnt, c->start_bp,
c->parts_bytes, c->base_offset,
c->elems, c->pool, c->parts);
}
}
static void strkey_qsort_by_len(ray_strkey_t* a, int64_t lo, int64_t hi) {
while (hi - lo > 16) {
int64_t mid = lo + (hi - lo) / 2;
if (a[lo].len > a[hi].len) { ray_strkey_t t=a[lo]; a[lo]=a[hi]; a[hi]=t; }
if (a[mid].len > a[hi].len) { ray_strkey_t t=a[mid]; a[mid]=a[hi]; a[hi]=t; }
if (a[lo].len > a[mid].len) { ray_strkey_t t=a[lo]; a[lo]=a[mid]; a[mid]=t; }
uint32_t pivot = a[mid].len;
int64_t i = lo - 1, j = hi + 1;
for (;;) {
do { i++; } while (a[i].len < pivot);
do { j--; } while (a[j].len > pivot);
if (i >= j) break;
ray_strkey_t t = a[i]; a[i] = a[j]; a[j] = t;
}
if (j - lo < hi - (j + 1)) {
strkey_qsort_by_len(a, lo, j);
lo = j + 1;
} else {
strkey_qsort_by_len(a, j + 1, hi);
hi = j;
}
}
for (int64_t i = lo + 1; i <= hi; i++) {
ray_strkey_t cur = a[i];
int64_t j = i - 1;
while (j >= lo && a[j].len > cur.len) {
a[j + 1] = a[j];
j--;
}
a[j + 1] = cur;
}
}
static bool strsort_repack_window(ray_strkey_t* keys, int64_t n,
int64_t base_offset,
const ray_str_t* elems, const char* pool,
int parts) {
bool any_tail = false;
uint32_t min_len = UINT32_MAX;
uint32_t max_len = 0;
for (int64_t i = 0; i < n; i++) {
const ray_str_t* s = &elems[keys[i].row];
int64_t len = s->len;
if (len > base_offset) any_tail = true;
if ((uint32_t)len < min_len) min_len = (uint32_t)len;
if ((uint32_t)len > max_len) max_len = (uint32_t)len;
const char* src = len > 0 ? ray_str_t_ptr(s, pool) : NULL;
for (int p = 0; p < parts; p++) {
int64_t off = base_offset + (int64_t)p * 8;
keys[i].parts[p] = (src && len > off)
? strkey_load_part(src, len, (int)off)
: 0;
}
}
if (!any_tail && n > 1 && min_len != max_len) {
strkey_qsort_by_len(keys, 0, n - 1);
}
return any_tail;
}
static void strsort_aflag(ray_strkey_t* keys, int64_t n, int bp,
int parts_bytes, int64_t base_offset,
const ray_str_t* elems, const char* pool,
int parts) {
for (;;) {
if (n <= 1) return;
if (n <= RAY_STRSORT_BASE_CASE) {
strkey_insertion_sort(keys, n, parts, elems, pool);
return;
}
if (bp >= parts_bytes) {
int64_t next_offset = base_offset + parts_bytes;
if (!strsort_repack_window(keys, n, next_offset,
elems, pool, parts)) {
return;
}
base_offset = next_offset;
bp = 0;
continue;
}
int64_t counts[256] = {0};
for (int64_t i = 0; i < n; i++) {
counts[strkey_byte_at(&keys[i], bp)]++;
}
int uniq_b = -1;
bool uniform = true;
for (int b = 0; b < 256; b++) {
if (counts[b] == 0) continue;
if (uniq_b < 0) uniq_b = b;
else { uniform = false; break; }
}
if (uniform) {
bp++;
continue;
}
int64_t starts[256];
int64_t ends[256];
{
int64_t sum = 0;
for (int b = 0; b < 256; b++) {
starts[b] = sum;
sum += counts[b];
ends[b] = sum;
}
}
int64_t cursors[256];
memcpy(cursors, starts, sizeof(cursors));
for (int b = 0; b < 256; b++) {
while (cursors[b] < ends[b]) {
ray_strkey_t v = keys[cursors[b]];
int bb = strkey_byte_at(&v, bp);
while (bb != b) {
ray_strkey_t tmp = keys[cursors[bb]];
keys[cursors[bb]] = v;
cursors[bb]++;
v = tmp;
bb = strkey_byte_at(&v, bp);
}
keys[cursors[b]] = v;
cursors[b]++;
}
}
int big_b = 0;
int64_t big_cnt = counts[0];
for (int b = 1; b < 256; b++) {
if (counts[b] > big_cnt) { big_cnt = counts[b]; big_b = b; }
}
for (int b = 0; b < 256; b++) {
if (b == big_b) continue;
int64_t cnt = counts[b];
if (cnt > 1) {
strsort_aflag(keys + starts[b], cnt, bp + 1,
parts_bytes, base_offset, elems, pool, parts);
}
}
keys += starts[big_b];
n = big_cnt;
bp++;
}
}
static bool sort_str_msd_inplace(int64_t* sorted_idx, int64_t nrows,
ray_t* col, bool desc, bool nulls_first) {
if (nrows <= 0) return true;
for (int64_t i = 0; i < nrows; i++) sorted_idx[i] = i;
int64_t null_count = 0;
bool has_nulls = (col->attrs & RAY_ATTR_HAS_NULLS) ||
((col->attrs & RAY_ATTR_SLICE) && col->slice_parent &&
(col->slice_parent->attrs & RAY_ATTR_HAS_NULLS));
if (has_nulls) {
int64_t w = 0;
int64_t null_pos;
for (int64_t i = 0; i < nrows; i++) {
if (!ray_vec_is_null(col, i)) sorted_idx[w++] = i;
}
null_count = nrows - w;
null_pos = w;
for (int64_t i = 0; i < nrows; i++) {
if (ray_vec_is_null(col, i)) sorted_idx[null_pos++] = i;
}
}
int64_t n_live = nrows - null_count;
if (n_live > 1) {
const ray_str_t* elems;
const char* pool;
str_resolve(col, &elems, &pool);
ray_pool_t* pool_p = ray_pool_get();
bool go_parallel = (pool_p && n_live >= RAY_PARALLEL_THRESHOLD);
int parts = strsort_probe_parts(sorted_idx, n_live, elems);
int parts_bytes = parts * 8;
ray_t* keys_hdr = NULL;
ray_strkey_t* keys = (ray_strkey_t*)scratch_alloc(&keys_hdr,
(size_t)n_live * sizeof(ray_strkey_t));
if (!keys) return false;
strsort_build_keys(keys, n_live, sorted_idx, elems, pool, parts);
int run_dir = strsort_detect_runs(keys, n_live, parts, parts_bytes);
bool want_asc = !desc;
if (run_dir == 1 && want_asc) {
strsort_emit_ctx_t ectx = { sorted_idx, keys };
if (go_parallel)
ray_pool_dispatch(pool_p, strsort_emit_fn, &ectx, n_live);
else
strsort_emit_fn(&ectx, 0, 0, n_live);
} else if (run_dir == -1 && !want_asc) {
strsort_emit_ctx_t ectx = { sorted_idx, keys };
if (go_parallel)
ray_pool_dispatch(pool_p, strsort_emit_fn, &ectx, n_live);
else
strsort_emit_fn(&ectx, 0, 0, n_live);
} else if (run_dir != 0) {
for (int64_t i = 0, j = n_live - 1; i < j; i++, j--) {
ray_strkey_t t = keys[i]; keys[i] = keys[j]; keys[j] = t;
}
strsort_emit_ctx_t ectx = { sorted_idx, keys };
if (go_parallel)
ray_pool_dispatch(pool_p, strsort_emit_fn, &ectx, n_live);
else
strsort_emit_fn(&ectx, 0, 0, n_live);
} else {
ray_t* tmp_hdr = NULL;
ray_strkey_t* keys_sorted = keys;
if (!go_parallel || parts_bytes == 0) {
strsort_aflag(keys, n_live, 0, parts_bytes,
0, elems, pool, parts);
} else {
ray_strkey_t* tmp = (ray_strkey_t*)scratch_alloc(&tmp_hdr,
(size_t)n_live * sizeof(ray_strkey_t));
if (!tmp) {
strsort_aflag(keys, n_live, 0, parts_bytes,
0, elems, pool, parts);
} else {
uint32_t n_tasks = ray_pool_total_workers(pool_p);
if (n_tasks < 1) n_tasks = 1;
ray_t* hist_hdr = NULL;
ray_t* off_hdr = NULL;
uint32_t* hist = (uint32_t*)scratch_alloc(&hist_hdr,
(size_t)n_tasks * 256 * sizeof(uint32_t));
int64_t* off = (int64_t*)scratch_alloc(&off_hdr,
(size_t)n_tasks * 256 * sizeof(int64_t));
if (!hist || !off) {
scratch_free(hist_hdr); scratch_free(off_hdr);
strsort_aflag(keys, n_live, 0, parts_bytes,
0, elems, pool, parts);
} else {
strsort_top_ctx_t tctx = {
.src = keys, .dst = tmp, .n = n_live,
.n_tasks = n_tasks, .hist = hist, .offsets = off,
};
ray_pool_dispatch_n(pool_p, strsort_top_hist_fn,
&tctx, n_tasks);
int64_t bucket_counts[256];
int64_t bucket_starts[256];
int64_t sum = 0;
for (int b = 0; b < 256; b++) {
bucket_starts[b] = sum;
int64_t bc = 0;
for (uint32_t t = 0; t < n_tasks; t++) {
off[t * 256 + b] = sum + bc;
bc += hist[t * 256 + b];
}
bucket_counts[b] = bc;
sum += bc;
}
ray_pool_dispatch_n(pool_p, strsort_top_scatter_fn,
&tctx, n_tasks);
scratch_free(hist_hdr);
scratch_free(off_hdr);
strsort_bucket_ctx_t bctx = {
.keys = tmp,
.starts = bucket_starts,
.counts = bucket_counts,
.parts_bytes = parts_bytes,
.base_offset = 0,
.elems = elems,
.pool = pool,
.parts = parts,
.start_bp = 1,
};
ray_pool_dispatch_n(pool_p, strsort_bucket_fn,
&bctx, 256);
keys_sorted = tmp;
}
}
}
strsort_emit_ctx_t ectx = { sorted_idx, keys_sorted };
if (go_parallel)
ray_pool_dispatch(pool_p, strsort_emit_fn, &ectx, n_live);
else
strsort_emit_fn(&ectx, 0, 0, n_live);
if (tmp_hdr) scratch_free(tmp_hdr);
if (desc) {
for (int64_t i = 0, j = n_live - 1; i < j; i++, j--) {
int64_t t = sorted_idx[i];
sorted_idx[i] = sorted_idx[j];
sorted_idx[j] = t;
}
}
}
scratch_free(keys_hdr);
}
if (null_count > 0 && nulls_first) {
int64_t a = 0, b = n_live - 1;
while (a < b) { int64_t t = sorted_idx[a]; sorted_idx[a] = sorted_idx[b]; sorted_idx[b] = t; a++; b--; }
a = n_live; b = nrows - 1;
while (a < b) { int64_t t = sorted_idx[a]; sorted_idx[a] = sorted_idx[b]; sorted_idx[b] = t; a++; b--; }
a = 0; b = nrows - 1;
while (a < b) { int64_t t = sorted_idx[a]; sorted_idx[a] = sorted_idx[b]; sorted_idx[b] = t; a++; b--; }
}
return true;
}
typedef struct {
const void* data;
int8_t type;
uint8_t attrs;
uint32_t* pw_max;
} enum_max_ctx_t;
static void enum_max_fn(void* arg, uint32_t wid,
int64_t start, int64_t end) {
enum_max_ctx_t* c = (enum_max_ctx_t*)arg;
uint32_t local_max = c->pw_max[wid];
for (int64_t i = start; i < end; i++) {
uint32_t v = (uint32_t)ray_read_sym(c->data, i, c->type, c->attrs);
if (v > local_max) local_max = v;
}
c->pw_max[wid] = local_max;
}
uint32_t* build_enum_rank(ray_t* col, int64_t nrows, ray_t** hdr_out) {
const void* data = ray_data(col);
int8_t type = col->type;
uint8_t attrs = col->attrs;
uint32_t max_id = 0;
ray_pool_t* pool = ray_pool_get();
if (pool && nrows > 100000) {
uint32_t nw = ray_pool_total_workers(pool);
uint32_t pw_max[nw];
memset(pw_max, 0, nw * sizeof(uint32_t));
enum_max_ctx_t ectx = { .data = data, .type = type, .attrs = attrs, .pw_max = pw_max };
ray_pool_dispatch(pool, enum_max_fn, &ectx, nrows);
for (uint32_t w = 0; w < nw; w++)
if (pw_max[w] > max_id) max_id = pw_max[w];
} else {
for (int64_t i = 0; i < nrows; i++) {
uint32_t v = (uint32_t)ray_read_sym(data, i, type, attrs);
if (v > max_id) max_id = v;
}
}
if (max_id >= UINT32_MAX - 1) { *hdr_out = NULL; return NULL; }
uint32_t n_ids = max_id + 1;
ray_scratch_arena_t arena;
ray_scratch_arena_init(&arena);
uint32_t* ids = (uint32_t*)ray_scratch_arena_push(&arena,
(size_t)n_ids * sizeof(uint32_t));
if (!ids) { ray_scratch_arena_reset(&arena); *hdr_out = NULL; return NULL; }
for (uint32_t i = 0; i < n_ids; i++) ids[i] = i;
const char** ptrs = (const char**)ray_scratch_arena_push(&arena,
(size_t)n_ids * sizeof(const char*));
uint32_t* lens = (uint32_t*)ray_scratch_arena_push(&arena,
(size_t)n_ids * sizeof(uint32_t));
if (!ptrs || !lens) {
ray_scratch_arena_reset(&arena); *hdr_out = NULL; return NULL;
}
for (uint32_t i = 0; i < n_ids; i++) {
ray_t* s = ray_sym_str((int64_t)i);
if (s) {
ptrs[i] = ray_str_ptr(s);
lens[i] = (uint32_t)ray_str_len(s);
} else {
ptrs[i] = NULL;
lens[i] = 0;
}
}
{
uint32_t* tmp = (uint32_t*)ray_scratch_arena_push(&arena,
(size_t)n_ids * sizeof(uint32_t));
if (!tmp) { ray_scratch_arena_reset(&arena);
*hdr_out = NULL; return NULL; }
for (uint32_t width = 1; width < n_ids; width *= 2) {
for (uint32_t i = 0; i < n_ids; i += 2 * width) {
uint32_t lo = i;
uint32_t mid = lo + width;
if (mid > n_ids) mid = n_ids;
uint32_t hi = lo + 2 * width;
if (hi > n_ids) hi = n_ids;
uint32_t a = lo, b = mid, k = lo;
while (a < mid && b < hi) {
uint32_t ia = ids[a], ib = ids[b];
uint32_t la = lens[ia], lb = lens[ib];
uint32_t ml = la < lb ? la : lb;
int cmp = 0;
if (ml > 0) cmp = memcmp(ptrs[ia], ptrs[ib], ml);
if (cmp == 0) cmp = (la > lb) - (la < lb);
if (cmp <= 0) tmp[k++] = ids[a++];
else tmp[k++] = ids[b++];
}
while (a < mid) tmp[k++] = ids[a++];
while (b < hi) tmp[k++] = ids[b++];
}
uint32_t* s = ids; ids = tmp; tmp = s;
}
}
ray_t* rank_hdr;
uint32_t* rank = (uint32_t*)scratch_calloc(&rank_hdr,
(size_t)n_ids * sizeof(uint32_t));
if (!rank) { ray_scratch_arena_reset(&arena); *hdr_out = NULL; return NULL; }
for (uint32_t i = 0; i < n_ids; i++)
rank[ids[i]] = i;
ray_scratch_arena_reset(&arena);
*hdr_out = rank_hdr;
return rank;
}
void sort_insertion(const sort_cmp_ctx_t* ctx, int64_t* arr, int64_t n) {
for (int64_t i = 1; i < n; i++) {
int64_t key = arr[i];
int64_t j = i - 1;
while (j >= 0 && sort_cmp(ctx, arr[j], key) > 0) {
arr[j + 1] = arr[j];
j--;
}
arr[j + 1] = key;
}
}
void sort_merge_recursive(const sort_cmp_ctx_t* ctx,
int64_t* arr, int64_t* tmp, int64_t n) {
if (n <= 64) {
sort_insertion(ctx, arr, n);
return;
}
int64_t mid = n / 2;
sort_merge_recursive(ctx, arr, tmp, mid);
sort_merge_recursive(ctx, arr + mid, tmp + mid, n - mid);
int64_t i = 0, j = mid, k = 0;
while (i < mid && j < n) {
if (sort_cmp(ctx, arr[i], arr[j]) <= 0)
tmp[k++] = arr[i++];
else
tmp[k++] = arr[j++];
}
while (i < mid) tmp[k++] = arr[i++];
while (j < n) tmp[k++] = arr[j++];
memcpy(arr, tmp, (size_t)n * sizeof(int64_t));
}
void sort_phase1_fn(void* arg, uint32_t worker_id, int64_t start, int64_t end) {
(void)worker_id;
sort_phase1_ctx_t* ctx = (sort_phase1_ctx_t*)arg;
for (int64_t chunk_idx = start; chunk_idx < end; chunk_idx++) {
int64_t chunk_size = (ctx->nrows + ctx->n_chunks - 1) / ctx->n_chunks;
int64_t lo = chunk_idx * chunk_size;
int64_t hi = lo + chunk_size;
if (hi > ctx->nrows) hi = ctx->nrows;
if (lo >= hi) continue;
sort_merge_recursive(ctx->cmp_ctx, ctx->indices + lo, ctx->tmp + lo, hi - lo);
}
}
static void merge_runs(const sort_cmp_ctx_t* ctx,
const int64_t* src, int64_t* dst,
int64_t lo, int64_t mid, int64_t hi) {
int64_t i = lo, j = mid, k = lo;
while (i < mid && j < hi) {
if (sort_cmp(ctx, src[i], src[j]) <= 0)
dst[k++] = src[i++];
else
dst[k++] = src[j++];
}
while (i < mid) dst[k++] = src[i++];
while (j < hi) dst[k++] = src[j++];
}
void sort_merge_fn(void* arg, uint32_t worker_id, int64_t start, int64_t end) {
(void)worker_id;
sort_merge_ctx_t* ctx = (sort_merge_ctx_t*)arg;
for (int64_t pair_idx = start; pair_idx < end; pair_idx++) {
int64_t lo = pair_idx * 2 * ctx->run_size;
int64_t mid = lo + ctx->run_size;
int64_t hi = mid + ctx->run_size;
if (mid > ctx->nrows) mid = ctx->nrows;
if (hi > ctx->nrows) hi = ctx->nrows;
if (lo >= ctx->nrows) continue;
if (mid >= hi) {
memcpy(ctx->dst + lo, ctx->src + lo, (size_t)(hi - lo) * sizeof(int64_t));
} else {
merge_runs(ctx->cmp_ctx, ctx->src, ctx->dst, lo, mid, hi);
}
}
}
void mk_prescan_fn(void* arg, uint32_t wid,
int64_t start, int64_t end) {
mk_prescan_ctx_t* c = (mk_prescan_ctx_t*)arg;
uint8_t nk = c->n_keys;
int64_t* my_mins = c->pw_mins + (int64_t)wid * nk;
int64_t* my_maxs = c->pw_maxs + (int64_t)wid * nk;
for (uint8_t k = 0; k < nk; k++) {
if (my_mins[k] == INT64_MAX) {
}
}
for (uint8_t k = 0; k < nk; k++) {
ray_t* col = c->vecs[k];
int64_t kmin = my_mins[k], kmax = my_maxs[k];
if (c->enum_ranks[k]) {
const void* cdata = ray_data(col);
int8_t ctype = col->type;
uint8_t cattrs = col->attrs;
const uint32_t* ranks = c->enum_ranks[k];
for (int64_t i = start; i < end; i++) {
uint32_t raw = (uint32_t)ray_read_sym(cdata, i, ctype, cattrs);
int64_t v = (int64_t)ranks[raw];
if (v < kmin) kmin = v;
if (v > kmax) kmax = v;
}
} else if (col->type == RAY_I64 || col->type == RAY_TIMESTAMP) {
const int64_t* d = (const int64_t*)ray_data(col);
for (int64_t i = start; i < end; i++) {
if (d[i] < kmin) kmin = d[i];
if (d[i] > kmax) kmax = d[i];
}
} else if (col->type == RAY_F64) {
const double* d = (const double*)ray_data(col);
for (int64_t i = start; i < end; i++) {
uint64_t bits;
memcpy(&bits, &d[i], 8);
uint64_t mask = -(bits >> 63) | ((uint64_t)1 << 63);
int64_t v = (int64_t)(bits ^ mask);
if (v < kmin) kmin = v;
if (v > kmax) kmax = v;
}
} else if (col->type == RAY_I32 || col->type == RAY_DATE || col->type == RAY_TIME) {
const int32_t* d = (const int32_t*)ray_data(col);
for (int64_t i = start; i < end; i++) {
int64_t v = (int64_t)d[i];
if (v < kmin) kmin = v;
if (v > kmax) kmax = v;
}
} else if (col->type == RAY_I16) {
const int16_t* d = (const int16_t*)ray_data(col);
for (int64_t i = start; i < end; i++) {
int64_t v = (int64_t)d[i];
if (v < kmin) kmin = v;
if (v > kmax) kmax = v;
}
} else if (col->type == RAY_BOOL || col->type == RAY_U8) {
const uint8_t* d = (const uint8_t*)ray_data(col);
for (int64_t i = start; i < end; i++) {
int64_t v = (int64_t)d[i];
if (v < kmin) kmin = v;
if (v > kmax) kmax = v;
}
}
my_mins[k] = kmin;
my_maxs[k] = kmax;
}
}
typedef struct { uint64_t key; int64_t idx; } topn_entry_t;
static inline void topn_sift_down(topn_entry_t* h, int64_t n, int64_t i) {
for (;;) {
int64_t largest = i, l = 2*i+1, r = 2*i+2;
if (l < n && h[l].key > h[largest].key) largest = l;
if (r < n && h[r].key > h[largest].key) largest = r;
if (largest == i) return;
topn_entry_t t = h[i]; h[i] = h[largest]; h[largest] = t;
i = largest;
}
}
typedef struct {
int64_t limit;
topn_entry_t* heaps;
int64_t* counts;
uint8_t n_keys;
ray_t** vecs;
int64_t mins[16];
int64_t ranges[16];
uint8_t bit_shifts[16];
uint8_t descs[16];
const uint32_t* enum_ranks[16];
} fused_topn_ctx_t;
__attribute__((unused))
static void fused_topn_fn(void* arg, uint32_t wid,
int64_t start, int64_t end) {
fused_topn_ctx_t* c = (fused_topn_ctx_t*)arg;
int64_t K = c->limit;
topn_entry_t* heap = c->heaps + (int64_t)wid * K;
int64_t cnt = c->counts[wid];
uint8_t nk = c->n_keys;
for (int64_t i = start; i < end; i++) {
uint64_t composite = 0;
for (uint8_t k = 0; k < nk; k++) {
ray_t* col = c->vecs[k];
int64_t val;
if (c->enum_ranks[k]) {
uint32_t raw = (uint32_t)ray_read_sym(ray_data(col), i, col->type, col->attrs);
val = (int64_t)c->enum_ranks[k][raw];
} else if (col->type == RAY_I64 || col->type == RAY_TIMESTAMP) {
val = ((const int64_t*)ray_data(col))[i];
} else if (col->type == RAY_F64) {
uint64_t bits;
memcpy(&bits, &((const double*)ray_data(col))[i], 8);
uint64_t mask = -(bits >> 63) | ((uint64_t)1 << 63);
val = (int64_t)(bits ^ mask);
} else if (col->type == RAY_I32 || col->type == RAY_DATE || col->type == RAY_TIME) {
val = (int64_t)((const int32_t*)ray_data(col))[i];
} else if (col->type == RAY_I16) {
val = (int64_t)((const int16_t*)ray_data(col))[i];
} else if (col->type == RAY_BOOL || col->type == RAY_U8) {
val = (int64_t)((const uint8_t*)ray_data(col))[i];
} else {
val = 0;
}
uint64_t part = (uint64_t)val - (uint64_t)c->mins[k];
if (c->descs[k]) part = (uint64_t)c->ranges[k] - part;
composite |= part << c->bit_shifts[k];
}
if (cnt < K) {
heap[cnt].key = composite;
heap[cnt].idx = i;
cnt++;
if (cnt == K) {
for (int64_t j = K/2 - 1; j >= 0; j--)
topn_sift_down(heap, K, j);
}
} else if (composite < heap[0].key) {
heap[0].key = composite;
heap[0].idx = i;
topn_sift_down(heap, K, 0);
}
}
c->counts[wid] = cnt;
}
typedef struct {
const uint64_t* keys;
int64_t limit;
topn_entry_t* heaps;
int64_t* counts;
} topn_ctx_t;
__attribute__((unused))
static void topn_scan_fn(void* arg, uint32_t wid, int64_t start, int64_t end) {
topn_ctx_t* c = (topn_ctx_t*)arg;
int64_t K = c->limit;
topn_entry_t* heap = c->heaps + (int64_t)wid * K;
const uint64_t* keys = c->keys;
int64_t cnt = c->counts[wid];
for (int64_t i = start; i < end; i++) {
uint64_t k = keys[i];
if (cnt < K) {
heap[cnt].key = k;
heap[cnt].idx = i;
cnt++;
if (cnt == K) {
for (int64_t j = K/2 - 1; j >= 0; j--)
topn_sift_down(heap, K, j);
}
} else if (k < heap[0].key) {
heap[0].key = k;
heap[0].idx = i;
topn_sift_down(heap, K, 0);
}
}
c->counts[wid] = cnt;
}
#define TOPN_MAX 8192
__attribute__((unused))
static int64_t topn_merge_fused(fused_topn_ctx_t* ctx, uint32_t n_workers,
int64_t* out, int64_t limit) {
if (limit > TOPN_MAX) limit = TOPN_MAX;
topn_entry_t merge[limit];
int64_t cnt = 0;
for (uint32_t w = 0; w < n_workers; w++) {
topn_entry_t* wh = ctx->heaps + (int64_t)w * limit;
int64_t wc = ctx->counts[w];
for (int64_t j = 0; j < wc; j++) {
if (cnt < limit) {
merge[cnt++] = wh[j];
if (cnt == limit) {
for (int64_t m = limit/2 - 1; m >= 0; m--)
topn_sift_down(merge, limit, m);
}
} else if (wh[j].key < merge[0].key) {
merge[0] = wh[j];
topn_sift_down(merge, limit, 0);
}
}
}
if (cnt > 1) {
for (int64_t m = cnt/2 - 1; m >= 0; m--)
topn_sift_down(merge, cnt, m);
for (int64_t i = cnt - 1; i > 0; i--) {
topn_entry_t t = merge[0]; merge[0] = merge[i]; merge[i] = t;
topn_sift_down(merge, i, 0);
}
}
for (int64_t i = 0; i < cnt; i++)
out[i] = merge[i].idx;
return cnt;
}
__attribute__((unused))
static int64_t topn_merge(topn_ctx_t* ctx, uint32_t n_workers,
int64_t* out, int64_t limit) {
if (limit > TOPN_MAX) limit = TOPN_MAX;
topn_entry_t merge[limit];
int64_t cnt = 0;
for (uint32_t w = 0; w < n_workers; w++) {
topn_entry_t* wh = ctx->heaps + (int64_t)w * limit;
int64_t wc = ctx->counts[w];
for (int64_t j = 0; j < wc; j++) {
if (cnt < limit) {
merge[cnt++] = wh[j];
if (cnt == limit) {
for (int64_t m = limit/2 - 1; m >= 0; m--)
topn_sift_down(merge, limit, m);
}
} else if (wh[j].key < merge[0].key) {
merge[0] = wh[j];
topn_sift_down(merge, limit, 0);
}
}
}
if (cnt > 1) {
for (int64_t m = cnt/2 - 1; m >= 0; m--)
topn_sift_down(merge, cnt, m);
for (int64_t i = cnt - 1; i > 0; i--) {
topn_entry_t t = merge[0]; merge[0] = merge[i]; merge[i] = t;
topn_sift_down(merge, i, 0);
}
}
for (int64_t i = 0; i < cnt; i++)
out[i] = merge[i].idx;
return cnt;
}
static void radix_decode_into(void* dst, int8_t type, const uint64_t* sorted_keys,
int64_t n, bool desc) {
if (type == RAY_I64 || type == RAY_TIMESTAMP) {
int64_t* d = (int64_t*)dst;
if (desc)
for (int64_t i = 0; i < n; i++)
d[i] = (int64_t)(~sorted_keys[i] ^ ((uint64_t)1 << 63));
else
for (int64_t i = 0; i < n; i++)
d[i] = (int64_t)(sorted_keys[i] ^ ((uint64_t)1 << 63));
} else if (type == RAY_F64) {
double* d = (double*)dst;
for (int64_t i = 0; i < n; i++) {
uint64_t k = desc ? ~sorted_keys[i] : sorted_keys[i];
uint64_t mask = (k >> 63) ? ((uint64_t)1 << 63) : ~(uint64_t)0;
uint64_t bits = k ^ mask;
memcpy(&d[i], &bits, 8);
}
} else if (type == RAY_I32 || type == RAY_DATE || type == RAY_TIME) {
int32_t* d = (int32_t*)dst;
if (desc)
for (int64_t i = 0; i < n; i++)
d[i] = (int32_t)((uint32_t)(~sorted_keys[i]) ^ ((uint32_t)1 << 31));
else
for (int64_t i = 0; i < n; i++)
d[i] = (int32_t)((uint32_t)sorted_keys[i] ^ ((uint32_t)1 << 31));
} else if (type == RAY_I16) {
int16_t* d = (int16_t*)dst;
if (desc)
for (int64_t i = 0; i < n; i++)
d[i] = (int16_t)((uint16_t)(~sorted_keys[i]) ^ ((uint16_t)1 << 15));
else
for (int64_t i = 0; i < n; i++)
d[i] = (int16_t)((uint16_t)sorted_keys[i] ^ ((uint16_t)1 << 15));
} else if (type == RAY_BOOL || type == RAY_U8) {
uint8_t* d = (uint8_t*)dst;
if (desc)
for (int64_t i = 0; i < n; i++) d[i] = (uint8_t)(~sorted_keys[i]);
else
for (int64_t i = 0; i < n; i++) d[i] = (uint8_t)sorted_keys[i];
}
}
static ray_t* sort_indices_ex(ray_t** cols, uint8_t* descs, uint8_t* nulls_first,
uint8_t n_cols, int64_t nrows,
uint64_t** sorted_keys_out, ray_t** keys_hdr_out) {
if (n_cols == 0 || nrows <= 0)
return ray_vec_new(RAY_I64, 0);
if (n_cols > 16)
return ray_error("nyi", NULL);
ray_t* indices_hdr;
int64_t* indices = (int64_t*)scratch_alloc(&indices_hdr,
(size_t)nrows * sizeof(int64_t));
if (!indices) return ray_error("oom", NULL);
bool iota_done = false;
bool radix_done = false;
int64_t* sorted_idx = indices;
ray_t* radix_itmp_hdr = NULL;
ray_t* enum_rank_hdrs[n_cols];
memset(enum_rank_hdrs, 0, n_cols * sizeof(ray_t*));
if (nrows > 64) {
if (n_cols == 1 && cols[0]->type == RAY_STR) {
bool desc = descs ? descs[0] : 0;
bool nf = nulls_first ? nulls_first[0] : !desc;
if (sort_str_msd_inplace(indices, nrows, cols[0], desc, nf)) {
sorted_idx = indices;
iota_done = true;
radix_done = true;
goto str_msd_done;
}
}
bool can_radix = true;
bool has_wide_key = false;
for (uint8_t k = 0; k < n_cols; k++) {
if (!cols[k]) { can_radix = false; break; }
int8_t t = cols[k]->type;
if (t == RAY_STR || t == RAY_GUID) { has_wide_key = true; continue; }
if (t != RAY_I64 && t != RAY_F64 && t != RAY_I32 && t != RAY_I16 &&
t != RAY_BOOL && t != RAY_U8 && t != RAY_SYM &&
t != RAY_DATE && t != RAY_TIME && t != RAY_TIMESTAMP) {
can_radix = false; break;
}
}
if (has_wide_key && n_cols == 1) can_radix = false;
if (can_radix) {
ray_pool_t* pool = ray_pool_get();
uint32_t* enum_ranks[n_cols];
memset(enum_ranks, 0, n_cols * sizeof(uint32_t*));
for (uint8_t k = 0; k < n_cols; k++) {
if (RAY_IS_SYM(cols[k]->type)) {
enum_ranks[k] = build_enum_rank(cols[k], nrows,
&enum_rank_hdrs[k]);
if (!enum_ranks[k]) { can_radix = false; break; }
}
}
if (can_radix && n_cols == 1) {
uint8_t key_nbytes_max = radix_key_bytes(cols[0]->type);
ray_pool_t* sk_pool = (nrows >= SMALL_POOL_THRESHOLD) ? pool : NULL;
ray_t *keys_hdr;
uint64_t* keys = (uint64_t*)scratch_alloc(&keys_hdr,
(size_t)nrows * sizeof(uint64_t));
if (keys) {
bool desc = descs ? descs[0] : 0;
bool nf = nulls_first ? nulls_first[0] : !desc;
radix_encode_ctx_t enc = {
.keys = keys, .indices = indices,
.data = ray_data(cols[0]),
.col = cols[0],
.type = cols[0]->type,
.col_attrs = cols[0]->attrs,
.desc = desc,
.nulls_first = nf,
.enum_rank = enum_ranks[0], .n_keys = 1,
};
if (sk_pool)
ray_pool_dispatch(sk_pool, radix_encode_fn, &enc, nrows);
else
radix_encode_fn(&enc, 0, 0, nrows);
iota_done = true;
if (nrows <= RADIX_SORT_THRESHOLD) {
key_introsort(keys, indices, nrows);
sorted_idx = indices;
radix_done = true;
} else {
uint8_t key_nbytes = compute_key_nbytes(
sk_pool, keys, nrows, key_nbytes_max);
uint8_t idx_bits = 0;
{ int64_t nn = nrows; while (nn > 0) { idx_bits++; nn >>= 1; } }
bool use_packed = (key_nbytes <= 3
&& key_nbytes * 8 + idx_bits <= 64);
if (use_packed) {
uint8_t key_bits = key_nbytes * 8;
ray_t *ptmp_hdr;
uint64_t* ptmp = (uint64_t*)scratch_alloc(&ptmp_hdr,
(size_t)nrows * sizeof(uint64_t));
if (ptmp) {
uint32_t pd_nw = sk_pool ? ray_pool_total_workers(sk_pool) : 1;
int64_t pd_pw[pd_nw], pd_nr[pd_nw];
memset(pd_pw, 0, (size_t)pd_nw * sizeof(int64_t));
memset(pd_nr, 0, (size_t)pd_nw * sizeof(int64_t));
uint64_t key_mask_pd =
(key_bits < 64) ? ((1ULL << key_bits) - 1) : ~0ULL;
packed_detect_ctx_t pd_ctx = {
.keys = keys, .key_bits = key_bits,
.key_mask = key_mask_pd,
.pw_unsorted = pd_pw, .pw_not_reverse = pd_nr,
};
if (sk_pool)
ray_pool_dispatch(sk_pool, packed_detect_fn, &pd_ctx, nrows);
else
packed_detect_fn(&pd_ctx, 0, 0, nrows);
int64_t total_unsorted = 0, total_not_rev = 0;
for (uint32_t t = 0; t < pd_nw; t++) {
total_unsorted += pd_pw[t];
total_not_rev += pd_nr[t];
}
int64_t grain = RAY_DISPATCH_MORSELS * RAY_MORSEL_ELEMS;
uint64_t key_mask_s =
(key_bits < 64) ? ((1ULL << key_bits) - 1) : ~0ULL;
for (int64_t b = grain; b < nrows; b += grain) {
uint64_t ka = keys[b-1] & key_mask_s;
uint64_t kb2 = keys[b] & key_mask_s;
if (kb2 < ka) total_unsorted++;
if (kb2 > ka) total_not_rev++;
}
if (total_unsorted == 0) {
sorted_idx = indices;
radix_done = true;
} else if (total_not_rev == 0 && nrows > 1) {
for (int64_t i = 0; i < nrows; i++)
indices[i] = nrows - 1 - i;
sorted_idx = indices;
radix_done = true;
} else {
uint64_t* sorted = packed_radix_sort_run(
sk_pool, keys, ptmp, nrows, key_nbytes);
if (sorted) {
uint64_t idx_mask =
(idx_bits < 64) ? ((1ULL << idx_bits) - 1) : ~0ULL;
packed_unpack_ctx_t up = {
.sorted = sorted, .indices = indices,
.keys_out = NULL,
.key_bits = key_bits,
.idx_mask = idx_mask, .key_mask = 0,
.extract_keys = false,
};
if (sk_pool)
ray_pool_dispatch(sk_pool, packed_unpack_fn, &up, nrows);
else
packed_unpack_fn(&up, 0, 0, nrows);
sorted_idx = indices;
radix_done = true;
}
}
}
scratch_free(ptmp_hdr);
} else {
double us_frac2 = detect_sortedness(sk_pool, keys, nrows);
if (us_frac2 == 0.0) {
sorted_idx = indices;
radix_done = true;
}
if (!radix_done) {
ray_t *ktmp_hdr, *itmp_hdr;
uint64_t* ktmp = (uint64_t*)scratch_alloc(&ktmp_hdr,
(size_t)nrows * sizeof(uint64_t));
int64_t* itmp = (int64_t*)scratch_alloc(&itmp_hdr,
(size_t)nrows * sizeof(int64_t));
if (ktmp && itmp) {
bool want_sk = sorted_keys_out
&& !RAY_IS_SYM(cols[0]->type);
uint64_t* sk_out = NULL;
sorted_idx = msd_radix_sort_run(sk_pool, keys, indices,
ktmp, itmp, nrows,
key_nbytes,
want_sk ? &sk_out : NULL);
radix_done = (sorted_idx != NULL);
if (radix_done && want_sk && sk_out) {
*sorted_keys_out = sk_out;
if (sk_out == ktmp) {
*keys_hdr_out = ktmp_hdr;
ktmp_hdr = NULL;
} else {
*keys_hdr_out = keys_hdr;
keys_hdr = NULL;
}
}
}
if (ktmp_hdr) scratch_free(ktmp_hdr);
if (sorted_idx != itmp) scratch_free(itmp_hdr);
else radix_itmp_hdr = itmp_hdr;
}
}
}
}
scratch_free(keys_hdr);
} else if (can_radix && n_cols > 1) {
int64_t mins[n_cols], maxs[n_cols];
uint16_t total_bits = 0;
bool fits = true;
ray_pool_t* mk_prescan_pool = (nrows >= SMALL_POOL_THRESHOLD) ? pool : NULL;
if (has_wide_key) {
total_bits = UINT16_MAX;
fits = false;
} else if (n_cols <= MK_PRESCAN_MAX_KEYS && mk_prescan_pool) {
uint32_t nw = ray_pool_total_workers(mk_prescan_pool);
size_t pw_count = (size_t)nw * n_cols;
int64_t pw_mins_stack[512], pw_maxs_stack[512];
ray_t *pw_mins_hdr = NULL, *pw_maxs_hdr = NULL;
int64_t* pw_mins = (pw_count <= 512)
? pw_mins_stack
: (int64_t*)scratch_alloc(&pw_mins_hdr, pw_count * sizeof(int64_t));
int64_t* pw_maxs = (pw_count <= 512)
? pw_maxs_stack
: (int64_t*)scratch_alloc(&pw_maxs_hdr, pw_count * sizeof(int64_t));
for (size_t i = 0; i < pw_count; i++) {
pw_mins[i] = INT64_MAX;
pw_maxs[i] = INT64_MIN;
}
mk_prescan_ctx_t pctx = {
.vecs = cols, .enum_ranks = enum_ranks,
.n_keys = n_cols, .nrows = nrows, .n_workers = nw,
.pw_mins = pw_mins, .pw_maxs = pw_maxs,
};
ray_pool_dispatch(mk_prescan_pool, mk_prescan_fn, &pctx, nrows);
for (uint8_t k = 0; k < n_cols; k++) {
int64_t kmin = INT64_MAX, kmax = INT64_MIN;
for (uint32_t w = 0; w < nw; w++) {
int64_t wmin = pw_mins[w * n_cols + k];
int64_t wmax = pw_maxs[w * n_cols + k];
if (wmin < kmin) kmin = wmin;
if (wmax > kmax) kmax = wmax;
}
mins[k] = kmin;
maxs[k] = kmax;
uint64_t range = (uint64_t)(kmax - kmin);
uint8_t bits = 1;
while (((uint64_t)1 << bits) <= range && bits < 64)
bits++;
total_bits = (uint16_t)(total_bits + bits);
}
if (pw_mins_hdr) scratch_free(pw_mins_hdr);
if (pw_maxs_hdr) scratch_free(pw_maxs_hdr);
} else {
for (uint8_t k = 0; k < n_cols; k++) {
ray_t* col = cols[k];
int64_t kmin = INT64_MAX, kmax = INT64_MIN;
if (enum_ranks[k]) {
const void* cdata = ray_data(col);
int8_t ctype = col->type;
uint8_t cattrs = col->attrs;
for (int64_t i = 0; i < nrows; i++) {
uint32_t raw = (uint32_t)ray_read_sym(cdata, i, ctype, cattrs);
int64_t v = (int64_t)enum_ranks[k][raw];
if (v < kmin) kmin = v;
if (v > kmax) kmax = v;
}
} else if (col->type == RAY_I64 || col->type == RAY_TIMESTAMP) {
const int64_t* d = (const int64_t*)ray_data(col);
for (int64_t i = 0; i < nrows; i++) {
if (d[i] < kmin) kmin = d[i];
if (d[i] > kmax) kmax = d[i];
}
} else if (col->type == RAY_F64) {
const double* d = (const double*)ray_data(col);
for (int64_t i = 0; i < nrows; i++) {
uint64_t bits;
memcpy(&bits, &d[i], 8);
uint64_t mask = -(bits >> 63) | ((uint64_t)1 << 63);
int64_t v = (int64_t)(bits ^ mask);
if (v < kmin) kmin = v;
if (v > kmax) kmax = v;
}
} else if (col->type == RAY_I32 || col->type == RAY_DATE || col->type == RAY_TIME) {
const int32_t* d = (const int32_t*)ray_data(col);
for (int64_t i = 0; i < nrows; i++) {
if (d[i] < kmin) kmin = (int64_t)d[i];
if (d[i] > kmax) kmax = (int64_t)d[i];
}
} else if (col->type == RAY_I16) {
const int16_t* d = (const int16_t*)ray_data(col);
for (int64_t i = 0; i < nrows; i++) {
if (d[i] < kmin) kmin = (int64_t)d[i];
if (d[i] > kmax) kmax = (int64_t)d[i];
}
} else if (col->type == RAY_BOOL || col->type == RAY_U8) {
const uint8_t* d = (const uint8_t*)ray_data(col);
for (int64_t i = 0; i < nrows; i++) {
if (d[i] < kmin) kmin = (int64_t)d[i];
if (d[i] > kmax) kmax = (int64_t)d[i];
}
}
mins[k] = kmin;
maxs[k] = kmax;
uint64_t range = (uint64_t)(kmax - kmin);
uint8_t bits = 1;
while (((uint64_t)1 << bits) <= range && bits < 64)
bits++;
total_bits = (uint16_t)(total_bits + bits);
}
}
if (total_bits > 64) {
fits = false;
ray_t* rank_hdrs[n_cols];
uint32_t* ranks[n_cols];
uint32_t rank_max[n_cols];
bool rank_ok = true;
for (uint8_t k = 0; k < n_cols; k++) {
rank_hdrs[k] = NULL; ranks[k] = NULL; rank_max[k] = 0;
}
for (uint8_t k = 0; k < n_cols && rank_ok; k++) {
uint8_t kdesc = descs ? descs[k] : 0;
uint8_t knf = nulls_first ? nulls_first[k] : !kdesc;
ray_t* col_arg[1] = { cols[k] };
uint8_t desc_arg[1] = { kdesc };
uint8_t nf_arg[1] = { knf };
ray_t* sk_idx = sort_indices_ex(col_arg, desc_arg,
nf_arg, 1, nrows,
NULL, NULL);
if (!sk_idx || RAY_IS_ERR(sk_idx)) { rank_ok = false; break; }
int64_t* sk_idx_data = (int64_t*)ray_data(sk_idx);
uint32_t* r = (uint32_t*)scratch_alloc(&rank_hdrs[k],
(size_t)nrows * sizeof(uint32_t));
if (!r) { ray_release(sk_idx); rank_ok = false; break; }
ranks[k] = r;
sort_cmp_ctx_t cctx = {
.vecs = col_arg, .desc = desc_arg,
.nulls_first = nf_arg, .n_sort = 1,
};
uint32_t cur = 0;
r[sk_idx_data[0]] = 0;
for (int64_t i = 1; i < nrows; i++) {
if (sort_cmp(&cctx, sk_idx_data[i-1], sk_idx_data[i]) != 0)
cur++;
r[sk_idx_data[i]] = cur;
}
rank_max[k] = cur;
ray_release(sk_idx);
}
if (rank_ok) {
uint8_t rank_bits[n_cols];
uint16_t rank_total = 0;
for (uint8_t k = 0; k < n_cols; k++) {
uint8_t b = 1;
while (((uint64_t)1 << b) <= rank_max[k] && b < 64) b++;
rank_bits[k] = b;
rank_total = (uint16_t)(rank_total + b);
}
if (rank_total <= 64) {
uint8_t rshift[n_cols];
uint16_t accum = 0;
for (int k = n_cols - 1; k >= 0; k--) {
rshift[k] = (uint8_t)accum;
accum = (uint16_t)(accum + rank_bits[k]);
}
uint8_t rcomp_nbytes = (uint8_t)((rank_total + 7) / 8);
if (rcomp_nbytes < 1) rcomp_nbytes = 1;
ray_pool_t* rk_pool =
(nrows >= SMALL_POOL_THRESHOLD) ? pool : NULL;
ray_t* rkeys_hdr;
uint64_t* rkeys = (uint64_t*)scratch_alloc(&rkeys_hdr,
(size_t)nrows * sizeof(uint64_t));
if (rkeys) {
for (int64_t i = 0; i < nrows; i++) {
uint64_t composite = 0;
for (uint8_t k = 0; k < n_cols; k++)
composite |= ((uint64_t)ranks[k][i]) << rshift[k];
rkeys[i] = composite;
indices[i] = i;
}
iota_done = true;
if (nrows <= RADIX_SORT_THRESHOLD) {
key_introsort(rkeys, indices, nrows);
sorted_idx = indices;
radix_done = true;
} else {
ray_t *rktmp_hdr, *ritmp_hdr;
uint64_t* rktmp = (uint64_t*)scratch_alloc(&rktmp_hdr,
(size_t)nrows * sizeof(uint64_t));
int64_t* ritmp = (int64_t*)scratch_alloc(&ritmp_hdr,
(size_t)nrows * sizeof(int64_t));
if (rktmp && ritmp) {
sorted_idx = msd_radix_sort_run(
rk_pool, rkeys, indices,
rktmp, ritmp, nrows, rcomp_nbytes, NULL);
radix_done = (sorted_idx != NULL);
}
if (rktmp_hdr) scratch_free(rktmp_hdr);
if (sorted_idx != ritmp) {
if (ritmp_hdr) scratch_free(ritmp_hdr);
} else {
radix_itmp_hdr = ritmp_hdr;
}
}
scratch_free(rkeys_hdr);
}
}
}
for (uint8_t k = 0; k < n_cols; k++)
if (rank_hdrs[k]) scratch_free(rank_hdrs[k]);
}
if (fits) {
uint8_t bit_shifts[n_cols];
uint8_t accum = 0;
for (int k = n_cols - 1; k >= 0; k--) {
bit_shifts[k] = accum;
uint64_t range = (uint64_t)(maxs[k] - mins[k]);
uint8_t bits = 1;
while (((uint64_t)1 << bits) <= range && bits < 64)
bits++;
accum += bits;
}
uint8_t comp_nbytes = (total_bits + 7) / 8;
if (comp_nbytes < 1) comp_nbytes = 1;
ray_pool_t* mk_pool = (nrows >= SMALL_POOL_THRESHOLD) ? pool : NULL;
{
ray_t *keys_hdr;
uint64_t* keys = (uint64_t*)scratch_alloc(&keys_hdr,
(size_t)nrows * sizeof(uint64_t));
if (keys) {
radix_encode_ctx_t enc = {
.keys = keys, .indices = indices,
.n_keys = n_cols, .vecs = cols,
};
for (uint8_t k = 0; k < n_cols; k++) {
enc.mins[k] = mins[k];
enc.ranges[k] = maxs[k] - mins[k];
enc.bit_shifts[k] = bit_shifts[k];
enc.descs[k] = descs ? descs[k] : 0;
enc.enum_ranks[k] = enum_ranks[k];
}
if (mk_pool)
ray_pool_dispatch(mk_pool, radix_encode_fn, &enc, nrows);
else
radix_encode_fn(&enc, 0, 0, nrows);
iota_done = true;
double unsorted_frac = detect_sortedness(mk_pool, keys, nrows);
if (unsorted_frac == 0.0) {
sorted_idx = indices;
radix_done = true;
} else if (nrows <= RADIX_SORT_THRESHOLD) {
key_introsort(keys, indices, nrows);
sorted_idx = indices;
radix_done = true;
} else {
ray_t *ktmp_hdr, *itmp_hdr;
uint64_t* ktmp = (uint64_t*)scratch_alloc(&ktmp_hdr,
(size_t)nrows * sizeof(uint64_t));
int64_t* itmp = (int64_t*)scratch_alloc(&itmp_hdr,
(size_t)nrows * sizeof(int64_t));
if (ktmp && itmp) {
sorted_idx = msd_radix_sort_run(mk_pool, keys, indices,
ktmp, itmp, nrows,
comp_nbytes, NULL);
radix_done = (sorted_idx != NULL);
}
scratch_free(ktmp_hdr);
if (sorted_idx != itmp) scratch_free(itmp_hdr);
else radix_itmp_hdr = itmp_hdr;
}
}
scratch_free(keys_hdr);
}
}
}
}
}
if (!radix_done) {
if (!iota_done)
for (int64_t i = 0; i < nrows; i++) indices[i] = i;
uint8_t default_nf[n_cols > 0 ? n_cols : 1];
if (!nulls_first) {
for (uint8_t k = 0; k < n_cols; k++)
default_nf[k] = descs ? !descs[k] : 1;
nulls_first = default_nf;
}
sort_cmp_ctx_t cmp_ctx = {
.vecs = cols,
.desc = descs,
.nulls_first = nulls_first,
.n_sort = n_cols,
};
if (nrows <= 64) {
sort_insertion(&cmp_ctx, indices, nrows);
} else {
ray_pool_t* pool = ray_pool_get();
uint32_t n_workers = pool ? ray_pool_total_workers(pool) : 1;
ray_t* tmp_hdr;
int64_t* tmp = (int64_t*)scratch_alloc(&tmp_hdr,
(size_t)nrows * sizeof(int64_t));
if (!tmp) {
for (uint8_t k = 0; k < n_cols; k++)
scratch_free(enum_rank_hdrs[k]);
scratch_free(indices_hdr);
return ray_error("oom", NULL);
}
uint32_t n_chunks = n_workers;
if (pool && n_chunks > 1 && nrows > 1024) {
sort_phase1_ctx_t p1ctx = {
.cmp_ctx = &cmp_ctx, .indices = indices, .tmp = tmp,
.nrows = nrows, .n_chunks = n_chunks,
};
ray_pool_dispatch_n(pool, sort_phase1_fn, &p1ctx, n_chunks);
} else {
n_chunks = 1;
sort_merge_recursive(&cmp_ctx, indices, tmp, nrows);
}
if (n_chunks > 1) {
int64_t chunk_size = (nrows + n_chunks - 1) / n_chunks;
int64_t run_size = chunk_size;
int64_t* src = indices;
int64_t* dst = tmp;
while (run_size < nrows) {
int64_t n_pairs = (nrows + 2 * run_size - 1) / (2 * run_size);
sort_merge_ctx_t mctx = {
.cmp_ctx = &cmp_ctx, .src = src, .dst = dst,
.nrows = nrows, .run_size = run_size,
};
if (pool && n_pairs > 1)
ray_pool_dispatch_n(pool, sort_merge_fn, &mctx,
(uint32_t)n_pairs);
else
sort_merge_fn(&mctx, 0, 0, n_pairs);
int64_t* t = src; src = dst; dst = t;
run_size *= 2;
}
if (src != indices)
memcpy(indices, src, (size_t)nrows * sizeof(int64_t));
}
scratch_free(tmp_hdr);
}
}
str_msd_done:;
if (sorted_keys_out && !*sorted_keys_out) {
*sorted_keys_out = NULL;
if (keys_hdr_out) *keys_hdr_out = NULL;
}
ray_t* result = ray_vec_new(RAY_I64, nrows);
if (!result || RAY_IS_ERR(result)) {
if (sorted_keys_out && *sorted_keys_out && keys_hdr_out)
scratch_free(*keys_hdr_out);
for (uint8_t k = 0; k < n_cols; k++)
scratch_free(enum_rank_hdrs[k]);
scratch_free(radix_itmp_hdr);
scratch_free(indices_hdr);
return result ? result : ray_error("oom", NULL);
}
result->len = nrows;
memcpy(ray_data(result), sorted_idx, (size_t)nrows * sizeof(int64_t));
for (uint8_t k = 0; k < n_cols; k++)
scratch_free(enum_rank_hdrs[k]);
scratch_free(radix_itmp_hdr);
scratch_free(indices_hdr);
return result;
}
ray_t* ray_sort_indices(ray_t** cols, uint8_t* descs, uint8_t* nulls_first,
uint8_t n_cols, int64_t nrows) {
return sort_indices_ex(cols, descs, nulls_first, n_cols, nrows, NULL, NULL);
}
ray_t* ray_sort(ray_t** cols, uint8_t* descs, uint8_t* nulls_first,
uint8_t n_cols, int64_t nrows) {
if (n_cols == 1) {
uint64_t* sorted_keys = NULL;
ray_t* keys_hdr = NULL;
ray_t* idx = sort_indices_ex(cols, descs, nulls_first, 1, nrows,
&sorted_keys, &keys_hdr);
if (!idx || RAY_IS_ERR(idx)) return idx;
if (sorted_keys && !RAY_IS_SYM(cols[0]->type)) {
ray_t* result = ray_vec_new(cols[0]->type, nrows);
if (!result || RAY_IS_ERR(result)) {
ray_release(idx);
if (keys_hdr) scratch_free(keys_hdr);
return result ? result : ray_error("oom", NULL);
}
result->len = nrows;
radix_decode_into(ray_data(result), cols[0]->type, sorted_keys,
nrows, descs ? descs[0] : 0);
if (cols[0]->attrs & RAY_ATTR_HAS_NULLS) {
int64_t* idx_data = (int64_t*)ray_data(idx);
for (int64_t i = 0; i < nrows; i++)
if (ray_vec_is_null(cols[0], idx_data[i]))
ray_vec_set_null(result, i, true);
}
ray_release(idx);
scratch_free(keys_hdr);
return result;
}
if (keys_hdr) scratch_free(keys_hdr);
ray_t* result = gather_by_idx(cols[0], (int64_t*)ray_data(idx), nrows);
ray_release(idx);
return result;
}
ray_t* idx = ray_sort_indices(cols, descs, nulls_first, n_cols, nrows);
if (!idx || RAY_IS_ERR(idx)) return idx;
ray_t* result = gather_by_idx(cols[0], (int64_t*)ray_data(idx), nrows);
ray_release(idx);
return result;
}
ray_t* exec_sort(ray_graph_t* g, ray_op_t* op, ray_t* tbl, int64_t limit) {
if (!tbl || RAY_IS_ERR(tbl)) return tbl;
ray_op_ext_t* ext = find_ext(g, op->id);
if (!ext) return ray_error("nyi", NULL);
int64_t nrows = ray_table_nrows(tbl);
int64_t ncols = ray_table_ncols(tbl);
if (ncols > 4096) return ray_error("nyi", NULL);
uint8_t n_sort = ext->sort.n_cols;
if (n_sort > 16) return ray_error("nyi", NULL);
ray_t* sort_vecs[n_sort > 0 ? n_sort : 1];
uint8_t sort_owned[n_sort > 0 ? n_sort : 1];
memset(sort_vecs, 0, (n_sort > 0 ? n_sort : 1) * sizeof(ray_t*));
memset(sort_owned, 0, n_sort > 0 ? n_sort : 1);
for (uint8_t k = 0; k < n_sort; k++) {
ray_op_t* key_op = ext->sort.columns[k];
ray_op_ext_t* key_ext = find_ext(g, key_op->id);
if (key_ext && key_ext->base.opcode == OP_SCAN) {
sort_vecs[k] = ray_table_get_col(tbl, key_ext->sym);
} else {
ray_t* saved = g->table;
g->table = tbl;
sort_vecs[k] = exec_node(g, key_op);
g->table = saved;
sort_owned[k] = 1;
}
if (!sort_vecs[k] || RAY_IS_ERR(sort_vecs[k])) {
ray_t* err = sort_vecs[k] ? sort_vecs[k] : ray_error("nyi", NULL);
for (uint8_t j = 0; j < k; j++) {
if (sort_owned[j] && sort_vecs[j] && !RAY_IS_ERR(sort_vecs[j]))
ray_release(sort_vecs[j]);
}
return err;
}
}
uint64_t* sorted_keys = NULL;
ray_t* sorted_keys_hdr = NULL;
ray_t* idx_vec = sort_indices_ex(sort_vecs, ext->sort.desc,
ext->sort.nulls_first, n_sort, nrows,
&sorted_keys, &sorted_keys_hdr);
if (!idx_vec || RAY_IS_ERR(idx_vec)) {
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
for (uint8_t k = 0; k < n_sort; k++) {
if (sort_owned[k] && sort_vecs[k] && !RAY_IS_ERR(sort_vecs[k]))
ray_release(sort_vecs[k]);
}
return idx_vec ? idx_vec : ray_error("oom", NULL);
}
int64_t* sorted_idx = (int64_t*)ray_data(idx_vec);
{
ray_pool_t* cp = ray_pool_get();
if (pool_cancelled(cp)) {
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
for (uint8_t k = 0; k < n_sort; k++) {
if (sort_owned[k] && sort_vecs[k] && !RAY_IS_ERR(sort_vecs[k]))
ray_release(sort_vecs[k]);
}
ray_release(idx_vec);
return ray_error("cancel", NULL);
}
}
int64_t gather_rows = nrows;
if (limit > 0 && limit < nrows) gather_rows = limit;
ray_t* result = ray_table_new(ncols);
if (!result || RAY_IS_ERR(result)) {
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
for (uint8_t k = 0; k < n_sort; k++) {
if (sort_owned[k] && sort_vecs[k] && !RAY_IS_ERR(sort_vecs[k]))
ray_release(sort_vecs[k]);
}
ray_release(idx_vec);
return result;
}
ray_pool_t* gather_pool = (gather_rows > RAY_PARALLEL_THRESHOLD) ? ray_pool_get() : NULL;
ray_t* new_cols[ncols];
int64_t col_names[ncols];
for (int64_t c = 0; c < ncols; c++) {
ray_t* col = ray_table_get_col_idx(tbl, c);
col_names[c] = ray_table_col_name(tbl, c);
if (!col) { new_cols[c] = NULL; continue; }
ray_t* nc;
if (col->type == RAY_LIST) {
nc = ray_list_new(gather_rows);
} else {
nc = col_vec_new(col, gather_rows);
}
if (!nc || RAY_IS_ERR(nc)) {
for (int64_t j = 0; j < c; j++)
if (new_cols[j]) ray_release(new_cols[j]);
ray_release(result);
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
for (uint8_t k = 0; k < n_sort; k++)
if (sort_owned[k] && sort_vecs[k] && !RAY_IS_ERR(sort_vecs[k]))
ray_release(sort_vecs[k]);
ray_release(idx_vec);
return nc ? nc : ray_error("oom", NULL);
}
if (col->type == RAY_LIST) {
ray_t** src_ptrs = (ray_t**)ray_data(col);
ray_t** dst_ptrs = (ray_t**)ray_data(nc);
for (int64_t r = 0; r < gather_rows; r++) {
dst_ptrs[r] = src_ptrs[sorted_idx[r]];
if (dst_ptrs[r]) ray_retain(dst_ptrs[r]);
}
}
nc->len = gather_rows;
new_cols[c] = nc;
}
int64_t sort_key_sym = -1;
if (sorted_keys && n_sort == 1 && !RAY_IS_SYM(sort_vecs[0]->type)) {
ray_op_ext_t* key_ext = find_ext(g, ext->sort.columns[0]->id);
if (key_ext && key_ext->base.opcode == OP_SCAN)
sort_key_sym = key_ext->sym;
}
int64_t decode_col_idx = -1;
if (sort_key_sym >= 0) {
for (int64_t c = 0; c < ncols; c++) {
if (col_names[c] == sort_key_sym && new_cols[c]) {
decode_col_idx = c;
break;
}
}
}
if (decode_col_idx >= 0) {
radix_decode_into(ray_data(new_cols[decode_col_idx]),
sort_vecs[0]->type, sorted_keys,
gather_rows, ext->sort.desc ? ext->sort.desc[0] : 0);
}
for (int64_t base = 0; base < ncols; ) {
char* g_srcs[MGATHER_MAX_COLS];
char* g_dsts[MGATHER_MAX_COLS];
uint8_t g_esz[MGATHER_MAX_COLS];
int64_t g_nc = 0;
for (; base < ncols && g_nc < MGATHER_MAX_COLS; base++) {
if (!new_cols[base] || base == decode_col_idx) continue;
ray_t* col = ray_table_get_col_idx(tbl, base);
if (col->type == RAY_LIST) continue;
g_srcs[g_nc] = (char*)ray_data(col);
g_dsts[g_nc] = (char*)ray_data(new_cols[base]);
g_esz[g_nc] = col_esz(col);
g_nc++;
}
if (g_nc == 0) continue;
if (n_sort == 1)
partitioned_gather(gather_pool, sorted_idx, gather_rows,
nrows, g_srcs, g_dsts, g_esz, g_nc);
else {
multi_gather_ctx_t mg = { .idx = sorted_idx, .ncols = g_nc };
for (int64_t i = 0; i < g_nc; i++) {
mg.srcs[i] = g_srcs[i];
mg.dsts[i] = g_dsts[i];
mg.esz[i] = g_esz[i];
}
if (gather_pool)
ray_pool_dispatch(gather_pool, multi_gather_fn, &mg,
gather_rows);
else
multi_gather_fn(&mg, 0, 0, gather_rows);
}
}
for (int64_t c = 0; c < ncols; c++) {
if (!new_cols[c]) continue;
ray_t* col = ray_table_get_col_idx(tbl, c);
if (!col) continue;
col_propagate_str_pool(new_cols[c], col);
if (col->type == RAY_SYM &&
!(col->attrs & RAY_ATTR_SLICE) &&
(!(col->attrs & RAY_ATTR_HAS_NULLS) || (col->attrs & RAY_ATTR_NULLMAP_EXT)) &&
col->sym_dict) {
ray_retain(col->sym_dict);
new_cols[c]->sym_dict = col->sym_dict;
}
bool src_has_nulls = (col->attrs & RAY_ATTR_HAS_NULLS) ||
((col->attrs & RAY_ATTR_SLICE) && col->slice_parent &&
(col->slice_parent->attrs & RAY_ATTR_HAS_NULLS));
if (src_has_nulls) {
for (int64_t r = 0; r < gather_rows; r++)
if (ray_vec_is_null(col, sorted_idx[r]))
ray_vec_set_null(new_cols[c], r, true);
}
}
for (int64_t c = 0; c < ncols; c++) {
if (!new_cols[c]) continue;
result = ray_table_add_col(result, col_names[c], new_cols[c]);
ray_release(new_cols[c]);
}
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
for (uint8_t k = 0; k < n_sort; k++) {
if (sort_owned[k] && sort_vecs[k] && !RAY_IS_ERR(sort_vecs[k]))
ray_release(sort_vecs[k]);
}
ray_release(idx_vec);
return result;
}
ray_t* ray_asc_fn(ray_t* x) {
if (!x || RAY_IS_ERR(x)) return x;
if (ray_is_atom(x)) { ray_retain(x); return x; }
if (!ray_is_vec(x)) return ray_error("type", "asc expects a vector");
int64_t n = ray_len(x);
if (n <= 1) { ray_retain(x); return x; }
uint8_t desc = 0;
return ray_sort(&x, &desc, NULL, 1, n);
}
ray_t* ray_desc_fn(ray_t* x) {
if (!x || RAY_IS_ERR(x)) return x;
if (ray_is_atom(x)) { ray_retain(x); return x; }
if (!ray_is_vec(x)) return ray_error("type", "desc expects a vector");
int64_t n = ray_len(x);
if (n <= 1) { ray_retain(x); return x; }
uint8_t desc = 1;
return ray_sort(&x, &desc, NULL, 1, n);
}
ray_t* ray_iasc_fn(ray_t* x) {
if (!x || RAY_IS_ERR(x)) return x;
if (!ray_is_vec(x)) return ray_error("type", "iasc expects a vector");
int64_t n = ray_len(x);
uint8_t desc = 0;
return ray_sort_indices(&x, &desc, NULL, 1, n);
}
ray_t* ray_idesc_fn(ray_t* x) {
if (!x || RAY_IS_ERR(x)) return x;
if (!ray_is_vec(x)) return ray_error("type", "idesc expects a vector");
int64_t n = ray_len(x);
uint8_t desc = 1;
return ray_sort_indices(&x, &desc, NULL, 1, n);
}
ray_t* ray_rank_fn(ray_t* x) {
if (!x || RAY_IS_ERR(x)) return x;
if (!ray_is_vec(x)) return ray_error("type", "rank expects a vector");
int64_t n = ray_len(x);
uint8_t desc = 0;
ray_t* idx = ray_sort_indices(&x, &desc, NULL, 1, n);
if (RAY_IS_ERR(idx)) return idx;
ray_t* result = ray_vec_new(RAY_I64, n);
if (RAY_IS_ERR(result)) { ray_release(idx); return result; }
result->len = n;
int64_t* idx_data = (int64_t*)ray_data(idx);
int64_t* rank_data = (int64_t*)ray_data(result);
for (int64_t i = 0; i < n; i++)
rank_data[idx_data[i]] = i;
ray_release(idx);
return result;
}
ray_t* sort_table_by_keys(ray_t* tbl, ray_t* keys, uint8_t descending) {
if (!tbl || tbl->type != RAY_TABLE)
return ray_error("type", "xasc/xdesc expects a table as first argument");
int64_t n_keys = 0;
int64_t key_ids[16];
if (keys->type == -RAY_SYM) {
key_ids[0] = keys->i64;
n_keys = 1;
} else if (keys->type == RAY_SYM) {
int64_t* syms = (int64_t*)ray_data(keys);
n_keys = ray_len(keys);
if (n_keys > 16) return ray_error("limit", "xasc/xdesc: max 16 key columns");
for (int64_t i = 0; i < n_keys; i++) key_ids[i] = syms[i];
} else if (is_list(keys)) {
ray_t** elems = (ray_t**)ray_data(keys);
n_keys = ray_len(keys);
if (n_keys > 16) return ray_error("limit", "xasc/xdesc: max 16 key columns");
for (int64_t i = 0; i < n_keys; i++) {
if (elems[i]->type != -RAY_SYM)
return ray_error("type", "xasc/xdesc key must be a symbol");
key_ids[i] = elems[i]->i64;
}
} else {
return ray_error("type", "xasc/xdesc key must be a symbol or list of symbols");
}
if (n_keys == 0) { ray_retain(tbl); return tbl; }
int64_t nrows = ray_table_nrows(tbl);
if (nrows <= 1) { ray_retain(tbl); return tbl; }
ray_t* key_cols[16];
for (int64_t i = 0; i < n_keys; i++) {
key_cols[i] = ray_table_get_col(tbl, key_ids[i]);
if (!key_cols[i])
return ray_error("domain", "xasc/xdesc: key column not found in table");
}
uint8_t descs[16];
for (int64_t i = 0; i < n_keys; i++) descs[i] = descending;
uint64_t* sorted_keys = NULL;
ray_t* sorted_keys_hdr = NULL;
ray_t* idx = sort_indices_ex(key_cols, descs, NULL, (uint8_t)n_keys, nrows,
&sorted_keys, &sorted_keys_hdr);
if (RAY_IS_ERR(idx)) {
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
return idx;
}
int64_t* idx_data = (int64_t*)ray_data(idx);
int64_t ncols = ray_table_ncols(tbl);
ray_pool_t* gather_pool = (nrows > RAY_PARALLEL_THRESHOLD)
? ray_pool_get() : NULL;
ray_t* nc_hdr = NULL;
ray_t** new_cols = (ray_t**)scratch_alloc(&nc_hdr,
(size_t)ncols * sizeof(ray_t*));
ray_t* cn_hdr = NULL;
int64_t* col_names = (int64_t*)scratch_alloc(&cn_hdr,
(size_t)ncols * sizeof(int64_t));
if (!new_cols || !col_names) {
if (nc_hdr) scratch_free(nc_hdr);
if (cn_hdr) scratch_free(cn_hdr);
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
ray_release(idx);
return ray_error("oom", NULL);
}
for (int64_t c = 0; c < ncols; c++) new_cols[c] = NULL;
for (int64_t c = 0; c < ncols; c++) {
ray_t* col = ray_table_get_col_idx(tbl, c);
col_names[c] = ray_table_col_name(tbl, c);
if (!col) continue;
ray_t* nc;
if (col->type == RAY_LIST)
nc = ray_list_new(nrows);
else
nc = col_vec_new(col, nrows);
if (!nc || RAY_IS_ERR(nc)) {
for (int64_t j = 0; j < c; j++)
if (new_cols[j]) ray_release(new_cols[j]);
scratch_free(nc_hdr);
scratch_free(cn_hdr);
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
ray_release(idx);
return nc ? nc : ray_error("oom", NULL);
}
if (col->type == RAY_LIST) {
ray_t** src_ptrs = (ray_t**)ray_data(col);
ray_t** dst_ptrs = (ray_t**)ray_data(nc);
for (int64_t r = 0; r < nrows; r++) {
dst_ptrs[r] = src_ptrs[idx_data[r]];
if (dst_ptrs[r]) ray_retain(dst_ptrs[r]);
}
}
nc->len = nrows;
new_cols[c] = nc;
}
int64_t decode_col_idx = -1;
if (sorted_keys && n_keys == 1 && !RAY_IS_SYM(key_cols[0]->type)) {
for (int64_t c = 0; c < ncols; c++) {
if (col_names[c] == key_ids[0] && new_cols[c]) {
decode_col_idx = c;
break;
}
}
}
if (decode_col_idx >= 0) {
radix_decode_into(ray_data(new_cols[decode_col_idx]),
key_cols[0]->type, sorted_keys,
nrows, descs[0]);
}
for (int64_t base = 0; base < ncols; ) {
char* g_srcs[MGATHER_MAX_COLS];
char* g_dsts[MGATHER_MAX_COLS];
uint8_t g_esz[MGATHER_MAX_COLS];
int64_t g_nc = 0;
for (; base < ncols && g_nc < MGATHER_MAX_COLS; base++) {
if (!new_cols[base] || base == decode_col_idx) continue;
ray_t* col = ray_table_get_col_idx(tbl, base);
if (col->type == RAY_LIST) continue;
g_srcs[g_nc] = (char*)ray_data(col);
g_dsts[g_nc] = (char*)ray_data(new_cols[base]);
g_esz[g_nc] = col_esz(col);
g_nc++;
}
if (g_nc == 0) continue;
if (n_keys == 1)
partitioned_gather(gather_pool, idx_data, nrows,
nrows, g_srcs, g_dsts, g_esz, g_nc);
else {
multi_gather_ctx_t mg = { .idx = idx_data, .ncols = g_nc };
for (int64_t i = 0; i < g_nc; i++) {
mg.srcs[i] = g_srcs[i];
mg.dsts[i] = g_dsts[i];
mg.esz[i] = g_esz[i];
}
if (gather_pool)
ray_pool_dispatch(gather_pool, multi_gather_fn, &mg, nrows);
else
multi_gather_fn(&mg, 0, 0, nrows);
}
}
for (int64_t c = 0; c < ncols; c++) {
if (!new_cols[c]) continue;
ray_t* col = ray_table_get_col_idx(tbl, c);
if (!col) continue;
col_propagate_str_pool(new_cols[c], col);
if (col->type == RAY_SYM &&
!(col->attrs & RAY_ATTR_SLICE) &&
(!(col->attrs & RAY_ATTR_HAS_NULLS) || (col->attrs & RAY_ATTR_NULLMAP_EXT)) &&
col->sym_dict) {
ray_retain(col->sym_dict);
new_cols[c]->sym_dict = col->sym_dict;
}
bool src_has_nulls = (col->attrs & RAY_ATTR_HAS_NULLS) ||
((col->attrs & RAY_ATTR_SLICE) && col->slice_parent &&
(col->slice_parent->attrs & RAY_ATTR_HAS_NULLS));
if (src_has_nulls) {
for (int64_t r = 0; r < nrows; r++)
if (ray_vec_is_null(col, idx_data[r]))
ray_vec_set_null(new_cols[c], r, true);
}
}
ray_t* result = ray_table_new(ncols);
if (!result || RAY_IS_ERR(result)) {
for (int64_t c = 0; c < ncols; c++)
if (new_cols[c]) ray_release(new_cols[c]);
scratch_free(nc_hdr);
scratch_free(cn_hdr);
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
ray_release(idx);
return result ? result : ray_error("oom", NULL);
}
for (int64_t c = 0; c < ncols; c++) {
if (!new_cols[c]) continue;
result = ray_table_add_col(result, col_names[c], new_cols[c]);
ray_release(new_cols[c]);
}
scratch_free(nc_hdr);
scratch_free(cn_hdr);
if (sorted_keys_hdr) scratch_free(sorted_keys_hdr);
ray_release(idx);
return result;
}
ray_t* ray_xasc_fn(ray_t* tbl, ray_t* keys) {
return sort_table_by_keys(tbl, keys, 0);
}
ray_t* ray_xdesc_fn(ray_t* tbl, ray_t* keys) {
return sort_table_by_keys(tbl, keys, 1);
}
ray_t* ray_xrank_fn(ray_t* n_obj, ray_t* vec) {
if (!is_numeric(n_obj))
return ray_error("type", "xrank: first arg must be integer");
if (!ray_is_vec(vec))
return ray_error("type", "xrank: second arg must be a vector");
int64_t n_groups = as_i64(n_obj);
int64_t len = ray_len(vec);
if (n_groups <= 0 || len == 0) return ray_vec_new(RAY_I64, 0);
uint8_t desc = 0;
ray_t* idx = ray_sort_indices(&vec, &desc, NULL, 1, len);
if (!idx || RAY_IS_ERR(idx)) return idx ? idx : ray_error("oom", NULL);
ray_t* result = ray_vec_new(RAY_I64, len);
if (!result || RAY_IS_ERR(result)) { ray_release(idx); return result ? result : ray_error("oom", NULL); }
result->len = len;
const int64_t* idx_data = (const int64_t*)ray_data(idx);
int64_t* out = (int64_t*)ray_data(result);
for (int64_t i = 0; i < len; i++)
out[idx_data[i]] = i * n_groups / len;
ray_release(idx);
return result;
}