#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <random>
#include <thread>
#include <vector>

#include <gtest/gtest.h>

#include <faiss/AutoTune.h>
#include <faiss/IVFlib.h>
#include <faiss/IndexBinaryIVF.h>
#include <faiss/IndexIVF.h>
#include <faiss/IndexPreTransform.h>
#include <faiss/index_factory.h>
#include <faiss/utils/distances.h>

using namespace faiss;
namespace {
int d = 32;
size_t nt = 5000;
size_t nb = 1000;
size_t nq = 200;
int k = 10;
std::mt19937 rng;
std::vector<float> make_data(size_t n) {
std::vector<float> database(n * d);
std::uniform_real_distribution<> distrib;
for (size_t i = 0; i < n * d; i++) {
database[i] = distrib(rng);
}
return database;
}
std::unique_ptr<Index> make_trained_index(
const char* index_type,
MetricType metric_type) {
auto index =
std::unique_ptr<Index>(index_factory(d, index_type, metric_type));
auto xt = make_data(nt);
index->train(nt, xt.data());
ParameterSpace().set_index_parameter(index.get(), "nprobe", 4);
return index;
}
std::vector<idx_t> search_index(Index* index, const float* xq) {
std::vector<idx_t> I(k * nq);
std::vector<float> D(k * nq);
index->search(nq, xq, k, D.data(), I.data());
return I;
}
void test_lowlevel_access(const char* index_key, MetricType metric) {
std::unique_ptr<Index> index = make_trained_index(index_key, metric);
auto xb = make_data(nb);
index->add(nb, xb.data());
const IndexPreTransform* index_pt =
dynamic_cast<const IndexPreTransform*>(index.get());
int dt = index->d;
const float* xbt = xb.data();
std::unique_ptr<float[]> del_xbt;
if (index_pt) {
dt = index_pt->index->d;
xbt = index_pt->apply_chain(nb, xb.data());
if (xbt != xb.data()) {
del_xbt.reset((float*)xbt);
}
}
IndexIVF* index_ivf = ivflib::extract_index_ivf(index.get());
std::vector<idx_t> list_nos(nb);
std::vector<uint8_t> codes(index_ivf->code_size * nb);
index_ivf->quantizer->assign(nb, xbt, list_nos.data());
index_ivf->encode_vectors(nb, xbt, list_nos.data(), codes.data());
std::vector<float> decoded(nb * dt);
std::vector<uint8_t> codes2(codes);
std::vector<float> decoded2(nb * dt);
index_ivf->decode_vectors(
nb, codes.data(), list_nos.data(), decoded.data());
index_ivf->encode_vectors(nb, xbt, list_nos.data(), codes2.data());
index_ivf->decode_vectors(
nb, codes2.data(), list_nos.data(), decoded2.data());
EXPECT_LT(
faiss::fvec_L2sqr(decoded.data(), decoded2.data(), nb * dt), 1e-5);
const InvertedLists* il = index_ivf->invlists;
for (int list_no = 0; list_no < index_ivf->nlist; list_no++) {
InvertedLists::ScopedCodes ivf_codes(il, list_no);
InvertedLists::ScopedIds ivf_ids(il, list_no);
size_t list_size = il->list_size(list_no);
for (int i = 0; i < list_size; i++) {
const uint8_t* ref_code = ivf_codes.get() + i * il->code_size;
const uint8_t* new_code = codes.data() + ivf_ids[i] * il->code_size;
EXPECT_EQ(memcmp(ref_code, new_code, il->code_size), 0);
}
}
auto xq = make_data(nq);
auto ref_I = search_index(index.get(), xq.data());
const float* xqt = xq.data();
std::unique_ptr<float[]> del_xqt;
if (index_pt) {
xqt = index_pt->apply_chain(nq, xq.data());
if (xqt != xq.data()) {
del_xqt.reset((float*)xqt);
}
}
int nprobe = index_ivf->nprobe;
std::vector<idx_t> q_lists(nq * nprobe);
std::vector<float> q_dis(nq * nprobe);
index_ivf->quantizer->search(nq, xqt, nprobe, q_dis.data(), q_lists.data());
std::unique_ptr<InvertedListScanner> scanner(
index_ivf->get_InvertedListScanner());
for (int i = 0; i < nq; i++) {
std::vector<idx_t> I(k, -1);
float default_dis = metric == METRIC_L2 ? HUGE_VAL : -HUGE_VAL;
std::vector<float> D(k, default_dis);
scanner->set_query(xqt + i * dt);
for (int j = 0; j < nprobe; j++) {
int list_no = q_lists[i * nprobe + j];
if (list_no < 0) {
continue;
}
scanner->set_list(list_no, q_dis[i * nprobe + j]);
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
D.data(),
I.data(),
k);
if (j == 0) {
for (int jj = 0; jj < k; jj++) {
int vno = I[jj];
if (vno < 0) {
break; }
float computed_D = scanner->distance_to_code(
codes.data() + vno * il->code_size);
EXPECT_FLOAT_EQ(computed_D, D[jj]);
}
}
}
if (metric == METRIC_L2) {
maxheap_reorder(k, D.data(), I.data());
} else {
minheap_reorder(k, D.data(), I.data());
}
for (int j = 0; j < k; j++) {
EXPECT_EQ(I[j], ref_I[i * k + j]);
}
}
}
}
// Low-level access checks over several IVF index_factory strings, each run
// with the metric(s) listed in its test name.
TEST(TestLowLevelIVF, IVFFlatL2) {
    test_lowlevel_access(/*index_key=*/"IVF32,Flat", /*metric=*/METRIC_L2);
}

TEST(TestLowLevelIVF, PCAIVFFlatL2) {
    test_lowlevel_access(
            /*index_key=*/"PCAR16,IVF32,Flat", /*metric=*/METRIC_L2);
}

TEST(TestLowLevelIVF, IVFFlatIP) {
    test_lowlevel_access(
            /*index_key=*/"IVF32,Flat", /*metric=*/METRIC_INNER_PRODUCT);
}

TEST(TestLowLevelIVF, IVFSQL2) {
    test_lowlevel_access(/*index_key=*/"IVF32,SQ8", /*metric=*/METRIC_L2);
}

TEST(TestLowLevelIVF, IVFSQIP) {
    test_lowlevel_access(
            /*index_key=*/"IVF32,SQ8", /*metric=*/METRIC_INNER_PRODUCT);
}

TEST(TestLowLevelIVF, IVFPQL2) {
    test_lowlevel_access(/*index_key=*/"IVF32,PQ4np", /*metric=*/METRIC_L2);
}

TEST(TestLowLevelIVF, IVFPQIP) {
    test_lowlevel_access(
            /*index_key=*/"IVF32,PQ4np", /*metric=*/METRIC_INNER_PRODUCT);
}

TEST(TestLowLevelIVF, IVFRaBitQ) {
    test_lowlevel_access(/*index_key=*/"IVF32,RaBitQ", /*metric=*/METRIC_L2);
}

TEST(TestLowLevelIVF, IVFRQ) {
    test_lowlevel_access(/*index_key=*/"IVF32,RQ16x8", /*metric=*/METRIC_L2);
}
namespace {
int nbit = 256;
std::vector<uint8_t> make_data_binary(size_t n) {
std::vector<uint8_t> database(n * nbit / 8);
std::uniform_int_distribution<> distrib;
for (size_t i = 0; i < n * d; i++) {
database[i] = distrib(rng);
}
return database;
}
std::unique_ptr<IndexBinary> make_trained_index_binary(const char* index_type) {
auto index = std::unique_ptr<IndexBinary>(
index_binary_factory(nbit, index_type));
auto xt = make_data_binary(nt);
index->train(nt, xt.data());
return index;
}
void test_lowlevel_access_binary(const char* index_key) {
std::unique_ptr<IndexBinary> index = make_trained_index_binary(index_key);
IndexBinaryIVF* index_ivf = dynamic_cast<IndexBinaryIVF*>(index.get());
assert(index_ivf);
index_ivf->nprobe = 4;
auto xb = make_data_binary(nb);
index->add(nb, xb.data());
std::vector<idx_t> list_nos(nb);
index_ivf->quantizer->assign(nb, xb.data(), list_nos.data());
const InvertedLists* il = index_ivf->invlists;
auto xq = make_data_binary(nq);
std::vector<idx_t> I_ref(k * nq);
std::vector<int32_t> D_ref(k * nq);
index->search(nq, xq.data(), k, D_ref.data(), I_ref.data());
int nprobe = index_ivf->nprobe;
std::vector<idx_t> q_lists(nq * nprobe);
std::vector<int32_t> q_dis(nq * nprobe);
index_ivf->quantizer->search(
nq, xq.data(), nprobe, q_dis.data(), q_lists.data());
std::unique_ptr<BinaryInvertedListScanner> scanner(
index_ivf->get_InvertedListScanner());
for (int i = 0; i < nq; i++) {
std::vector<idx_t> I(k, -1);
uint32_t default_dis = 1 << 30;
std::vector<int32_t> D(k, default_dis);
scanner->set_query(xq.data() + i * index_ivf->code_size);
for (int j = 0; j < nprobe; j++) {
int list_no = q_lists[i * nprobe + j];
if (list_no < 0) {
continue;
}
scanner->set_list(list_no, q_dis[i * nprobe + j]);
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
D.data(),
I.data(),
k);
if (j == 0) {
for (int jj = 0; jj < k; jj++) {
int vno = I[jj];
if (vno < 0) {
break; }
float computed_D = scanner->distance_to_code(
xb.data() + vno * il->code_size);
EXPECT_EQ(computed_D, D[jj]);
}
}
}
heap_reorder<CMax<int32_t, idx_t>>(k, D.data(), I.data());
for (int j = 0; j < k; j++) {
EXPECT_LE(D[j], D_ref[i * k + k - 1]);
if (D[j] < D_ref[i * k + k - 1]) {
int j2 = 0;
while (j2 < k) {
if (I[j] == I_ref[i * k + j2]) {
break;
}
j2++;
}
EXPECT_LT(j2, k); if (j2 < k) {
EXPECT_EQ(D[j], D_ref[i * k + j2]);
}
}
}
}
}
}
// Low-level access check for a binary IVF index with 32 lists.
TEST(TestLowLevelIVF, IVFBinary) {
    test_lowlevel_access_binary(/*index_key=*/"BIVF32");
}
namespace {
void test_threaded_search(const char* index_key, MetricType metric) {
std::unique_ptr<Index> index = make_trained_index(index_key, metric);
auto xb = make_data(nb);
index->add(nb, xb.data());
const IndexPreTransform* index_pt =
dynamic_cast<const IndexPreTransform*>(index.get());
int dt = index->d;
const float* xbt = xb.data();
std::unique_ptr<float[]> del_xbt;
if (index_pt) {
dt = index_pt->index->d;
xbt = index_pt->apply_chain(nb, xb.data());
if (xbt != xb.data()) {
del_xbt.reset((float*)xbt);
}
}
IndexIVF* index_ivf = ivflib::extract_index_ivf(index.get());
auto xq = make_data(nq);
auto ref_I = search_index(index.get(), xq.data());
const float* xqt = xq.data();
std::unique_ptr<float[]> del_xqt;
if (index_pt) {
xqt = index_pt->apply_chain(nq, xq.data());
if (xqt != xq.data()) {
del_xqt.reset((float*)xqt);
}
}
int nprobe = index_ivf->nprobe;
std::vector<idx_t> q_lists(nq * nprobe);
std::vector<float> q_dis(nq * nprobe);
index_ivf->quantizer->search(nq, xqt, nprobe, q_dis.data(), q_lists.data());
int nproc = 3;
for (int i = 0; i < nq; i++) {
std::vector<idx_t> I(k * nproc, -1);
float default_dis = metric == METRIC_L2 ? HUGE_VAL : -HUGE_VAL;
std::vector<float> D(k * nproc, default_dis);
auto search_function = [index_ivf,
&I,
&D,
dt,
i,
nproc,
xqt,
nprobe,
&q_dis,
&q_lists](int rank) {
const InvertedLists* il = index_ivf->invlists;
std::unique_ptr<InvertedListScanner> scanner(
index_ivf->get_InvertedListScanner());
idx_t* local_I = I.data() + rank * k;
float* local_D = D.data() + rank * k;
scanner->set_query(xqt + i * dt);
for (int j = rank; j < nprobe; j += nproc) {
int list_no = q_lists[i * nprobe + j];
if (list_no < 0) {
continue;
}
scanner->set_list(list_no, q_dis[i * nprobe + j]);
scanner->scan_codes(
il->list_size(list_no),
InvertedLists::ScopedCodes(il, list_no).get(),
InvertedLists::ScopedIds(il, list_no).get(),
local_D,
local_I,
k);
}
};
std::vector<std::thread> threads;
for (int rank = 0; rank < nproc; rank++) {
threads.emplace_back(search_function, rank);
}
for (int rank = 0; rank < nproc; rank++) {
threads[rank].join();
if (rank == 0) {
continue; }
if (metric == METRIC_L2) {
maxheap_addn(
k,
D.data(),
I.data(),
D.data() + rank * k,
I.data() + rank * k,
k);
} else {
minheap_addn(
k,
D.data(),
I.data(),
D.data() + rank * k,
I.data() + rank * k,
k);
}
}
if (metric == METRIC_L2) {
maxheap_reorder(k, D.data(), I.data());
} else {
minheap_reorder(k, D.data(), I.data());
}
for (int j = 0; j < k; j++) {
EXPECT_EQ(I[j], ref_I[i * k + j]);
}
}
}
}
// Multi-threaded low-level scan of an L2 IVF flat index with 32 lists.
TEST(TestLowLevelIVF, ThreadedSearch) {
    test_threaded_search(/*index_key=*/"IVF32,Flat", /*metric=*/METRIC_L2);
}