#include <gtest/gtest.h>
#include <cstddef>
#include <memory>
#include <vector>
#include <faiss/IndexIVF.h>
#include <faiss/clone_index.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/index_factory.h>
#include <faiss/invlists/InvertedLists.h>
#include <faiss/utils/random.h>
namespace {

// Vector dimensionality shared by every index and vector set in this file.
// It is read-only (including from the OpenMP region of the test), so make it
// a compile-time constant rather than a mutable global.
constexpr int d = 64;

} // namespace
// Returns `n` vectors of dimension `d` (n * d floats, row-major as filled by
// faiss::rand_smooth_vectors) generated from `seed`.
// The test relies on calling this twice with the same seed to obtain the same
// vectors (see the reference vs. dispatching searches), so the seed is used
// as-is; the previous `seed++` on the by-value parameter was a dead store.
std::vector<float> get_random_vectors(size_t n, int seed) {
    std::vector<float> x(n * d);
    faiss::rand_smooth_vectors(n, d, x.data(), seed);
    return x;
}
// Inverted lists that hold no data of their own: each iteration request is
// forwarded to the faiss::InvertedLists passed as the per-search
// `inverted_list_context`. The direct-access methods (list_size / get_codes /
// get_ids) throw, so callers must go through the iterator interface, which is
// requested by setting `use_iterator = true`.
struct DispatchingInvertedLists : faiss::ReadOnlyInvertedLists {
    DispatchingInvertedLists(size_t nlist, size_t code_size)
            : faiss::ReadOnlyInvertedLists(nlist, code_size) {
        use_iterator = true;
    }

    // Forwards to `inverted_list_context->get_iterator(list_no)`.
    // `inverted_list_context` must point to a const faiss::InvertedLists.
    // A plain assert() would vanish under NDEBUG and turn a missing context
    // into a null dereference, so fail loudly with a Faiss exception instead.
    faiss::InvertedListsIterator* get_iterator(
            size_t list_no,
            void* inverted_list_context = nullptr) const override {
        FAISS_THROW_IF_NOT_MSG(
                inverted_list_context,
                "DispatchingInvertedLists requires an inverted_list_context");
        auto il =
                static_cast<const faiss::InvertedLists*>(inverted_list_context);
        return il->get_iterator(list_no);
    }

    using idx_t = faiss::idx_t;

    // Random-access API intentionally unsupported: data lives in the
    // per-search context, reachable only via get_iterator().
    size_t list_size(size_t /*list_no*/) const override {
        FAISS_THROW_MSG("use iterator interface");
    }

    const uint8_t* get_codes(size_t /*list_no*/) const override {
        FAISS_THROW_MSG("use iterator interface");
    }

    const idx_t* get_ids(size_t /*list_no*/) const override {
        FAISS_THROW_MSG("use iterator interface");
    }
};
// Builds N databases that share one trained IVF coarse quantizer + PQ, then
// checks that a single index backed by DispatchingInvertedLists, with each
// query batch routed to its own ArrayInvertedLists through
// SearchParametersIVF::inverted_list_context, returns the same neighbors as
// searching each database's standalone index directly.
TEST(COMMON, test_common_trained_index) {
    const int N = 3;    // number of independent databases
    const int nt = 500; // training vectors
    const int nb = 200; // database vectors per database
    const int nq = 10;  // queries per database
    const int k = 4;    // neighbors returned per query

    // Own the factory result before downcasting: if the cast failed, the
    // original code leaked the index and then dereferenced a null pointer.
    std::unique_ptr<faiss::Index> owner(
            faiss::index_factory(d, "IVF32,PQ8np"));
    auto* empty_index = dynamic_cast<faiss::IndexIVF*>(owner.get());
    ASSERT_NE(empty_index, nullptr);

    auto xt = get_random_vectors(nt, 123);
    empty_index->train(nt, xt.data());
    empty_index->nprobe = 4;

    // Reference results: one standalone clone of the trained index per
    // database, populated and searched independently.
    std::vector<std::vector<faiss::idx_t>> ref_I(N);
    for (int i = 0; i < N; i++) {
        std::unique_ptr<faiss::Index> index(faiss::clone_index(empty_index));
        auto xb = get_random_vectors(nb, 1234 + i);
        auto xq = get_random_vectors(nq, 12345 + i);
        index->add(nb, xb.data());
        std::vector<float> D(k * nq);
        std::vector<faiss::idx_t> I(k * nq);
        index->search(nq, xq.data(), k, D.data(), I.data());
        ref_I[i] = I;
    }

    // Fill one ArrayInvertedLists per database by temporarily attaching it
    // (non-owning) to the shared trained index. reserve() guarantees
    // emplace_back never reallocates, so the addresses handed to
    // replace_invlists() here and to the search params below stay valid.
    std::vector<faiss::ArrayInvertedLists> sub_invlists;
    sub_invlists.reserve(N);
    for (int i = 0; i < N; i++) {
        sub_invlists.emplace_back(empty_index->nlist, empty_index->code_size);
        empty_index->replace_invlists(&sub_invlists.back(), false);
        empty_index->reset();
        auto xb = get_random_vectors(nb, 1234 + i);
        empty_index->add(nb, xb.data());
    }

    // Swap in the dispatcher (non-owning); from now on the data for each
    // search comes from the context passed in SearchParametersIVF.
    DispatchingInvertedLists di(empty_index->nlist, empty_index->code_size);
    empty_index->replace_invlists(&di, false);

    std::vector<std::vector<faiss::idx_t>> new_I(N);

#pragma omp parallel for
    for (int i = 0; i < N; i++) {
        auto xq = get_random_vectors(nq, 12345 + i);
        std::vector<float> D(k * nq);
        std::vector<faiss::idx_t> I(k * nq);
        faiss::SearchParametersIVF params;
        params.nprobe = empty_index->nprobe;
        params.inverted_list_context = &sub_invlists[i];
        // Was `¶ms` (mis-decoded `&params`), which does not compile.
        empty_index->search(nq, xq.data(), k, D.data(), I.data(), &params);
        new_I[i] = I; // each iteration writes a distinct element
    }

    for (int i = 0; i < N; i++) {
        ASSERT_EQ(ref_I[i], new_I[i]);
    }
}