use serde::{Deserialize, Serialize};
use dynvec::distance::Distance as EngineDistance;
use dynvec::encoding::Codec as EngineCodec;
/// Element type of the vectors declared by a [`VectorSchema`].
///
/// Serialized in `snake_case`: `"float32"`, `"float16"`, `"int8"`.
/// Converted to the engine's storage codec by [`VectorType::engine_codec`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum VectorType {
    /// Declared as 32-bit floats. Note: mapped to the engine's `Fp16` codec,
    /// the same codec as [`VectorType::Float16`].
    Float32,
    /// Declared as 16-bit floats; mapped to the engine's `Fp16` codec.
    Float16,
    /// Declared as 8-bit integers; mapped to the engine's `Int8Quantized` codec.
    Int8,
}
impl VectorType {
    /// Returns the `dynvec` storage codec used for vectors of this type.
    ///
    /// Both floating-point variants share [`EngineCodec::Fp16`]; `Int8` uses
    /// [`EngineCodec::Int8Quantized`].
    // NOTE(review): `Float32` is stored with what is presumably a half-precision codec.
    // The unit tests pin this mapping, so it looks intentional — confirm the engine
    // has no full-precision codec before relying on f32 round-tripping.
    #[must_use]
    pub fn engine_codec(self) -> EngineCodec {
        match self {
            Self::Int8 => EngineCodec::Int8Quantized,
            Self::Float16 => EngineCodec::Fp16,
            Self::Float32 => EngineCodec::Fp16,
        }
    }
}
/// Distance function used to compare vectors, as declared by a [`VectorSchema`].
///
/// Serialized in `snake_case`: `"l2"`, `"inner_product"`, `"cosine"`.
/// Converted to the engine's distance by [`DistanceMetric::engine_distance`].
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum DistanceMetric {
    /// Mapped to the engine's `Euclidean` distance.
    L2,
    /// Mapped to the engine's `DotProduct` distance.
    InnerProduct,
    /// Mapped to the engine's `Cosine` distance.
    Cosine,
}
impl DistanceMetric {
#[must_use]
pub fn engine_distance(self) -> EngineDistance {
match self {
Self::L2 => EngineDistance::Euclidean,
Self::InnerProduct => EngineDistance::DotProduct,
Self::Cosine => EngineDistance::Cosine,
}
}
}
/// Index algorithm requested by a [`VectorSchema`].
///
/// Serialized in `snake_case`: `"hnsw"`, `"flat"`.
// NOTE(review): `VectorSchema::to_engine_schema` does not read this value and always
// supplies default HNSW parameters; confirm where `Flat` is actually honored.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum IndexAlgorithm {
    /// HNSW graph index.
    Hnsw,
    /// Flat index.
    Flat,
}
/// Kind of a metadata (non-vector) field declared in a [`VectorSchema`].
///
/// Serialized in `snake_case`: `"text"`, `"numeric"`, `"tag"`, `"geo"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum MetadataFieldType {
    /// Text field.
    Text,
    /// Numeric field.
    Numeric,
    /// Tag field; presumably split on
    /// [`MetadataField::effective_tag_separator`] — TODO confirm against the indexer.
    Tag,
    /// Geo field.
    Geo,
}
/// A metadata field declared alongside the vector field of a [`VectorSchema`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MetadataField {
    /// Field name.
    pub name: String,
    /// Kind of data the field holds.
    pub field_type: MetadataFieldType,
    /// Optional separator byte. It may be omitted from serialized input, in which case it
    /// deserializes to `None` (`#[serde(default)]`). Read it through
    /// [`MetadataField::effective_tag_separator`], which falls back to `b','`.
    #[serde(default)]
    pub tag_separator: Option<u8>,
}
impl MetadataField {
    /// Separator returned when no explicit `tag_separator` is configured.
    const DEFAULT_TAG_SEPARATOR: u8 = b',';

    /// Returns the configured tag separator, or `b','` when `tag_separator` is `None`.
    #[must_use]
    pub fn effective_tag_separator(&self) -> u8 {
        self.tag_separator.unwrap_or(Self::DEFAULT_TAG_SEPARATOR)
    }
}
/// Description of a vector index: the vector field and its encoding, the distance metric
/// and algorithm, byte prefixes, and metadata fields.
///
/// Converted to the engine's storage schema by [`VectorSchema::to_engine_schema`].
/// No field carries `#[serde(default)]`, so every field must be present when deserializing.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct VectorSchema {
    /// Name of the field that holds the vector.
    pub vector_field: String,
    /// Element type of the vector; determines the engine codec.
    pub vector_type: VectorType,
    /// Vector dimensionality; passed through unchanged as the engine schema's `dim`.
    pub dim: u16,
    /// Distance metric; determines the engine distance.
    pub distance: DistanceMetric,
    /// Requested index algorithm (not consulted by `to_engine_schema`).
    pub algorithm: IndexAlgorithm,
    /// Raw byte prefixes — presumably the key prefixes this index covers; TODO confirm.
    pub prefixes: Vec<Vec<u8>>,
    /// Metadata fields declared alongside the vector field.
    pub metadata_fields: Vec<MetadataField>,
}
impl VectorSchema {
    /// Builds the `dynvec` storage schema for a table called `table_name`.
    ///
    /// Only `dim`, `vector_type` and `distance` are carried over. `vector_type` goes through
    /// [`VectorType::engine_codec`] and `distance` through [`DistanceMetric::engine_distance`].
    /// The HNSW parameters are always the engine defaults.
    // NOTE(review): `algorithm`, `prefixes`, `vector_field` and `metadata_fields` are not
    // consulted here, so a `Flat` schema still receives default HNSW params — confirm
    // whether callers handle these fields elsewhere.
    #[must_use]
    pub fn to_engine_schema(&self, table_name: &str) -> dynvec::storage::TableSchema {
        let codec = self.vector_type.engine_codec();
        let distance = self.distance.engine_distance();

        dynvec::storage::TableSchema {
            name: table_name.to_owned(),
            dim: self.dim,
            codec,
            distance,
            hnsw: dynvec::index::HnswParams::default(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `VectorType` maps to its expected engine codec; both float types share `Fp16`.
    #[test]
    fn vector_type_maps_to_engine_codec() {
        let cases = [
            (VectorType::Float32, EngineCodec::Fp16),
            (VectorType::Float16, EngineCodec::Fp16),
            (VectorType::Int8, EngineCodec::Int8Quantized),
        ];
        for (vector_type, expected) in cases {
            assert_eq!(vector_type.engine_codec(), expected, "{vector_type:?}");
        }
    }

    /// Each `DistanceMetric` maps to its expected engine distance.
    #[test]
    fn distance_metric_maps_to_engine_distance() {
        let cases = [
            (DistanceMetric::L2, EngineDistance::Euclidean),
            (DistanceMetric::InnerProduct, EngineDistance::DotProduct),
            (DistanceMetric::Cosine, EngineDistance::Cosine),
        ];
        for (metric, expected) in cases {
            assert_eq!(metric.engine_distance(), expected, "{metric:?}");
        }
    }

    /// Name, dimension, codec and distance are carried into the engine schema.
    #[test]
    fn schema_compiles_to_engine_schema() {
        let title = MetadataField {
            name: String::from("title"),
            field_type: MetadataFieldType::Text,
            tag_separator: None,
        };
        let schema = VectorSchema {
            vector_field: String::from("vec"),
            vector_type: VectorType::Float32,
            dim: 16,
            distance: DistanceMetric::Cosine,
            algorithm: IndexAlgorithm::Hnsw,
            prefixes: vec![],
            metadata_fields: vec![title],
        };

        let engine = schema.to_engine_schema("docs");

        assert_eq!(engine.name, "docs");
        assert_eq!(engine.dim, 16);
        assert_eq!(engine.codec, EngineCodec::Fp16);
        assert_eq!(engine.distance, EngineDistance::Cosine);
    }
}