#include "cli.h"

#include <errno.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
/*
 * One entry of the lookup table. Both pointers borrow memory owned by the
 * decoded DtobValue tree (g_root); the table is calloc'd, so a NULL key
 * marks an empty slot.
 */
typedef struct McSlot {
    uint8_t *key;      /* key bytes (not NUL-terminated) */
    size_t   key_len;  /* length of key, in bytes */
    uint8_t *data;     /* value bytes sent back to clients */
    size_t   data_len; /* length of data, in bytes */
} McSlot;
/* Lookup table indexed by key_to_u64(key) % g_modulus; built by load_dtob so
 * that every loaded key occupies its own slot. */
static McSlot *g_slots;
/* Number of entries in g_slots (chosen by find_modulus). */
static size_t g_modulus;
/* Common length of every loaded key; handle_client reads exactly this many
 * bytes from each client. */
static size_t g_key_len;
/* Unix socket path (copied from argv); unlinked by cleanup_and_exit. */
static char g_sock_path[256];
/* Path of the .dtob file, kept so a SIGHUP can reload it. */
static char g_dtob_path[4096];
/* Listening socket, or -1 before it has been created. */
static int g_listen_fd = -1;
/* Set by the SIGHUP handler and polled by the accept loop;
 * volatile sig_atomic_t is the type C allows a signal handler to write. */
static volatile sig_atomic_t g_reload = 0;
/* Decoded file. g_slots' key/data pointers point into this tree, so it must
 * stay alive as long as g_slots does. */
static DtobValue *g_root = NULL;
/*
 * Hash a key to 64 bits for slot selection (slot = hash % modulus).
 *
 * The first (up to) 8 bytes are packed little-endian, so distinct keys of
 * at most 8 bytes always get distinct hashes. Any further bytes are folded
 * in with an FNV-1a style step (xor, then multiply by the 64-bit FNV prime).
 * Without that, keys sharing their first 8 bytes (e.g. "user:0000001" and
 * "user:0000002") would always collide and find_modulus could never place
 * them in separate slots. Each fold step is a bijection on the running
 * value, so keys differing in a single byte past the 8th never collide.
 */
static uint64_t key_to_u64(const uint8_t *key, size_t key_len)
{
    uint64_t v = 0;
    size_t n = key_len < 8 ? key_len : 8;
    for (size_t i = 0; i < n; i++)
        v |= (uint64_t)key[i] << (i * 8);
    for (size_t i = n; i < key_len; i++) {
        v ^= key[i];
        v *= UINT64_C(0x100000001b3);
    }
    return v;
}
/*
 * SIGTERM/SIGINT handler; also called directly when the accept loop ends.
 * Closes the listening socket, removes the socket file and exits.
 *
 * Only async-signal-safe calls (close, unlink, _exit) are made here. The
 * signal can arrive while the program is inside malloc/free (e.g. during a
 * SIGHUP reload in load_dtob), so calling free() here could corrupt the heap
 * or deadlock. The same goes for dtob_free(), which is assumed to call free().
 * The process is exiting, so the kernel reclaims the heap anyway.
 */
static void cleanup_and_exit(int sig)
{
    (void)sig;
    if (g_listen_fd >= 0)
        close(g_listen_fd);
    unlink(g_sock_path);
    _exit(0);
}
/*
 * SIGHUP handler. It only records the request; the accept loop performs the
 * reload outside signal context, because load_dtob allocates memory.
 */
static void request_reload(int signo)
{
    g_reload = 1;
    (void)signo;
}
/*
 * Return the smallest m >= n such that key_to_u64(key) % m differs for all
 * n pairs, i.e. a table size in which every key gets its own slot.
 * Returns 1 when n == 0 so callers never take a value modulo zero.
 *
 * Precondition: the n key hashes must be pairwise distinct. If two keys
 * hash to the same value no modulus can separate them and this search never
 * ends, so callers must reject such inputs first.
 * On allocation failure the process exits (perror + exit(1)).
 */
static size_t find_modulus(DtobKVPair *pairs, size_t n)
{
    if (n == 0)
        return 1;

    /* Hash every key once instead of once per candidate modulus. */
    uint64_t *hashes = calloc(n, sizeof *hashes);
    if (!hashes) { perror("calloc"); exit(1); }
    for (size_t i = 0; i < n; i++)
        hashes[i] = key_to_u64(pairs[i].key, pairs[i].key_len);

    uint8_t *seen = NULL;
    size_t seen_cap = 0;
    for (size_t m = n; ; m++) {
        if (m > seen_cap) {
            /* Grow geometrically so reallocations stay rare. */
            free(seen);
            seen_cap = m * 2;
            seen = calloc(seen_cap, 1);
            if (!seen) { perror("calloc"); free(hashes); exit(1); }
        } else {
            /* Only indices < m can have been marked. */
            memset(seen, 0, m);
        }
        int ok = 1;
        for (size_t i = 0; i < n; i++) {
            size_t slot = (size_t)(hashes[i] % m);
            if (seen[slot]) { ok = 0; break; }
            seen[slot] = 1;
        }
        if (ok) {
            free(seen);
            free(hashes);
            return m;
        }
    }
}
/* qsort comparator: orders uint64_t values ascending. */
static int mc_cmp_u64(const void *a, const void *b)
{
    uint64_t x = *(const uint64_t *)a;
    uint64_t y = *(const uint64_t *)b;
    return (x > y) - (x < y);
}

/*
 * Check that `root` can be served. It must be a non-empty DTOB_KV_SET whose
 * keys all share one length, whose values have a length that fits the
 * 4-byte reply header, and whose key hashes are pairwise distinct (the
 * precondition of find_modulus, which would otherwise never terminate).
 * On success stores the common key length in *key_len_out and returns 0;
 * otherwise returns -1.
 */
static int mc_validate_root(const DtobValue *root, size_t *key_len_out)
{
    if (root->type != DTOB_KV_SET || root->num_pairs == 0)
        return -1;

    size_t n = root->num_pairs;
    size_t key_len = root->pairs[0].key_len;
    for (size_t i = 0; i < n; i++) {
        const DtobKVPair *p = &root->pairs[i];
        if (p->key_len != key_len) {
            fprintf(stderr, "memcache: keys have differing lengths\n");
            return -1;
        }
        /* NOTE(review): values are assumed to carry data/data_len; the NULL
         * guard is defensive and may be unreachable for decoded input. */
        if (!p->value) {
            fprintf(stderr, "memcache: pair without a value\n");
            return -1;
        }
        if ((uint64_t)p->value->data_len > UINT32_MAX) {
            fprintf(stderr, "memcache: value too large for 32-bit length header\n");
            return -1;
        }
    }

    /* Duplicate keys (or keys whose hashes collide) can never be placed in
     * separate slots; detect them here instead of letting find_modulus spin. */
    uint64_t *hashes = calloc(n, sizeof *hashes);
    if (!hashes)
        return -1;
    for (size_t i = 0; i < n; i++)
        hashes[i] = key_to_u64(root->pairs[i].key, key_len);
    qsort(hashes, n, sizeof *hashes, mc_cmp_u64);
    int rc = 0;
    for (size_t i = 1; i < n; i++) {
        if (hashes[i] == hashes[i - 1]) {
            fprintf(stderr, "memcache: duplicate or colliding keys, cannot build lookup table\n");
            rc = -1;
            break;
        }
    }
    free(hashes);
    if (rc == 0)
        *key_len_out = key_len;
    return rc;
}

/*
 * Load the .dtob file at `path` and replace the served table with it.
 *
 * The new table is fully built before anything global is touched. On any
 * failure -1 is returned and the previously served g_root/g_slots stay in
 * place, so a failed SIGHUP reload keeps the old data. On success the old
 * root and slots are freed, the globals are updated and 0 is returned.
 */
static int load_dtob(const char *path)
{
    size_t flen;
    uint8_t *buf = read_file(path, &flen);
    if (!buf)
        return -1;
    DtobValue *root = dtob_decode(buf, flen);
    free(buf);
    if (!root)
        return -1;

    size_t key_len = 0;
    if (mc_validate_root(root, &key_len) != 0) {
        dtob_free(root);
        return -1;
    }

    size_t modulus = find_modulus(root->pairs, root->num_pairs);
    McSlot *slots = calloc(modulus, sizeof *slots);
    if (!slots) {
        dtob_free(root);
        return -1;
    }
    for (size_t i = 0; i < root->num_pairs; i++) {
        DtobKVPair *p = &root->pairs[i];
        size_t idx = (size_t)(key_to_u64(p->key, p->key_len) % modulus);
        /* Slots borrow the key/data buffers owned by `root`. */
        slots[idx].key = p->key;
        slots[idx].key_len = p->key_len;
        slots[idx].data = p->value->data;
        slots[idx].data_len = p->value->data_len;
    }

    /* The old slots point into the old root, so release both together. */
    if (g_root)
        dtob_free(g_root);
    free(g_slots);
    g_root = root;
    g_slots = slots;
    g_modulus = modulus;
    g_key_len = key_len;
    fprintf(stderr, "memcache: loaded %zu pairs, modulus %zu, key_len %zu\n",
            root->num_pairs, modulus, key_len);
    return 0;
}
/* Read exactly `len` bytes from fd, retrying on EINTR and short reads.
 * Returns 0 on success, -1 on error or EOF before `len` bytes arrived. */
static int mc_read_all(int fd, uint8_t *buf, size_t len)
{
    size_t got = 0;
    while (got < len) {
        ssize_t r = read(fd, buf + got, len - got);
        if (r < 0 && errno == EINTR)
            continue;
        if (r <= 0)
            return -1;
        got += (size_t)r;
    }
    return 0;
}

/* Write all `len` bytes to fd, retrying on EINTR and short writes.
 * Returns 0 on success, -1 on error. */
static int mc_write_all(int fd, const void *buf, size_t len)
{
    const uint8_t *p = buf;
    while (len > 0) {
        ssize_t w = write(fd, p, len);
        if (w < 0) {
            if (errno == EINTR)
                continue;
            return -1;
        }
        p += w;
        len -= (size_t)w;
    }
    return 0;
}

/*
 * Serve one connection. Reads exactly key_len key bytes, looks the key up,
 * and replies with a 4-byte little-endian length followed by that many
 * value bytes. A miss is answered with length 0, which is indistinguishable
 * from an empty value. The connection is always closed before returning.
 *
 * I/O is retried on EINTR because SIGHUP is installed without SA_RESTART and
 * may interrupt it. Writes are looped because write() may accept fewer bytes
 * than requested.
 */
static void handle_client(int fd, size_t key_len)
{
    /* malloc(0) may legitimately return NULL; always ask for at least 1 byte. */
    uint8_t *keybuf = malloc(key_len > 0 ? key_len : 1);
    if (keybuf && mc_read_all(fd, keybuf, key_len) == 0) {
        const McSlot *s = &g_slots[(size_t)(key_to_u64(keybuf, key_len) % g_modulus)];
        int hit = s->key && s->key_len == key_len &&
                  memcmp(s->key, keybuf, key_len) == 0;
        size_t dlen = hit ? s->data_len : 0;
        if (dlen <= UINT32_MAX) {
            uint8_t hdr[4] = {
                (uint8_t)dlen,         (uint8_t)(dlen >> 8),
                (uint8_t)(dlen >> 16), (uint8_t)(dlen >> 24),
            };
            if (mc_write_all(fd, hdr, sizeof hdr) == 0 && dlen > 0)
                (void)mc_write_all(fd, s->data, dlen);
        }
        /* Otherwise the length does not fit the 32-bit header; close without
         * replying rather than send a frame the client would misparse. */
    }
    free(keybuf);
    close(fd);
}
int cmd_memcache(int argc, char **argv)
{
if (argc < 1) {
fprintf(stderr, "usage: dtob memcache <file.dtob> [socket_path]\n");
return 1;
}
const char *dtob_path = argv[0];
const char *sock_path = (argc >= 2) ? argv[1] : "/tmp/dtob-memcache.sock";
if (strlen(dtob_path) >= sizeof(g_dtob_path)) {
fprintf(stderr, "error: dtob path too long\n");
return 1;
}
strcpy(g_dtob_path, dtob_path);
if (strlen(sock_path) >= sizeof(g_sock_path)) {
fprintf(stderr, "error: socket path too long\n");
return 1;
}
strcpy(g_sock_path, sock_path);
if (load_dtob(g_dtob_path) != 0) {
fprintf(stderr, "error: failed to load %s\n", g_dtob_path);
return 1;
}
pid_t pid = fork();
if (pid < 0) { perror("fork"); return 1; }
if (pid > 0) {
fprintf(stderr, "memcache: daemon pid %d, socket %s\n", (int)pid, g_sock_path);
_exit(0);
}
setsid();
{
char pid_path[260];
size_t slen = strlen(g_sock_path);
if (slen > 5 && strcmp(g_sock_path + slen - 5, ".sock") == 0) {
memcpy(pid_path, g_sock_path, slen - 5);
strcpy(pid_path + slen - 5, ".pid");
} else {
snprintf(pid_path, sizeof(pid_path), "%s.pid", g_sock_path);
}
FILE *pf = fopen(pid_path, "w");
if (pf) { fprintf(pf, "%d\n", (int)getpid()); fclose(pf); }
}
struct sigaction sa;
memset(&sa, 0, sizeof(sa));
sa.sa_handler = cleanup_and_exit;
sigaction(SIGTERM, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
struct sigaction sa_hup;
memset(&sa_hup, 0, sizeof(sa_hup));
sa_hup.sa_handler = request_reload;
sigaction(SIGHUP, &sa_hup, NULL);
g_listen_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (g_listen_fd < 0) { perror("socket"); _exit(1); }
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, g_sock_path, sizeof(addr.sun_path) - 1);
unlink(g_sock_path);
if (bind(g_listen_fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
perror("bind");
_exit(1);
}
if (listen(g_listen_fd, 16) < 0) {
perror("listen");
_exit(1);
}
for (;;) {
if (g_reload) {
g_reload = 0;
fprintf(stderr, "memcache: SIGHUP — reloading %s\n", g_dtob_path);
if (load_dtob(g_dtob_path) != 0)
fprintf(stderr, "memcache: reload failed, keeping old data\n");
}
int client = accept(g_listen_fd, NULL, NULL);
if (client < 0) {
if (errno == EINTR) continue;
break;
}
handle_client(client, g_key_len);
}
cleanup_and_exit(0);
return 0;
}