#include "ops/internal.h"
#include "lang/internal.h"
#include "mem/sys.h"
/*
 * Tells whether v can be used as a numeric vector by the routines below:
 * it must be non-NULL, pass ray_is_vec(), and have element type F32, F64,
 * I32 or I64.
 */
static bool rayvec_is_numeric(ray_t* v) {
    if (v == NULL) return false;
    if (!ray_is_vec(v)) return false;

    switch (v->type) {
    case RAY_F32:
    case RAY_F64:
    case RAY_I32:
    case RAY_I64:
        return true;
    default:
        return false;
    }
}
/*
 * Fetches element i of v as a double.
 *
 * F32, I32 and I64 elements are converted; F64 elements are returned as-is.
 * Any other element type yields 0.0. The index is not bounds-checked.
 */
static double rayvec_at_f64(ray_t* v, int64_t i) {
    void* base = ray_data(v);

    if (v->type == RAY_F64) return ((double*)base)[i];
    if (v->type == RAY_F32) return (double)((float*)base)[i];
    if (v->type == RAY_I64) return (double)((int64_t*)base)[i];
    if (v->type == RAY_I32) return (double)((int32_t*)base)[i];
    return 0.0;
}
/*
 * Writes the first dim elements of v into dst as floats.
 *
 * F32 sources are copied in bulk; every other source type goes through
 * rayvec_at_f64() element by element and is narrowed to float.
 * dst must have room for dim floats.
 */
static void rayvec_to_floats(ray_t* v, float* dst, int32_t dim) {
    if (v->type != RAY_F32) {
        for (int32_t k = 0; k < dim; k++) {
            dst[k] = (float)rayvec_at_f64(v, k);
        }
        return;
    }
    memcpy(dst, ray_data(v), (size_t)dim * sizeof(float));
}
/*
 * Checks that list is a RAY_LIST whose elements are all numeric vectors of
 * one common, non-zero length, and reports that length through out_dim.
 *
 * Returns:
 *   0  success; *out_dim holds the common length (0 for an empty list)
 *   1  list is NULL or not a RAY_LIST
 *   2  the first element is not a numeric vector, is empty, or is too long
 *      for its length to be represented as int32_t
 *   3  a later element is not a numeric vector or has a different length
 *
 * *out_dim is written only on success.
 */
static int list_vec_validate(ray_t* list, int32_t* out_dim) {
    if (!list || list->type != RAY_LIST) return 1;
    if (list->len <= 0) { *out_dim = 0; return 0; }

    ray_t* first = ray_list_get(list, 0);
    if (!rayvec_is_numeric(first) || first->len <= 0) return 2;
    /* The dimension is carried around as int32_t; refuse lengths that would
     * be silently truncated by the narrowing conversion below. */
    if (first->len > INT32_MAX) return 2;

    int32_t dim = (int32_t)first->len;
    for (int64_t i = 1; i < list->len; i++) {
        ray_t* e = ray_list_get(list, i);
        if (!rayvec_is_numeric(e) || e->len != dim) return 3;
    }

    *out_dim = dim;
    return 0;
}
/*
 * Copies every vector of list into one contiguous row-major float buffer
 * of list->len rows by dim columns.
 *
 * *out_n always receives list->len. Returns NULL when the list is empty,
 * when dim is not positive, when the buffer size would overflow size_t, or
 * when allocation fails. A non-NULL result was obtained from
 * ray_sys_alloc(); the caller is expected to release it with
 * ray_sys_free(). Elements are not validated here; each is read for dim
 * values.
 */
static float* list_flatten_floats(ray_t* list, int32_t dim, int64_t* out_n) {
    int64_t n = list->len;
    *out_n = n;
    if (n == 0) return NULL;
    if (dim <= 0) return NULL;

    /* Guard n * dim * sizeof(float) against wrapping: a wrapped product
     * would under-allocate and the copy loop would write out of bounds.
     * A negative n also fails the unsigned comparison. */
    const uint64_t max_elems = (uint64_t)(SIZE_MAX / sizeof(float));
    if ((uint64_t)dim > max_elems) return NULL;
    if ((uint64_t)n > max_elems / (uint64_t)dim) return NULL;

    float* buf = (float*)ray_sys_alloc((size_t)n * (size_t)dim * sizeof(float));
    if (!buf) return NULL;

    for (int64_t i = 0; i < n; i++) {
        ray_t* e = ray_list_get(list, i);
        rayvec_to_floats(e, buf + i * dim, dim);
    }
    return buf;
}
/*
 * Scoring mode selector for row_score():
 *   MET_COS_DIST   - 1 minus the cosine similarity of the two vectors
 *   MET_INNER_PROD - plain dot product
 *   MET_L2_DIST    - Euclidean distance (square root of summed squared diffs)
 */
typedef enum { MET_COS_DIST, MET_INNER_PROD, MET_L2_DIST } metric_kind_t;
/*
 * Scores one stored vector (row) against a query that has already been
 * expanded to doubles.
 *
 *   k       which metric to compute
 *   row     vector read through rayvec_at_f64() for indices 0..dim-1
 *   q       query values, dim entries
 *   q_norm  Euclidean norm of q; used only for the cosine distance
 *   dim     number of components to process
 *
 * Returns the L2 distance, the inner product, or the cosine distance
 * (1 - similarity). For the cosine distance, a non-positive denominator
 * (for example a zero-norm operand) counts as similarity 0, so the result
 * is 1.0.
 */
static double row_score(metric_kind_t k, ray_t* row,
                        const double* q, double q_norm, int32_t dim) {
    if (k == MET_L2_DIST) {
        double sq_sum = 0.0;
        int32_t j = 0;
        while (j < dim) {
            double diff = rayvec_at_f64(row, j) - q[j];
            sq_sum += diff * diff;
            j++;
        }
        return sqrt(sq_sum);
    }

    /* The row norm is only accumulated when the cosine distance needs it. */
    const int want_norm = (k == MET_COS_DIST);
    double dot = 0.0;
    double row_sq = 0.0;
    for (int32_t j = 0; j < dim; j++) {
        const double x = rayvec_at_f64(row, j);
        dot += x * q[j];
        if (want_norm) row_sq += x * x;
    }
    if (k == MET_INNER_PROD) return dot;

    const double denom = q_norm * sqrt(row_sq);
    if (denom > 0.0) return 1.0 - dot / denom;
    return 1.0 - 0.0;
}
/*
 * Expands the first dim elements of q into a freshly allocated array of
 * doubles and stores the Euclidean norm of those values in *q_norm_out.
 *
 * Returns NULL if ray_sys_alloc() fails, in which case *q_norm_out is left
 * untouched. Otherwise the caller releases the array with ray_sys_free().
 */
static double* query_to_doubles(ray_t* q, int32_t dim, double* q_norm_out) {
    const size_t bytes = (size_t)dim * sizeof(double);
    double* vals = (double*)ray_sys_alloc(bytes);
    if (vals == NULL) return NULL;

    double sq_sum = 0.0;
    int32_t j = 0;
    while (j < dim) {
        vals[j] = rayvec_at_f64(q, j);
        sq_sum += vals[j] * vals[j];
        j++;
    }

    *q_norm_out = sqrt(sq_sum);
    return vals;
}
/*
 * Shared implementation behind the cosine-distance, inner-product and
 * L2-distance built-ins.
 *
 * Accepted operand shapes:
 *   - a RAY_LIST of numeric vectors and one numeric vector, in either
 *     argument order: returns an F64 vector with one score per list element;
 *   - two numeric vectors of equal, non-zero length: returns the value built
 *     by make_f64() from their single score.
 *
 * Errors:
 *   "type"    NULL or unsupported operands, or a list rejected by
 *             list_vec_validate()
 *   "length"  vector lengths differ, are zero, or exceed INT32_MAX
 *   "oom"     the query scratch buffer or the result could not be allocated
 */
static ray_t* vec_binary_metric(metric_kind_t kind, ray_t* a, ray_t* b) {
    if (!a || !b) return ray_error("type", NULL);

    /* Work out which operand, if any, is the list and which is the query. */
    ray_t* list = NULL;
    ray_t* query = NULL;
    if (a->type == RAY_LIST && rayvec_is_numeric(b)) { list = a; query = b; }
    else if (b->type == RAY_LIST && rayvec_is_numeric(a)) { list = b; query = a; }

    if (list) {
        int32_t dim;
        if (list_vec_validate(list, &dim) != 0) return ray_error("type", NULL);
        if (query->len != dim) return ray_error("length", NULL);

        double q_norm;
        double* q = query_to_doubles(query, dim, &q_norm);
        if (!q) return ray_error("oom", NULL);

        int64_t n = list->len;
        ray_t* result = ray_vec_new(RAY_F64, n);
        if (!result || RAY_IS_ERR(result)) { ray_sys_free(q); return ray_error("oom", NULL); }
        result->len = n;

        double* out = (double*)ray_data(result);
        for (int64_t i = 0; i < n; i++) {
            ray_t* row = ray_list_get(list, i);
            out[i] = row_score(kind, row, q, q_norm, dim);
        }
        ray_sys_free(q);
        return result;
    }

    /* Vector against vector. */
    if (!rayvec_is_numeric(a) || !rayvec_is_numeric(b)) return ray_error("type", NULL);
    if (a->len != b->len || a->len <= 0) return ray_error("length", NULL);
    /* The scoring helpers take an int32_t dimension; refuse lengths that the
     * narrowing conversion below would truncate, since the score would then
     * be computed over a prefix of the data only. */
    if (a->len > INT32_MAX) return ray_error("length", NULL);

    int32_t dim = (int32_t)a->len;
    double q_norm;
    double* q = query_to_doubles(b, dim, &q_norm);
    if (!q) return ray_error("oom", NULL);

    double v = row_score(kind, a, q, q_norm, dim);
    ray_sys_free(q);
    return make_f64(v);
}
/* Cosine-distance entry point: forwards both operands to
 * vec_binary_metric() with MET_COS_DIST. */
ray_t* ray_cos_dist_fn(ray_t* a, ray_t* b)
{
    return vec_binary_metric(MET_COS_DIST, a, b);
}
/* Inner-product entry point: forwards both operands to
 * vec_binary_metric() with MET_INNER_PROD. */
ray_t* ray_inner_prod_fn(ray_t* a, ray_t* b)
{
    return vec_binary_metric(MET_INNER_PROD, a, b);
}
/* L2-distance entry point: forwards both operands to
 * vec_binary_metric() with MET_L2_DIST. */
ray_t* ray_l2_dist_fn(ray_t* a, ray_t* b)
{
    return vec_binary_metric(MET_L2_DIST, a, b);
}
/*
 * Euclidean norm.
 *
 *   - For a RAY_LIST accepted by list_vec_validate(): returns an F64 vector
 *     holding the norm of each element.
 *   - For a single numeric vector: returns the value built by make_f64()
 *     from its norm.
 *
 * Errors: "type" for NULL, a list that fails validation, or a non-numeric
 * operand; "oom" when the result vector cannot be allocated.
 */
ray_t* ray_norm_fn(ray_t* x) {
    if (x == NULL) return ray_error("type", NULL);

    if (x->type != RAY_LIST) {
        /* Single-vector case. */
        if (!rayvec_is_numeric(x)) return ray_error("type", NULL);

        double sq_sum = 0.0;
        int64_t i = 0;
        while (i < x->len) {
            const double val = rayvec_at_f64(x, i);
            sq_sum += val * val;
            i++;
        }
        return make_f64(sqrt(sq_sum));
    }

    /* List case: one norm per row. */
    int32_t dim;
    if (list_vec_validate(x, &dim) != 0) return ray_error("type", NULL);

    const int64_t count = x->len;
    ray_t* norms = ray_vec_new(RAY_F64, count);
    if (norms == NULL || RAY_IS_ERR(norms)) return ray_error("oom", NULL);
    norms->len = count;

    double* dst = (double*)ray_data(norms);
    for (int64_t r = 0; r < count; r++) {
        ray_t* row = ray_list_get(x, r);
        double sq_sum = 0.0;
        for (int32_t c = 0; c < dim; c++) {
            const double val = rayvec_at_f64(row, c);
            sq_sum += val * val;
        }
        dst[r] = sqrt(sq_sum);
    }
    return norms;
}
/*
 * Maps a symbol atom to an HNSW metric.
 *
 * Recognised symbols are `cosine`, `l2` and `ip`, looked up in that order
 * via ray_sym_find(). Returns 1 and writes *out on a match; returns 0
 * (leaving *out untouched) when s is NULL, not a symbol atom, or names
 * anything else.
 */
static int parse_metric_sym(ray_t* s, ray_hnsw_metric_t* out) {
    static const struct {
        const char*       name;
        size_t            len;
        ray_hnsw_metric_t metric;
    } known[] = {
        { "cosine", 6, RAY_HNSW_COSINE },
        { "l2",     2, RAY_HNSW_L2     },
        { "ip",     2, RAY_HNSW_IP     },
    };

    if (s == NULL || s->type != -RAY_SYM) return 0;

    const int64_t id = s->i64;
    for (size_t t = 0; t < sizeof(known) / sizeof(known[0]); t++) {
        if (id == ray_sym_find(known[t].name, known[t].len)) {
            *out = known[t].metric;
            return 1;
        }
    }
    return 0;
}
/*
 * Widens an I64, I32 or I16 atom to int64_t.
 * Returns 0 for NULL and for any other type, so callers that need to tell
 * "zero" from "not an integer" should test atom_is_int() first.
 */
static int64_t atom_to_i64(ray_t* a) {
    if (a == NULL) return 0;

    if (a->type == -RAY_I64) return a->i64;
    if (a->type == -RAY_I32) return (int64_t)a->i32;
    if (a->type == -RAY_I16) return (int64_t)a->i16;
    return 0;
}
/* True when a is a non-NULL atom of type I64, I32 or I16. */
static bool atom_is_int(ray_t* a) {
    if (a == NULL) return false;

    switch (a->type) {
    case -RAY_I64:
    case -RAY_I32:
    case -RAY_I16:
        return true;
    default:
        return false;
    }
}
/*
 * Wraps a row-id vector and a distance vector into the two-column result
 * table (`_rowid`, `_dist`). The caller's references to rv and dv are
 * released on every path. Returns an "oom" error if the table cannot be
 * created.
 */
static ray_t* knn_result_table(ray_t* rv, ray_t* dv) {
    ray_t* tbl = ray_table_new(2);
    if (!tbl || RAY_IS_ERR(tbl)) {
        ray_release(rv);
        ray_release(dv);
        return ray_error("oom", NULL);
    }
    tbl = ray_table_add_col(tbl, sym_intern_safe("_rowid", 6), rv);
    ray_release(rv);
    tbl = ray_table_add_col(tbl, sym_intern_safe("_dist", 5), dv);
    ray_release(dv);
    return tbl;
}

/*
 * Exact (brute-force) k-nearest-neighbour search.
 *
 *   args[0]  RAY_LIST of equal-length numeric vectors (the column)
 *   args[1]  numeric query vector of the same length
 *   args[2]  integer atom k (> 0); capped at the number of rows
 *   args[3]  optional metric symbol: cosine (default), l2 or ip
 *
 * Returns a table with columns `_rowid` (I64) and `_dist` (F64), ordered by
 * ascending distance. For `ip` the distance is the negated inner product,
 * so larger inner products come first.
 *
 * Errors: "rank" (argument count), "type" (operand types or column
 * validation), "domain" (unknown metric, k <= 0), "length" (query length
 * differs from the column dimension), "oom" (allocation failure).
 */
ray_t* ray_knn_fn(ray_t** args, int64_t n) {
    if (n < 3 || n > 4) return ray_error("rank", NULL);
    ray_t* col = args[0];
    ray_t* query = args[1];
    ray_t* katom = args[2];
    if (!col || col->type != RAY_LIST) return ray_error("type", NULL);
    if (!rayvec_is_numeric(query)) return ray_error("type", NULL);
    if (!atom_is_int(katom)) return ray_error("type", NULL);

    ray_hnsw_metric_t metric = RAY_HNSW_COSINE;
    if (n == 4 && !parse_metric_sym(args[3], &metric)) return ray_error("domain", NULL);

    int32_t dim;
    if (list_vec_validate(col, &dim) != 0) return ray_error("type", NULL);
    if (query->len != dim) return ray_error("length", NULL);

    int64_t k = atom_to_i64(katom);
    if (k <= 0) return ray_error("domain", NULL);
    int64_t nrows = col->len;
    if (k > nrows) k = nrows;

    if (nrows == 0) {
        /* NOTE(review): list_vec_validate() reports dim 0 for an empty
         * column, so the length check above lets only an empty query reach
         * this branch; a non-empty query against an empty column yields
         * "length". Confirm that is the intended behaviour. */
        ray_t* rv = ray_vec_new(RAY_I64, 0);
        ray_t* dv = ray_vec_new(RAY_F64, 0);
        /* Same allocation checks as the non-empty path below. */
        if (!rv || RAY_IS_ERR(rv) || !dv || RAY_IS_ERR(dv)) {
            if (rv && !RAY_IS_ERR(rv)) ray_release(rv);
            if (dv && !RAY_IS_ERR(dv)) ray_release(dv);
            return ray_error("oom", NULL);
        }
        return knn_result_table(rv, dv);
    }

    double q_norm;
    double* q = query_to_doubles(query, dim, &q_norm);
    if (!q) return ray_error("oom", NULL);

    /* Bounded max-heap keyed on distance: heap[0] is always the worst of
     * the best-k candidates seen so far, so a new row only has to beat the
     * root to get in. */
    typedef struct { double d; int64_t id; } ent_t;
    ent_t* heap = (ent_t*)ray_sys_alloc((size_t)k * sizeof(ent_t));
    if (!heap) { ray_sys_free(q); return ray_error("oom", NULL); }
    int64_t hsz = 0;

    for (int64_t i = 0; i < nrows; i++) {
        ray_t* row = ray_list_get(col, i);
        double d;
        switch (metric) {
        case RAY_HNSW_L2:
            d = row_score(MET_L2_DIST, row, q, q_norm, dim);
            break;
        case RAY_HNSW_IP:
            /* Negate so that "smaller is better" holds for every metric. */
            d = -row_score(MET_INNER_PROD, row, q, q_norm, dim);
            break;
        case RAY_HNSW_COSINE:
        default:
            d = row_score(MET_COS_DIST, row, q, q_norm, dim);
            break;
        }

        if (hsz < k) {
            /* Heap not full yet: append and sift up. */
            int64_t j = hsz++;
            heap[j] = (ent_t){ d, i };
            while (j > 0) {
                int64_t p = (j - 1) / 2;
                if (heap[p].d >= heap[j].d) break;
                ent_t t = heap[p]; heap[p] = heap[j]; heap[j] = t;
                j = p;
            }
        } else if (d < heap[0].d) {
            /* Better than the current worst: replace the root, sift down. */
            heap[0] = (ent_t){ d, i };
            int64_t j = 0;
            for (;;) {
                int64_t l = 2*j+1, r = 2*j+2, best = j;
                if (l < hsz && heap[l].d > heap[best].d) best = l;
                if (r < hsz && heap[r].d > heap[best].d) best = r;
                if (best == j) break;
                ent_t t = heap[j]; heap[j] = heap[best]; heap[best] = t;
                j = best;
            }
        }
    }
    ray_sys_free(q);

    /* Insertion sort of the surviving candidates, ascending by distance. */
    for (int64_t i = 1; i < hsz; i++) {
        ent_t key = heap[i];
        int64_t j = i - 1;
        while (j >= 0 && heap[j].d > key.d) {
            heap[j + 1] = heap[j];
            j--;
        }
        heap[j + 1] = key;
    }

    ray_t* rv = ray_vec_new(RAY_I64, hsz);
    ray_t* dv = ray_vec_new(RAY_F64, hsz);
    if (!rv || RAY_IS_ERR(rv) || !dv || RAY_IS_ERR(dv)) {
        ray_sys_free(heap);
        if (rv && !RAY_IS_ERR(rv)) ray_release(rv);
        if (dv && !RAY_IS_ERR(dv)) ray_release(dv);
        return ray_error("oom", NULL);
    }
    int64_t* rd = (int64_t*)ray_data(rv);
    double* dd = (double*)ray_data(dv);
    for (int64_t i = 0; i < hsz; i++) { rd[i] = heap[i].id; dd[i] = heap[i].d; }
    rv->len = hsz;
    dv->len = hsz;
    ray_sys_free(heap);

    return knn_result_table(rv, dv);
}
/*
 * Recovers the index pointer stored in an HNSW handle.
 *
 * A handle is an I64 atom carrying the RAY_ATTR_HNSW attribute, with the
 * pointer kept in its i64 payload. Returns NULL when h is NULL, has a
 * different type, or lacks the attribute.
 */
static ray_hnsw_t* hnsw_unwrap(ray_t* h) {
    if (h == NULL || h->type != -RAY_I64) return NULL;
    if ((h->attrs & RAY_ATTR_HNSW) == 0) return NULL;

    uintptr_t raw = (uintptr_t)h->i64;
    return (ray_hnsw_t*)raw;
}
/*
 * Boxes an index pointer into a handle object: an I64 atom tagged with
 * RAY_ATTR_HNSW whose i64 payload holds the pointer.
 *
 * Returns an "oom" error if ray_alloc() yields NULL, or passes through the
 * error object ray_alloc() produced. idx itself is not freed on failure.
 */
static ray_t* hnsw_wrap(ray_hnsw_t* idx) {
    ray_t* handle = ray_alloc(0);
    if (handle == NULL) return ray_error("oom", NULL);
    if (RAY_IS_ERR(handle)) return handle;

    handle->type = -RAY_I64;
    handle->attrs |= RAY_ATTR_HNSW;
    handle->i64 = (int64_t)(uintptr_t)idx;
    return handle;
}
/*
 * Builds an HNSW index from a column of vectors and returns a handle to it.
 *
 *   args[0]  RAY_LIST of equal-length numeric vectors
 *   args[1]  optional metric symbol (default: cosine)
 *   args[2]  optional integer M; used only when in 1..512, otherwise
 *            HNSW_DEFAULT_M is kept
 *   args[3]  optional integer ef_construction; used only when in 1..4096,
 *            otherwise HNSW_DEFAULT_EF_C is kept
 *
 * Errors: "rank", "type" (bad column or non-integer tuning argument),
 * "domain" (unknown metric), "length" (validation yields dimension 0, which
 * is what an empty column produces), "oom" (flattening or build failure).
 */
ray_t* ray_hnsw_build_fn(ray_t** args, int64_t n) {
    if (n < 1 || n > 4) return ray_error("rank", NULL);

    ray_t* col = args[0];
    if (col == NULL || col->type != RAY_LIST) return ray_error("type", NULL);

    ray_hnsw_metric_t metric = RAY_HNSW_COSINE;
    if (n >= 2) {
        if (!parse_metric_sym(args[1], &metric)) return ray_error("domain", NULL);
    }

    int32_t M = HNSW_DEFAULT_M;
    if (n >= 3) {
        ray_t* m_arg = args[2];
        if (!atom_is_int(m_arg)) return ray_error("type", NULL);
        const int64_t requested = atom_to_i64(m_arg);
        /* Out-of-range values are ignored rather than rejected. */
        if (requested >= 1 && requested <= 512) M = (int32_t)requested;
    }

    int32_t ef_c = HNSW_DEFAULT_EF_C;
    if (n >= 4) {
        ray_t* ef_arg = args[3];
        if (!atom_is_int(ef_arg)) return ray_error("type", NULL);
        const int64_t requested = atom_to_i64(ef_arg);
        if (requested >= 1 && requested <= 4096) ef_c = (int32_t)requested;
    }

    int32_t dim;
    if (list_vec_validate(col, &dim) != 0) return ray_error("type", NULL);
    if (dim <= 0) return ray_error("length", NULL);

    /* The builder takes one contiguous float matrix; it is only needed for
     * the duration of the build call. */
    int64_t n_rows;
    float* flat = list_flatten_floats(col, dim, &n_rows);
    if (flat == NULL && n_rows > 0) return ray_error("oom", NULL);

    ray_hnsw_t* idx = ray_hnsw_build(flat, n_rows, dim, metric, M, ef_c);
    if (flat != NULL) ray_sys_free(flat);
    if (idx == NULL) return ray_error("oom", NULL);

    ray_t* handle = hnsw_wrap(idx);
    if (!handle || RAY_IS_ERR(handle)) {
        /* No handle owns the index, so free it here. */
        ray_hnsw_free(idx);
    }
    return handle;
}
/*
 * Approximate nearest-neighbour search on an HNSW index.
 *
 *   args[0]  index handle (see hnsw_unwrap)
 *   args[1]  numeric query vector whose length equals the index dimension
 *   args[2]  integer atom k (> 0); capped at the number of indexed nodes
 *   args[3]  optional integer ef; used only when in 1..4096. Without it,
 *            ef is the larger of k and HNSW_DEFAULT_EF_S.
 *
 * Returns a table with columns `_rowid` (I64) and `_dist` (F64) holding the
 * hits reported by ray_hnsw_search(), in the order it produced them.
 *
 * Errors: "rank", "type" (bad handle, non-numeric query, non-integer k or
 * ef), "length" (query/index dimension mismatch), "domain" (k <= 0), "oom"
 * (allocation failure or a negative result from ray_hnsw_search()).
 */
ray_t* ray_ann_fn(ray_t** args, int64_t n) {
    if (n < 3 || n > 4) return ray_error("rank", NULL);
    ray_hnsw_t* idx = hnsw_unwrap(args[0]);
    if (!idx) return ray_error("type", NULL);
    if (!rayvec_is_numeric(args[1])) return ray_error("type", NULL);
    if (!atom_is_int(args[2])) return ray_error("type", NULL);

    int32_t dim = idx->dim;
    if (args[1]->len != dim) return ray_error("length", NULL);

    int64_t k = atom_to_i64(args[2]);
    if (k <= 0) return ray_error("domain", NULL);

    /* Cap k at the index size, as ray_knn_fn caps it at the row count, so
     * that the output buffers below are bounded by the index rather than by
     * a caller-supplied number. Assumes ray_hnsw_search() never reports more
     * hits than there are indexed nodes. An empty index keeps the caller's
     * k so the buffers are never zero-sized. */
    int64_t n_nodes = (int64_t)idx->n_nodes;
    if (n_nodes > 0 && k > n_nodes) k = n_nodes;

    /* ef is an int32_t; saturate instead of letting the conversion wrap. */
    int32_t ef = (k > INT32_MAX) ? INT32_MAX : (int32_t)k;
    if (ef < HNSW_DEFAULT_EF_S) ef = HNSW_DEFAULT_EF_S;
    if (n == 4) {
        if (!atom_is_int(args[3])) return ray_error("type", NULL);
        int64_t v = atom_to_i64(args[3]);
        /* Out-of-range overrides are ignored rather than rejected. */
        if (v > 0 && v <= 4096) ef = (int32_t)v;
    }

    /* The search routine takes the query as a float buffer. */
    float* qbuf = (float*)ray_sys_alloc((size_t)dim * sizeof(float));
    if (!qbuf) return ray_error("oom", NULL);
    rayvec_to_floats(args[1], qbuf, dim);

    int64_t* out_ids = (int64_t*)ray_sys_alloc((size_t)k * sizeof(int64_t));
    double* out_ds = (double*)ray_sys_alloc((size_t)k * sizeof(double));
    if (!out_ids || !out_ds) {
        ray_sys_free(qbuf);
        if (out_ids) ray_sys_free(out_ids);
        if (out_ds) ray_sys_free(out_ds);
        return ray_error("oom", NULL);
    }

    int64_t found = ray_hnsw_search(idx, qbuf, dim, k, ef, out_ids, out_ds);
    if (found < 0) {
        ray_sys_free(qbuf); ray_sys_free(out_ids); ray_sys_free(out_ds);
        return ray_error("oom", NULL);
    }

    ray_t* rv = ray_vec_new(RAY_I64, found);
    ray_t* dv = ray_vec_new(RAY_F64, found);
    if (!rv || RAY_IS_ERR(rv) || !dv || RAY_IS_ERR(dv)) {
        ray_sys_free(qbuf); ray_sys_free(out_ids); ray_sys_free(out_ds);
        if (rv && !RAY_IS_ERR(rv)) ray_release(rv);
        if (dv && !RAY_IS_ERR(dv)) ray_release(dv);
        return ray_error("oom", NULL);
    }
    int64_t* rd = (int64_t*)ray_data(rv);
    double* dd = (double*)ray_data(dv);
    for (int64_t i = 0; i < found; i++) { rd[i] = out_ids[i]; dd[i] = out_ds[i]; }
    rv->len = found;
    dv->len = found;
    ray_sys_free(qbuf); ray_sys_free(out_ids); ray_sys_free(out_ds);

    ray_t* tbl = ray_table_new(2);
    if (!tbl || RAY_IS_ERR(tbl)) { ray_release(rv); ray_release(dv); return ray_error("oom", NULL); }
    tbl = ray_table_add_col(tbl, sym_intern_safe("_rowid", 6), rv);
    ray_release(rv);
    tbl = ray_table_add_col(tbl, sym_intern_safe("_dist", 5), dv);
    ray_release(dv);
    return tbl;
}
/*
 * Destroys the index behind a handle and disarms the handle: its payload is
 * zeroed and the RAY_ATTR_HNSW attribute cleared, so a later
 * hnsw_unwrap() on the same object returns NULL.
 *
 * Returns RAY_NULL_OBJ, or a "type" error if h is not a live handle.
 */
ray_t* ray_hnsw_free_fn(ray_t* h) {
    ray_hnsw_t* idx = hnsw_unwrap(h);
    if (idx == NULL) {
        return ray_error("type", NULL);
    }

    ray_hnsw_free(idx);
    h->attrs &= ~RAY_ATTR_HNSW;
    h->i64 = 0;
    return RAY_NULL_OBJ;
}
/*
 * Writes the index behind handle h to the file named by the string atom
 * path.
 *
 * The path bytes are copied into a local NUL-terminated buffer before being
 * handed to ray_hnsw_save(); paths that are empty or 1023 bytes and longer
 * are rejected.
 *
 * Returns RAY_NULL_OBJ on success. Errors: "type" (bad handle or path
 * object), "domain" (unusable path), "io" (ray_hnsw_save() did not return
 * RAY_OK).
 */
ray_t* ray_hnsw_save_fn(ray_t* h, ray_t* path) {
    ray_hnsw_t* idx = hnsw_unwrap(h);
    if (idx == NULL) return ray_error("type", NULL);
    if (path == NULL || path->type != -RAY_STR) return ray_error("type", NULL);

    const char* src = ray_str_ptr(path);
    const size_t n = ray_str_len(path);
    if (src == NULL) return ray_error("domain", NULL);
    if (n == 0 || n >= 1023) return ray_error("domain", NULL);

    char cpath[1024];
    memcpy(cpath, src, n);
    cpath[n] = '\0';

    if (ray_hnsw_save(idx, cpath) != RAY_OK) {
        return ray_error("io", NULL);
    }
    return RAY_NULL_OBJ;
}
/*
 * Loads an index from the file named by the string atom path and returns a
 * handle to it.
 *
 * The path bytes are copied into a local NUL-terminated buffer before being
 * handed to ray_hnsw_load(); paths that are empty or 1023 bytes and longer
 * are rejected.
 *
 * Errors: "type" (bad path object), "domain" (unusable path), "io"
 * (ray_hnsw_load() returned NULL), or whatever hnsw_wrap() reports, in
 * which case the loaded index is freed again.
 */
ray_t* ray_hnsw_load_fn(ray_t* path) {
    if (path == NULL || path->type != -RAY_STR) return ray_error("type", NULL);

    const char* src = ray_str_ptr(path);
    const size_t n = ray_str_len(path);
    if (src == NULL) return ray_error("domain", NULL);
    if (n == 0 || n >= 1023) return ray_error("domain", NULL);

    char cpath[1024];
    memcpy(cpath, src, n);
    cpath[n] = '\0';

    ray_hnsw_t* idx = ray_hnsw_load(cpath);
    if (idx == NULL) return ray_error("io", NULL);

    ray_t* handle = hnsw_wrap(idx);
    if (!handle || RAY_IS_ERR(handle)) {
        /* No handle owns the index, so free it here. */
        ray_hnsw_free(idx);
    }
    return handle;
}
/*
 * Describes an index as a dictionary with the keys
 * nrows, dim, metric, nlayers, M and efc, taken from the index fields
 * n_nodes, dim, metric, n_layers, M and ef_construction respectively.
 * The metric is reported as one of the symbols cosine, l2 or ip.
 *
 * Returns a "type" error for an invalid handle, or the error object
 * produced while allocating the key vector or value list.
 */
ray_t* ray_hnsw_info_fn(ray_t* h) {
    ray_hnsw_t* idx = hnsw_unwrap(h);
    if (idx == NULL) return ray_error("type", NULL);

    /* Anything that is not L2 or IP is reported as cosine. */
    const ray_hnsw_metric_t m = (ray_hnsw_metric_t)idx->metric;
    const char* mname = "cosine";
    if (m == RAY_HNSW_L2) {
        mname = "l2";
    } else if (m == RAY_HNSW_IP) {
        mname = "ip";
    }

    ray_t* keys = ray_sym_vec_new(RAY_SYM_W64, 6);
    if (RAY_IS_ERR(keys)) return keys;
    ray_t* vals = ray_list_new(6);
    if (RAY_IS_ERR(vals)) { ray_release(keys); return vals; }

    /* All value objects are created up front, before any key is interned. */
    struct { const char* name; size_t nlen; ray_t* val; } rows[] = {
        { "nrows",   5, make_i64(idx->n_nodes) },
        { "dim",     3, make_i64((int64_t)idx->dim) },
        { "metric",  6, ray_sym(sym_intern_safe(mname, strlen(mname))) },
        { "nlayers", 7, make_i64((int64_t)idx->n_layers) },
        { "M",       1, make_i64((int64_t)idx->M) },
        { "efc",     3, make_i64((int64_t)idx->ef_construction) },
    };
    const size_t nrows = sizeof(rows) / sizeof(rows[0]);

    size_t r = 0;
    while (r < nrows) {
        int64_t key_id = sym_intern_safe(rows[r].name, rows[r].nlen);
        keys = ray_vec_append(keys, &key_id);
        vals = ray_list_append(vals, rows[r].val);
        /* Drop this function's reference to the value after appending. */
        ray_release(rows[r].val);
        r++;
    }

    return ray_dict_new(keys, vals);
}