#include <iostream>
#include <sstream>
#include <chrono>
#include <regex>
#include <time.h>
#include <stdio.h>
#include <stdexcept>
#include <memory>
#include <vector>
#include <chrono>
#include <map>
#include <utility>
#include <algorithm>
#include <unordered_map>
#include "bm.h"
#include "bmalgo.h"
#include "bmserial.h"
#include "bmaggregator.h"
#include "bmdbg.h"
#include "bmtimer.h"
#include "bmundef.h"
using namespace std;
/// Print the command-line usage summary to standard error, one option per line.
static
void show_help()
{
    // One entry per output line; each is written followed by std::endl.
    static const char* const usage_lines[] = {
        "BitMagic DNA Search Sample (c) 2018",
        "-fa file-name -- input FASTA file",
        "-s hi|lo -- run substring search benchmark",
        "-diag -- run diagnostics",
        "-timing -- collect timings",
    };
    for (const char* text : usage_lines)
        std::cerr << text << std::endl;
}
// Command-line options. All of them are written by parse_args().
std::string ifa_name;       // path of the input FASTA file (-fa); empty when not given
bool is_diag{false};        // set by -diag; not read anywhere in the visible code
bool is_timing{false};      // set by -timing; main prints the timing map when true
bool is_bench{false};       // set by -bench; not read anywhere in the visible code
bool is_search{false};      // set by -s / -search; main runs the search benchmark when true
bool h_word_set{true};      // true: search the most frequent k-mers (hi); false: the least frequent (lo)
/// Parse the command line into the file-scope option variables.
///
/// Recognised options (each also accepted with a leading "--"):
///   -h / --help      print usage and stop parsing
///   -fa <file>       input FASTA file name
///   -diag | -d       set is_diag
///   -timing | -t     set is_timing
///   -bench | -b      set is_bench
///   -search | -s     set is_search; an optional following "l"/"lo" or
///                    "h"/"hi" picks the low- or high-frequency word set
/// Unrecognised arguments are ignored.
///
/// @return 0 on success (including after printing help),
///         1 when -fa is the last argument and has no file name.
static
int parse_args(int argc, char *argv[])
{
    int pos = 1;
    while (pos < argc)
    {
        const std::string opt(argv[pos]);

        if (opt == "-h" || opt == "--help")
        {
            show_help();
            return 0;
        }
        else if (opt == "-fa" || opt == "--fa")
        {
            if (pos + 1 >= argc)
            {
                std::cerr << "Error: -fa requires file name" << std::endl;
                return 1;
            }
            ifa_name = argv[++pos]; // consume the file name as well
        }
        else if (opt == "-diag" || opt == "--diag" || opt == "-d" || opt == "--d")
        {
            is_diag = true;
        }
        else if (opt == "-timing" || opt == "--timing" || opt == "-t" || opt == "--t")
        {
            is_timing = true;
        }
        else if (opt == "-bench" || opt == "--bench" || opt == "-b" || opt == "--b")
        {
            is_bench = true;
        }
        else if (opt == "-search" || opt == "--search" || opt == "-s" || opt == "--s")
        {
            is_search = true;
            // The frequency selector is optional: it is consumed only when
            // the next argument is one of the four recognised spellings.
            if (pos + 1 < argc)
            {
                const std::string level(argv[pos + 1]);
                if (level == "l" || level == "lo")
                {
                    h_word_set = false;
                    ++pos;
                }
                else if (level == "h" || level == "hi")
                {
                    h_word_set = true;
                    ++pos;
                }
            }
        }
        ++pos;
    }
    return 0;
}
// Word -> occurrence count. Not referenced by the visible code.
using freq_map = std::map<std::string, unsigned>;
// (count, word) pairs. Not referenced by the visible code.
using dict_vect = std::vector<std::pair<unsigned, std::string>>;
// Aggregator over BitMagic bit-vectors, used for the fused SHIFT+AND searches.
using aggregator_type = bm::aggregator<bm::bvector<>>;
// Shared timing registry: passed to every bm::chrono_taker in this file and
// printed by main when -timing is given.
bm::chrono_taker::duration_map_type timing_map;
/// Read a FASTA file into a flat character vector.
///
/// Lines that are empty or start with '>' (FASTA headers) are skipped; the
/// characters of every other line are appended to seq_vect in file order.
/// A trailing '\r' is stripped from each line so that files with CRLF line
/// endings do not inject carriage returns into the sequence.
///
/// @param fname     path of the FASTA file
/// @param seq_vect  receives the sequence; cleared first
/// @return 0 on success, -1 when the file cannot be opened
static
int load_FASTA(const std::string& fname, std::vector<char>& seq_vect)
{
    bm::chrono_taker tt1("1. Parse FASTA", 1, &timing_map);
    seq_vect.resize(0);
    std::ifstream fin(fname.c_str(), std::ios::in);
    if (!fin.good())
        return -1;
    std::string line;
    while (std::getline(fin, line))
    {
        // std::getline only removes '\n'; drop the '\r' left by CRLF files.
        if (!line.empty() && line.back() == '\r')
            line.pop_back();
        if (line.empty() || line.front() == '>')
            continue;
        seq_vect.insert(seq_vect.end(), line.begin(), line.end());
    }
    return 0;
}
class DNA_FingerprintScanner
{
public:
enum { eA = 0, eC, eG, eT, eN, eEnd };
DNA_FingerprintScanner() {}
void Build(const vector<char>& sequence)
{
bm::bvector<>::bulk_insert_iterator iA(m_FPrintBV[eA], bm::BM_SORTED);
bm::bvector<>::bulk_insert_iterator iC(m_FPrintBV[eC], bm::BM_SORTED);
bm::bvector<>::bulk_insert_iterator iG(m_FPrintBV[eG], bm::BM_SORTED);
bm::bvector<>::bulk_insert_iterator iT(m_FPrintBV[eT], bm::BM_SORTED);
bm::bvector<>::bulk_insert_iterator iN(m_FPrintBV[eN], bm::BM_SORTED);
for (size_t i = 0; i < sequence.size(); ++i)
{
unsigned pos = unsigned(i);
switch (sequence[i])
{
case 'A':
iA = pos;
break;
case 'C':
iC = pos;
break;
case 'G':
iG = pos;
break;
case 'T':
iT = pos;
break;
case 'N':
iN = pos;
break;
default:
break;
}
}
}
const bm::bvector<>& GetVector(char letter) const
{
switch (letter)
{
case 'A':
return m_FPrintBV[eA];
case 'C':
return m_FPrintBV[eC];
case 'G':
return m_FPrintBV[eG];
case 'T':
return m_FPrintBV[eT];
case 'N':
return m_FPrintBV[eN];
default:
break;
}
throw runtime_error("Error. Invalid letter!");
}
void Find(const string& word, vector<unsigned>& res)
{
if (word.empty())
return;
bm::bvector<> bv(GetVector(word[0]));
for (size_t i = 1; i < word.size(); ++i)
{
bv.shift_right(); const bm::bvector<>& bv_mask = GetVector(word[i]);
bv &= bv_mask;
auto any = bv.any();
if (!any)
break;
}
unsigned ws = unsigned(word.size()) - 1;
TranslateResults(bv, ws, res);
};
void FindAggFused(const string& word, vector<unsigned>& res)
{
if (word.empty())
return;
m_Agg.reset();
for (size_t i = 0; i < word.size(); ++i)
{
const bm::bvector<>& bv_mask = GetVector(word[i]);
m_Agg.add(&bv_mask);
}
bm::bvector<> bv;
m_Agg.combine_shift_right_and(bv);
unsigned ws = unsigned(word.size()) - 1;
TranslateResults(bv, ws, res);
};
void FindCollection(const vector<tuple<string,int> >& words,
vector<vector<unsigned>>& hits)
{
vector<unique_ptr<aggregator_type> > agg_pipeline;
unsigned ws = 0;
for (const auto& w : words)
{
unique_ptr<aggregator_type> agg_ptr(new aggregator_type());
agg_ptr->set_operation(aggregator_type::BM_SHIFT_R_AND);
const string& word = get<0>(w);
for (size_t i = 0; i < word.size(); ++i)
{
const bm::bvector<>& bv_mask = GetVector(word[i]);
agg_ptr->add(&bv_mask);
}
agg_pipeline.emplace_back(agg_ptr.release());
ws = unsigned(word.size()) - 1;
}
bm::aggregator_pipeline_execute<aggregator_type,
vector<unique_ptr<aggregator_type> >::iterator>(agg_pipeline.begin(), agg_pipeline.end());
for (size_t i = 0; i < agg_pipeline.size(); ++i)
{
const aggregator_type* agg_ptr = agg_pipeline[i].get();
auto bv = agg_ptr->get_target();
vector<unsigned> res;
res.reserve(12000);
TranslateResults(*bv, ws, res);
hits.emplace_back(res);
}
}
protected:
void TranslateResults(const bm::bvector<>& bv,
unsigned left_shift,
vector<unsigned>& res)
{
bm::bvector<>::enumerator en = bv.first();
for (; en.valid(); ++en)
{
auto pos = *en;
res.push_back(pos - left_shift);
}
}
private:
bm::bvector<> m_FPrintBV[eEnd];
aggregator_type m_Agg;
};
// Length, in letters, of the k-mers that main generates and searches for.
static constexpr size_t WORD_SIZE = 28;
// List of sequence positions at which a word was found.
typedef vector<unsigned> THitList;
/// Sample k-mers from a sequence and pick the most and least frequent ones.
///
/// The sequence is cut into consecutive, non-overlapping windows of
/// word_size letters (including the final window when it fits exactly).
/// Windows containing 'N' are ignored. The distinct windows are ranked by
/// occurrence count: the N most frequent go to top_words, the N least
/// frequent to lo_words, each as (word, count). The two lists can overlap
/// when there are fewer than 2*N distinct words.
///
/// Progress and status messages are written to standard output.
///
/// @param top_words  receives the most frequent k-mers; cleared first
/// @param lo_words   receives the least frequent k-mers; cleared first
/// @param data       the sequence
/// @param N          maximum number of k-mers per output list
/// @param word_size  k-mer length; 0 or longer than data yields empty lists
static
void generate_kmers(std::vector<std::tuple<std::string,int>>& top_words,
                    std::vector<std::tuple<std::string,int>>& lo_words,
                    const std::vector<char>& data,
                    std::size_t N,
                    unsigned word_size)
{
    std::cout << "k-mer generation... " << std::endl;
    top_words.clear();
    lo_words.clear();
    // word_size == 0 would never advance the scan below.
    if (word_size == 0 || data.size() < word_size)
        return;

    // Last index at which a full window still fits (inclusive).
    const std::size_t end_pos = data.size() - word_size;
    std::map<std::string, int> words;
    for (std::size_t i = 0; i <= end_pos; )
    {
        std::string s(&data[i], word_size);
        if (s.find('N') == std::string::npos)
            words[s] += 1;
        i += word_size;
        if (i % 10000 == 0)
        {
            std::cout << "\r" << i << "/" << end_pos << std::flush;
        }
    }

    std::cout << std::endl << "Picking k-mer samples..." << std::flush;
    // Re-key by count, highest first. Words with equal counts keep the
    // alphabetical order they had in `words`.
    std::multimap<int, std::string, std::greater<int>> dst;
    for (const auto& entry : words) // by reference: no per-entry string copy
        dst.emplace(entry.second, entry.first);

    {
        auto it = dst.begin();
        for (std::size_t count = 0; count < N && it != dst.end(); ++it, ++count)
            top_words.emplace_back(it->second, it->first);
    }
    {
        auto it = dst.rbegin();
        for (std::size_t count = 0; count < N && it != dst.rend(); ++it, ++count)
            lo_words.emplace_back(it->second, it->first);
    }
    std::cout << "OK" << std::endl;
}
/// Brute-force substring search with a two-ended ("2-way") comparison.
///
/// Every start position of data, including the last one where the word
/// still fits, is tested by comparing letters from both ends of the word
/// towards the middle. Matching start positions are appended to r in
/// increasing order.
///
/// @param data       the sequence to scan (not modified)
/// @param word       the word to find; at least word_size readable chars
/// @param word_size  number of letters of word to match; 0 finds nothing
/// @param r          receives the start positions of all matches
static
void find_word_2way(std::vector<char>& data,
                    const char* word, unsigned word_size,
                    std::vector<unsigned>& r)
{
    if (word_size == 0 || data.size() < word_size)
        return;

    // Last start position at which the whole word still fits (inclusive).
    const std::size_t last_start = data.size() - word_size;
    for (std::size_t i = 0; i <= last_start; ++i)
    {
        bool found = true;
        // k walks up from the front, l down from the back. l is decremented
        // before use so the pair (k == l) also checks the middle letter of
        // odd-length words, and l never goes below zero.
        for (std::size_t k = 0, l = word_size; k < l; ++k)
        {
            --l;
            if (data[i + k] != word[k] || data[i + l] != word[l])
            {
                found = false;
                break;
            }
        }
        if (found)
            r.push_back(static_cast<unsigned>(i));
    }
}
/// Brute-force search for several equal-length words in a single pass.
///
/// At every start position of data (including the last one where a word
/// still fits) the words are tried in order with a two-ended comparison;
/// the first word that matches gets the position appended to its hit list
/// and the remaining words are not tried at that position.
///
/// @param data       the sequence to scan
/// @param words      words to find; each has at least word_size readable chars
/// @param word_size  common length of all words; 0 finds nothing
/// @param hits       hits[idx] receives the matches of words[idx]; grown to
///                   words.size() entries if it is shorter
static
void find_words(const std::vector<char>& data,
                std::vector<const char*> words,
                unsigned word_size,
                std::vector<std::vector<unsigned>>& hits)
{
    // Guarantee a slot per word so hits[idx] below is always valid.
    if (hits.size() < words.size())
        hits.resize(words.size());
    if (word_size == 0 || data.size() < word_size)
        return;

    // Last start position at which a whole word still fits (inclusive).
    const std::size_t last_start = data.size() - word_size;
    const std::size_t words_size = words.size();
    for (std::size_t i = 0; i <= last_start; ++i)
    {
        for (std::size_t idx = 0; idx < words_size; ++idx)
        {
            const char* word = words[idx];
            bool found = true;
            // k walks up from the front, l down from the back. l is
            // decremented before use so the middle letter of odd-length
            // words is compared too.
            for (std::size_t k = 0, l = word_size; k < l; ++k)
            {
                --l;
                if (data[i + k] != word[k] || data[i + l] != word[l])
                {
                    found = false;
                    break;
                }
            }
            if (found)
            {
                hits[idx].push_back(static_cast<unsigned>(i));
                break; // at most one word is recorded per position
            }
        }
    }
}
/// Check whether two hit lists hold the same positions in the same order.
/// When the lists differ in length, both sizes are reported on standard
/// error before returning false.
static
bool hitlist_compare(const THitList& h1, const THitList& h2)
{
    const auto len1 = h1.size();
    const auto len2 = h2.size();
    if (len1 == len2)
        return std::equal(h1.begin(), h1.end(), h2.begin());

    std::cerr << "size1 = " << len1 << " size2 = " << len2 << std::endl;
    return false;
}
int main(int argc, char *argv[])
{
if (argc < 3)
{
show_help();
return 1;
}
std::vector<char> seq_vect;
try
{
auto ret = parse_args(argc, argv);
if (ret != 0)
return ret;
DNA_FingerprintScanner idx;
if (!ifa_name.empty())
{
auto res = load_FASTA(ifa_name, seq_vect);
if (res != 0)
return res;
std::cout << "FASTA sequence size=" << seq_vect.size() << std::endl;
{
bm::chrono_taker tt1("2. Build DNA index", 1, &timing_map);
idx.Build(seq_vect);
}
}
if (is_search)
{
vector<tuple<string,int> > h_words;
vector<tuple<string,int> > l_words;
vector<tuple<string,int>>& words = h_word_set ? h_words : l_words;
generate_kmers(h_words, l_words, seq_vect, 25, WORD_SIZE);
vector<THitList> word_hits;
vector<THitList> word_hits_agg;
{
vector<const char*> word_list;
for (const auto& w : words)
{
word_list.push_back(get<0>(w).c_str());
}
word_hits.resize(words.size());
for_each(word_hits.begin(), word_hits.end(), [](THitList& ht) {
ht.reserve(12000);
});
bm::chrono_taker tt1("6. String search 2-way single pass",
unsigned(words.size()), &timing_map);
find_words(seq_vect, word_list, unsigned(WORD_SIZE), word_hits);
}
{
bm::chrono_taker tt1("7. Aggregated search single pass",
unsigned(words.size()), &timing_map);
idx.FindCollection(words, word_hits_agg);
}
for (size_t word_idx = 0; word_idx < words.size(); ++ word_idx)
{
auto& word = get<0>(words[word_idx]);
THitList hits1;
{
bm::chrono_taker tt1("3. String search 2-way", 1, &timing_map);
find_word_2way(seq_vect,
word.c_str(), unsigned(word.size()),
hits1);
}
THitList hits2;
{
bm::chrono_taker tt1("4. Search with bvector SHIFT+AND", 1, &timing_map);
idx.Find(word, hits2);
}
THitList hits4;
{
bm::chrono_taker tt1("5. Search with aggregator fused SHIFT+AND", 1, &timing_map);
idx.FindAggFused(word, hits4);
}
if (!hitlist_compare(hits1, hits2)
|| !hitlist_compare(hits2, hits4))
{
cout << "Mismatch ERROR for: " << word << endl;
}
else
if (!hitlist_compare(word_hits[word_idx], hits1)
|| !hitlist_compare(word_hits_agg[word_idx], hits1))
{
cout << "Sigle pass mismatch ERROR for: " << word << endl;
}
else
{
cout << word_idx << ": " << word << ": " << hits1.size() << " hits " << endl;
}
}
}
if (is_timing) {
std::cout << std::endl << "Performance:" << std::endl;
bm::chrono_taker::print_duration_map(timing_map, bm::chrono_taker::ct_all);
}
}
catch (std::exception& ex)
{
std::cerr << "Error:" << ex.what() << std::endl;
return 1;
}
return 0;
}