use crate::types::distance::DistanceMetric;
use crate::types::err::{Error, ErrorCode};
use crate::types::filter::*;
use crate::types::record::*;
use crate::utils::file;
use rayon::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use sqlx::any::AnyRow;
use std::any::Any;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
use std::fmt::Debug;
use std::path::Path;
mod idx_flat;
mod idx_ivfpq;
pub use idx_flat::{IndexFlat, ParamsFlat};
pub use idx_ivfpq::{IndexIVFPQ, ParamsIVFPQ};
/// Name of a table in the source database.
pub type TableName = String;
/// Kind of SQL database a source connection points at.
///
/// Values are obtained from a connection scheme string through
/// `SourceType::from`.
#[allow(missing_docs)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SourceType {
    SQLITE,
    POSTGRES,
    MYSQL,
}

impl From<&str> for SourceType {
    /// Maps a connection scheme to the matching database kind.
    ///
    /// Recognized schemes are `sqlite`, `postgres`, `postgresql` and `mysql`.
    /// The comparison is exact (case-sensitive).
    ///
    /// # Panics
    ///
    /// Panics when `scheme` is not one of the recognized values.
    fn from(scheme: &str) -> Self {
        match scheme {
            "mysql" => Self::MYSQL,
            "postgres" | "postgresql" => Self::POSTGRES,
            "sqlite" => Self::SQLITE,
            unsupported => {
                panic!("Unsupported database scheme: {unsupported}.")
            }
        }
    }
}
/// Describes which table and columns of the source database feed the index.
///
/// The values are interpolated as-is into the `SELECT` statements produced
/// by `to_query` / `to_query_after`; no quoting or escaping is applied there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceConfig {
    /// Table the records are read from.
    pub table: TableName,
    /// Column holding the record identifier.
    pub primary_key: ColumnName,
    /// Column holding the vector.
    pub vector: ColumnName,
    /// Optional extra columns selected alongside the key and the vector.
    pub metadata: Option<Vec<ColumnName>>,
    /// Optional SQL condition placed after `WHERE` in the generated queries.
    pub filter: Option<String>,
}
/// Test-only default: the `embeddings` table with `id` and `vector` columns,
/// no metadata columns and no filter.
#[cfg(test)]
impl Default for SourceConfig {
    fn default() -> Self {
        // `new` leaves `metadata` and `filter` unset, which is exactly the
        // default we want here.
        Self::new("embeddings", "id", "vector")
    }
}
impl SourceConfig {
    /// Creates a configuration that reads the primary key and vector columns
    /// from `table`, with no metadata columns and no filter.
    pub fn new(
        table: impl Into<TableName>,
        primary_key: impl Into<ColumnName>,
        vector: impl Into<ColumnName>,
    ) -> Self {
        SourceConfig {
            table: table.into(),
            primary_key: primary_key.into(),
            vector: vector.into(),
            metadata: None,
            filter: None,
        }
    }

    /// Sets the additional columns to select next to the key and the vector.
    ///
    /// Replaces any metadata columns configured earlier.
    pub fn with_metadata(
        mut self,
        metadata: Vec<impl Into<ColumnName>>,
    ) -> Self {
        self.metadata = Some(metadata.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the SQL condition used to restrict the selected rows.
    ///
    /// Surrounding whitespace is trimmed before the value is stored. The
    /// text is later inserted verbatim after `WHERE`, so it must come from a
    /// trusted source.
    pub fn with_filter(mut self, filter: impl Into<String>) -> Self {
        let filter: String = filter.into();
        self.filter = Some(filter.trim().to_string());
        self
    }

    /// Returns the selected columns in query order: primary key, vector,
    /// then the metadata columns (if any).
    pub fn columns(&self) -> Vec<ColumnName> {
        let mut columns = vec![&self.primary_key, &self.vector];
        if let Some(metadata) = &self.metadata {
            columns.extend(metadata.iter());
        }
        columns.into_iter().cloned().collect()
    }

    /// Returns the configured filter with surrounding whitespace removed, or
    /// `None` when no filter is set or the filter is blank.
    ///
    /// A blank filter must not reach the query text: it would otherwise
    /// produce a dangling `WHERE` / `AND ()`, which is invalid SQL.
    fn active_filter(&self) -> Option<&str> {
        self.filter
            .as_deref()
            .map(str::trim)
            .filter(|clause| !clause.is_empty())
    }

    /// Builds the `SELECT` statement that reads every matching row.
    ///
    /// The `WHERE` clause is emitted only when a non-blank filter is set.
    /// Table, column names and the filter are interpolated without quoting.
    pub(crate) fn to_query(&self) -> String {
        let table = &self.table;
        let columns = self.columns().join(", ");
        let mut query = format!("SELECT {columns} FROM {table}");
        if let Some(filter) = self.active_filter() {
            query.push_str(" WHERE ");
            query.push_str(filter);
        }
        query
    }

    /// Builds the `SELECT` statement that reads only rows whose primary key
    /// is greater than `checkpoint`.
    ///
    /// A non-blank filter is appended as `AND (<filter>)`; the parentheses
    /// keep any `OR` inside the filter from escaping the checkpoint
    /// condition.
    pub(crate) fn to_query_after(&self, checkpoint: &RecordID) -> String {
        let table = &self.table;
        let pk = &self.primary_key;
        let columns = self.columns().join(", ");
        let mut query = format!(
            "SELECT {columns} FROM {table} WHERE {pk} > {}",
            checkpoint.0
        );
        if let Some(filter) = self.active_filter() {
            query.push_str(" AND (");
            query.push_str(filter);
            query.push(')');
        }
        query
    }

    /// Converts a database row into a record ID and a record.
    ///
    /// The ID and the vector are read from the configured columns; every
    /// configured metadata column is read into the record's `data` map under
    /// its column name.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned while reading a column from `row`.
    pub(crate) fn to_record(
        &self,
        row: &AnyRow,
    ) -> Result<(RecordID, Record), Error> {
        let id = RecordID::from_row(&self.primary_key, row)?;
        let vector = Vector::from_row(&self.vector, row)?;

        let mut metadata = HashMap::new();
        if let Some(metadata_columns) = &self.metadata {
            for column in metadata_columns {
                let value = RowOps::from_row(column.to_owned(), row)?;
                metadata.insert(column.to_owned(), value);
            }
        }

        let record = Record { vector, data: metadata };
        Ok((id, record))
    }
}
/// Selects an index implementation together with its parameters.
#[allow(missing_docs)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndexAlgorithm {
    /// Backed by [`IndexFlat`], configured with [`ParamsFlat`].
    Flat(ParamsFlat),
    /// Backed by [`IndexIVFPQ`], configured with [`ParamsIVFPQ`].
    IVFPQ(ParamsIVFPQ),
}
impl IndexAlgorithm {
pub fn name(&self) -> &str {
match self {
Self::Flat(_) => "FLAT",
Self::IVFPQ(_) => "IVFPQ",
}
}
}
/// Two algorithms are equal when they are the same variant; the parameters
/// carried by the variant are not compared.
impl PartialEq for IndexAlgorithm {
    fn eq(&self, other: &Self) -> bool {
        // Every variant has a unique name, so comparing discriminants gives
        // the same answer as comparing names.
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}

impl Eq for IndexAlgorithm {}
impl IndexAlgorithm {
pub(crate) fn initialize(&self) -> Result<Box<dyn VectorIndex>, Error> {
macro_rules! initialize {
($index_type:ident, $params:expr) => {{
let index = $index_type::new($params)?;
Ok(Box::new(index))
}};
}
match self.to_owned() {
Self::Flat(params) => initialize!(IndexFlat, params),
Self::IVFPQ(params) => initialize!(IndexIVFPQ, params),
}
}
pub(crate) fn load_index(
&self,
path: impl AsRef<Path>,
) -> Result<Box<dyn VectorIndex>, Error> {
macro_rules! load {
($index_type:ident) => {{
let index = Self::_load_index::<$index_type>(path)?;
Ok(Box::new(index))
}};
}
match self {
Self::Flat(_) => load!(IndexFlat),
Self::IVFPQ(_) => load!(IndexIVFPQ),
}
}
pub(crate) fn persist_index(
&self,
tmp_dir: impl AsRef<Path>,
path: impl AsRef<Path>,
index: &dyn VectorIndex,
) -> Result<(), Error> {
macro_rules! persist {
($index_type:ident) => {{
Self::_persist_index::<$index_type>(tmp_dir, path, index)
}};
}
match self {
Self::Flat(_) => persist!(IndexFlat),
Self::IVFPQ(_) => persist!(IndexIVFPQ),
}
}
fn _load_index<T: VectorIndex + IndexOps + 'static>(
path: impl AsRef<Path>,
) -> Result<T, Error> {
let index = T::load(path)?;
Ok(index)
}
fn _persist_index<T: VectorIndex + IndexOps + 'static>(
tmp_dir: impl AsRef<Path>,
path: impl AsRef<Path>,
index: &dyn VectorIndex,
) -> Result<(), Error> {
let index = index.as_any().downcast_ref::<T>().ok_or_else(|| {
let code = ErrorCode::InternalError;
let message = "Failed to downcast index to concrete type.";
Error::new(code, message)
})?;
index.persist(tmp_dir, path)?;
Ok(())
}
}
/// Bookkeeping information exposed by an index through
/// `VectorIndex::metadata`.
///
/// Both fields start out as `false` / `None` via the derived `Default`.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct IndexMetadata {
    /// NOTE(review): presumably set once the index has been built; confirm
    /// against the index implementations.
    pub built: bool,
    /// NOTE(review): presumably the ID of the most recently inserted record,
    /// used as the checkpoint for incremental reads; confirm against callers.
    pub last_inserted: Option<RecordID>,
}
/// A single hit returned by `VectorIndex::search`.
#[derive(Debug)]
pub struct SearchResult {
    /// ID of the matching record.
    pub id: RecordID,
    /// Column values attached to the record, keyed by column name.
    pub data: HashMap<ColumnName, Option<DataValue>>,
    /// Distance associated with this hit; results are ordered by this value.
    pub distance: f32,
}
/// Two results are equal when they refer to the same record ID; the
/// distance and the attached data are not compared.
///
/// NOTE(review): equality is keyed on `id` while ordering is keyed on
/// `distance`, so `a == b` does not imply `a.cmp(&b) == Ordering::Equal`.
/// Callers should not rely on the two agreeing.
impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    /// Always returns `Some`, delegating to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    /// Orders results by ascending distance.
    ///
    /// Uses `f32::total_cmp` so that the order stays total even when a
    /// distance is NaN (a positive NaN sorts after every finite value).
    /// Treating an incomparable pair as `Equal` would break transitivity,
    /// which `BinaryHeap` and the slice sort functions rely on.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
/// Construction and persistence operations of a concrete index type.
///
/// These methods need the concrete, serializable type and are therefore kept
/// separate from the `dyn`-compatible `VectorIndex` trait.
pub trait IndexOps: Debug + Serialize + DeserializeOwned {
    /// Creates a new index from the given parameters.
    fn new(params: impl IndexParams) -> Result<Self, Error>;
    /// Loads an index from `path` by delegating to `file::read_binary_file`.
    fn load(path: impl AsRef<Path>) -> Result<Self, Error> {
        file::read_binary_file(path)
    }
    /// Persists the index to `path` by delegating to
    /// `file::write_binary_file`, forwarding `tmp_dir` unchanged.
    fn persist(
        &self,
        tmp_dir: impl AsRef<Path>,
        path: impl AsRef<Path>,
    ) -> Result<(), Error> {
        file::write_binary_file(tmp_dir, path, self)
    }
}
/// Common interface of all index implementations; used as
/// `Box<dyn VectorIndex>` so that callers can work with any algorithm.
pub trait VectorIndex: Debug + Send + Sync {
    /// Returns the distance metric used by the index.
    fn metric(&self) -> &DistanceMetric;
    /// Returns the bookkeeping metadata of the index.
    fn metadata(&self) -> &IndexMetadata;
    /// Builds the index from the given records, keyed by record ID.
    fn build(
        &mut self,
        records: HashMap<RecordID, Record>,
    ) -> Result<(), Error>;
    /// Inserts the given records, keyed by record ID, into the index.
    fn insert(
        &mut self,
        records: HashMap<RecordID, Record>,
    ) -> Result<(), Error>;
    /// Updates the records identified by the given IDs with the new values.
    fn update(
        &mut self,
        records: HashMap<RecordID, Record>,
    ) -> Result<(), Error>;
    /// Deletes the records with the given IDs from the index.
    fn delete(&mut self, ids: Vec<RecordID>) -> Result<(), Error>;
    /// Searches the index for `query` with the given filters applied.
    ///
    /// NOTE(review): presumably returns at most `k` results, nearest first;
    /// confirm against the implementations.
    fn search(
        &self,
        query: Vector,
        k: usize,
        filters: Filters,
    ) -> Result<Vec<SearchResult>, Error>;
    /// Returns the number of records in the index.
    fn len(&self) -> usize;
    /// Returns `true` when `len()` is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Returns the index as `Any` so that it can be downcast to its concrete
    /// type.
    fn as_any(&self) -> &dyn Any;
}
/// Parameters used to construct an index.
pub trait IndexParams: Debug + Default + Clone {
    /// Returns the distance metric configured in the parameters.
    fn metric(&self) -> &DistanceMetric;
    /// Returns the parameters as `Any` so that they can be downcast to their
    /// concrete type (see `downcast_params`).
    fn as_any(&self) -> &dyn Any;
}
pub(crate) fn downcast_params<T: IndexParams + 'static>(
params: impl IndexParams,
) -> Result<T, Error> {
params.as_any().downcast_ref::<T>().cloned().ok_or_else(|| {
let code = ErrorCode::InternalError;
let message = "Failed to downcast index parameters to concrete type.";
Error::new(code, message)
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare configuration selects only the primary key and vector columns.
    #[test]
    fn test_source_config_new() {
        let query = SourceConfig::new("table", "id", "embedding").to_query();
        assert_eq!(query, "SELECT id, embedding FROM table");
    }

    /// Metadata columns are appended to the column list and the filter is
    /// rendered as the `WHERE` clause.
    #[test]
    fn test_source_config_new_complete() {
        let query = SourceConfig::new("table", "id", "embedding")
            .with_metadata(vec!["metadata"])
            .with_filter("id > 100")
            .to_query();
        assert_eq!(
            query,
            "SELECT id, embedding, metadata FROM table WHERE id > 100"
        );
    }
}
// Shared test routines that exercise any `VectorIndex` implementation
// through the trait interface only.
#[cfg(test)]
mod index_tests {
    use super::*;
    // Number of components of every test vector.
    const DIMENSION: usize = 128;
    // Number of results requested by every search.
    const K: usize = 10;
    /// Runs the whole scenario against `index`.
    ///
    /// The steps mutate the index and must run in this order: each later
    /// assertion depends on the state left behind by the earlier steps.
    pub fn test_index(index: &mut impl VectorIndex) {
        populate_index(index);
        test_search(index);
        test_search_with_filters(index);
        test_search_after_insert(index);
        test_search_after_update(index);
        test_search_after_delete(index);
    }
    /// Builds the index from 100 records: record `i` has every vector
    /// component equal to `i` and a `number` value of `1000 + i`.
    pub fn populate_index(index: &mut impl VectorIndex) {
        let mut records = HashMap::new();
        for i in 0..100 {
            let id = RecordID(i as u32);
            let vector = Vector::from(vec![i as f32; DIMENSION]);
            let data = HashMap::from([(
                "number".into(),
                Some(DataValue::Integer(1000 + i)),
            )]);
            let record = Record { vector, data };
            records.insert(id, record);
        }
        index.build(records).unwrap();
        assert_eq!(index.len(), 100);
    }
    /// An unfiltered search around the all-zero query returns `K` results
    /// including record 0 (whose vector is all zeros).
    pub fn test_search(index: &impl VectorIndex) {
        let results = search_index(index, Filters::NONE);
        assert_eq!(results.len(), K);
        assert!(results.contains(&RecordID(0)));
    }
    /// With `number > 1010` record 0 (number 1000) must be excluded and
    /// record 11 (number 1011) must be among the `K` results.
    pub fn test_search_with_filters(index: &impl VectorIndex) {
        let filters = Filters::from("number > 1010");
        let results = search_index(index, filters);
        assert_eq!(results.len(), K);
        assert!(results.contains(&RecordID(11)));
        assert!(!results.contains(&RecordID(0)));
    }
    /// Inserts record 100 with all components 0.1 and expects both it and
    /// record 0 among the results.
    pub fn test_search_after_insert(index: &mut impl VectorIndex) {
        let id = RecordID(100);
        let vector = Vector::from(vec![0.1; DIMENSION]);
        let data = HashMap::from([(
            "number".to_string(),
            Some(DataValue::Integer(2000)),
        )]);
        let record = Record { vector, data };
        let records = HashMap::from([(id, record)]);
        index.insert(records).unwrap();
        let results = search_index(index, Filters::NONE);
        assert_eq!(results.len(), K);
        assert!(results.contains(&RecordID(100)));
        assert!(results.contains(&RecordID(0)));
    }
    /// Moves record 0 to all components 100.0; it must drop out of the
    /// results while record 1 is still returned.
    pub fn test_search_after_update(index: &mut impl VectorIndex) {
        let id = RecordID(0);
        let vector = Vector::from(vec![100.0; DIMENSION]);
        let data = HashMap::from([(
            "number".to_string(),
            Some(DataValue::Integer(2000)),
        )]);
        let record = Record { vector, data };
        let records = HashMap::from([(id, record)]);
        index.update(records).unwrap();
        let results = search_index(index, Filters::NONE);
        assert_eq!(results.len(), K);
        assert!(!results.contains(&RecordID(0)));
        assert!(results.contains(&RecordID(1)));
    }
    /// Deletes records 1 and 2; they must no longer be returned while
    /// record 3 still is.
    pub fn test_search_after_delete(index: &mut impl VectorIndex) {
        let ids = vec![RecordID(1), RecordID(2)];
        index.delete(ids).unwrap();
        let results = search_index(index, Filters::NONE);
        assert_eq!(results.len(), K);
        assert!(!results.contains(&RecordID(1)));
        assert!(!results.contains(&RecordID(2)));
        assert!(results.contains(&RecordID(3)));
    }
    /// Searches for the `K` results around the all-zero vector with the
    /// given filters and returns only the record IDs.
    fn search_index(
        index: &impl VectorIndex,
        filters: Filters,
    ) -> Vec<RecordID> {
        let query = Vector::from(vec![0.0; DIMENSION]);
        index
            .search(query, K, filters)
            .unwrap()
            .iter()
            .map(|result| result.id)
            .collect()
    }
}