#include "simple_tokenizer.h"
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <vector>
namespace simple_tokenizer {
// File-local state backing SimpleTokenizer::get_pinyin() and SimpleTokenizer::set_pinyin_dict().
namespace {
// Held by get_pinyin() and set_pinyin_dict() while they read or replace global_pinyin.
std::mutex pinyin_mutex;
// Shared PinYin instance. Starts out null; get_pinyin() creates a default one on demand and
// set_pinyin_dict() replaces it.
std::shared_ptr<PinYin> global_pinyin;
}
/// Constructs a tokenizer from an argument list.
///
/// @param azArg Argument strings; only azArg[0] is inspected.
/// @param nArg  Number of entries in azArg.
///
/// When at least one argument is supplied, azArg[0] is parsed with atoi() and pinyin
/// support is switched on for any non-zero value and off for zero. With no arguments
/// enable_pinyin is left at whatever value the class declaration gives it.
SimpleTokenizer::SimpleTokenizer(const char **azArg, int nArg) {
  if (nArg < 1) {
    return;
  }
  const int requested = atoi(azArg[0]);
  enable_pinyin = (requested != 0);
}
/// Returns the shared PinYin instance, creating a default-constructed one on first use.
///
/// The lookup and the lazy creation both happen while pinyin_mutex is held. The caller
/// receives its own shared_ptr copy, so the object it got stays alive even if
/// set_pinyin_dict() installs a different instance afterwards.
std::shared_ptr<PinYin> SimpleTokenizer::get_pinyin() {
  const std::lock_guard<std::mutex> guard(pinyin_mutex);
  if (!global_pinyin) {
    // Nothing installed yet: fall back to the default-constructed PinYin.
    global_pinyin = std::make_shared<PinYin>();
  }
  return global_pinyin;
}
/// Replaces the shared PinYin instance.
///
/// @param pinyin_file_path Path handed to the PinYin constructor; an empty string selects
///                         the default-constructed PinYin instead.
/// @param err              Receives e.what() when construction throws a std::exception;
///                         left untouched on success.
/// @return true when the new instance was installed, false when construction threw.
///
/// The replacement is built before pinyin_mutex is taken, so the lock is held only for the
/// pointer swap and a failed construction leaves the current instance in place.
bool SimpleTokenizer::set_pinyin_dict(const std::string &pinyin_file_path, std::string &err) {
  std::shared_ptr<PinYin> replacement;
  try {
    replacement = pinyin_file_path.empty() ? std::make_shared<PinYin>()
                                           : std::make_shared<PinYin>(pinyin_file_path);
  } catch (const std::exception &ex) {
    err = ex.what();
    return false;
  }

  const std::lock_guard<std::mutex> guard(pinyin_mutex);
  global_pinyin = replacement;
  return true;
}
/// Classifies a single byte of input text.
///
/// @param c Byte to classify.
/// @return  OTHER for any byte above 127, DIGIT for decimal digits, SPACE for whitespace
///          and control characters, ASCII_ALPHABETIC for letters, and OTHER for every
///          remaining 7-bit character (punctuation and symbols).
static TokenCategory from_char(char c) {
  // Work on the unsigned value: the <cctype> predicates must not be given negative chars.
  const auto byte = static_cast<unsigned char>(c);
  const bool is_seven_bit = byte <= 127;
  if (!is_seven_bit) {
    return TokenCategory::OTHER;
  }
  if (std::isdigit(byte) != 0) {
    return TokenCategory::DIGIT;
  }
  const bool blank_or_control = std::isspace(byte) != 0 || std::iscntrl(byte) != 0;
  if (blank_or_control) {
    return TokenCategory::SPACE;
  }
  return std::isalpha(byte) != 0 ? TokenCategory::ASCII_ALPHABETIC : TokenCategory::OTHER;
}
std::string SimpleTokenizer::tokenize_query(const char *text, int textLen, int flags) {
int start = 0;
int index = 0;
std::string tmp;
std::string result;
while (index < textLen) {
TokenCategory category = from_char(text[index]);
switch (category) {
case TokenCategory::OTHER:
index += PinYin::get_str_len(text[index]);
break;
default:
while (++index < textLen && from_char(text[index]) == category) {
}
break;
}
tmp.clear();
std::copy(text + start, text + index, std::back_inserter(tmp));
append_result(result, tmp, category, start, flags);
start = index;
}
return result;
}
#ifdef USE_JIEBA
// Prefix prepended to each jieba dictionary file name when tokenize_jieba_query() constructs
// its function-local static cppjieba::Jieba. It is mutable with external linkage, so it is
// presumably meant to be overridden before that first call - TODO confirm against callers.
std::string jieba_dict_path = "./dict/";
/// Converts raw user text into a query string using cppjieba word segmentation.
///
/// @param text    Input bytes. A null pointer yields an empty result.
/// @param textLen Number of bytes of text to segment. A negative value keeps the previous
///                behaviour of treating text as NUL-terminated.
/// @param flags   Forwarded unchanged to append_result().
/// @return        The query string accumulated by append_result().
///
/// The segmenter is a function-local static, so the dictionaries under jieba_dict_path are
/// loaded once, on the first call. Each word is given the category shared by all of its
/// bytes, or OTHER when the bytes disagree.
std::string SimpleTokenizer::tokenize_jieba_query(const char *text, int textLen, int flags) {
  static cppjieba::Jieba jieba(jieba_dict_path + "jieba.dict.utf8", jieba_dict_path + "hmm_model.utf8",
                               jieba_dict_path + "user.dict.utf8", jieba_dict_path + "idf.utf8",
                               jieba_dict_path + "stop_words.utf8");
  std::string result;
  if (text == nullptr) {
    return result;
  }
  // Honour the explicit length instead of relying on a terminator past the caller's buffer.
  const std::string sentence =
      textLen < 0 ? std::string(text) : std::string(text, static_cast<std::string::size_type>(textLen));

  std::vector<cppjieba::Word> words;
  jieba.Cut(sentence, words);
  for (const auto &word : words) {
    TokenCategory category = from_char(sentence[word.offset]);
    for (const char c : word.word) {
      if (from_char(c) != category) {
        // Mixed content: fall back to the quoted, non-prefix form.
        category = TokenCategory::OTHER;
        break;
      }
    }
    append_result(result, word.word, category, static_cast<int>(word.offset), flags);
  }
  return result;
}
#endif
/// Appends one token to an FTS5 query string under construction.
///
/// @param result   Query built so far; the token is appended in place.
/// @param part     Token text.
/// @param category Category of the token. SPACE tokens are ignored.
/// @param offset   Byte offset of the token in the original text. Kept for interface
///                 compatibility; the separator decision is based on result instead.
/// @param flags    When non-zero, alphabetic tokens longer than one character are expanded
///                 through PinYin::split_pinyin() into an OR group of prefix terms.
///
/// Output per token:
///   - alphabetic, expanded:  ( a* OR b* ... )
///   - alphabetic, plain:     lowercased text followed by *
///   - digit:                 "text"*
///   - other:                 "text"
/// Tokens are joined with " AND ".
void SimpleTokenizer::append_result(std::string &result, std::string part, TokenCategory category, int offset,
                                    int flags) {
  (void)offset;
  if (category == TokenCategory::SPACE) {
    return;
  }

  const bool is_alpha = category == TokenCategory::ASCII_ALPHABETIC;
  if (is_alpha) {
    std::transform(part.begin(), part.end(), part.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  }

  // Decide on the separator from what has actually been emitted. Using the byte offset
  // produced a dangling leading " AND " whenever the text started with skipped whitespace.
  const bool is_first = result.empty();

  if (flags != 0 && is_alpha && part.size() > 1) {
    result.append(is_first ? "( " : " AND ( ");
    const std::set<std::string> pys = SimpleTokenizer::get_pinyin()->split_pinyin(part);
    bool add_or = false;
    for (const std::string &s : pys) {
      if (add_or) {
        result.append(" OR ");
      }
      result.append(s);
      result.append("*");
      add_or = true;
    }
    result.append(" )");
    return;
  }

  if (!is_first) {
    result.append(" AND ");
  }
  if (is_alpha) {
    result.append(part);
  } else {
    // Emit as an FTS5 string; embedded double quotes are escaped by doubling them.
    result.push_back('"');
    for (const char c : part) {
      if (c == '"') {
        result.push_back('"');
      }
      result.push_back(c);
    }
    result.push_back('"');
  }
  if (category != TokenCategory::OTHER) {
    // Digits and letters are matched as prefixes.
    result.append("*");
  }
}
int SimpleTokenizer::tokenize(void *pCtx, int flags, const char *text, int textLen, xTokenFn xToken) const {
int rc = SQLITE_OK;
int start = 0;
int index = 0;
std::string result;
while (index < textLen) {
TokenCategory category = from_char(text[index]);
switch (category) {
case TokenCategory::OTHER:
index += PinYin::get_str_len(text[index]);
break;
default:
while (++index < textLen && from_char(text[index]) == category) {
}
break;
}
if (category != TokenCategory::SPACE) {
result.clear();
std::copy(text + start, text + index, std::back_inserter(result));
if (category == TokenCategory::ASCII_ALPHABETIC) {
std::transform(result.begin(), result.end(), result.begin(), [](unsigned char c) { return std::tolower(c); });
}
rc = xToken(pCtx, 0, result.c_str(), (int)result.length(), start, index);
if (enable_pinyin && category == TokenCategory::OTHER && (flags & FTS5_TOKENIZE_DOCUMENT)) {
std::shared_ptr<PinYin> pinyin = SimpleTokenizer::get_pinyin();
const std::vector<std::string> &pys = pinyin->get_pinyin(result);
for (const std::string &s : pys) {
rc = xToken(pCtx, FTS5_TOKEN_COLOCATED, s.c_str(), (int)s.length(), start, index);
}
}
}
start = index;
}
return rc;
}
}