//! ngt 0.8.0
//!
//! Rust wrappers for NGT nearest neighbor search.
//!
//! Documentation
use std::marker::PhantomData;
use std::ptr;

use half::f16;
use ngt_sys as sys;
use num_enum::TryFromPrimitive;
use scopeguard::defer;

use crate::error::{make_err, Result};
use crate::ngt::NgtObjectType;
use crate::{NgtDistance, NgtProperties};

/// Element type stored in a QG index.
///
/// The discriminants are the integer codes exchanged with NGT's property
/// object-type getter, so the enum can be rebuilt from them via `TryFromPrimitive`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)]
#[repr(i32)]
pub enum QgObject {
    /// Unsigned 8-bit integer elements (`u8`).
    Uint8 = 1,
    /// 32-bit floating point elements (`f32`).
    Float = 2,
    /// 16-bit floating point elements (`half::f16`).
    Float16 = 3,
}

mod private {
    pub trait Sealed {}
}

pub trait QgObjectType: private::Sealed {
    fn as_obj() -> QgObject;
}

impl private::Sealed for f32 {}
impl QgObjectType for f32 {
    fn as_obj() -> QgObject {
        QgObject::Float
    }
}

impl private::Sealed for u8 {}
impl QgObjectType for u8 {
    fn as_obj() -> QgObject {
        QgObject::Uint8
    }
}

impl private::Sealed for f16 {}
impl QgObjectType for f16 {
    fn as_obj() -> QgObject {
        QgObject::Float16
    }
}

/// Distance functions supported by QG indexes.
///
/// The discriminants are the integer codes exchanged with NGT's property
/// distance-type getter, so the enum can be rebuilt from them via `TryFromPrimitive`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)]
#[repr(i32)]
pub enum QgDistance {
    /// L2 distance.
    L2 = 1,
    /// Cosine distance.
    Cosine = 4,
}

/// Every QG distance has a same-named NGT counterpart, so this conversion is infallible.
impl From<QgDistance> for NgtDistance {
    fn from(distance: QgDistance) -> Self {
        match distance {
            QgDistance::Cosine => Self::Cosine,
            QgDistance::L2 => Self::L2,
        }
    }
}

/// Narrows an NGT distance to one QG supports.
///
/// # Errors
///
/// Returns an error for any NGT distance other than `L2` or `Cosine`.
impl TryFrom<NgtDistance> for QgDistance {
    type Error = crate::Error;

    fn try_from(d: NgtDistance) -> Result<Self> {
        let supported = match d {
            NgtDistance::L2 => Self::L2,
            NgtDistance::Cosine => Self::Cosine,
            _ => return Err(format!("Invalid distance {d:?} isn't supported for QG").into()),
        };
        Ok(supported)
    }
}

/// Settings for a QG index whose vectors hold elements of type `T`.
///
/// Owns an NGT property handle (`raw_prop`) that is destroyed on drop. The other
/// fields mirror the values applied to (or read from) that handle.
#[derive(Debug)]
pub struct QgProperties<T> {
    /// Number of elements per vector.
    pub(crate) dimension: i32,
    /// Edge size used while building the graph.
    pub(crate) creation_edge_size: i16,
    /// Edge size used while searching the graph.
    pub(crate) search_edge_size: i16,
    /// Element type tag.
    pub(crate) object_type: QgObject,
    /// Distance function.
    pub(crate) distance_type: QgDistance,
    /// Owned NGT property handle.
    pub(crate) raw_prop: sys::NGTProperty,
    _marker: PhantomData<T>,
}

// SAFETY(review): `raw_prop` is owned exclusively by this value, and no `&self` method in
// this file reads or writes it. `Send` assumes an NGT property object is not tied to the
// thread that created it; `Sync` assumes other crate code only reads `raw_prop` through
// `&self`. Confirm both against NGT and the rest of the crate.
unsafe impl<T> Send for QgProperties<T> {}
unsafe impl<T> Sync for QgProperties<T> {}

impl<T> QgProperties<T>
where
    T: QgObjectType,
{
    /// Creates properties for vectors of `dimension` elements of type `T`.
    ///
    /// Defaults: creation edge size 10, search edge size 40, L2 distance.
    ///
    /// # Errors
    ///
    /// Returns an error if `dimension` does not fit in an `i32`, or if NGT fails to
    /// allocate or configure the underlying property object.
    pub fn dimension(dimension: usize) -> Result<Self> {
        let dimension = i32::try_from(dimension)?;
        let creation_edge_size = 10;
        let search_edge_size = 40;

        Self::build(
            dimension,
            creation_edge_size,
            search_edge_size,
            T::as_obj(),
            QgDistance::L2,
        )
    }

    /// Creates an independent copy that owns a fresh NGT property object with the
    /// same settings as `self`.
    ///
    /// # Errors
    ///
    /// Returns an error if NGT fails to allocate or configure the new property object.
    pub fn try_clone(&self) -> Result<Self> {
        Self::build(
            self.dimension,
            self.creation_edge_size,
            self.search_edge_size,
            self.object_type,
            self.distance_type,
        )
    }

    /// Reads the settings of an open NGT `index` into a new property object.
    ///
    /// NOTE(review): this does not check that the stored object type equals
    /// `T::as_obj()`; confirm that callers do.
    ///
    /// # Errors
    ///
    /// Returns an error if NGT fails to allocate or fetch the properties, if any getter
    /// reports failure (a negative value), or if the object/distance code is not a
    /// QG-supported value.
    pub(crate) fn from(index: sys::NGTIndex) -> Result<Self> {
        let raw_prop = Self::create_raw_prop()?;
        // Take ownership right away so `Drop` releases `raw_prop` on every early return
        // below. The placeholder settings are all overwritten before `prop` is returned.
        let mut prop = Self {
            dimension: 0,
            creation_edge_size: 0,
            search_edge_size: 0,
            object_type: T::as_obj(),
            distance_type: QgDistance::L2,
            raw_prop,
            _marker: PhantomData,
        };

        unsafe {
            let ebuf = sys::ngt_create_error_object();
            defer! { sys::ngt_destroy_error_object(ebuf); }

            if !sys::ngt_get_property(index, raw_prop, ebuf) {
                Err(make_err(ebuf))?
            }

            let dimension = sys::ngt_get_property_dimension(raw_prop, ebuf);
            if dimension < 0 {
                Err(make_err(ebuf))?
            }
            prop.dimension = dimension;

            let creation_edge_size = sys::ngt_get_property_edge_size_for_creation(raw_prop, ebuf);
            if creation_edge_size < 0 {
                Err(make_err(ebuf))?
            }
            prop.creation_edge_size = creation_edge_size;

            let search_edge_size = sys::ngt_get_property_edge_size_for_search(raw_prop, ebuf);
            if search_edge_size < 0 {
                Err(make_err(ebuf))?
            }
            prop.search_edge_size = search_edge_size;

            let object_type = sys::ngt_get_property_object_type(raw_prop, ebuf);
            if object_type < 0 {
                Err(make_err(ebuf))?
            }
            prop.object_type = QgObject::try_from(object_type)?;

            let distance_type = sys::ngt_get_property_distance_type(raw_prop, ebuf);
            if distance_type < 0 {
                Err(make_err(ebuf))?
            }
            prop.distance_type = QgDistance::try_from(distance_type)?;
        }

        Ok(prop)
    }

    /// Allocates a fresh NGT property object and applies every given setting to it.
    fn build(
        dimension: i32,
        creation_edge_size: i16,
        search_edge_size: i16,
        object_type: QgObject,
        distance_type: QgDistance,
    ) -> Result<Self> {
        let raw_prop = Self::create_raw_prop()?;
        // Take ownership right away so `Drop` releases `raw_prop` if any setter fails.
        let prop = Self {
            dimension,
            creation_edge_size,
            search_edge_size,
            object_type,
            distance_type,
            raw_prop,
            _marker: PhantomData,
        };

        unsafe {
            Self::set_dimension(raw_prop, dimension)?;
            Self::set_creation_edge_size(raw_prop, creation_edge_size)?;
            Self::set_search_edge_size(raw_prop, search_edge_size)?;
            Self::set_object_type(raw_prop, object_type)?;
            Self::set_distance_type(raw_prop, distance_type)?;
        }

        Ok(prop)
    }

    /// Allocates an empty NGT property object. The caller becomes responsible for
    /// destroying it (normally by wrapping it in `Self` straight away).
    fn create_raw_prop() -> Result<sys::NGTProperty> {
        unsafe {
            let ebuf = sys::ngt_create_error_object();
            defer! { sys::ngt_destroy_error_object(ebuf); }

            let raw_prop = sys::ngt_create_property(ebuf);
            if raw_prop.is_null() {
                Err(make_err(ebuf))?
            }
            Ok(raw_prop)
        }
    }

    /// Writes `dimension` into `raw_prop`.
    ///
    /// # Safety
    ///
    /// `raw_prop` must be a live handle obtained from `ngt_create_property` that has
    /// not been destroyed.
    unsafe fn set_dimension(raw_prop: sys::NGTProperty, dimension: i32) -> Result<()> {
        let ebuf = sys::ngt_create_error_object();
        defer! { sys::ngt_destroy_error_object(ebuf); }

        if !sys::ngt_set_property_dimension(raw_prop, dimension, ebuf) {
            Err(make_err(ebuf))?
        }

        Ok(())
    }

    /// Sets the edge size used while building the graph.
    ///
    /// # Errors
    ///
    /// Returns an error if `size` does not fit in an `i16` or if NGT rejects it.
    pub fn creation_edge_size(mut self, size: usize) -> Result<Self> {
        let size = i16::try_from(size)?;
        self.creation_edge_size = size;
        unsafe { Self::set_creation_edge_size(self.raw_prop, size)? };
        Ok(self)
    }

    /// Writes the creation edge size into `raw_prop`.
    ///
    /// # Safety
    ///
    /// `raw_prop` must be a live handle obtained from `ngt_create_property` that has
    /// not been destroyed.
    unsafe fn set_creation_edge_size(raw_prop: sys::NGTProperty, size: i16) -> Result<()> {
        let ebuf = sys::ngt_create_error_object();
        defer! { sys::ngt_destroy_error_object(ebuf); }

        if !sys::ngt_set_property_edge_size_for_creation(raw_prop, size, ebuf) {
            Err(make_err(ebuf))?
        }

        Ok(())
    }

    /// Sets the edge size used while searching the graph.
    ///
    /// # Errors
    ///
    /// Returns an error if `size` does not fit in an `i16` or if NGT rejects it.
    pub fn search_edge_size(mut self, size: usize) -> Result<Self> {
        let size = i16::try_from(size)?;
        self.search_edge_size = size;
        unsafe { Self::set_search_edge_size(self.raw_prop, size)? };
        Ok(self)
    }

    /// Writes the search edge size into `raw_prop`.
    ///
    /// # Safety
    ///
    /// `raw_prop` must be a live handle obtained from `ngt_create_property` that has
    /// not been destroyed.
    unsafe fn set_search_edge_size(raw_prop: sys::NGTProperty, size: i16) -> Result<()> {
        let ebuf = sys::ngt_create_error_object();
        defer! { sys::ngt_destroy_error_object(ebuf); }

        if !sys::ngt_set_property_edge_size_for_search(raw_prop, size, ebuf) {
            Err(make_err(ebuf))?
        }

        Ok(())
    }

    /// Writes the element type into `raw_prop`, using the NGT setter for that type.
    ///
    /// # Safety
    ///
    /// `raw_prop` must be a live handle obtained from `ngt_create_property` that has
    /// not been destroyed.
    unsafe fn set_object_type(raw_prop: sys::NGTProperty, object_type: QgObject) -> Result<()> {
        let ebuf = sys::ngt_create_error_object();
        defer! { sys::ngt_destroy_error_object(ebuf); }

        match object_type {
            QgObject::Uint8 => {
                if !sys::ngt_set_property_object_type_integer(raw_prop, ebuf) {
                    Err(make_err(ebuf))?
                }
            }
            QgObject::Float => {
                if !sys::ngt_set_property_object_type_float(raw_prop, ebuf) {
                    Err(make_err(ebuf))?
                }
            }
            QgObject::Float16 => {
                if !sys::ngt_set_property_object_type_float16(raw_prop, ebuf) {
                    Err(make_err(ebuf))?
                }
            }
        }

        Ok(())
    }

    /// Sets the distance function.
    ///
    /// # Errors
    ///
    /// Returns an error if NGT rejects the setting.
    pub fn distance_type(mut self, distance_type: QgDistance) -> Result<Self> {
        self.distance_type = distance_type;
        unsafe { Self::set_distance_type(self.raw_prop, distance_type)? };
        Ok(self)
    }

    /// Writes the distance function into `raw_prop`, using the NGT setter for it.
    ///
    /// # Safety
    ///
    /// `raw_prop` must be a live handle obtained from `ngt_create_property` that has
    /// not been destroyed.
    unsafe fn set_distance_type(
        raw_prop: sys::NGTProperty,
        distance_type: QgDistance,
    ) -> Result<()> {
        let ebuf = sys::ngt_create_error_object();
        defer! { sys::ngt_destroy_error_object(ebuf); }

        match distance_type {
            QgDistance::L2 => {
                if !sys::ngt_set_property_distance_type_l2(raw_prop, ebuf) {
                    Err(make_err(ebuf))?
                }
            }
            QgDistance::Cosine => {
                if !sys::ngt_set_property_distance_type_cosine(raw_prop, ebuf) {
                    Err(make_err(ebuf))?
                }
            }
        }

        Ok(())
    }
}

impl<T> Drop for QgProperties<T> {
    fn drop(&mut self) {
        if !self.raw_prop.is_null() {
            unsafe { sys::ngt_destroy_property(self.raw_prop) };
            self.raw_prop = ptr::null_mut();
        }
    }
}

impl<T> TryFrom<QgProperties<T>> for NgtProperties<T>
where
    T: QgObjectType,
    T: NgtObjectType,
{
    type Error = crate::Error;

    fn try_from(prop: QgProperties<T>) -> Result<Self> {
        NgtProperties::dimension(prop.dimension as usize)?
            .creation_edge_size(prop.creation_edge_size as usize)?
            .search_edge_size(prop.search_edge_size as usize)?
            .distance_type(prop.distance_type.into())
    }
}

/// Parameters that control how a QG index is quantized.
#[derive(Debug, Clone, PartialEq)]
pub struct QgQuantizationParams {
    /// Dimension of each subvector. Defaults to `0.0`; presumably this lets NGT pick
    /// a value (TODO confirm against NGT).
    pub dimension_of_subvector: f32,
    /// Maximum number of edges. Defaults to 128.
    pub max_number_of_edges: usize,
}

impl Default for QgQuantizationParams {
    fn default() -> Self {
        let max_number_of_edges = 128;
        let dimension_of_subvector = 0.0;
        Self {
            dimension_of_subvector,
            max_number_of_edges,
        }
    }
}

impl QgQuantizationParams {
    /// Converts into the FFI struct, copying each field across unchanged.
    pub(crate) fn into_raw(self) -> sys::NGTQGQuantizationParameters {
        let Self {
            dimension_of_subvector,
            max_number_of_edges,
        } = self;
        sys::NGTQGQuantizationParameters {
            dimension_of_subvector,
            max_number_of_edges,
        }
    }
}