#include <algorithm>
#include <cstring>
#include <stdexcept>
#include <unordered_map>
#include "marisa.h"
#include "Lexicon.hpp"
#include "MarisaDict.hpp"
#include "SerializedValues.hpp"
using namespace opencc;
namespace {
// Magic string written at the start of every OCD2 (Marisa) dictionary file.
// It is written and compared without its terminating NUL; the length used
// for I/O is strlen(OCD2_HEADER). Declared as a constexpr array so the
// identifier cannot be re-pointed elsewhere (the previous `const char*` was
// a mutable pointer).
constexpr char OCD2_HEADER[] = "OPENCC_MARISA_0.2.5";
}
// Private implementation state for MarisaDict.
class MarisaDict::MarisaInternal {
public:
  // Raw bytes of a loaded dictionary (everything after the file header).
  // NewFromFile maps `marisa` directly onto this buffer, so it is declared
  // first: members are destroyed in reverse declaration order, which keeps
  // the buffer alive until after the trie that references it is gone.
  std::string mappedBuffer;

  // The key trie. It is either built in memory or mapped onto mappedBuffer.
  std::unique_ptr<marisa::Trie> marisa;

  MarisaInternal() : marisa(new marisa::Trie()) {}
};
// Creates a dictionary whose private state holds an empty, unpopulated trie.
// The NewFromFile / NewFromDict factories populate the trie and lexicon.
MarisaDict::MarisaDict()
    : internal(new MarisaInternal()) {}
// Defined out of line so MarisaInternal is a complete type where the owning
// pointer to it is destroyed.
MarisaDict::~MarisaDict() = default;
// Length in bytes of the longest key stored in this dictionary.
size_t MarisaDict::KeyMaxLength() const {
  return maxLength;
}
// Exact lookup of the first `len` bytes of `word`.
// Returns the entry whose key equals that byte range, or Null() if absent.
Optional<const DictEntry*> MarisaDict::Match(const char* word,
                                             size_t len) const {
  // A query longer than the longest stored key cannot match; skip the trie.
  if (len <= maxLength) {
    marisa::Agent agent;
    agent.set_query(word, len);
    const marisa::Trie& trie = *internal->marisa;
    if (trie.lookup(agent)) {
      const size_t keyId = agent.key().id();
      return Optional<const DictEntry*>(lexicon->At(keyId));
    }
  }
  return Optional<const DictEntry*>::Null();
}
// Finds the longest stored key that is a prefix of word[0, len).
// Returns that entry, or Null() when no stored key is a prefix.
Optional<const DictEntry*> MarisaDict::MatchPrefix(const char* word,
                                                   size_t len) const {
  marisa::Agent agent;
  // No stored key is longer than maxLength, so never search past it.
  agent.set_query(word, (std::min)(maxLength, len));
  const marisa::Trie& trie = *internal->marisa;
  const DictEntry* longest = nullptr;
  // Keep overwriting; the entry left after the loop is the last prefix the
  // trie reports (MatchAllPrefixes reverses this order to go longest-first).
  while (trie.common_prefix_search(agent)) {
    longest = lexicon->At(agent.key().id());
  }
  if (longest != nullptr) {
    return Optional<const DictEntry*>(longest);
  }
  return Optional<const DictEntry*>::Null();
}
std::vector<const DictEntry*> MarisaDict::MatchAllPrefixes(const char* word,
size_t len) const {
const marisa::Trie& trie = *internal->marisa;
marisa::Agent agent;
agent.set_query(word, (std::min)(maxLength, len));
std::vector<const DictEntry*> matches;
while (trie.common_prefix_search(agent)) {
matches.push_back(lexicon->At(agent.key().id()));
}
std::reverse(matches.begin(), matches.end());
return matches;
}
// Returns the lexicon; entry i corresponds to trie key id i.
LexiconPtr MarisaDict::GetLexicon() const {
  return lexicon;
}
// Loads an OCD2 (Marisa-based) dictionary from `fp`.
//
// Layout read, starting at the current position of fp:
//   1. OCD2_HEADER (strlen(OCD2_HEADER) bytes, no terminating NUL),
//   2. a serialized marisa trie,
//   3. a SerializedValues block whose lexicon is indexed by trie key id.
// Everything after the header up to end-of-file is read into memory and the
// trie is mapped onto that buffer; the buffer is kept in the returned dict.
//
// Throws InvalidFormat if the header does not match, the stream cannot be
// sized with ftell/fseek, the data is truncated, the trie cannot be mapped,
// or the trie and the stored values disagree.
MarisaDictPtr MarisaDict::NewFromFile(FILE* fp) {
  // Validate the magic header. The std::string owns the scratch space, so it
  // is released even when the check throws.
  const size_t headerLen = strlen(OCD2_HEADER);
  std::string header(headerLen, '\0');
  size_t bytesRead = fread(&header[0], sizeof(char), headerLen, fp);
  if (bytesRead != headerLen ||
      memcmp(header.data(), OCD2_HEADER, headerLen) != 0) {
    throw InvalidFormat("Invalid OpenCC dictionary header");
  }

  // Determine how many bytes follow the header. Report ftell/fseek failures
  // (e.g. a non-seekable stream) instead of silently treating them as size 0.
  const long trieOffset = ftell(fp);
  if (trieOffset < 0 || fseek(fp, 0L, SEEK_END) != 0) {
    throw InvalidFormat("Invalid OpenCC Marisa dictionary (cannot seek)");
  }
  const long fileEnd = ftell(fp);
  if (fileEnd < 0 || fseek(fp, trieOffset, SEEK_SET) != 0) {
    throw InvalidFormat("Invalid OpenCC Marisa dictionary (cannot seek)");
  }
  const size_t remainingSize =
      (fileEnd > trieOffset) ? static_cast<size_t>(fileEnd - trieOffset) : 0;

  MarisaDictPtr dict(new MarisaDict());
  std::string& buffer = dict->internal->mappedBuffer;
  buffer.resize(remainingSize);
  bytesRead = fread(&buffer[0], sizeof(char), remainingSize, fp);
  if (bytesRead != remainingSize) {
    throw InvalidFormat("Invalid OpenCC Marisa dictionary.");
  }

  // Map the trie onto the in-memory buffer; marisa errors become
  // InvalidFormat.
  marisa::Trie& trie = *dict->internal->marisa;
  try {
    trie.map(buffer.data(), buffer.size());
  } catch (const std::exception& e) {
    throw InvalidFormat(std::string("Invalid OpenCC Marisa dictionary: ") +
                        e.what());
  }

  // The serialized values start right after the trie's bytes.
  const size_t trieSize = trie.io_size();
  if (trieSize > buffer.size()) {
    throw InvalidFormat(
        "Invalid OpenCC Marisa dictionary (trie exceeds file size)");
  }
  size_t valuesBytesRead = 0;
  std::shared_ptr<SerializedValues> serialized_values =
      SerializedValues::NewFromBuffer(buffer.data() + trieSize,
                                      buffer.size() - trieSize,
                                      &valuesBytesRead);
  LexiconPtr values_lexicon = serialized_values->GetLexicon();
  const size_t numKeys = trie.num_keys();
  if (numKeys != values_lexicon->Length()) {
    throw InvalidFormat(
        "Invalid OpenCC Marisa dictionary (key count mismatch)");
  }

  // Enumerate every key (predictive search on the empty prefix) and pair it
  // with the values stored under the same id.
  marisa::Agent agent;
  agent.set_query("");
  std::vector<std::unique_ptr<DictEntry>> entries(numKeys);
  size_t maxLength = 0;
  try {
    while (trie.predictive_search(agent)) {
      const std::string key(agent.key().ptr(), agent.key().length());
      const size_t id = agent.key().id();
      if (id >= entries.size()) {
        throw InvalidFormat(
            "Invalid OpenCC Marisa dictionary (key id out of bounds)");
      }
      maxLength = (std::max)(key.length(), maxLength);
      entries[id].reset(
          DictEntryFactory::New(key, values_lexicon->At(id)->Values()));
    }
  } catch (const InvalidFormat&) {
    throw;
  } catch (const std::exception& e) {
    throw InvalidFormat(std::string("Invalid OpenCC Marisa dictionary: ") +
                        e.what());
  }

  dict->lexicon.reset(new Lexicon(std::move(entries)));
  dict->maxLength = maxLength;
  return dict;
}
// Builds a MarisaDict holding copies of every entry in `thatDict`.
//
// All keys are inserted into a fresh marisa trie. The resulting lexicon is
// ordered by trie key id, so lexicon->At(id) is the entry for trie key `id`.
// If thatDict contains the same key more than once, the trie stores it once
// and the copy taken from the last such entry in thatDict's lexicon is kept.
MarisaDictPtr MarisaDict::NewFromDict(const Dict& thatDict) {
  const LexiconPtr& thatLexicon = thatDict.GetLexicon();
  size_t maxLength = 0;
  marisa::Keyset keyset;
  std::unordered_map<std::string, std::unique_ptr<DictEntry>> key_value_map;
  for (size_t i = 0; i < thatLexicon->Length(); i++) {
    const DictEntry* entry = thatLexicon->At(i);
    const std::string& key = entry->Key();
    // Pass the explicit length: the c_str() overload measures the key with
    // strlen and would truncate it at an embedded NUL byte, disagreeing with
    // the full key used in key_value_map.
    keyset.push_back(key.data(), key.length());
    key_value_map[key].reset(DictEntryFactory::New(entry));
    maxLength = (std::max)(entry->KeyLength(), maxLength);
  }

  MarisaDictPtr dict(new MarisaDict());
  marisa::Trie& trie = *dict->internal->marisa;
  trie.build(keyset);

  // Size by the trie's key count, not the source lexicon: duplicate keys
  // collapse in the trie, and over-sizing would leave null trailing entries.
  std::vector<std::unique_ptr<DictEntry>> entries(trie.num_keys());
  marisa::Agent agent;
  agent.set_query("");
  while (trie.predictive_search(agent)) {
    const std::string key(agent.key().ptr(), agent.key().length());
    auto found = key_value_map.find(key);
    if (found != key_value_map.end()) {
      entries[agent.key().id()] = std::move(found->second);
    }
  }

  dict->lexicon.reset(new Lexicon(std::move(entries)));
  dict->maxLength = maxLength;
  return dict;
}
void MarisaDict::SerializeToFile(FILE* fp) const {
fwrite(OCD2_HEADER, sizeof(char), strlen(OCD2_HEADER), fp);
marisa::fwrite(fp, *internal->marisa);
std::unique_ptr<SerializedValues> serialized_values(
new SerializedValues(lexicon));
serialized_values->SerializeToFile(fp);
}