#include <stdbool.h>
#include "core/or/or.h"
#include "ht.h"
#include "siphash.h"
#define BTRACK_ORCONN_PRIVATE
#include "feature/control/btrack_orconn.h"
#include "feature/control/btrack_orconn_maps.h"
#include "lib/log/log.h"
static inline unsigned int
bto_gid_hash_(bt_orconn_t *elm)
{
return (unsigned)siphash24g(&elm->gid, sizeof(elm->gid));
}
static inline int
bto_gid_eq_(bt_orconn_t *a, bt_orconn_t *b)
{
return a->gid == b->gid;
}
static inline unsigned int
bto_chan_hash_(bt_orconn_t *elm)
{
return (unsigned)siphash24g(&elm->chan, sizeof(elm->chan));
}
static inline int
bto_chan_eq_(bt_orconn_t *a, bt_orconn_t *b)
{
return a->chan == b->chan;
}
/* Hash table of tracked ORCONN entries keyed by gid, linked through each
 * entry's "node" member.  Entries are only inserted here once their gid is
 * nonzero (see bto_new() and bto_update()).
 * NOTE(review): 0.6 is presumably the maximum load factor argument of
 * HT_GENERATE2 -- confirm against ht.h. */
HT_HEAD(bto_gid_ht, bt_orconn_t);
HT_PROTOTYPE(bto_gid_ht, bt_orconn_t, node, bto_gid_hash_, bto_gid_eq_);
HT_GENERATE2(bto_gid_ht, bt_orconn_t, node,
bto_gid_hash_, bto_gid_eq_, 0.6,
tor_reallocarray_, tor_free_);
/* Head of the gid map: allocated in bto_init_maps(), released in
 * bto_gid_clear_map(). */
static struct bto_gid_ht *bto_gid_map;
/* Hash table of the same ORCONN entries keyed by channel id, linked through
 * each entry's separate "chan_node" member, so one entry can sit in both
 * maps at once.  Entries are only inserted here once their chan is nonzero
 * (see bto_new() and bto_update()).
 * NOTE(review): 0.6 is presumably the maximum load factor -- confirm. */
HT_HEAD(bto_chan_ht, bt_orconn_t);
HT_PROTOTYPE(bto_chan_ht, bt_orconn_t, chan_node, bto_chan_hash_,
bto_chan_eq_);
HT_GENERATE2(bto_chan_ht, bt_orconn_t, chan_node,
bto_chan_hash_, bto_chan_eq_, 0.6,
tor_reallocarray_, tor_free_);
/* Head of the chan map: allocated in bto_init_maps(), released in
 * bto_chan_clear_map(). */
static struct bto_chan_ht *bto_chan_map;
/**
 * Tear down the gid map.
 *
 * Every entry is unlinked from bto_gid_map and has its gid zeroed.  An entry
 * whose chan is still nonzero is NOT freed here, because it is still linked
 * into bto_chan_map; bto_chan_clear_map() frees entries once both ids are
 * zero.  The table and its head are then released.
 *
 * Does nothing if the map is absent (never initialized, or already cleared),
 * so teardown is safe to repeat.
 */
static void
bto_gid_clear_map(void)
{
  bt_orconn_t **elt, **next, *c;

  /* HT_START reads through the head pointer, so there must be a map.
   * NOTE(review): relies on tor_free() having NULLed bto_gid_map on a
   * previous clear -- confirm. */
  if (!bto_gid_map)
    return;

  for (elt = HT_START(bto_gid_ht, bto_gid_map);
       elt;
       elt = next) {
    c = *elt;
    /* Fetch the next slot while removing the current one, before c can be
     * freed below. */
    next = HT_NEXT_RMV(bto_gid_ht, bto_gid_map, elt);

    c->gid = 0;
    /* Still in the chan map: leave it for bto_chan_clear_map() to free. */
    if (!c->chan)
      tor_free(c);
  }
  HT_CLEAR(bto_gid_ht, bto_gid_map);
  tor_free(bto_gid_map);
}
/**
 * Tear down the chan map.
 *
 * Every entry is unlinked from bto_chan_map and has its chan zeroed.  An
 * entry whose gid is still nonzero is NOT freed here, because it is still
 * linked into bto_gid_map; bto_gid_clear_map() frees entries once both ids
 * are zero.  The table and its head are then released.
 *
 * Does nothing if the map is absent (never initialized, or already cleared),
 * so teardown is safe to repeat.
 */
static void
bto_chan_clear_map(void)
{
  bt_orconn_t **elt, **next, *c;

  /* HT_START reads through the head pointer, so there must be a map.
   * NOTE(review): relies on tor_free() having NULLed bto_chan_map on a
   * previous clear -- confirm. */
  if (!bto_chan_map)
    return;

  for (elt = HT_START(bto_chan_ht, bto_chan_map);
       elt;
       elt = next) {
    c = *elt;
    /* Fetch the next slot while removing the current one, before c can be
     * freed below. */
    next = HT_NEXT_RMV(bto_chan_ht, bto_chan_map, elt);

    c->chan = 0;
    /* Still in the gid map: leave it for bto_gid_clear_map() to free. */
    if (!c->gid)
      tor_free(c);
  }
  HT_CLEAR(bto_chan_ht, bto_chan_map);
  tor_free(bto_chan_map);
}
void
bto_delete(uint64_t gid)
{
bt_orconn_t key, *bto;
key.gid = gid;
key.chan = 0;
bto = HT_FIND(bto_gid_ht, bto_gid_map, &key);
if (!bto) {
log_debug(LD_BTRACK, "tried to delete unregistered ORCONN gid=%"PRIu64,
gid);
return;
}
HT_REMOVE(bto_gid_ht, bto_gid_map, &key);
if (bto->chan) {
key.chan = bto->chan;
HT_REMOVE(bto_chan_ht, bto_chan_map, &key);
}
tor_free(bto);
}
/**
 * Merge the ids from <b>key</b> into the existing entry <b>bto</b>.
 *
 * An id that <b>bto</b> already has must not change (asserted separately for
 * gid and chan).  Each id that <b>bto</b> was missing and <b>key</b> supplies
 * is copied in, logged at debug level, and <b>bto</b> is inserted into the
 * corresponding map.  Returns <b>bto</b>.
 */
static bt_orconn_t *
bto_update(bt_orconn_t *bto, const bt_orconn_t *key)
{
  bool adopt_gid, adopt_chan;

  tor_assert(!bto->gid || !key->gid || bto->gid == key->gid);
  adopt_gid = key->gid != 0 && bto->gid == 0;
  if (adopt_gid) {
    log_debug(LD_BTRACK, "ORCONN chan=%"PRIu64" newgid=%"PRIu64, key->chan,
              key->gid);
    bto->gid = key->gid;
    HT_INSERT(bto_gid_ht, bto_gid_map, bto);
  }

  tor_assert(!bto->chan || !key->chan || bto->chan == key->chan);
  adopt_chan = key->chan != 0 && bto->chan == 0;
  if (adopt_chan) {
    /* Runs after the gid step, so the log reflects a gid adopted above. */
    log_debug(LD_BTRACK, "ORCONN gid=%"PRIu64" newchan=%"PRIu64,
              bto->gid, key->chan);
    bto->chan = key->chan;
    HT_INSERT(bto_chan_ht, bto_chan_map, bto);
  }

  return bto;
}
static bt_orconn_t *
bto_new(const bt_orconn_t *key)
{
struct bt_orconn_t *bto = tor_malloc(sizeof(*bto));
bto->gid = key->gid;
bto->chan = key->chan;
bto->state = 0;
bto->proxy_type = 0;
bto->is_orig = false;
bto->is_onehop = true;
if (bto->gid)
HT_INSERT(bto_gid_ht, bto_gid_map, bto);
if (bto->chan)
HT_INSERT(bto_chan_ht, bto_chan_map, bto);
return bto;
}
bt_orconn_t *
bto_find_or_new(uint64_t gid, uint64_t chan)
{
bt_orconn_t key, *bto = NULL;
tor_assert(gid || chan);
key.gid = gid;
key.chan = chan;
if (key.gid)
bto = HT_FIND(bto_gid_ht, bto_gid_map, &key);
if (!bto && key.chan) {
bto = HT_FIND(bto_chan_ht, bto_chan_map, &key);
}
if (bto)
return bto_update(bto, &key);
else
return bto_new(&key);
}
/**
 * Create both lookup maps (by gid and by chan) in their empty state.
 *
 * The other functions in this file use bto_gid_map and bto_chan_map, so
 * this needs to run before them.
 */
void
bto_init_maps(void)
{
  /* Allocate both heads first, then put each into its empty state. */
  bto_gid_map = tor_malloc(sizeof(*bto_gid_map));
  bto_chan_map = tor_malloc(sizeof(*bto_chan_map));

  HT_INIT(bto_gid_ht, bto_gid_map);
  HT_INIT(bto_chan_ht, bto_chan_map);
}
/**
 * Free every tracked ORCONN entry and release both maps.
 *
 * Each clear function zeroes its own id on every entry and frees an entry
 * only once the other id is zero as well.  An entry present in both maps is
 * therefore freed exactly once, by bto_chan_clear_map(), which runs second.
 * Both map heads are released, so bto_init_maps() must run again before the
 * maps are reused.
 */
void
bto_clear_maps(void)
{
bto_gid_clear_map();
bto_chan_clear_map();
}