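// marisa-ffi 0.3.1
//
// Rust FFI bindings for libmarisa, a space-efficient trie data structure.
// This file implements the extern "C" functions declared in marisa_ffi.h by
// forwarding to the marisa C++ API.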
#include "marisa_ffi.h"

#include <cstdio>
#include <cstring>
#include <exception>
#include <new>
#include <stdexcept>
#include <system_error>

#include "marisa/base.h"
#include "marisa.h"

// Error handling helper
//
// Translates the exception currently being handled into a MARISA_* status
// code so that no C++ exception ever crosses the extern "C" boundary.
//
// Meant to be called only from inside a `catch` block. If it is reached with
// no active exception, it reports MARISA_FORMAT_ERROR. Without that guard,
// the bare `throw;` below would call std::terminate().
//
// Mapping:
//   std::bad_alloc          -> MARISA_MEMORY_ERROR
//   std::system_error       -> MARISA_IO_ERROR
//   anything else           -> MARISA_FORMAT_ERROR
static int handle_exception() noexcept {
    if (!std::current_exception()) {
        return MARISA_FORMAT_ERROR;
    }
    try {
        throw;  // Rethrow the in-flight exception to dispatch on its type.
    } catch (const std::bad_alloc&) {
        return MARISA_MEMORY_ERROR;
    } catch (const std::system_error&) {
        // system_error is treated as an I/O failure.
        return MARISA_IO_ERROR;
    } catch (...) {
        // Any other std::exception or foreign exception has no finer mapping
        // available here.
        return MARISA_FORMAT_ERROR;
    }
}

// Core trie functions
extern "C" {

// Allocates an empty marisa::Trie and returns it as an opaque handle.
// Any exception (e.g. allocation failure) is swallowed and reported as a
// null handle, because exceptions must not cross the C boundary.
// Release the handle with marisa_destroy().
marisa_t* marisa_create(void) {
    marisa::Trie* created = nullptr;
    try {
        created = new marisa::Trie();
    } catch (...) {
        created = nullptr;
    }
    return reinterpret_cast<marisa_t*>(created);
}

// Destroys a trie handle obtained from marisa_create(). A null handle is
// accepted and ignored.
void marisa_destroy(marisa_t* trie) {
    if (trie == nullptr) {
        return;
    }
    marisa::Trie* const owned = reinterpret_cast<marisa::Trie*>(trie);
    delete owned;
}

// Loads a trie image from `filename` into `trie` via marisa::Trie::load.
// Returns MARISA_NULL_ERROR if either argument is null, MARISA_OK on success,
// or the code produced by handle_exception() if marisa throws.
int marisa_open(marisa_t* trie, const char* filename) {
    // The argument check cannot throw, so it can live outside the try block.
    if (trie == nullptr || filename == nullptr) {
        return MARISA_NULL_ERROR;
    }
    try {
        marisa::Trie& target = *reinterpret_cast<marisa::Trie*>(trie);
        target.load(filename);
    } catch (...) {
        return handle_exception();
    }
    return MARISA_OK;
}

int marisa_save(const marisa_t* trie, const char* filename) {
    try {
        if (!trie || !filename) {
            return MARISA_NULL_ERROR;
        }
        reinterpret_cast<const marisa::Trie*>(trie)->save(filename);
        return MARISA_OK;
    } catch (...) {
        return handle_exception();
    }
}

int marisa_write(const marisa_t* trie, FILE* file) {
    try {
        if (!trie || !file) {
            return MARISA_NULL_ERROR;
        }
        reinterpret_cast<const marisa::Trie*>(trie)->write(fileno(file));
        return MARISA_OK;
    } catch (...) {
        return handle_exception();
    }
}

// Deserializes a trie into `trie` from the file descriptor underlying the
// stdio stream `file`.
//
// Returns MARISA_NULL_ERROR if either argument is null, MARISA_OK on success,
// or the code produced by handle_exception() if marisa throws.
int marisa_read(marisa_t* trie, FILE* file) {
    try {
        if (!trie || !file) {
            return MARISA_NULL_ERROR;
        }
        // marisa reads from the raw descriptor, bypassing stdio. If the caller
        // already consumed part of `file` through stdio, the stream may have
        // read ahead and left the descriptor past the logical position.
        // On POSIX, fflush() on a seekable input stream moves the descriptor
        // back to the stream position. This is best-effort: some libcs (e.g.
        // BSD/macOS) reject fflush on read-only streams, so the result is
        // deliberately ignored rather than treated as an error.
        (void)std::fflush(file);
        reinterpret_cast<marisa::Trie*>(trie)->read(fileno(file));
        return MARISA_OK;
    } catch (...) {
        return handle_exception();
    }
}

// Memory-maps the trie image in `filename` into `trie` via
// marisa::Trie::mmap.
// Returns MARISA_NULL_ERROR if either argument is null, MARISA_OK on success,
// or the code produced by handle_exception() if marisa throws.
int marisa_map(marisa_t* trie, const char* filename) {
    if (trie == nullptr || filename == nullptr) {
        return MARISA_NULL_ERROR;
    }
    try {
        marisa::Trie& target = *reinterpret_cast<marisa::Trie*>(trie);
        target.mmap(filename);
    } catch (...) {
        return handle_exception();
    }
    return MARISA_OK;
}

// Releases the trie's current contents, including a file mapping set up by
// marisa_map(). The handle stays valid and empty afterwards: it can be
// reloaded, remapped, or passed to marisa_destroy().
//
// marisa::Trie has no dedicated unmap call. Trie::clear() drops the internal
// storage, which is what owns the mapping. Previously this function returned
// MARISA_OK without releasing anything.
//
// Returns MARISA_NULL_ERROR if `trie` is null, MARISA_OK on success, or the
// code produced by handle_exception() if marisa throws.
int marisa_unmap(marisa_t* trie) {
    try {
        if (!trie) {
            return MARISA_NULL_ERROR;
        }
        reinterpret_cast<marisa::Trie*>(trie)->clear();
        return MARISA_OK;
    } catch (...) {
        return handle_exception();
    }
}

// Builds `trie` from the keys in `keyset` using marisa's default build
// configuration.
// Returns MARISA_NULL_ERROR if either argument is null, MARISA_OK on success,
// or the code produced by handle_exception() if marisa throws.
int marisa_build(marisa_t* trie, marisa_keyset_t* keyset) {
    if (trie == nullptr || keyset == nullptr) {
        return MARISA_NULL_ERROR;
    }
    try {
        marisa::Trie& target = *reinterpret_cast<marisa::Trie*>(trie);
        marisa::Keyset& keys = *reinterpret_cast<marisa::Keyset*>(keyset);
        target.build(keys);
    } catch (...) {
        return handle_exception();
    }
    return MARISA_OK;
}

// Builds `trie` from `keyset`, forwarding `trie_mode` unchanged as the second
// argument of marisa::Trie::build.
// NOTE(review): presumably a bitwise OR of marisa config flags — confirm.
// Returns MARISA_NULL_ERROR if `trie` or `keyset` is null, MARISA_OK on
// success, or the code produced by handle_exception() if marisa throws.
int marisa_build_trie(marisa_t* trie, marisa_keyset_t* keyset, int trie_mode) {
    if (trie == nullptr || keyset == nullptr) {
        return MARISA_NULL_ERROR;
    }
    try {
        marisa::Trie& target = *reinterpret_cast<marisa::Trie*>(trie);
        marisa::Keyset& keys = *reinterpret_cast<marisa::Keyset*>(keyset);
        target.build(keys, trie_mode);
    } catch (...) {
        return handle_exception();
    }
    return MARISA_OK;
}

/* 精确查询 */
int marisa_lookup(const marisa_t* trie, const char* key, size_t length, marisa_id_t* id) {
    try {
        if (!trie || !key || !id) {
            return MARISA_NULL_ERROR;
        }
        
        marisa::Agent agent;
        agent.set_query(key, length);
        
        if (reinterpret_cast<const marisa::Trie*>(trie)->lookup(agent)) {
            *id = static_cast<marisa_id_t>(agent.key().id());
            return MARISA_OK;
        } else {
            return MARISA_FORMAT_ERROR; // Key not found
        }
    } catch (...) {
        return handle_exception();
    }
}

/* 前缀 / 预测查询 */
int marisa_predictive_search(const marisa_t* trie, const char* ptr, size_t length,
                             marisa_agent_t* agent) {
    try {
        if (!trie || !ptr || !agent) {
            return MARISA_NULL_ERROR;
        }
        
        reinterpret_cast<marisa::Agent*>(agent)->set_query(ptr, length);
        
        if (reinterpret_cast<const marisa::Trie*>(trie)->predictive_search(*reinterpret_cast<marisa::Agent*>(agent))) {
            return MARISA_OK;
        } else {
            return MARISA_FORMAT_ERROR; // No results
        }
    } catch (...) {
        return handle_exception();
    }
}

/* 反向查询(由 id 拿 key) */
int marisa_reverse_lookup(const marisa_t* trie, marisa_id_t id, marisa_agent_t* agent) {
    try {
        if (!trie || !agent) {
            return MARISA_NULL_ERROR;
        }
        
        reinterpret_cast<marisa::Agent*>(agent)->set_query(static_cast<size_t>(id));
        reinterpret_cast<const marisa::Trie*>(trie)->reverse_lookup(*reinterpret_cast<marisa::Agent*>(agent));
        return MARISA_OK;
    } catch (...) {
        return handle_exception();
    }
}

/* common-prefix-search(逐字返回所有前缀) */
int marisa_common_prefix_search(const marisa_t* trie, const char* ptr, size_t length,
                                marisa_agent_t* agent) {
    try {
        if (!trie || !ptr || !agent) {
            return MARISA_NULL_ERROR;
        }
        
        reinterpret_cast<marisa::Agent*>(agent)->set_query(ptr, length);
        
        if (reinterpret_cast<const marisa::Trie*>(trie)->common_prefix_search(*reinterpret_cast<marisa::Agent*>(agent))) {
            return MARISA_OK;
        } else {
            return MARISA_FORMAT_ERROR; // No results
        }
    } catch (...) {
        return handle_exception();
    }
}

// Agent functions
// Allocates a fresh marisa::Agent (the query/result cursor used by the search
// functions) and returns it as an opaque handle. Any exception is reported as
// a null handle. Release with marisa_agent_destroy().
marisa_agent_t* marisa_agent_create(void) {
    marisa::Agent* created = nullptr;
    try {
        created = new marisa::Agent();
    } catch (...) {
        created = nullptr;
    }
    return reinterpret_cast<marisa_agent_t*>(created);
}

// Destroys an agent handle from marisa_agent_create(). A null handle is
// accepted and ignored.
void marisa_agent_destroy(marisa_agent_t* agent) {
    if (agent == nullptr) {
        return;
    }
    marisa::Agent* const owned = reinterpret_cast<marisa::Agent*>(agent);
    delete owned;
}

/* 取出当前结果 */
const char* marisa_agent_key(const marisa_agent_t* agent) {
    try {
        if (!agent) {
            return nullptr;
        }
        return reinterpret_cast<const marisa::Agent*>(agent)->key().ptr();
    } catch (...) {
        return nullptr;
    }
}

size_t marisa_agent_key_length(const marisa_agent_t* agent) {
    try {
        if (!agent) {
            return 0;
        }
        return reinterpret_cast<const marisa::Agent*>(agent)->key().length();
    } catch (...) {
        return 0;
    }
}

// Id of the current result key held by `agent`.
// Returns 0 for a null agent or if reading the key throws. That is
// indistinguishable from a genuine id of 0, so callers should validate the
// agent themselves.
marisa_id_t marisa_agent_id(const marisa_agent_t* agent) {
    if (agent == nullptr) {
        return 0;
    }
    try {
        const marisa::Agent& current = *reinterpret_cast<const marisa::Agent*>(agent);
        const auto raw_id = current.key().id();
        return static_cast<marisa_id_t>(raw_id);
    } catch (...) {
        return 0;
    }
}

/* Advance to the next prefix match — not implemented.
 *
 * marisa::Agent has no next() of its own, and advancing a search needs the
 * trie, which this function does not receive. Until that is designed, this
 * returns MARISA_STATE_ERROR for any non-null agent and MARISA_NULL_ERROR for
 * a null one. Nothing here can throw, so no try/catch is needed.
 */
int marisa_agent_next(marisa_agent_t* agent) {
    if (agent == nullptr) {
        return MARISA_NULL_ERROR;
    }
    return MARISA_STATE_ERROR;
}

// Keyset functions
// Allocates an empty marisa::Keyset and returns it as an opaque handle.
// Any exception is reported as a null handle. Release with
// marisa_keyset_destroy().
marisa_keyset_t* marisa_keyset_create(void) {
    marisa::Keyset* created = nullptr;
    try {
        created = new marisa::Keyset();
    } catch (...) {
        created = nullptr;
    }
    return reinterpret_cast<marisa_keyset_t*>(created);
}

// Destroys a keyset handle from marisa_keyset_create(). A null handle is
// accepted and ignored.
void marisa_keyset_destroy(marisa_keyset_t* keyset) {
    if (keyset == nullptr) {
        return;
    }
    marisa::Keyset* const owned = reinterpret_cast<marisa::Keyset*>(keyset);
    delete owned;
}

// Appends the `length` bytes at `key` to `keyset` via
// marisa::Keyset::push_back(ptr, length).
// Returns MARISA_NULL_ERROR if `keyset` or `key` is null, MARISA_OK on
// success, or the code produced by handle_exception() if marisa throws.
int marisa_keyset_push(marisa_keyset_t* keyset, const char* key, size_t length) {
    if (keyset == nullptr || key == nullptr) {
        return MARISA_NULL_ERROR;
    }
    try {
        marisa::Keyset& keys = *reinterpret_cast<marisa::Keyset*>(keyset);
        keys.push_back(key, length);
    } catch (...) {
        return handle_exception();
    }
    return MARISA_OK;
}

// Appends the `length` bytes at `key` to `keyset`, tagging the entry with
// `id` through marisa::Key::set_id before pushing it.
// Returns MARISA_NULL_ERROR if `keyset` or `key` is null, MARISA_OK on
// success, or the code produced by handle_exception() if marisa throws.
//
// NOTE(review): Trie::build assigns its own ids, and marisa::Key appears to
// keep id and weight in the same union. The id set here may therefore be
// read back as the key's weight during build — confirm against marisa/key.h
// before relying on it.
int marisa_keyset_push_back(marisa_keyset_t* keyset, const char* key, size_t length,
                            marisa_id_t id) {
    if (keyset == nullptr || key == nullptr) {
        return MARISA_NULL_ERROR;
    }
    try {
        marisa::Key entry;
        entry.set_str(key, length);
        entry.set_id(static_cast<size_t>(id));
        marisa::Keyset& keys = *reinterpret_cast<marisa::Keyset*>(keyset);
        keys.push_back(entry);
    } catch (...) {
        return handle_exception();
    }
    return MARISA_OK;
}

// Resets `keyset` via marisa::Keyset::reset. A null handle is ignored.
// This void C entry point has no channel to report failure, so any exception
// from reset() is swallowed (best effort).
void marisa_keyset_reset(marisa_keyset_t* keyset) {
    if (keyset == nullptr) {
        return;
    }
    try {
        marisa::Keyset& keys = *reinterpret_cast<marisa::Keyset*>(keyset);
        keys.reset();
    } catch (...) {
        // Deliberately ignored: nowhere to report it.
    }
}

// Utility functions
const char* marisa_strerror(int err) {
    switch (err) {
        case MARISA_OK: return "Success";
        case MARISA_STATE_ERROR: return "State error";
        case MARISA_NULL_ERROR: return "Null pointer error";
        case MARISA_BOUND_ERROR: return "Bound error";
        case MARISA_RANGE_ERROR: return "Range error";
        case MARISA_CODE_ERROR: return "Code error";
        case MARISA_RESET_ERROR: return "Reset error";
        case MARISA_SIZE_ERROR: return "Size error";
        case MARISA_MEMORY_ERROR: return "Memory error";
        case MARISA_IO_ERROR: return "I/O error";
        case MARISA_FORMAT_ERROR: return "Format error";
        default: return "Unknown error";
    }
}

// Returns the version string of these bindings ("0.3.1"). This is not
// necessarily the version of the linked libmarisa. The string has static
// storage duration and must not be freed.
const char* marisa_version(void) {
    static const char kBindingsVersion[] = "0.3.1";
    return kBindingsVersion;
}

} // extern "C"