from ctypes import (Structure,
POINTER,
c_char,
c_uint,
c_uint8,
c_void_p,
c_char_p,
c_int,
c_double,
c_float,
cdll)
import ctypes.util
import csv
class NativeError(Exception):
    """Signals that a call into the NGT native library reported a failure.

    The exception message is the error text fetched from the library's
    error object.
    """
class APIError(Exception):
    """Signals misuse of the Python wrapper, detected before reaching native code.

    For example, inserting vectors whose dimensionality does not match the
    index.
    """
class Index(object):
    """Handle to an on-disk NGT index, driven through ``libngt`` via ctypes.

    ``Index(path)`` opens an existing index; :meth:`create` builds a new one
    on disk and returns it opened.  Failures reported by the native library
    raise :class:`NativeError`; invalid arguments caught on the Python side
    raise :class:`APIError`.
    """

    class ObjectDistance(Structure):
        """One search hit as returned by value from ``ngt_get_result``."""

        _fields_ = [
            ('id', c_uint),
            ('distance', c_float)
        ]

        def __repr__(self):
            return "ObjectDistances({},{})".format(
                self.id,
                self.distance)

    __ngt = cdll.LoadLibrary(ctypes.util.find_library("ngt"))

    # --- Index lifecycle -------------------------------------------------
    __ngt.ngt_open_index.argtypes = [c_char_p, c_void_p]
    __ngt.ngt_open_index.restype = c_void_p
    __ngt.ngt_create_graph_and_tree.argtypes = [c_char_p, c_void_p, c_void_p]
    __ngt.ngt_create_graph_and_tree.restype = c_void_p
    __ngt.ngt_save_index.argtypes = [c_void_p, c_char_p, c_void_p]
    __ngt.ngt_close_index.argtypes = [c_void_p]
    __ngt.ngt_create_index.argtypes = [c_void_p, c_int, c_void_p]
    __ngt.ngt_remove_index.argtypes = [c_void_p, c_uint, c_void_p]

    # --- Properties ------------------------------------------------------
    __ngt.ngt_create_property.argtypes = [c_void_p]
    __ngt.ngt_create_property.restype = c_void_p
    __ngt.ngt_destroy_property.argtypes = [c_void_p]
    __ngt.ngt_get_property.argtypes = [c_void_p, c_void_p, c_void_p]
    __ngt.ngt_get_property_dimension.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_dimension.argtypes = [c_void_p, c_int, c_void_p]
    __ngt.ngt_set_property_edge_size_for_creation.argtypes = [
        c_void_p, c_int, c_void_p]
    __ngt.ngt_set_property_edge_size_for_search.argtypes = [
        c_void_p, c_int, c_void_p]
    __ngt.ngt_get_property_edge_size_for_creation.argtypes = [
        c_void_p, c_void_p]
    __ngt.ngt_get_property_edge_size_for_search.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_get_property_distance_type.argtypes = [c_void_p, c_void_p]
    # Called as (prop, err).  Declaring only one argument made ctypes pass
    # the error pointer through its default int conversion, which masks the
    # value to a C int.
    __ngt.ngt_get_property_object_type.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_is_property_object_type_float.argtypes = [c_int]
    __ngt.ngt_is_property_object_type_integer.argtypes = [c_int]
    __ngt.ngt_set_property_object_type_float.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_object_type_integer.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_distance_type_l1.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_distance_type_l2.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_distance_type_angle.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_distance_type_hamming.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_distance_type_jaccard.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_set_property_distance_type_cosine.argtypes = [c_void_p, c_void_p]

    # --- Search results --------------------------------------------------
    # Previously misspelled ``argtype``, so the prototype was never applied.
    __ngt.ngt_create_empty_results.argtypes = [c_void_p]
    __ngt.ngt_create_empty_results.restype = c_void_p
    __ngt.ngt_destroy_results.argtypes = [c_void_p]
    __ngt.ngt_get_size.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_get_result.argtypes = [c_void_p, c_int, c_void_p]
    __ngt.ngt_get_result.restype = ObjectDistance

    # --- Objects ---------------------------------------------------------
    # Called as (index, err); see the note on ngt_get_property_object_type.
    __ngt.ngt_get_object_space.argtypes = [c_void_p, c_void_p]
    __ngt.ngt_get_object_space.restype = c_void_p
    __ngt.ngt_get_object_as_float.argtypes = [c_void_p, c_int, c_void_p]
    __ngt.ngt_get_object_as_float.restype = POINTER(c_float)
    __ngt.ngt_get_object_as_integer.argtypes = [c_void_p, c_int, c_void_p]
    __ngt.ngt_get_object_as_integer.restype = POINTER(c_uint8)
    # Vector parameters are declared as element pointers, once, at class
    # level.  They used to be set in __init__ as ``c_double * self.dim`` on
    # this shared library object, so opening a second index with another
    # dimension broke search/insert on the first.  ctypes accepts arrays of
    # the element type for these pointer parameters.  The length is now
    # checked explicitly by _check_dimension.
    __ngt.ngt_search_index.argtypes = [
        c_void_p,
        POINTER(c_double),
        c_int,
        c_int,
        c_float,
        c_float,
        c_void_p,
        c_void_p
    ]
    __ngt.ngt_insert_index.argtypes = [
        c_void_p, POINTER(c_double), c_int, c_void_p]
    __ngt.ngt_insert_index.restype = c_uint
    __ngt.ngt_batch_append_index.argtypes = [
        c_void_p, POINTER(c_float), c_int, c_void_p]

    # --- Error objects ---------------------------------------------------
    __ngt.ngt_create_error_object.argtypes = []
    __ngt.ngt_create_error_object.restype = c_void_p
    __ngt.ngt_get_error_string.argtypes = [c_void_p]
    __ngt.ngt_get_error_string.restype = c_char_p
    __ngt.ngt_clear_error_string.argtypes = [c_void_p]
    __ngt.ngt_destroy_error_object.argtypes = [c_void_p]

    # Native setter for each ``object_type`` / ``distance_type`` name that
    # create() accepts.
    _OBJECT_TYPE_SETTERS = {
        "Float": __ngt.ngt_set_property_object_type_float,
        "Integer": __ngt.ngt_set_property_object_type_integer,
    }
    _DISTANCE_TYPE_SETTERS = {
        "L1": __ngt.ngt_set_property_distance_type_l1,
        "L2": __ngt.ngt_set_property_distance_type_l2,
        "Angle": __ngt.ngt_set_property_distance_type_angle,
        "Hamming": __ngt.ngt_set_property_distance_type_hamming,
        "Jaccard": __ngt.ngt_set_property_distance_type_jaccard,
        "Cosine": __ngt.ngt_set_property_distance_type_cosine,
    }

    @staticmethod
    def _raise_native_error(error):
        """Read and clear the message held by ``error``, then raise it.

        Raises:
            NativeError: always.  ``c_char_p`` results arrive as bytes on
            Python 3, so the message is decoded to text first.
        """
        message = Index.__ngt.ngt_get_error_string(error)
        Index.__ngt.ngt_clear_error_string(error)
        if isinstance(message, bytes):
            message = message.decode("utf-8", "replace")
        raise NativeError(message)

    @staticmethod
    def _check_error_num(status, error):
        """Raise NativeError if a native call returned 0 (false)."""
        if status == 0:
            Index._raise_native_error(error)

    @staticmethod
    def _check_error_pnum(status, error):
        """Raise NativeError if a native call returned a negative number."""
        if status < 0:
            Index._raise_native_error(error)

    @staticmethod
    def _check_error_obj(status, error):
        """Raise NativeError if a native call returned a NULL handle (None)."""
        if status is None:
            Index._raise_native_error(error)

    @staticmethod
    def _encode_path(path):
        """Return ``path`` as bytes, as the ``c_char_p`` parameters require.

        Bytes pass through unchanged; text is encoded as UTF-8.  Without this,
        ``str`` paths are rejected by ctypes on Python 3.
        """
        if isinstance(path, bytes):
            return path
        return path.encode("utf-8")

    def _check_dimension(self, caller, vector):
        """Raise APIError unless ``vector`` has exactly ``self.dim`` items.

        The native calls receive raw pointers plus ``self.dim``, so a
        shorter vector would be read out of bounds.
        """
        if len(vector) != self.dim:
            raise APIError(
                "{}: Inconsistent dimensionality. "
                "The expected is {}. The specified is {}".format(
                    caller, self.dim, len(vector)))

    @staticmethod
    def create(path, dimension,
               edge_size_for_creation=10, edge_size_for_search=40,
               object_type="Float", distance_type="L2"):
        """Create a new index at ``path`` and return it opened as an Index.

        Args:
            path: Directory path for the new index (str or bytes).
            dimension: Number of components per vector.
            edge_size_for_creation: Passed to the native property setter.
            edge_size_for_search: Passed to the native property setter.
            object_type: One of ``"Float"`` or ``"Integer"``.
            distance_type: One of ``"L1"``, ``"L2"``, ``"Angle"``,
                ``"Hamming"``, ``"Jaccard"``, ``"Cosine"``.

        Raises:
            APIError: for an unsupported ``object_type`` or ``distance_type``.
                These names used to be silently ignored, leaving the native
                default in place.
            NativeError: if any native step fails.
        """
        if object_type not in Index._OBJECT_TYPE_SETTERS:
            raise APIError(
                "create: Unsupported object type {!r}. "
                "Expected one of {}".format(
                    object_type, sorted(Index._OBJECT_TYPE_SETTERS)))
        if distance_type not in Index._DISTANCE_TYPE_SETTERS:
            raise APIError(
                "create: Unsupported distance type {!r}. "
                "Expected one of {}".format(
                    distance_type, sorted(Index._DISTANCE_TYPE_SETTERS)))
        err = None
        prop = None
        index = None
        try:
            err = Index.__ngt.ngt_create_error_object()
            prop = Index.__ngt.ngt_create_property(err)
            Index._check_error_obj(prop, err)
            stat = Index.__ngt.ngt_set_property_dimension(prop, dimension, err)
            Index._check_error_num(stat, err)
            stat = Index.__ngt.ngt_set_property_edge_size_for_creation(
                prop, edge_size_for_creation, err)
            Index._check_error_num(stat, err)
            stat = Index.__ngt.ngt_set_property_edge_size_for_search(
                prop, edge_size_for_search, err)
            Index._check_error_num(stat, err)
            stat = Index._OBJECT_TYPE_SETTERS[object_type](prop, err)
            Index._check_error_num(stat, err)
            stat = Index._DISTANCE_TYPE_SETTERS[distance_type](prop, err)
            Index._check_error_num(stat, err)
            index = Index.__ngt.ngt_create_graph_and_tree(
                Index._encode_path(path), prop, err)
            Index._check_error_obj(index, err)
            # Close the freshly built handle and reopen through __init__ so
            # the returned object is initialised exactly like a loaded index.
            Index.__ngt.ngt_close_index(index)
            index = None
            obj = Index(path)
        finally:
            if err is not None:
                Index.__ngt.ngt_destroy_error_object(err)
            if prop is not None:
                Index.__ngt.ngt_destroy_property(prop)
            if index is not None:
                Index.__ngt.ngt_close_index(index)
        return obj

    def __init__(self, path):
        """Open the existing index stored at ``path`` (str or bytes).

        Raises:
            NativeError: if the index or its properties cannot be loaded.
        """
        self.path = path
        # Pre-set every native handle so __del__ can release whatever was
        # acquired, even when a later step raises.
        self.err = None
        self.prop = None
        self.index = None
        self.err = self.__ngt.ngt_create_error_object()
        self.index = self.__ngt.ngt_open_index(
            self._encode_path(path), self.err)
        self._check_error_obj(self.index, self.err)
        self.prop = self.__ngt.ngt_create_property(self.err)
        self._check_error_obj(self.prop, self.err)
        stat = self.__ngt.ngt_get_property(self.index, self.prop, self.err)
        self._check_error_num(stat, self.err)
        self.dim = self.__ngt.ngt_get_property_dimension(self.prop, self.err)
        self._check_error_pnum(self.dim, self.err)
        self.otype = self.__ngt.ngt_get_property_object_type(
            self.prop, self.err)
        # No restype is set, so this is a plain int and a None check could
        # never fire.  NOTE(review): assumes failure is signalled by a
        # negative value - confirm against NGT's Capi.
        self._check_error_pnum(self.otype, self.err)
        self.is_float = self.__ngt.ngt_is_property_object_type_float(
            self.otype) > 0
        self.ospace = self.__ngt.ngt_get_object_space(self.index, self.err)
        self._check_error_obj(self.ospace, self.err)

    def search(self, query, k=20, epsilon=0.1):
        """Return up to ``k`` neighbours of ``query`` as ObjectDistance items.

        ``epsilon`` is forwarded to ``ngt_search_index``.  A fixed ``-1.0`` is
        passed as the following float argument (presumably the search
        radius - confirm against NGT docs).

        Raises:
            APIError: if ``len(query) != self.dim``.
            NativeError: if the native search fails.
        """
        self._check_dimension("search", query)
        err = None
        results = None
        try:
            err = self.__ngt.ngt_create_error_object()
            results = self.__ngt.ngt_create_empty_results(err)
            self._check_error_obj(results, err)
            cvec = (c_double * len(query))(*query)
            stat = self.__ngt.ngt_search_index(
                self.index, cvec, self.dim, k, epsilon, -1.0, results, err)
            self._check_error_num(stat, err)
            rsize = self.__ngt.ngt_get_size(results, err)
            self._check_error_pnum(rsize, err)
            # ngt_get_result returns the struct by value, never None, so the
            # former None check on each item was a no-op and is omitted.
            return [self.__ngt.ngt_get_result(results, i, err)
                    for i in range(rsize)]
        finally:
            if err is not None:
                self.__ngt.ngt_destroy_error_object(err)
            if results is not None:
                self.__ngt.ngt_destroy_results(results)

    def insert_object(self, object):
        """Insert one vector without rebuilding; return its native object id.

        Raises:
            APIError: if ``len(object) != self.dim``.
            NativeError: if the native insert fails.
        """
        self._check_dimension("insert_object", object)
        cvec = (c_double * len(object))(*object)
        object_id = self.__ngt.ngt_insert_index(
            self.index, cvec, self.dim, self.err)
        # restype is c_uint, so the old ``< 0`` test could never fire.
        # NOTE(review): NGT object ids are 1-based, and 0 signals failure -
        # confirm against NGT's Capi.
        self._check_error_num(object_id, self.err)
        return object_id

    def insert(self, objects, num_threads=8):
        """Insert every vector in ``objects``, rebuild, and return their ids."""
        id_list = [self.insert_object(vector) for vector in objects]
        self.build_index(num_threads)
        return id_list

    def insert_blob(self, objects, num_threads=8):
        """Append all ``objects`` in one native batch call, then rebuild.

        The vectors are flattened into a single ``c_float`` buffer.  With an
        empty ``objects`` nothing is appended (this used to raise
        IndexError), but the index is still rebuilt, mirroring ``insert``.

        Raises:
            APIError: if any object does not have ``self.dim`` components.
            NativeError: if the native append or rebuild fails.
        """
        data_count = len(objects)
        merged_vectors = []
        for vector in objects:
            # Validate every row, not just the first: one short row would
            # shift all later vectors in the flat buffer and overrun it.
            self._check_dimension("insert_blob", vector)
            merged_vectors.extend(vector)
        if data_count:
            cvec = (c_float * len(merged_vectors))(*merged_vectors)
            stat = self.__ngt.ngt_batch_append_index(
                self.index, cvec, data_count, self.err)
            self._check_error_num(stat, self.err)
        self.build_index(num_threads)
        return

    def insert_from_tsv(self, path, num_threads=8, dlmt='\t'):
        """Insert one vector per row of a delimited text file, then rebuild.

        Each field is parsed with ``float``.  Blank lines, such as a trailing
        newline, are skipped; previously they produced an empty vector and
        failed.  Returns the inserted ids in file order.
        """
        id_list = []
        with open(path) as f:
            for row in csv.reader(f, delimiter=dlmt):
                if not row:
                    continue
                id_list.append(self.insert_object([float(x) for x in row]))
        self.build_index(num_threads)
        return id_list

    def build_index(self, num_threads=8):
        """Run ``ngt_create_index`` with ``num_threads`` threads."""
        stat = self.__ngt.ngt_create_index(self.index, num_threads, self.err)
        self._check_error_num(stat, self.err)

    def remove(self, id):
        """Remove the object with native id ``id`` from the index."""
        stat = self.__ngt.ngt_remove_index(self.index, id, self.err)
        self._check_error_num(stat, self.err)

    def get_object(self, id):
        """Return the stored vector for ``id`` as a list of ``self.dim`` items.

        Items are floats for float indexes and ints for integer indexes.

        Raises:
            NativeError: if the library returns a NULL object pointer.
        """
        err = self.__ngt.ngt_create_error_object()
        try:
            if self.is_float:
                cvec = self.__ngt.ngt_get_object_as_float(self.ospace, id, err)
            else:
                cvec = self.__ngt.ngt_get_object_as_integer(
                    self.ospace, id, err)
            # A NULL ctypes pointer is falsy.  Test it directly instead of
            # catching every exception from indexing it with a bare except.
            if not cvec:
                self._raise_native_error(err)
            return cvec[:self.dim]
        finally:
            if err is not None:
                self.__ngt.ngt_destroy_error_object(err)

    def save(self, path=None):
        """Save the index to ``path`` (defaults to the path it was opened from)."""
        if path is None:
            path = self.path
        stat = self.__ngt.ngt_save_index(
            self.index, self._encode_path(path), self.err)
        self._check_error_num(stat, self.err)

    def __del__(self):
        # getattr(): __init__ may have raised before every handle was set.
        # Each handle is cleared after release so a repeated call is harmless.
        err = getattr(self, "err", None)
        if err:
            self.__ngt.ngt_destroy_error_object(err)
            self.err = None
        prop = getattr(self, "prop", None)
        if prop:
            self.__ngt.ngt_destroy_property(prop)
            self.prop = None
        index = getattr(self, "index", None)
        if index:
            self.__ngt.ngt_close_index(index)
            self.index = None