import faiss
import numpy as np
# Dimensionality of every vector and number of vectors in the dataset.
d = 768
n = 1_000
# Sequential 64-bit ids 0..n-1.
ids = np.arange(n, dtype="int64")
# Random float32 vectors in [0, 1), used to train the PQ indexes.
training_data = np.random.rand(n, d).astype("float32")
def read_ids_codes():
    """Load the persisted ids and codes arrays from /tmp.

    Returns:
        A ``(ids, codes)`` pair loaded with ``np.load`` from
        ``/tmp/ids.npy`` and ``/tmp/codes.npy``, or ``(None, None)`` when
        either file does not exist (i.e. nothing has been stored yet).
    """
    try:
        stored_ids = np.load("/tmp/ids.npy")
        stored_codes = np.load("/tmp/codes.npy")
    except FileNotFoundError:
        return None, None
    return stored_ids, stored_codes
def write_ids_codes(ids, codes):
    """Persist ``ids`` and ``codes`` to /tmp as ``.npy`` files.

    ``codes`` is reshaped to one row per id before saving, so e.g. an input
    of shape ``(k, 1, code_size)`` is stored as ``(k, code_size)``.
    """
    np.save("/tmp/ids.npy", ids)
    one_row_per_id = codes.reshape(len(ids), -1)
    np.save("/tmp/codes.npy", one_row_per_id)
def write_template_index(template_index):
    """Serialize ``template_index`` with faiss to ``/tmp/template.index``."""
    destination = "/tmp/template.index"
    faiss.write_index(template_index, destination)
def read_template_index_instance():
    """Deserialize ``/tmp/template.index`` and return the resulting index.

    The file is read anew on every call, so each call yields a new index
    object.
    """
    source = "/tmp/template.index"
    loaded_index = faiss.read_index(source)
    return loaded_index
# Determinism check: ten independently constructed LSH indexes with the same
# (d, nbits) must encode the same vector to identical codes. The persisted-codes
# workflow that follows relies on re-creating an equivalent codec later.
nbits = 1536  # LSH code length in bits (2 * d)
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
codes = []

for i in range(10):
    # A brand-new id-mapped LSH index each iteration; it is never trained.
    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
    code = index.index.sa_encode(database_vector_float32)
    codes.append(code)

# Every copy must match the first one exactly.
for i in range(1, 10):
    assert np.array_equal(codes[i], codes[0])
# Simulate inserting one new vector into a persisted (ids, codes) store using
# LSH: encode it with a freshly constructed LSH index and append the id/code
# pair to whatever is already on disk.
ids, codes = read_ids_codes()
# Continue after the largest stored id; the very first vector gets id 1.
database_vector_id = ids.max() + 1 if ids is not None else 1
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
code = index.index.sa_encode(database_vector_float32)
# Ids are built as explicit int64 rather than relying on the platform's default
# integer dtype (faiss ids are 64-bit).
new_ids = np.array([database_vector_id], dtype="int64")
if ids is not None and codes is not None:
    ids = np.concatenate((ids, new_ids))
    codes = np.vstack((codes, code))
else:
    ids = new_ids
    codes = code
write_ids_codes(ids, codes)
# Rebuild a searchable LSH index purely from the persisted ids/codes (no raw
# vectors) and run a k=5 search with a random query.
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
ids, codes = read_ids_codes()
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
index.add_sa_codes(codes, ids)
index.search(query_vector_float32, k=5)
# Remove the persisted ids/codes so the PQ walkthrough that follows starts from
# an empty store. Pure-Python replacement for the IPython `!rm` shell escape,
# which is a syntax error outside a notebook; like `rm` in a notebook cell, a
# missing file does not stop execution.
import os
from contextlib import suppress

for _stale_path in ("/tmp/ids.npy", "/tmp/codes.npy"):
    with suppress(FileNotFoundError):
        os.remove(_stale_path)
# PQ variant. Unlike the LSH index above, the PQ index is trained, so train a
# single template once and persist it; later steps reload it from disk.
nbits = 8
M = d // 8  # 96 for d = 768
template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
template_index.train(training_data)
write_template_index(template_index)
# Same insert workflow for PQ: reload the trained template index persisted
# earlier, encode one new random vector with it, and append the id/code pair
# to the persisted store.
index = read_template_index_instance()
ids, codes = read_ids_codes()
# Continue after the largest stored id; the very first vector gets id 1.
database_vector_id = ids.max() + 1 if ids is not None else 1
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
code = index.index.sa_encode(database_vector_float32)
# Ids are built as explicit int64 rather than relying on the platform's default
# integer dtype (faiss ids are 64-bit).
new_ids = np.array([database_vector_id], dtype="int64")
if ids is not None and codes is not None:
    ids = np.concatenate((ids, new_ids))
    codes = np.vstack((codes, code))
else:
    ids = new_ids
    codes = code
write_ids_codes(ids, codes)
# Rebuild a searchable PQ index from a fresh template instance plus the
# persisted ids/codes, then run a k=5 search with a random query.
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
ids, codes = read_ids_codes()
id_wrapper_index = read_template_index_instance()
id_wrapper_index.add_sa_codes(codes, ids)
id_wrapper_index.search(query_vector_float32, k=5)
# Remove the temporary files written by the PQ walkthrough. Pure-Python
# replacement for the IPython `!rm` shell escape, which is a syntax error
# outside a notebook; like `rm` in a notebook cell, a missing file does not
# stop execution.
import os
from contextlib import suppress

for _stale_path in ("/tmp/ids.npy", "/tmp/codes.npy", "/tmp/template.index"):
    with suppress(FileNotFoundError):
        os.remove(_stale_path)
n, d  # echo the dataset size (only displayed when run as a notebook cell)
# Benchmark data: n database vectors with ids 0..n-1 and n random queries.
database_vector_ids = np.arange(n)
database_vector_float32s = np.random.rand(n, d).astype(np.float32)
query_vector_float32s = np.random.rand(n, d).astype(np.float32)
# Ground-truth top-1 neighbours come from an uncompressed Flat index.
index = faiss.index_factory(d, "IDMap2,Flat")
index.add_with_ids(database_vector_float32s, database_vector_ids)
_, ground_truth_result_ids = index.search(query_vector_float32s, k=1)
from dataclasses import dataclass
# PQ configurations to benchmark, as (M, nbits) pairs.
pq_m_nbits = ((96, 8), (192, 4), (192, 8), (384, 4), (384, 8), (768, 4))
# LSH code lengths in bits: 768 doubled five times (768 .. 24576).
lsh_nbits = tuple(768 * 2 ** power for power in range(6))
@dataclass
class Record:
    """One benchmarked index configuration and its measured top-1 accuracy."""

    # Index family: "pq" or "lsh".
    type_: str
    # The trained and populated faiss index (string annotation so the class
    # can be defined without evaluating faiss at class-creation time).
    index: "faiss.Index"
    # Construction parameters: (M, nbits) for PQ, (nbits,) for LSH.
    args: tuple
    # Number of queries whose top-1 id matched the ground truth, stored as the
    # integer array the benchmark loop produces (plotting reads recall[0] and
    # divides by the query count). Previously mis-annotated as float.
    recall: np.ndarray
def _count_top1_hits(result_ids, expected_ids):
    """Count, per result column, the queries whose top-1 id matches.

    ``np.sum(..., axis=0)`` counts in one vectorised call and yields one
    count per column of ``result_ids``, which is what the plotting code
    indexes with ``[0]``.
    """
    return np.sum(result_ids == expected_ids, axis=0)


# Benchmark every PQ and LSH configuration against the Flat ground truth,
# keeping the populated index so its code size can be inspected later.
results = []
for m, nbits in pq_m_nbits:
    print("pq", m, nbits)
    index = faiss.index_factory(d, f"IDMap2,PQ{m}x{nbits}")
    index.train(training_data)
    index.add_with_ids(database_vector_float32s, database_vector_ids)
    _, result_ids = index.search(query_vector_float32s, k=1)
    recall = _count_top1_hits(result_ids, ground_truth_result_ids)
    results.append(Record("pq", index, (m, nbits), recall))
for nbits in lsh_nbits:
    print("lsh", nbits)
    # LSH needs no training step before adding vectors.
    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
    index.add_with_ids(database_vector_float32s, database_vector_ids)
    _, result_ids = index.search(query_vector_float32s, k=1)
    recall = _count_top1_hits(result_ids, ground_truth_result_ids)
    results.append(Record("lsh", index, (nbits,), recall))
import matplotlib.pyplot as plt
import numpy as np
def create_grouped_bar_chart(
    x_values,
    y_values_list,
    labels_list,
    xlabel,
    ylabel,
    title,
    *,
    figsize=(12, 6),
    relative_bar_width=0.08,
):
    """Draw one group of annotated bars per x position on a log-scaled x axis.

    Args:
        x_values: Group positions; each is both the centre of its group and a
            tick location. Should be positive because the x axis is log-scaled.
        y_values_list: For each entry of ``x_values``, the bar heights of that
            group.
        labels_list: For each entry of ``x_values``, the text drawn above each
            bar of that group (paired with ``y_values_list`` by position).
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Figure title.
        figsize: Figure size passed to ``plt.figure`` (default ``(12, 6)``).
        relative_bar_width: Bar width as a fraction of the group's x value
            (default ``0.08``).

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=figsize)
    for x, y_values, labels in zip(x_values, y_values_list, labels_list):
        num_bars = len(y_values)
        # A width proportional to x keeps groups roughly equally wide on the
        # log-scaled axis, where visual distance corresponds to ratios.
        bar_width = relative_bar_width * x
        # Lay the bars side by side, centred on x.
        bar_positions = np.arange(num_bars) * bar_width - (num_bars - 1) * bar_width / 2 + x
        bars = plt.bar(bar_positions, y_values, width=bar_width)
        for bar, label in zip(bars, labels):
            plt.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                xytext=(0, 3),  # 3 points above the top of the bar
                textcoords="offset points",
                ha="center",
                va="bottom",
            )
    plt.xscale("log")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(x_values, labels=[str(x) for x in x_values])
    plt.tight_layout()
    plt.show()
from collections import defaultdict
# Group (label, recall@1) pairs by compression ratio, then plot each group.
# Compression ratio = raw float32 vector size (4 bytes per dimension) divided
# by the stored code size of the index.
x = defaultdict(list)
# The Flat index is the ground truth itself: uncompressed, recall 1.
x[1].append(("flat", 1.00))
for r in results:
    # Hit count divided by the number of queries (n) gives recall@1.
    y_value = r.recall[0] / n
    x_value = int(d * 4 / r.index.sa_code_size())
    if r.type_ == "pq":
        label = f"PQ{r.args[0]}x{r.args[1]}"
    elif r.type_ == "lsh":
        label = f"LSH{r.args[0]}"
    else:
        label = None
    x[x_value].append((label, y_value))
x_values = sorted(x)
create_grouped_bar_chart(
    x_values,
    [[recall_at_1 for _, recall_at_1 in x[x_value]] for x_value in x_values],
    [[bar_label for bar_label, _ in x[x_value]] for x_value in x_values],
    "compression ratio",
    # Derived from n/d so the captions cannot drift from the actual data.
    f"recall@1 q={n:,} queries",
    f"recall@1 for a database of n={n:,} d={d} vectors",
)