#include <cstdio>
#include <cstdlib>
#include <memory>
#include <random>
#include <vector>
#include <gtest/gtest.h>
#include <faiss/IVFlib.h>
#include <faiss/IndexIVF.h>
#include <faiss/VectorTransform.h>
#include <faiss/index_factory.h>
namespace {
// Integer type used for vector ids and inverted-list numbers in these tests.
using idx_t = faiss::idx_t;
// Dimensionality of the generated vectors and of the indexes built from them.
int d = 64;
// Number of database vectors used to train and populate each index.
size_t nb = 8000;
// Number of query vectors used by the search tests.
size_t nq = 200;
// Random generator shared by every make_data() call. It is default-constructed
// (standard default seed), so the stream of generated values is reproducible
// as long as the calls happen in the same order.
std::mt19937 rng;
/**
 * Generate `n` random vectors of dimension `d`, stored back to back.
 *
 * Each component is drawn from the shared generator `rng` through a
 * default-constructed std::uniform_real_distribution<> (doubles in [0, 1))
 * and narrowed to float when stored.
 *
 * @param n number of vectors to generate
 * @return flat buffer holding n * d floats
 */
std::vector<float> make_data(size_t n) {
    std::uniform_real_distribution<> unit_interval;
    std::vector<float> values(n * d);
    // One draw per component, in storage order.
    for (float& component : values) {
        component = static_cast<float>(unit_interval(rng));
    }
    return values;
}
std::unique_ptr<faiss::Index> make_index(
const char* index_type,
const std::vector<float>& x) {
auto index =
std::unique_ptr<faiss::Index>(faiss::index_factory(d, index_type));
index->train(nb, x.data());
index->add(nb, x.data());
return index;
}
bool test_search_centroid(const char* index_key) {
std::vector<float> xb = make_data(nb); auto index = make_index(index_key, xb);
std::vector<idx_t> centroid_ids(nb);
faiss::ivflib::search_centroid(
index.get(), xb.data(), nb, centroid_ids.data());
const faiss::IndexIVF* ivf = faiss::ivflib::extract_index_ivf(index.get());
for (int i = 0; i < nb; i++) {
bool found = false;
int list_no = centroid_ids[i];
int list_size = ivf->invlists->list_size(list_no);
auto* list = ivf->invlists->get_ids(list_no);
for (int j = 0; j < list_size; j++) {
if (list[j] == i) {
found = true;
break;
}
}
if (!found) {
return false;
}
}
return true;
}
int test_search_and_return_centroids(const char* index_key) {
std::vector<float> xb = make_data(nb); auto index = make_index(index_key, xb);
std::vector<idx_t> centroid_ids(nb);
faiss::ivflib::search_centroid(
index.get(), xb.data(), nb, centroid_ids.data());
faiss::IndexIVF* ivf = faiss::ivflib::extract_index_ivf(index.get());
ivf->nprobe = 4;
std::vector<float> xq = make_data(nq);
int k = 5;
std::vector<idx_t> refI(nq * k);
std::vector<float> refD(nq * k);
index->search(nq, xq.data(), k, refD.data(), refI.data());
std::vector<idx_t> newI(nq * k);
std::vector<float> newD(nq * k);
std::vector<idx_t> query_centroid_ids(nq);
std::vector<idx_t> result_centroid_ids(nq * k);
faiss::ivflib::search_and_return_centroids(
index.get(),
nq,
xq.data(),
k,
newD.data(),
newI.data(),
query_centroid_ids.data(),
result_centroid_ids.data());
if (newI != refI) {
return 1;
}
for (int i = 0; i < nq * k; i++) {
int list_no = result_centroid_ids[i];
int result_no = newI[i];
if (result_no < 0) {
continue;
}
bool found = false;
int list_size = ivf->invlists->list_size(list_no);
auto* list = ivf->invlists->get_ids(list_no);
for (int j = 0; j < list_size; j++) {
if (list[j] == result_no) {
found = true;
break;
}
}
if (!found) {
return 2;
}
}
return 0;
}
}
// search_centroid consistency check on an index built from "IVF32,Flat".
TEST(testSearchCentroid, IVFFlat) {
    const bool every_vector_found = test_search_centroid("IVF32,Flat");
    EXPECT_TRUE(every_vector_found);
}
// Same consistency check, with a "PCA16" stage in front of the IVF index.
TEST(testSearchCentroid, PCAIVFFlat) {
    const bool every_vector_found = test_search_centroid("PCA16,IVF32,Flat");
    EXPECT_TRUE(every_vector_found);
}
// search_and_return_centroids check on an index built from "IVF32,Flat".
TEST(testSearchAndReturnCentroids, IVFFlat) {
    const int status = test_search_and_return_centroids("IVF32,Flat");
    // 1: labels differ from Index::search.
    EXPECT_NE(status, 1);
    // 2: a returned id is missing from its reported inverted list.
    EXPECT_NE(status, 2);
}
// Same check, with a "PCA16" stage in front of the IVF index.
TEST(testSearchAndReturnCentroids, PCAIVFFlat) {
    const int status = test_search_and_return_centroids("PCA16,IVF32,Flat");
    // 1: labels differ from Index::search.
    EXPECT_NE(status, 1);
    // 2: a returned id is missing from its reported inverted list.
    EXPECT_NE(status, 2);
}