#include "Util/SymSpell/SymSpell.h"
#include "Util/StringUtil.h"
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
// Creates an empty dictionary.
//   strategy                   - Normal: delete variants are generated as each word is added;
//                                LazyLoaded: generation is deferred until the first LookUp.
//   maxDictionaryEditDistance  - largest edit distance the delete index is built for.
//   prefixLength               - number of leading characters of each word that are indexed.
// The arguments are stored as given; no range validation is performed here.
SymSpell::SymSpell(Strategy strategy,
                   int maxDictionaryEditDistance,
                   int prefixLength)
    : _prefixLength(prefixLength),
      // 0x03FFFFFC: keeps hash bits 2..25 and leaves the two low bits clear so that
      // GetStringHash can store a small length tag there.
      _compactMask((std::numeric_limits<uint32_t>::max() >> 8) << 2),
      _maxDictionaryWordLength(0),
      _maxDictionaryEditDistance(maxDictionaryEditDistance),
      _words(),
      _strategy(strategy) {}
// Loads a dictionary file by reading it fully into memory and delegating the parsing
// to LoadWordDictionaryFromBuffer.
// Returns false when the file cannot be opened, otherwise whatever the buffer loader returns.
bool SymSpell::LoadWordDictionary(std::string path) {
    std::ifstream file(path);
    if (!file.is_open()) {
        return false;
    }
    std::ostringstream contents;
    contents << file.rdbuf();
    return LoadWordDictionaryFromBuffer(contents.str());
}
// Parses an in-memory dictionary with one entry per line.
//
// Each line is either "word count" or just "word"; in the latter case, and for lines with
// more than two space-separated tokens, the count defaults to 1. Counts are clamped to the
// int range before being handed to CreateDictionaryEntry, which ignores non-positive counts.
// A single trailing '\r' (left by Windows CRLF line endings) is stripped from tokens, and
// lines whose word is empty (blank lines, a trailing newline) are skipped.
//
// Returns true when every line was processed. If parsing or insertion throws a
// std::exception (e.g. std::stoll on a non-numeric count), the message is written to
// std::cerr and false is returned; entries read before the failing line stay loaded.
bool SymSpell::LoadWordDictionaryFromBuffer(std::string_view buffer) {
    // Drops one trailing carriage return so CRLF files behave like LF files.
    const auto stripCarriageReturn = [](std::string text) {
        if (!text.empty() && text.back() == '\r') {
            text.pop_back();
        }
        return text;
    };
    auto lines = string_util::Split(buffer, "\n");
    try {
        for (auto &line: lines) {
            auto tokens = string_util::Split(line, " ");
            if (tokens.empty()) {
                continue;
            }
            const std::string word = stripCarriageReturn(std::string(tokens[0]));
            // An empty word would land in the zero-length delete bucket and be suggested
            // for every short input; never add it.
            if (word.empty()) {
                continue;
            }
            int count = 1;
            if (tokens.size() == 2) {
                // std::stoll throws std::invalid_argument / std::out_of_range on bad input.
                const long long parsed = std::stoll(stripCarriageReturn(std::string(tokens[1])));
                // Clamp on both sides: narrowing an out-of-range value to int is
                // implementation-defined and could turn a huge negative into a positive count.
                count = static_cast<int>(std::clamp<long long>(parsed,
                                                              std::numeric_limits<int>::min(),
                                                              std::numeric_limits<int>::max()));
            }
            CreateDictionaryEntry(word, count);
        }
    } catch (const std::exception &e) {
        std::cerr << e.what() << '\n';
        return false;
    }
    return true;
}
// Adds `key` with frequency `count`, or adds `count` to an existing entry.
//
// Returns false when `count` is non-positive (nothing changes) or when `key` was already
// present (its count is increased, saturating at INT_MAX). For a new key, returns true in
// LazyLoaded mode (delete generation is deferred), otherwise the result of BuildDeletesWords.
bool SymSpell::CreateDictionaryEntry(const std::string &key, int count) {
    if (count <= 0) {
        return false;
    }
    auto existing = _words.find(key);
    if (existing != _words.end()) {
        const int previous = existing->second;
        const int headroom = std::numeric_limits<int>::max() - previous;
        // Saturating add so repeated entries cannot overflow the count.
        existing->second = (headroom > count) ? previous + count
                                              : std::numeric_limits<int>::max();
        return false;
    }
    _words.insert({key, count});
    if (_strategy == Strategy::LazyLoaded) {
        return true;
    }
    return BuildDeletesWords(key);
}
// Reports whether `word` needs no correction.
// Words no longer than _maxDictionaryEditDistance are accepted unconditionally; longer
// words are correct only if they are dictionary keys.
bool SymSpell::IsCorrectWord(const std::string &word) const {
    if (word.size() <= static_cast<std::size_t>(_maxDictionaryEditDistance)) {
        return true;
    }
    return _words.find(word) != _words.end();
}
std::vector<SuggestItem> SymSpell::LookUp(const std::string &input) {
return LookUp(input, _maxDictionaryEditDistance);
}
// Finds the closest dictionary words to `input` within `maxEditDistance` edits.
//
// Uses the symmetric-delete scheme: deletes of the input prefix are generated
// breadth-first (queue `candidates`) and matched by hash against the precomputed deletes
// of dictionary word prefixes (`_deletes`). Surviving words are verified with
// `_editDistance.Compare`. Only the closest suggestions are kept: a strictly smaller
// distance discards earlier results and tightens the search bound (maxEditDistance2).
//
// In LazyLoaded mode the delete index is built on the first call and the strategy
// switches to Normal.
//
// Returns suggestions sorted with SuggestItem::Comapare. The result is empty when:
// maxEditDistance is negative or above _maxDictionaryEditDistance, the input is empty,
// the input is too long to match any word, the input is itself a dictionary word, or
// nothing lies within range.
std::vector<SuggestItem> SymSpell::LookUp(const std::string &input, int maxEditDistance) {
    if (_strategy == Strategy::LazyLoaded) {
        BuildAllDeletesWords();
        _strategy = Strategy::Normal;
    }
    std::vector<SuggestItem> suggestions;
    // The delete index only covers edits up to _maxDictionaryEditDistance.
    if (maxEditDistance > _maxDictionaryEditDistance) {
        return suggestions;
    }
    // Empty input yields no suggestions; a negative bound can never be satisfied.
    if (input.empty() || maxEditDistance < 0) {
        return suggestions;
    }
    // Early exit: even after maxEditDistance deletions the input would still be longer
    // than every dictionary word, so no suggestion is possible.
    if (input.size() > _maxDictionaryWordLength + static_cast<std::size_t>(maxEditDistance)) {
        return suggestions;
    }
    // A correctly spelled word gets no suggestions (the input itself is never suggested).
    if (_words.find(input) != _words.end()) {
        return suggestions;
    }
    // hashset1: input deletes already queued as candidates.
    // hashset2: dictionary words already considered, so each is evaluated at most once.
    std::unordered_set<std::string> hashset1;
    std::unordered_set<std::string> hashset2;
    hashset2.insert(input);
    // Current bound; shrinks to the best distance found so far.
    int maxEditDistance2 = maxEditDistance;
    std::vector<std::string> candidates;
    auto inputPrefixLen = input.size();
    const auto inputLen = inputPrefixLen;
    // Only the first _prefixLength characters are indexed.
    if (inputPrefixLen > _prefixLength) {
        inputPrefixLen = _prefixLength;
        candidates.emplace_back(input.substr(0, inputPrefixLen));
    } else {
        candidates.emplace_back(input);
    }
    std::size_t candidateIndex = 0;
    while (candidateIndex < candidates.size()) {
        // Copied on purpose: appending to `candidates` below may reallocate it.
        const std::string candidate = candidates[candidateIndex++];
        const auto candidateLen = candidate.size();
        const int lengthDiff = static_cast<int>(inputPrefixLen - candidateLen);
        // Candidates are queued in non-increasing length order, so every later one needs
        // at least as many deletions; stop once that exceeds the bound.
        if (lengthDiff > maxEditDistance2) {
            break;
        }
        auto it = _deletes.find(GetStringHash(candidate));
        if (it != _deletes.end()) {
            auto &dictSuggestions = it->second;
            for (auto &suggestion: dictSuggestions) {
                const auto suggestionLen = suggestion.size();
                if (suggestion == input) {
                    continue;
                }
                // Skip words whose length difference alone exceeds the bound. Also skip
                // bucket entries that can only be here through a hash collision: words
                // shorter than the delete, or words of equal length that differ from it.
                if ((std::abs(static_cast<int>(suggestionLen) - static_cast<int>(inputLen)) > maxEditDistance2)
                    || (suggestionLen < candidateLen)
                    || (suggestionLen == candidateLen && suggestion != candidate)) {
                    continue;
                }
                const auto suggestPrefixLen = std::min(suggestionLen, _prefixLength);
                if ((suggestPrefixLen > inputPrefixLen)
                    && (static_cast<int>(suggestPrefixLen - candidateLen) > maxEditDistance2)) {
                    continue;
                }
                int distance = 0;
                if (candidateLen == 0) {
                    // No common prefix characters: the distance is the longer length.
                    distance = static_cast<int>(std::max(inputLen, suggestionLen));
                    const bool firstVisit = hashset2.insert(suggestion).second;
                    if (distance > maxEditDistance2 || !firstVisit) {
                        continue;
                    }
                } else if (suggestionLen == 1) {
                    // Single-character word: inputLen - 1 edits if the input contains it,
                    // otherwise inputLen edits.
                    if (input.find(suggestion[0]) == std::string::npos) {
                        distance = static_cast<int>(inputLen);
                    } else {
                        distance = static_cast<int>(inputLen) - 1;
                    }
                    const bool firstVisit = hashset2.insert(suggestion).second;
                    if (distance > maxEditDistance2 || !firstVisit) {
                        continue;
                    }
                } else {
                    // NOTE(review): the reference SymSpell applies this suffix shortcut when
                    // candidateLen == prefixLength - maxEditDistance; confirm the "- 1" here.
                    if (_prefixLength - 1 == candidateLen) {
                        // Signed arithmetic: min(inputLen, suggestionLen) may be below _prefixLength.
                        const int minLen = static_cast<int>(std::min(inputLen, suggestionLen))
                                           - static_cast<int>(_prefixLength);
                        if (minLen > 1 && input.substr(inputLen + 1 - minLen)
                                          != suggestion.substr(suggestionLen + 1 - minLen)) {
                            continue;
                        }
                        if (minLen > 0
                            && (input[inputLen - minLen] != suggestion[suggestionLen - minLen])
                            && ((input[inputLen - minLen - 1] != suggestion[suggestionLen - minLen])
                                || (input[inputLen - minLen] != suggestion[suggestionLen - minLen - 1]))) {
                            continue;
                        }
                    }
                    // Filter likely hash collisions, then evaluate each word only once.
                    if (!DeleteInSuggestionPrefix(candidate, candidateLen, suggestion, suggestionLen)
                        || !hashset2.insert(suggestion).second) {
                        continue;
                    }
                    distance = _editDistance.Compare(input, suggestion, maxEditDistance2);
                    // Negative results are rejected (presumably "exceeds maxEditDistance2").
                    if (distance < 0) {
                        continue;
                    }
                }
                if (distance <= maxEditDistance2) {
                    const int suggestionCount = _words[suggestion];
                    SuggestItem si(suggestion, distance, suggestionCount);
                    // A strictly closer match makes all previously collected ones obsolete.
                    if (!suggestions.empty() && distance < maxEditDistance2) {
                        suggestions.clear();
                    }
                    maxEditDistance2 = distance;
                    suggestions.push_back(si);
                }
            }
        }
        // Queue the next level of deletes while more edits within the prefix are allowed.
        if ((lengthDiff < maxEditDistance) && (candidateLen <= _prefixLength)) {
            // Deeper deletes could not produce anything closer than what is already found.
            if (lengthDiff >= maxEditDistance2) {
                continue;
            }
            for (std::size_t i = 0; i < candidateLen; i++) {
                std::string del(candidate);
                del.erase(i, 1);
                if (hashset1.insert(del).second) {
                    candidates.push_back(del);
                }
            }
        }
    }
    std::sort(suggestions.begin(), suggestions.end(), SuggestItem::Comapare);
    return suggestions;
}
bool SymSpell::BuildDeletesWords(const std::string &key) {
if (key.size() > _maxDictionaryWordLength) {
_maxDictionaryWordLength = key.size();
}
auto edits = EditsPrefix(key);
for (auto it = edits.begin(); it != edits.end(); ++it) {
int deleteHash = GetStringHash(*it);
auto deletesFounded = _deletes.find(deleteHash);
if (deletesFounded != _deletes.end()) {
auto &suggestions = deletesFounded->second;
suggestions.emplace_back(key);
} else {
std::vector<std::string> suggestions = {key};
_deletes.insert({deleteHash, suggestions});
}
}
return true;
}
// Builds the delete index for every word currently in the dictionary.
// Returns true only if every BuildDeletesWords call succeeded. All words are processed
// even after a failure.
bool SymSpell::BuildAllDeletesWords() {
    bool allBuilt = true;
    for (const auto &entry : _words) {
        // Call first, then combine, so that a failure never skips later words.
        const bool built = BuildDeletesWords(entry.first);
        allBuilt = allBuilt && built;
    }
    return allBuilt;
}
// Returns the set of delete variants for `key`:
// - the first _prefixLength characters of `key` themselves;
// - every string obtained from that prefix by Edits;
// - the empty string, when the whole key is no longer than _maxDictionaryEditDistance
//   (it can be deleted away entirely).
std::unordered_set<std::string> SymSpell::EditsPrefix(std::string_view key) {
    std::unordered_set<std::string> variants;
    const bool fullyDeletable = key.size() <= static_cast<std::size_t>(_maxDictionaryEditDistance);
    if (fullyDeletable) {
        variants.insert("");
    }
    // substr clamps the length, so short keys are used whole.
    const std::string_view prefix = key.substr(0, _prefixLength);
    variants.insert(std::string(prefix));
    Edits(prefix, 0, variants);
    return variants;
}
// Recursively collects into `deleteWord` every string obtained from `word` by removing
// one character at a time.
// `editDistance` is the number of deletions already applied to reach `word`. A variant
// is expanded further only if it was newly added and its depth is still below
// _maxDictionaryEditDistance. The first level of deletions is always added. Words of
// length 0 or 1 produce nothing.
void SymSpell::Edits(std::string_view word, int editDistance, std::unordered_set<std::string> &deleteWord) {
    if (word.size() <= 1) {
        return;
    }
    const int depth = editDistance + 1;
    for (std::size_t pos = 0; pos < word.size(); ++pos) {
        std::string shorter;
        shorter.reserve(word.size() - 1);
        shorter.append(word.substr(0, pos));
        shorter.append(word.substr(pos + 1));
        const bool isNew = deleteWord.insert(shorter).second;
        if (isNew && depth < _maxDictionaryEditDistance) {
            Edits(shorter, depth, deleteWord);
        }
    }
}
// Hashes `source` with 32-bit FNV-1a over its chars, then keeps only the bits selected
// by _compactMask and ORs min(length, 3) into the low two bits as a coarse length tag.
// Each char is converted to uint32_t before the XOR, so on platforms where char is
// signed, bytes >= 0x80 are sign-extended.
int SymSpell::GetStringHash(std::string_view source) {
    constexpr uint32_t kFnvOffsetBasis = 2166136261u;
    constexpr uint32_t kFnvPrime = 16777619u;
    uint32_t hash = kFnvOffsetBasis;
    for (std::size_t i = 0; i < source.size(); ++i) {
        hash = (hash ^ source[i]) * kFnvPrime;
    }
    const uint32_t lengthTag = std::min<uint32_t>(static_cast<uint32_t>(source.size()), 3u);
    return static_cast<int>((hash & _compactMask) | lengthTag);
}
// Cheap filter for hash collisions in the delete index.
// Checks that each of the first `deleteLen` characters of `deleteSuggest` occurs in the
// suggestion's indexed prefix (its first min(suggestionLen, _prefixLength) characters).
// The characters are looked up in order, each search starting at the position where the
// previous one matched. The matched position is not consumed, so consecutive equal
// delete characters may match the same suggestion character.
// An empty delete always passes.
bool SymSpell::DeleteInSuggestionPrefix(std::string_view deleteSuggest, std::size_t deleteLen,
                                        std::string_view suggestion,
                                        std::size_t suggestionLen) {
    const std::string_view indexedPrefix = suggestion.substr(0, std::min(suggestionLen, _prefixLength));
    std::size_t from = 0;
    for (std::size_t i = 0; i < deleteLen; ++i) {
        from = indexedPrefix.find(deleteSuggest[i], from);
        if (from == std::string_view::npos) {
            return false;
        }
    }
    return true;
}