#include "internal/uint_set.h"
#include "internal/common.h"
#include <assert.h>
/*
 * Prepare s for use as an empty range set. This only initialises the
 * underlying list head; it allocates nothing.
 */
void ossl_uint_set_init(UINT_SET *s)
{
    ossl_list_uint_set_init(s);
}
/*
 * Release every range item held by s.
 *
 * Each item is unlinked from the list before it is freed. As a result, s
 * never references freed memory after this returns. Calling destroy a
 * second time then finds no items, instead of walking dangling pointers.
 */
void ossl_uint_set_destroy(UINT_SET *s)
{
    UINT_SET_ITEM *x, *xnext;

    for (x = ossl_list_uint_set_head(s); x != NULL; x = xnext) {
        /* Fetch the successor before x is unlinked and freed. */
        xnext = ossl_list_uint_set_next(x);
        ossl_list_uint_set_remove(s, x);
        OPENSSL_free(x);
    }
}
/*
 * If the item immediately preceding x ends exactly one below x's start, fold
 * it into x: x takes over the predecessor's start, and the predecessor is
 * unlinked and freed. Otherwise this does nothing.
 */
static void uint_set_merge_adjacent(UINT_SET *s, UINT_SET_ITEM *x)
{
    UINT_SET_ITEM *before = ossl_list_uint_set_prev(x);

    /*
     * before->range.end + 1 == x->range.start is the same relation as
     * x->range.start - 1 == before->range.end under uint64_t wraparound.
     */
    if (before != NULL && before->range.end + 1 == x->range.start) {
        x->range.start = before->range.start;
        ossl_list_uint_set_remove(s, before);
        OPENSSL_free(before);
    }
}
/* Smaller of two uint64_t values. */
static uint64_t u64_min(uint64_t x, uint64_t y)
{
    if (y <= x)
        return y;
    return x;
}
/* Larger of two uint64_t values. */
static uint64_t u64_max(uint64_t x, uint64_t y)
{
    if (y >= x)
        return y;
    return x;
}
/*
 * Nonzero iff the closed ranges a and b share at least one value. That is
 * the case when the lower of the two ends is not below the higher of the two
 * starts.
 */
static int uint_range_overlaps(const UINT_RANGE *a,
                               const UINT_RANGE *b)
{
    uint64_t lo = a->start > b->start ? a->start : b->start;
    uint64_t hi = a->end < b->end ? a->end : b->end;

    return hi >= lo;
}
/*
 * Allocate a fresh, unlinked list item covering [start, end].
 * Returns NULL if allocation fails. The caller takes ownership of the item.
 */
static UINT_SET_ITEM *create_set_item(uint64_t start, uint64_t end)
{
    UINT_SET_ITEM *item = OPENSSL_malloc(sizeof(*item));

    if (item != NULL) {
        ossl_list_uint_set_init_elem(item);
        item->range.start = start;
        item->range.end = end;
    }
    return item;
}
/*
 * Add the closed range [range->start, range->end] to s.
 *
 * Ranges in s are kept disjoint and in ascending order. The new range is
 * merged with any existing ranges it overlaps. In the cases handled below,
 * an existing range that abuts the new one (end + 1 == start) is extended
 * rather than a new item being created.
 *
 * Returns 1 on success. Returns 0 if range->start > range->end or if an
 * allocation fails; in both failure cases s has not been modified.
 */
int ossl_uint_set_insert(UINT_SET *s, const UINT_RANGE *range)
{
    UINT_SET_ITEM *x, *xnext, *z, *zprev, *f;
    uint64_t start = range->start, end = range->end;

    if (!ossl_assert(start <= end))
        return 0;

    if (ossl_list_uint_set_is_empty(s)) {
        /* Empty set: the new range becomes the only item. */
        x = create_set_item(start, end);
        if (x == NULL)
            return 0;
        ossl_list_uint_set_insert_head(s, x);
        return 1;
    }

    z = ossl_list_uint_set_tail(s);
    if (start > z->range.end) {
        /*
         * The new range lies wholly after the last range. Extend the last
         * range if the two abut, otherwise append a new item.
         */
        if (z->range.end + 1 == start) {
            z->range.end = end;
            return 1;
        }

        x = create_set_item(start, end);
        if (x == NULL)
            return 0;
        ossl_list_uint_set_insert_tail(s, x);
        return 1;
    }

    f = ossl_list_uint_set_head(s);
    if (start <= f->range.start && end >= z->range.end) {
        /*
         * The new range covers every existing range. Reuse the head item to
         * hold it, then unlink AND free every other item so none is leaked.
         */
        x = f;
        x->range.start = start;
        x->range.end = end;
        for (x = ossl_list_uint_set_next(x); x != NULL; x = xnext) {
            xnext = ossl_list_uint_set_next(x);
            ossl_list_uint_set_remove(s, x);
            OPENSSL_free(x);
        }
        return 1;
    }

    /*
     * Walk backwards from the tail (insertion near the end is presumably the
     * common case). If the new range ends before the first range, start at
     * the head instead so the whole list is not walked.
     */
    z = end < f->range.start ? f : z;

    for (; z != NULL; z = zprev) {
        zprev = ossl_list_uint_set_prev(z);

        /* Already wholly covered by an existing range: nothing to do. */
        if (z->range.start <= start && z->range.end >= end)
            return 1;

        if (uint_range_overlaps(&z->range, range)) {
            /*
             * z is the highest range overlapping the new one. Find the lowest
             * overlapping range, widen z (ovend) to span the union, and
             * unlink and free every overlapping item that precedes it.
             */
            UINT_SET_ITEM *ovend = z;

            ovend->range.end = u64_max(end, z->range.end);

            while (zprev != NULL && uint_range_overlaps(&zprev->range, range)) {
                z = zprev;
                zprev = ossl_list_uint_set_prev(z);
            }

            ovend->range.start = u64_min(start, z->range.start);

            while (z != ovend) {
                z = ossl_list_uint_set_next(x = z);
                ossl_list_uint_set_remove(s, x);
                OPENSSL_free(x);
            }
            break;
        } else if (end < z->range.start
                   && (zprev == NULL || start > zprev->range.end)) {
            /* The new range fits entirely in the gap just before z. */
            if (z->range.start == end + 1) {
                /*
                 * Abuts z: extend z backwards, then fold in zprev if the
                 * gap is now closed.
                 */
                z->range.start = start;
                uint_set_merge_adjacent(s, z);
            } else if (zprev != NULL && zprev->range.end + 1 == start) {
                /* Abuts zprev: extend zprev forwards. */
                zprev->range.end = end;
                uint_set_merge_adjacent(s, z);
            } else {
                /* Touches neither neighbour: insert a new item before z. */
                x = create_set_item(start, end);
                if (x == NULL)
                    return 0;
                ossl_list_uint_set_insert_before(s, z, x);
            }
            break;
        }
    }

    return 1;
}
/*
 * Remove the closed range [range->start, range->end] from s.
 *
 * Walks the ranges from the tail towards the head. Each range that
 * intersects the removed span is deleted, trimmed or split. The walk stops
 * at the first range ending below range->start, or after a trim/split that
 * must be the last affected range.
 *
 * Returns 1 on success, including when nothing was removed. Returns 0 if
 * range->start > range->end, or if allocating the upper half of a split
 * fails.
 */
int ossl_uint_set_remove(UINT_SET *s, const UINT_RANGE *range)
{
UINT_SET_ITEM *z, *zprev, *y;
uint64_t start = range->start, end = range->end;
if (!ossl_assert(start <= end))
return 0;
for (z = ossl_list_uint_set_tail(s); z != NULL; z = zprev) {
/* Fetch the predecessor first: z may be freed below. */
zprev = ossl_list_uint_set_prev(z);
/* z, and every range before it, ends below the removed span: done. */
if (start > z->range.end)
break;
if (start <= z->range.start && end >= z->range.end) {
/* The removed span covers z entirely: drop it. */
ossl_list_uint_set_remove(s, z);
OPENSSL_free(z);
} else if (start <= z->range.start && end >= z->range.start) {
/* Span covers the low part of z only: move z's start up past it. */
assert(end < z->range.end);
z->range.start = end + 1;
} else if (end >= z->range.end) {
/*
 * Span covers the high part of z only: pull z's end down. The
 * start > 0 assert guards the start - 1 below against wrapping.
 */
assert(start > z->range.start);
assert(start > 0);
z->range.end = start - 1;
break;
} else if (start > z->range.start && end < z->range.end) {
/*
 * Span lies strictly inside z: split z into [z.start, start - 1]
 * and a new item [end + 1, z.end] linked right after it.
 */
y = create_set_item(end + 1, z->range.end);
if (y == NULL)
return 0;
ossl_list_uint_set_insert_after(s, z, y);
z->range.end = start - 1;
break;
} else {
/* Any remaining case must be a range entirely above the span. */
assert(!uint_range_overlaps(&z->range, range));
}
}
return 1;
}
/*
 * Return 1 if v lies inside one of the ranges in s, otherwise 0.
 *
 * Ranges are scanned from the tail (highest) downwards. The scan gives up
 * as soon as it meets a range that ends below v, because ranges are kept in
 * ascending order and no earlier range can contain v.
 */
int ossl_uint_set_query(const UINT_SET *s, uint64_t v)
{
    UINT_SET_ITEM *cur;

    if (ossl_list_uint_set_is_empty(s))
        return 0;

    cur = ossl_list_uint_set_tail(s);
    while (cur != NULL) {
        if (v > cur->range.end)
            return 0;
        if (v >= cur->range.start)
            return 1;
        cur = ossl_list_uint_set_prev(cur);
    }
    return 0;
}