#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <stdio.h>
#include <emmintrin.h>
#include <assert.h>
#include "libpmemcto.h"
#include "art.h"
/*
 * Leaf tagging helpers.
 *
 * A pointer stored in the tree is marked as a leaf by setting its lowest bit.
 * IS_LEAF tests the tag, SET_LEAF adds it, LEAF_RAW strips it again.
 * The argument is parenthesized so that expressions (e.g. "p + 2" or a
 * conditional) are converted as a whole before the bit operation is applied.
 */
#define IS_LEAF(x) (((uintptr_t)(x) & 1))
#define SET_LEAF(x) ((void *)((uintptr_t)(x) | 1))
#define LEAF_RAW(x) ((void *)((uintptr_t)(x) & ~(uintptr_t)1))
/*
 * alloc_node -- allocate a zero-filled inner node of the given type
 *
 * pcp:  pool the node is allocated from
 * type: one of NODE4, NODE16, NODE48, NODE256
 *
 * Returns the new node with its type field set.  The callers in this file
 * never check the result, so an unknown type or an allocation failure aborts
 * the process.  The failure check is explicit (not an assert) so that it is
 * still performed when NDEBUG is defined.
 */
static art_node *
alloc_node(PMEMctopool *pcp, uint8_t type)
{
	size_t size;

	switch (type) {
	case NODE4:
		size = sizeof(art_node4);
		break;
	case NODE16:
		size = sizeof(art_node16);
		break;
	case NODE48:
		size = sizeof(art_node48);
		break;
	case NODE256:
		size = sizeof(art_node256);
		break;
	default:
		abort();
	}

	art_node *n = pmemcto_calloc(pcp, 1, size);
	if (n == NULL)
		abort();

	n->type = type;
	return n;
}
/*
 * art_tree_init -- put a tree handle into the empty state
 *
 * Clears the element count and the root pointer.  Always returns 0.
 */
int
art_tree_init(art_tree *t)
{
	t->size = 0;
	t->root = NULL;

	return 0;
}
/*
 * destroy_node -- free a subtree: every leaf and every inner node below n
 *
 * pcp: pool the nodes were allocated from
 * n:   subtree root; may be NULL or a tagged leaf pointer
 *
 * NODE48 children are reached through the 256-entry key index rather than by
 * walking the first num_children slots: remove_child48() clears slots in
 * place and add_child48() reuses the first free one, so occupied slots are
 * not guaranteed to be contiguous.  Walking only the first num_children
 * slots would leak subtrees stored behind a hole.
 */
static void
destroy_node(PMEMctopool *pcp, art_node *n)
{
	if (!n)
		return;

	if (IS_LEAF(n)) {
		pmemcto_free(pcp, LEAF_RAW(n));
		return;
	}

	int i;

	switch (n->type) {
	case NODE4: {
		art_node4 *n4 = (art_node4 *)n;

		for (i = 0; i < n->num_children; i++)
			destroy_node(pcp, n4->children[i]);
		break;
	}
	case NODE16: {
		art_node16 *n16 = (art_node16 *)n;

		for (i = 0; i < n->num_children; i++)
			destroy_node(pcp, n16->children[i]);
		break;
	}
	case NODE48: {
		art_node48 *n48 = (art_node48 *)n;

		for (i = 0; i < 256; i++) {
			/* keys[] holds slot index + 1; 0 means "no child" */
			int slot = n48->keys[i];

			if (slot)
				destroy_node(pcp, n48->children[slot - 1]);
		}
		break;
	}
	case NODE256: {
		art_node256 *n256 = (art_node256 *)n;

		for (i = 0; i < 256; i++) {
			if (n256->children[i])
				destroy_node(pcp, n256->children[i]);
		}
		break;
	}
	default:
		abort();
	}

	/* children are gone, release the inner node itself */
	pmemcto_free(pcp, n);
}
/*
 * art_tree_destroy -- free every node and leaf of the tree
 *
 * After the nodes are released the handle is reset to the empty state, so
 * it no longer holds a pointer to freed memory and a repeated destroy (or a
 * later insert) operates on an empty tree.  Always returns 0.
 */
int
art_tree_destroy(PMEMctopool *pcp, art_tree *t)
{
	destroy_node(pcp, t->root);

	t->root = NULL;
	t->size = 0;

	return 0;
}
/*
 * find_child -- locate the child slot of inner node n for key byte c
 *
 * Returns the address of the child pointer inside the node, or NULL when
 * the node has no child for that byte.  Aborts on an unknown node type.
 */
static art_node **
find_child(art_node *n, unsigned char c)
{
	switch (n->type) {
	case NODE4: {
		art_node4 *n4 = (art_node4 *)n;

		for (int k = 0; k < n->num_children; k++) {
			if (n4->keys[k] == c)
				return &n4->children[k];
		}
		break;
	}
	case NODE16: {
		art_node16 *n16 = (art_node16 *)n;

		/* compare c against all 16 key bytes at once */
		__m128i wanted = _mm_set1_epi8(c);
		__m128i stored = _mm_loadu_si128((__m128i *)n16->keys);
		int hits = _mm_movemask_epi8(_mm_cmpeq_epi8(wanted, stored));

		/* ignore lanes beyond the populated part of the array */
		int populated = (1 << n->num_children) - 1;

		hits &= populated;
		if (hits)
			return &n16->children[__builtin_ctz(hits)];
		break;
	}
	case NODE48: {
		art_node48 *n48 = (art_node48 *)n;

		/* keys[] stores slot index + 1; 0 means "absent" */
		int slot = n48->keys[c];

		if (slot)
			return &n48->children[slot - 1];
		break;
	}
	case NODE256: {
		art_node256 *n256 = (art_node256 *)n;

		if (n256->children[c])
			return &n256->children[c];
		break;
	}
	default:
		abort();
	}

	return NULL;
}
/*
 * min -- return the smaller of two ints (b when they are equal)
 */
static inline int
min(int a, int b)
{
	if (a < b)
		return a;

	return b;
}
/*
 * check_prefix -- count how many stored prefix bytes of n match the key
 *
 * Comparison starts at key[depth] and covers at most MAX_PREFIX_LEN bytes,
 * the node's partial_len, and the bytes remaining in the key, whichever is
 * smallest.  Returns the number of leading bytes that matched.
 */
static int
check_prefix(const art_node *n, const unsigned char *key, int key_len,
	int depth)
{
	int stored = min(n->partial_len, MAX_PREFIX_LEN);
	int limit = min(stored, key_len - depth);
	int matched = 0;

	while (matched < limit &&
	    n->partial[matched] == key[depth + matched])
		matched++;

	return matched;
}
/*
 * leaf_matches -- compare a leaf's key with the given key
 *
 * Returns 0 when the keys are identical (same length and same bytes) and a
 * non-zero value otherwise.  The depth argument is unused.
 */
static int
leaf_matches(const art_leaf *n, const unsigned char *key, int key_len,
	int depth)
{
	(void) depth;

	if (n->key_len == (uint32_t)key_len)
		return memcmp(n->key, key, key_len);

	/* different lengths can never be equal */
	return 1;
}
void *
art_search(const art_tree *t, const unsigned char *key, int key_len)
{
art_node **child;
art_node *n = t->root;
int prefix_len, depth = 0;
while (n) {
if (IS_LEAF(n)) {
n = LEAF_RAW(n);
if (!leaf_matches((art_leaf *)n,
key, key_len, depth)) {
return ((art_leaf *)n)->value;
}
return NULL;
}
if (n->partial_len) {
prefix_len = check_prefix(n, key, key_len, depth);
if (prefix_len != min(MAX_PREFIX_LEN,
n->partial_len))
return NULL;
depth = depth + n->partial_len;
}
child = find_child(n, key[depth]);
n = (child) ? *child : NULL;
depth++;
}
return NULL;
}
/*
 * minimum -- find the leaf reached by always following the first child
 *
 * Returns NULL for a NULL subtree, otherwise the untagged leaf pointer.
 * Aborts on an unknown node type.
 */
static art_leaf *
minimum(const art_node *n)
{
	while (n && !IS_LEAF(n)) {
		int idx = 0;

		switch (n->type) {
		case NODE4:
			n = ((art_node4 *)n)->children[0];
			break;
		case NODE16:
			n = ((art_node16 *)n)->children[0];
			break;
		case NODE48:
			/* first key byte that has a child assigned */
			while (!((art_node48 *)n)->keys[idx])
				idx++;
			idx = ((art_node48 *)n)->keys[idx] - 1;
			assert(idx < 48);
			n = ((art_node48 *)n)->children[idx];
			break;
		case NODE256:
			while (!((art_node256 *)n)->children[idx])
				idx++;
			n = ((art_node256 *)n)->children[idx];
			break;
		default:
			abort();
		}
	}

	if (!n)
		return NULL;

	return LEAF_RAW(n);
}
/*
 * maximum -- find the leaf reached by always following the last child
 *
 * Returns NULL for a NULL subtree, otherwise the untagged leaf pointer.
 * Aborts on an unknown node type.
 */
static art_leaf *
maximum(const art_node *n)
{
	while (n && !IS_LEAF(n)) {
		int idx = 255;

		switch (n->type) {
		case NODE4:
			n = ((art_node4 *)n)->children[n->num_children - 1];
			break;
		case NODE16:
			n = ((art_node16 *)n)->children[n->num_children - 1];
			break;
		case NODE48:
			/* last key byte that has a child assigned */
			while (!((art_node48 *)n)->keys[idx])
				idx--;
			idx = ((art_node48 *)n)->keys[idx] - 1;
			assert((idx >= 0) && (idx < 48));
			n = ((art_node48 *)n)->children[idx];
			break;
		case NODE256:
			while (!((art_node256 *)n)->children[idx])
				idx--;
			n = ((art_node256 *)n)->children[idx];
			break;
		default:
			abort();
		}
	}

	if (!n)
		return NULL;

	return LEAF_RAW(n);
}
/*
 * art_minimum -- return the first leaf of the tree, or NULL if it is empty
 */
art_leaf *
art_minimum(art_tree *t)
{
	art_node *root = (art_node *)t->root;

	return minimum(root);
}
/*
 * art_maximum -- return the last leaf of the tree, or NULL if it is empty
 */
art_leaf *
art_maximum(art_tree *t)
{
	art_node *root = (art_node *) t->root;

	return maximum(root);
}
/*
 * make_leaf -- allocate a leaf holding private copies of key and value
 *
 * The leaf is a single allocation: the header is followed by key_len key
 * bytes and then val_len value bytes, and the key/value pointers of the
 * leaf are set to point into that trailing data area.
 *
 * Returns the (untagged) leaf.  Callers never check the result, so an
 * allocation failure aborts the process; the check is explicit so that it
 * is still performed when NDEBUG is defined.
 */
static art_leaf *
make_leaf(PMEMctopool *pcp, const unsigned char *key, int key_len, void *value,
	int val_len)
{
	art_leaf *l = pmemcto_malloc(pcp, sizeof(art_leaf) + key_len + val_len);

	if (l == NULL)
		abort();

	l->key_len = key_len;
	l->val_len = val_len;

	/* key first, value directly behind it */
	l->key = &(l->data[0]) + 0;
	l->value = &(l->data[0]) + key_len;

	memcpy(l->key, key, key_len);
	memcpy(l->value, value, val_len);

	return l;
}
/*
 * longest_common_prefix -- number of equal key bytes of two leaves,
 * counted from offset depth
 *
 * The comparison stops at the end of the shorter key.
 */
static int
longest_common_prefix(art_leaf *l1, art_leaf *l2, int depth)
{
	int limit = min(l1->key_len, l2->key_len) - depth;
	int same = 0;

	while (same < limit &&
	    l1->key[depth + same] == l2->key[depth + same])
		same++;

	return same;
}
/*
 * copy_header -- copy child count and compressed prefix from src to dest
 *
 * At most MAX_PREFIX_LEN prefix bytes are copied; the node type of dest is
 * left untouched.
 */
static void
copy_header(art_node *dest, art_node *src)
{
	int stored = min(MAX_PREFIX_LEN, src->partial_len);

	memcpy(dest->partial, src->partial, stored);
	dest->partial_len = src->partial_len;
	dest->num_children = src->num_children;
}
/*
 * add_child256 -- store child under key byte c in a NODE256
 *
 * The node never has to grow, so pcp and ref are not used.
 */
static void
add_child256(PMEMctopool *pcp, art_node256 *n, art_node **ref, unsigned char c,
	void *child)
{
	(void) pcp;
	(void) ref;

	n->children[c] = child;
	n->n.num_children++;
}
/*
 * add_child48 -- add child under key byte c to a NODE48
 *
 * While the node has fewer than 48 children the child goes into the first
 * free slot and keys[c] records slot index + 1.  A full node is replaced by
 * a NODE256 (published through *ref, old node freed) before the child is
 * added to the replacement.
 */
static void
add_child48(PMEMctopool *pcp, art_node48 *n, art_node **ref, unsigned char c,
	void *child)
{
	if (n->n.num_children >= 48) {
		art_node256 *bigger = (art_node256 *)alloc_node(pcp, NODE256);

		/* move every child to the slot of its key byte */
		for (int byte = 0; byte < 256; byte++) {
			int slot = n->keys[byte];

			if (slot)
				bigger->children[byte] = n->children[slot - 1];
		}
		copy_header((art_node *)bigger, (art_node *)n);
		*ref = (art_node *)bigger;
		pmemcto_free(pcp, n);
		add_child256(pcp, bigger, ref, c, child);
		return;
	}

	int free_slot = 0;

	while (n->children[free_slot])
		free_slot++;

	n->children[free_slot] = child;
	n->keys[c] = free_slot + 1;
	n->n.num_children++;
}
/*
 * add_child16 -- add child under key byte c to a NODE16
 *
 * While the node has fewer than 16 children the key is inserted so that
 * keys[] stays sorted in ascending unsigned byte order, matching the order
 * add_child4() produces and the assumption of minimum()/maximum() that the
 * first/last array entry is the smallest/largest key.  A full node is
 * replaced by a NODE48 (published through *ref, old node freed) before the
 * child is added to the replacement.
 *
 * _mm_cmplt_epi8 compares signed bytes.  Flipping the top bit of both
 * operands first maps unsigned order onto signed order, so key bytes of
 * 0x80 and above are no longer treated as smaller than 0x00..0x7f.
 */
static void
add_child16(PMEMctopool *pcp, art_node16 *n, art_node **ref, unsigned char c,
	void *child)
{
	if (n->n.num_children < 16) {
		const __m128i sign_flip = _mm_set1_epi8((char)0x80);
		__m128i probe = _mm_xor_si128(_mm_set1_epi8((char)c),
				sign_flip);
		__m128i stored = _mm_xor_si128(
				_mm_loadu_si128((__m128i *)n->keys), sign_flip);

		/* lanes whose stored key is greater than c */
		__m128i cmp = _mm_cmplt_epi8(probe, stored);

		/* only the populated lanes are meaningful */
		unsigned mask = (1 << n->n.num_children) - 1;
		unsigned bitfield = _mm_movemask_epi8(cmp) & mask;
		unsigned idx;

		if (bitfield) {
			/* insert before the first greater key */
			idx = __builtin_ctz(bitfield);
			memmove(n->keys + idx + 1, n->keys + idx,
				n->n.num_children - idx);
			memmove(n->children + idx + 1, n->children + idx,
				(n->n.num_children - idx) * sizeof(void *));
		} else {
			/* c is the largest key so far: append */
			idx = n->n.num_children;
		}

		n->keys[idx] = c;
		n->children[idx] = child;
		n->n.num_children++;
	} else {
		art_node48 *new = (art_node48 *)alloc_node(pcp, NODE48);

		memcpy(new->children, n->children,
			sizeof(void *) * n->n.num_children);
		for (int i = 0; i < n->n.num_children; i++) {
			/* NODE48 index stores slot + 1, 0 means "absent" */
			new->keys[n->keys[i]] = i + 1;
		}
		copy_header((art_node *)new, (art_node *)n);
		*ref = (art_node *) new;
		pmemcto_free(pcp, n);
		add_child48(pcp, new, ref, c, child);
	}
}
/*
 * add_child4 -- add child under key byte c to a NODE4
 *
 * While the node has fewer than 4 children the key is inserted in front of
 * the first stored key that is greater than c.  A full node is replaced by
 * a NODE16 (published through *ref, old node freed) before the child is
 * added to the replacement.
 */
static void
add_child4(PMEMctopool *pcp, art_node4 *n, art_node **ref, unsigned char c,
	void *child)
{
	if (n->n.num_children >= 4) {
		art_node16 *bigger = (art_node16 *)alloc_node(pcp, NODE16);

		memcpy(bigger->children, n->children,
			sizeof(void *) * n->n.num_children);
		memcpy(bigger->keys, n->keys,
			sizeof(unsigned char) * n->n.num_children);
		copy_header((art_node *)bigger, (art_node *)n);
		*ref = (art_node *)bigger;
		pmemcto_free(pcp, n);
		add_child16(pcp, bigger, ref, c, child);
		return;
	}

	/* position of the first key greater than c (or the end) */
	int at = 0;

	while (at < n->n.num_children && !(c < n->keys[at]))
		at++;

	/* open a gap at that position in both parallel arrays */
	memmove(n->keys + at + 1, n->keys + at,
		n->n.num_children - at);
	memmove(n->children + at + 1, n->children + at,
		(n->n.num_children - at) * sizeof(void *));

	n->keys[at] = c;
	n->children[at] = child;
	n->n.num_children++;
}
/*
 * add_child -- add child under key byte c to inner node n
 *
 * Dispatches on the node type to the matching add_childN() routine.  The
 * node may be replaced by a larger one; in that case *ref is updated by the
 * callee.  Aborts on an unknown node type.
 *
 * The helpers are called as plain statements: a return statement with an
 * expression is not permitted in a function returning void.
 */
static void
add_child(PMEMctopool *pcp, art_node *n, art_node **ref, unsigned char c,
	void *child)
{
	switch (n->type) {
	case NODE4:
		add_child4(pcp, (art_node4 *)n, ref, c, child);
		break;
	case NODE16:
		add_child16(pcp, (art_node16 *)n, ref, c, child);
		break;
	case NODE48:
		add_child48(pcp, (art_node48 *)n, ref, c, child);
		break;
	case NODE256:
		add_child256(pcp, (art_node256 *)n, ref, c, child);
		break;
	default:
		abort();
	}
}
/*
 * prefix_mismatch -- offset of the first byte where the key and the
 * compressed path of n differ
 *
 * The bytes stored in the node (at most MAX_PREFIX_LEN) are compared first.
 * When the path is longer than what the node stores, the comparison goes on
 * against the key of the subtree's minimum leaf.  The returned offset is
 * relative to depth; if no difference is found it equals the number of
 * bytes compared.
 */
static int
prefix_mismatch(const art_node *n, const unsigned char *key, int key_len,
	int depth)
{
	int limit = min(min(MAX_PREFIX_LEN, n->partial_len), key_len - depth);
	int idx = 0;

	while (idx < limit) {
		if (n->partial[idx] != key[depth + idx])
			return idx;
		idx++;
	}

	/* the whole path is stored in the node: nothing more to compare */
	if (!(n->partial_len > MAX_PREFIX_LEN))
		return idx;

	art_leaf *l = minimum(n);

	assert(l != NULL);
	limit = min(l->key_len, key_len) - depth;
	while (idx < limit) {
		if (l->key[idx + depth] != key[depth + idx])
			return idx;
		idx++;
	}

	return idx;
}
/*
 * recursive_insert -- insert key/value below node n
 *
 * pcp:     pool used for new nodes and leaves
 * n:       current node (NULL, tagged leaf, or inner node)
 * ref:     address of the pointer that refers to n; rewritten when n is
 *          replaced
 * key, key_len, value, val_len: data to insert
 * depth:   number of key bytes consumed so far
 * old:     set to 1 when an existing key was updated
 *
 * Returns the previous value pointer when the key already existed, NULL
 * otherwise.
 */
static void *
recursive_insert(PMEMctopool *pcp, art_node *n, art_node **ref,
const unsigned char *key, int key_len, void *value,
int val_len, int depth, int *old)
{
/* empty slot: the new leaf simply takes it */
if (!n) {
*ref = (art_node *)SET_LEAF(
make_leaf(pcp, key, key_len, value, val_len));
return NULL;
}
if (IS_LEAF(n)) {
art_leaf *l = LEAF_RAW(n);
/* same key: replace the value pointer and hand back the old one */
if (!leaf_matches(l, key, key_len, depth)) {
*old = 1;
void *old_val = l->value;
/*
 * NOTE(review): unlike make_leaf(), this stores the caller's
 * pointer without copying the bytes and leaves l->val_len
 * unchanged -- confirm this is intended.
 */
l->value = value;
return old_val;
}
/* different key: split into a NODE4 holding both leaves */
art_node4 *new = (art_node4 *)alloc_node(pcp, NODE4);
art_leaf *l2 = make_leaf(pcp, key, key_len, value, val_len);
int longest_prefix = longest_common_prefix(l, l2, depth);
new->n.partial_len = longest_prefix;
/* only the first MAX_PREFIX_LEN bytes of the shared path are stored */
memcpy(new->n.partial, key + depth,
min(MAX_PREFIX_LEN, longest_prefix));
*ref = (art_node *)new;
/* each leaf is keyed by its byte right after the shared path */
add_child4(pcp, new, ref, l->key[depth + longest_prefix],
SET_LEAF(l));
add_child4(pcp, new, ref, l2->key[depth + longest_prefix],
SET_LEAF(l2));
return NULL;
}
if (n->partial_len) {
int prefix_diff = prefix_mismatch(n, key, key_len, depth);
/* key agrees with the whole compressed path: skip it and descend */
if ((uint32_t)prefix_diff >= n->partial_len) {
depth += n->partial_len;
goto RECURSE_SEARCH;
}
/* key diverges inside the path: put a new NODE4 above n */
art_node4 *new = (art_node4 *)alloc_node(pcp, NODE4);
*ref = (art_node *)new;
new->n.partial_len = prefix_diff;
memcpy(new->n.partial, n->partial,
min(MAX_PREFIX_LEN, prefix_diff));
if (n->partial_len <= MAX_PREFIX_LEN) {
/* whole path is stored in n: take the branch byte from it */
add_child4(pcp, new, ref,
n->partial[prefix_diff], n);
n->partial_len -= (prefix_diff + 1);
memmove(n->partial, n->partial + prefix_diff + 1,
min(MAX_PREFIX_LEN, n->partial_len));
} else {
/* path longer than stored: recover the bytes from the minimum leaf */
n->partial_len -= (prefix_diff + 1);
art_leaf *l = minimum(n);
assert(l != NULL);
add_child4(pcp, new, ref,
l->key[depth + prefix_diff], n);
memcpy(n->partial, l->key + depth + prefix_diff + 1,
min(MAX_PREFIX_LEN, n->partial_len));
}
/* the new key becomes the second child of the new node */
art_leaf *l = make_leaf(pcp, key, key_len, value, val_len);
add_child4(pcp, new, ref,
key[depth + prefix_diff], SET_LEAF(l));
return NULL;
}
RECURSE_SEARCH:;
/* NOTE(review): key[depth] is read without comparing depth to key_len */
art_node **child = find_child(n, key[depth]);
if (child) {
return recursive_insert(pcp, *child, child, key, key_len,
value, val_len, depth + 1, old);
}
/* no child for this byte yet: attach a new leaf (n may grow) */
art_leaf *l = make_leaf(pcp, key, key_len, value, val_len);
add_child(pcp, n, ref, key[depth], SET_LEAF(l));
return NULL;
}
/*
 * art_insert -- insert a key/value pair into the tree
 *
 * Returns NULL when the key was new; the element count is then incremented.
 * When the key already existed, the value pointer that was stored before is
 * returned and the count is left unchanged.
 */
void *
art_insert(PMEMctopool *pcp, art_tree *t, const unsigned char *key, int key_len,
	void *value, int val_len)
{
	int replaced = 0;
	void *previous;

	previous = recursive_insert(pcp, t->root, &t->root, key, key_len,
			value, val_len, 0, &replaced);

	if (replaced == 0)
		t->size++;

	return previous;
}
/*
 * remove_child256 -- drop the child stored under key byte c from a NODE256
 *
 * When the child count falls to 37 the node is replaced by a NODE48
 * (published through *ref) and the old node is freed.
 */
static void
remove_child256(PMEMctopool *pcp, art_node256 *n, art_node **ref,
	unsigned char c)
{
	n->children[c] = NULL;
	n->n.num_children--;

	/* keep the big node until it drops to the shrink threshold */
	if (n->n.num_children != 37)
		return;

	art_node48 *smaller = (art_node48 *)alloc_node(pcp, NODE48);

	*ref = (art_node *) smaller;
	copy_header((art_node *)smaller, (art_node *)n);

	int used = 0;

	for (int byte = 0; byte < 256; byte++) {
		if (!n->children[byte])
			continue;
		assert(used < 48);
		smaller->children[used] = n->children[byte];
		/* NODE48 index stores slot + 1, 0 means "absent" */
		smaller->keys[byte] = used + 1;
		used++;
	}

	pmemcto_free(pcp, n);
}
/*
 * remove_child48 -- drop the child stored under key byte c from a NODE48
 *
 * The slot is cleared in place.  When the child count falls to 12 the node
 * is replaced by a NODE16 (published through *ref) whose keys are filled in
 * ascending byte order, and the old node is freed.
 */
static void
remove_child48(PMEMctopool *pcp, art_node48 *n, art_node **ref, unsigned char c)
{
	int slot = n->keys[c];

	n->keys[c] = 0;
	n->children[slot - 1] = NULL;
	n->n.num_children--;

	/* keep the node until it drops to the shrink threshold */
	if (n->n.num_children != 12)
		return;

	art_node16 *smaller = (art_node16 *)alloc_node(pcp, NODE16);

	*ref = (art_node *)smaller;
	copy_header((art_node *) smaller, (art_node *)n);

	int used = 0;

	for (int byte = 0; byte < 256; byte++) {
		slot = n->keys[byte];
		if (!slot)
			continue;
		assert(used < 16);
		smaller->keys[used] = byte;
		smaller->children[used] = n->children[slot - 1];
		used++;
	}

	pmemcto_free(pcp, n);
}
/*
 * remove_child16 -- drop the child whose slot address is l from a NODE16
 *
 * Later entries are shifted down by one.  When the child count falls to 3
 * the node is replaced by a NODE4 (published through *ref) and the old node
 * is freed.
 */
static void
remove_child16(PMEMctopool *pcp, art_node16 *n, art_node **ref, art_node **l)
{
	int at = l - n->children;

	/* close the gap in both parallel arrays */
	memmove(n->keys + at, n->keys + at + 1,
		n->n.num_children - 1 - at);
	memmove(n->children + at, n->children + at + 1,
		(n->n.num_children - 1 - at) * sizeof(void *));
	n->n.num_children--;

	if (n->n.num_children != 3)
		return;

	art_node4 *smaller = (art_node4 *)alloc_node(pcp, NODE4);

	*ref = (art_node *) smaller;
	copy_header((art_node *)smaller, (art_node *)n);
	/* four entries are copied; num_children (3) marks the valid ones */
	memcpy(smaller->keys, n->keys, 4);
	memcpy(smaller->children, n->children, 4 * sizeof(void *));
	pmemcto_free(pcp, n);
}
/*
 * remove_child4 -- drop the child whose slot address is l from a NODE4
 *
 * Later entries are shifted down by one.  When a single child remains the
 * node is removed: *ref is pointed at that child and the node is freed.
 * If the remaining child is an inner node, its compressed path is first
 * extended in front by this node's path plus the key byte that led to it
 * (at most MAX_PREFIX_LEN bytes are stored).
 */
static void
remove_child4(PMEMctopool *pcp, art_node4 *n, art_node **ref, art_node **l)
{
	int at = l - n->children;

	/* close the gap in both parallel arrays */
	memmove(n->keys + at, n->keys + at + 1,
		n->n.num_children - 1 - at);
	memmove(n->children + at, n->children + at + 1,
		(n->n.num_children - 1 - at) * sizeof(void *));
	n->n.num_children--;

	if (n->n.num_children != 1)
		return;

	art_node *only = n->children[0];

	if (!IS_LEAF(only)) {
		/* build "own path + key byte + child path" in n's buffer */
		int filled = n->n.partial_len;

		if (filled < MAX_PREFIX_LEN) {
			n->n.partial[filled] = n->keys[0];
			filled++;
		}
		if (filled < MAX_PREFIX_LEN) {
			int from_child = min(only->partial_len,
					MAX_PREFIX_LEN - filled);

			memcpy(n->n.partial + filled, only->partial,
				from_child);
			filled += from_child;
		}

		/* hand the merged path over to the surviving child */
		memcpy(only->partial, n->n.partial,
			min(filled, MAX_PREFIX_LEN));
		only->partial_len += n->n.partial_len + 1;
	}

	*ref = only;
	pmemcto_free(pcp, n);
}
/*
 * remove_child -- remove one child from inner node n
 *
 * c: key byte of the child (used by NODE48 and NODE256)
 * l: address of the child's slot inside n (used by NODE4 and NODE16)
 *
 * Dispatches on the node type to the matching remove_childN() routine.  The
 * node may be replaced by a smaller one or removed; in that case *ref is
 * updated by the callee.  Aborts on an unknown node type.
 *
 * The helpers are called as plain statements: a return statement with an
 * expression is not permitted in a function returning void.
 */
static void
remove_child(PMEMctopool *pcp, art_node *n, art_node **ref, unsigned char c,
	art_node **l)
{
	switch (n->type) {
	case NODE4:
		remove_child4(pcp, (art_node4 *)n, ref, l);
		break;
	case NODE16:
		remove_child16(pcp, (art_node16 *)n, ref, l);
		break;
	case NODE48:
		remove_child48(pcp, (art_node48 *)n, ref, c);
		break;
	case NODE256:
		remove_child256(pcp, (art_node256 *)n, ref, c);
		break;
	default:
		abort();
	}
}
/*
 * recursive_delete -- unlink the leaf holding key from the subtree at n
 *
 * pcp:   pool, needed when nodes shrink or disappear
 * n:     current node (NULL, tagged leaf, or inner node)
 * ref:   address of the pointer that refers to n
 * depth: number of key bytes consumed so far
 *
 * Returns the unlinked (untagged) leaf, which the caller has to free, or
 * NULL when the key is not present.
 *
 * Only the first MAX_PREFIX_LEN bytes of a compressed path are verified,
 * but depth advances by the full partial_len, so depth can reach or pass
 * key_len.  In that case the key cannot select a child of this node and
 * the search stops instead of reading key[depth] beyond the caller's
 * buffer.
 */
static art_leaf *
recursive_delete(PMEMctopool *pcp, art_node *n, art_node **ref,
	const unsigned char *key, int key_len, int depth)
{
	if (!n)
		return NULL;

	/* a leaf referenced directly by *ref (e.g. the root) */
	if (IS_LEAF(n)) {
		art_leaf *l = LEAF_RAW(n);

		if (!leaf_matches(l, key, key_len, depth)) {
			*ref = NULL;
			return l;
		}
		return NULL;
	}

	if (n->partial_len) {
		int prefix_len = check_prefix(n, key, key_len, depth);

		if (prefix_len != min(MAX_PREFIX_LEN, n->partial_len))
			return NULL;
		depth = depth + n->partial_len;
	}

	/* key exhausted: no byte left to pick a child with */
	if (depth >= key_len)
		return NULL;

	art_node **child = find_child(n, key[depth]);

	if (!child)
		return NULL;

	if (!IS_LEAF(*child)) {
		return recursive_delete(pcp, *child, child, key,
				key_len, depth + 1);
	}

	art_leaf *l = LEAF_RAW(*child);

	if (leaf_matches(l, key, key_len, depth))
		return NULL;

	/* detach the leaf from n; n itself may shrink or be replaced */
	remove_child(pcp, n, ref, key[depth], child);
	return l;
}
/*
 * art_delete -- remove a key from the tree
 *
 * Returns NULL when the key was not found.  Otherwise the element count is
 * decremented, the leaf is freed, and the value pointer the leaf held is
 * returned.
 *
 * NOTE(review): leaves built by make_leaf() keep their value inside the
 * leaf allocation, so for those the returned pointer refers to memory that
 * has just been freed -- callers should presumably only test it against
 * NULL; confirm against the users of this function.
 */
void *
art_delete(PMEMctopool *pcp, art_tree *t, const unsigned char *key, int key_len)
{
	art_leaf *removed = recursive_delete(pcp, t->root, &t->root, key,
			key_len, 0);

	if (!removed)
		return NULL;

	t->size--;

	void *value = removed->value;

	pmemcto_free(pcp, removed);
	return value;
}
/*
 * recursive_iter -- call cb for every leaf below n
 *
 * Children are visited in array order for NODE4/NODE16 and in ascending
 * key-byte order for NODE48/NODE256.  The walk stops at the first non-zero
 * callback result, which is then returned; otherwise 0 is returned.
 * Aborts on an unknown node type.
 */
static int
recursive_iter(art_node *n, art_callback cb, void *data)
{
	if (!n)
		return 0;

	if (IS_LEAF(n)) {
		art_leaf *leaf = LEAF_RAW(n);

		return cb(data, (const unsigned char *)leaf->key,
			leaf->key_len, leaf->value, leaf->val_len);
	}

	int rc = 0;

	switch (n->type) {
	case NODE4: {
		art_node4 *n4 = (art_node4 *)n;

		for (int i = 0; i < n->num_children && rc == 0; i++)
			rc = recursive_iter(n4->children[i], cb, data);
		break;
	}
	case NODE16: {
		art_node16 *n16 = (art_node16 *)n;

		for (int i = 0; i < n->num_children && rc == 0; i++)
			rc = recursive_iter(n16->children[i], cb, data);
		break;
	}
	case NODE48: {
		art_node48 *n48 = (art_node48 *)n;

		for (int i = 0; i < 256 && rc == 0; i++) {
			/* keys[] stores slot index + 1; 0 means "absent" */
			int slot = n48->keys[i];

			if (slot)
				rc = recursive_iter(n48->children[slot - 1],
						cb, data);
		}
		break;
	}
	case NODE256: {
		art_node256 *n256 = (art_node256 *)n;

		for (int i = 0; i < 256 && rc == 0; i++) {
			if (n256->children[i])
				rc = recursive_iter(n256->children[i],
						cb, data);
		}
		break;
	}
	default:
		abort();
	}

	return rc;
}
/*
 * art_iter -- call cb for every leaf of the tree
 *
 * Returns 0 after a complete walk, or the first non-zero value returned by
 * cb, which ends the walk early.
 */
int
art_iter(art_tree *t, art_callback cb, void *data)
{
	art_node *root = t->root;

	return recursive_iter(root, cb, data);
}
/*
 * recursive_iter2 -- walk the subtree at n, reporting inner nodes as well
 * as leaves to cb
 *
 * For every child of an inner node, cb is first called with a cb_data
 * record describing the parent (node, node_type, child_idx, first_child)
 * and NULL/0 for key and value; the result of that call is ignored.  For a
 * leaf, cb is called with the record plus the leaf's key and value, and its
 * result is returned.  A non-zero result from a recursive call stops the
 * walk and is returned.
 *
 * The data argument is only passed on to recursive calls; cb always
 * receives the local cb_data record as its first argument.
 */
static int
recursive_iter2(art_node *n, art_callback cb, void *data)
{
cb_data _cbd, *cbd = &_cbd;
int first = 1;
if (!n)
return 0;
cbd->node = (void *)n;
/*
 * NOTE(review): n->type is read before the IS_LEAF() test below, so for
 * a leaf the read goes through the tagged pointer -- confirm callbacks
 * do not rely on node_type for leaves.
 */
cbd->node_type = n->type;
cbd->child_idx = -1;
if (IS_LEAF(n)) {
art_leaf *l = LEAF_RAW(n);
/* first_child is not set on this path */
return cb(cbd, (const unsigned char *)l->key,
l->key_len, l->value, l->val_len);
}
int idx, res;
switch (n->type) {
case NODE4:
for (int i = 0; i < n->num_children; i++) {
/* announce the parent/child edge, then descend */
cbd->first_child = first;
first = 0;
cbd->child_idx = i;
cb((void *)cbd, NULL, 0, NULL, 0);
res = recursive_iter2(((art_node4 *)n)->children[i],
cb, data);
if (res)
return res;
}
break;
case NODE16:
for (int i = 0; i < n->num_children; i++) {
cbd->first_child = first;
first = 0;
cbd->child_idx = i;
cb((void *)cbd, NULL, 0, NULL, 0);
res = recursive_iter2(((art_node16 *)n)->children[i],
cb, data);
if (res)
return res;
}
break;
case NODE48:
for (int i = 0; i < 256; i++) {
/* keys[] stores slot index + 1; 0 means no child for byte i */
idx = ((art_node48 *)n)->keys[i];
if (!idx)
continue;
cbd->first_child = first;
first = 0;
/* child_idx is the key byte here, not the slot number */
cbd->child_idx = i;
cb((void *)cbd, NULL, 0, NULL, 0);
res = recursive_iter2(
((art_node48 *)n)->children[idx - 1],
cb, data);
if (res)
return res;
}
break;
case NODE256:
for (int i = 0; i < 256; i++) {
if (!((art_node256 *)n)->children[i])
continue;
cbd->first_child = first;
first = 0;
cbd->child_idx = i;
cb((void *)cbd, NULL, 0, NULL, 0);
res = recursive_iter2(
((art_node256 *)n)->children[i],
cb, data);
if (res)
return res;
}
break;
default:
abort();
}
return 0;
}
/*
 * art_iter2 -- walk the whole tree with recursive_iter2(), starting at the
 * root, and return its result
 */
int
art_iter2(art_tree *t, art_callback cb, void *data)
{
	art_node *root = t->root;

	return recursive_iter2(root, cb, data);
}
/*
 * leaf_prefix_matches -- check whether a leaf's key starts with prefix
 *
 * Returns 0 when the first prefix_len bytes of the key equal the prefix and
 * a non-zero value otherwise (including when the key is shorter than the
 * prefix).
 */
static int
leaf_prefix_matches(const art_leaf *n, const unsigned char *prefix,
	int prefix_len)
{
	if (!(n->key_len < (uint32_t)prefix_len))
		return memcmp(n->key, prefix, prefix_len);

	/* key too short to contain the prefix */
	return 1;
}
/*
 * art_iter_prefix -- call cb for every leaf whose key starts with the
 * given prefix
 *
 * t:       tree to walk
 * key:     prefix bytes
 * key_len: number of valid bytes in key
 * cb/data: callback and its first argument
 *
 * Returns 0 after a complete walk (or when nothing matches), otherwise the
 * first non-zero value returned by cb.
 *
 * Handling of a compressed path on the way down:
 *  - prefix_mismatch() may compare against the minimum leaf and report a
 *    match length longer than the node's path; the length is clamped to
 *    partial_len so that "prefix ends here" is only concluded when the
 *    prefix really ends inside this node's path.  Without the clamp a
 *    prefix that extends past the path and happens to equal the minimum
 *    leaf would cause the whole node to be reported.
 *  - when the prefix differs from the path before either of them ends,
 *    no key below the node can match, so the walk stops.  This also
 *    guarantees depth stays below key_len when key[depth] is read.
 */
int
art_iter_prefix(art_tree *t, const unsigned char *key, int key_len,
	art_callback cb, void *data)
{
	art_node **child;
	art_node *n = t->root;
	int prefix_len, depth = 0;

	while (n) {
		if (IS_LEAF(n)) {
			art_leaf *l = LEAF_RAW(n);

			if (!leaf_prefix_matches(l, key, key_len)) {
				return cb(data, (const unsigned char *)l->key,
					l->key_len, l->value, l->val_len);
			}
			return 0;
		}

		/* whole prefix consumed: verify once, then report subtree */
		if (depth == key_len) {
			art_leaf *l = minimum(n);

			assert(l != NULL);
			if (!leaf_prefix_matches(l, key, key_len))
				return recursive_iter(n, cb, data);
			return 0;
		}

		if (n->partial_len) {
			prefix_len = prefix_mismatch(n, key, key_len, depth);

			/* never count matches beyond this node's own path */
			if ((uint32_t)prefix_len > n->partial_len)
				prefix_len = n->partial_len;

			if (!prefix_len)
				return 0;

			/* prefix ends inside the path: everything below matches */
			if (depth + prefix_len == key_len)
				return recursive_iter(n, cb, data);

			/* prefix and path diverge: nothing below can match */
			if ((uint32_t)prefix_len < n->partial_len)
				return 0;

			depth = depth + n->partial_len;
		}

		child = find_child(n, key[depth]);
		n = (child) ? *child : NULL;
		depth++;
	}

	return 0;
}