use serde::{Deserialize, Serialize};
/// Element type of the components stored in a vector.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate has to keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Dtype {
/// The only dtype defined so far; `Dtype::element_size` reports 4 bytes for it.
F32,
}
impl Dtype {
#[must_use]
pub const fn element_size(self) -> usize {
match self {
Dtype::F32 => 4,
}
}
}
/// Which distance / similarity measure a collection is configured with.
///
/// NOTE(review): the scoring code is not in this file; the per-variant notes
/// below follow the usual meaning of these names and should be confirmed
/// against the implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
/// Presumably dot (inner) product.
Dot,
/// Presumably cosine similarity.
Cosine,
/// Presumably Euclidean (L2) distance.
L2,
}
/// The kind of index requested for a collection.
///
/// `Hnsw` is the `Default`. Because of `rename_all = "snake_case"`, formats
/// that write variant names use `hnsw`, `vamana`, `disk_vamana`, `ivf` and
/// `colbert`. The enum is `#[non_exhaustive]`, so external matches need a
/// wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum IndexKind {
/// Default index kind.
#[default]
Hnsw,
/// Serialized by name as `vamana`.
Vamana,
/// Serialized by name as `disk_vamana`.
DiskVamana,
/// Serialized by name as `ivf`.
Ivf,
/// Serialized by name as `colbert`.
Colbert,
}
/// Index configuration carried inside a [`Descriptor`].
///
/// `IndexSpec::default()` is `IndexKind::Hnsw` with `pq_subspaces == None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct IndexSpec {
/// Which index to use.
pub kind: IndexKind,
/// Optional subspace count.
/// NOTE(review): presumably the number of product-quantization subspaces,
/// with `None` meaning "no PQ" — confirm against the index builders.
pub pq_subspaces: Option<u32>,
}
/// Value type of a filterable field.
///
/// Name-based formats see `keyword` / `numeric` (`rename_all = "snake_case"`).
/// The enum is `#[non_exhaustive]`, so external matches need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum FieldType {
/// Field built by `FilterableField::keyword`.
Keyword,
/// Field built by `FilterableField::numeric`.
Numeric,
}
/// One field that a [`Descriptor`] declares as filterable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FilterableField {
/// Identifies the field.
/// NOTE(review): presumably a path into the record's metadata — the lookup
/// code is not in this file, so confirm the expected syntax there.
pub path: String,
/// How values of this field are typed.
pub field_type: FieldType,
}
impl FilterableField {
#[must_use]
pub fn keyword(path: impl Into<String>) -> Self {
Self {
path: path.into(),
field_type: FieldType::Keyword,
}
}
#[must_use]
pub fn numeric(path: impl Into<String>) -> Self {
Self {
path: path.into(),
field_type: FieldType::Numeric,
}
}
}
/// How (or whether) stored vectors are encrypted.
///
/// `None` is the `Default`. Name-based formats see `none`, `dcpe` and
/// `client_side` (`rename_all = "snake_case"`).
///
/// NOTE(review): variant order looks wire-significant. The tests in this file
/// check that an older payload ending in a `bool` decodes `false` as `None`
/// and `true` as `Dcpe`, so reordering these variants would break that.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum VectorEncryption {
/// Default: no vector encryption.
#[default]
None,
/// NOTE(review): presumably distance-comparison-preserving encryption — confirm.
Dcpe,
/// NOTE(review): presumably vectors are encrypted by the client before upload — confirm.
ClientSide,
}
/// Description of a vector collection: dimension, element type, metric, index
/// configuration, filterable fields and encryption mode.
///
/// The `#[serde(default)]` fields may be omitted in formats that can express
/// a missing field. That does not make older postcard payloads readable
/// directly: the tests in this file assert that `postcard::from_bytes::<Descriptor>`
/// fails on them. Use `Descriptor::decode` for stored bytes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Descriptor {
/// Number of components per vector.
pub dim: u32,
/// Element type of each component.
pub dtype: Dtype,
/// Distance / similarity measure.
pub metric: DistanceMetric,
/// Index configuration; defaults to `IndexSpec::default()`.
#[serde(default)]
pub index: IndexSpec,
/// Fields declared as filterable; defaults to empty.
#[serde(default)]
pub filterable: Vec<FilterableField>,
/// Multivector flag; defaults to `false`.
/// NOTE(review): presumably "each record holds several vectors" — confirm with the storage code.
#[serde(default)]
pub multivector: bool,
/// Vector encryption mode; defaults to `VectorEncryption::None`.
#[serde(default)]
pub vector_encryption: VectorEncryption,
}
impl Descriptor {
#[must_use]
pub fn new(dim: u32, dtype: Dtype, metric: DistanceMetric) -> Self {
Self {
dim,
dtype,
metric,
index: IndexSpec::default(),
filterable: Vec::new(),
multivector: false,
vector_encryption: VectorEncryption::None,
}
}
#[must_use]
pub fn with_index(mut self, index: IndexSpec) -> Self {
self.index = index;
self
}
#[must_use]
pub fn with_filterable(mut self, filterable: Vec<FilterableField>) -> Self {
self.filterable = filterable;
self
}
#[must_use]
pub fn with_multivector(mut self, multivector: bool) -> Self {
self.multivector = multivector;
self
}
#[must_use]
pub fn with_vector_encryption(mut self, vector_encryption: VectorEncryption) -> Self {
self.vector_encryption = vector_encryption;
self
}
pub fn decode(bytes: &[u8]) -> std::result::Result<Self, postcard::Error> {
postcard::from_bytes::<Self>(bytes)
.or_else(|_| postcard::from_bytes::<DescriptorV4>(bytes).map(Self::from))
.or_else(|_| postcard::from_bytes::<DescriptorV3>(bytes).map(Self::from))
.or_else(|_| postcard::from_bytes::<DescriptorV2>(bytes).map(Self::from))
.or_else(|_| postcard::from_bytes::<LegacyDescriptor>(bytes).map(Self::from))
}
#[must_use]
pub fn stride(&self) -> usize {
self.dim as usize * self.dtype.element_size()
}
}
/// Older wire layout: the current `Descriptor` without `vector_encryption`.
///
/// Only used as a decoding fallback in `Descriptor::decode`. Field order is
/// the wire layout, so it must not be rearranged.
#[derive(Deserialize)]
struct DescriptorV4 {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
index: IndexSpec,
filterable: Vec<FilterableField>,
multivector: bool,
}
impl From<DescriptorV4> for Descriptor {
fn from(v: DescriptorV4) -> Self {
Self {
dim: v.dim,
dtype: v.dtype,
metric: v.metric,
index: v.index,
filterable: v.filterable,
multivector: v.multivector,
vector_encryption: VectorEncryption::None,
}
}
}
/// Older wire layout: has `filterable` but neither `multivector` nor
/// `vector_encryption`.
///
/// Only used as a decoding fallback in `Descriptor::decode`. Field order is
/// the wire layout, so it must not be rearranged.
#[derive(Deserialize)]
struct DescriptorV3 {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
index: IndexSpec,
filterable: Vec<FilterableField>,
}
impl From<DescriptorV3> for Descriptor {
fn from(v: DescriptorV3) -> Self {
Self {
dim: v.dim,
dtype: v.dtype,
metric: v.metric,
index: v.index,
filterable: v.filterable,
multivector: false,
vector_encryption: VectorEncryption::None,
}
}
}
/// Older wire layout: has `index` but no `filterable`, `multivector` or
/// `vector_encryption`.
///
/// Only used as a decoding fallback in `Descriptor::decode`. Field order is
/// the wire layout, so it must not be rearranged.
#[derive(Deserialize)]
struct DescriptorV2 {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
index: IndexSpec,
}
impl From<DescriptorV2> for Descriptor {
fn from(v: DescriptorV2) -> Self {
Self {
dim: v.dim,
dtype: v.dtype,
metric: v.metric,
index: v.index,
filterable: Vec::new(),
multivector: false,
vector_encryption: VectorEncryption::None,
}
}
}
/// Oldest wire layout handled here: only `dim`, `dtype` and `metric`.
///
/// Only used as the last decoding fallback in `Descriptor::decode`. Field
/// order is the wire layout, so it must not be rearranged.
#[derive(Deserialize)]
struct LegacyDescriptor {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
}
impl From<LegacyDescriptor> for Descriptor {
fn from(v: LegacyDescriptor) -> Self {
Self {
dim: v.dim,
dtype: v.dtype,
metric: v.metric,
index: IndexSpec::default(),
filterable: Vec::new(),
multivector: false,
vector_encryption: VectorEncryption::None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// `stride` is dim * element size, and `new` leaves the index at its default.
#[test]
fn stride_matches_dim_and_dtype() {
    let desc = Descriptor::new(128, Dtype::F32, DistanceMetric::L2);
    // 128 components at 4 bytes each.
    assert_eq!(desc.stride(), 512);
    assert_eq!(Dtype::F32.element_size(), 4);
    // The default index spec is HNSW.
    let default_index = IndexSpec::default();
    assert_eq!(desc.index, default_index);
    assert_eq!(desc.index.kind, IndexKind::Hnsw);
}
// A descriptor with a non-default index survives a plain postcard round trip.
#[test]
fn descriptor_roundtrips_through_postcard() {
    let index = IndexSpec {
        kind: IndexKind::DiskVamana,
        pq_subspaces: Some(16),
    };
    let original = Descriptor::new(8, Dtype::F32, DistanceMetric::Cosine).with_index(index);
    let encoded = postcard::to_allocvec(&original).unwrap();
    let decoded: Descriptor = postcard::from_bytes(&encoded).unwrap();
    assert_eq!(original, decoded);
}
// A payload holding only `dim`, `dtype` and `metric` cannot be read as the
// current `Descriptor`, but `Descriptor::decode` accepts it and fills in the
// default index.
#[test]
fn pre_phase2_descriptor_deserializes_with_default_index() {
// Local stand-in for the oldest layout; used only to produce the bytes.
#[derive(serde::Serialize)]
struct OldDescriptor {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
}
let old = OldDescriptor {
dim: 16,
dtype: Dtype::F32,
metric: DistanceMetric::L2,
};
let bytes = postcard::to_allocvec(&old).unwrap();
// Direct deserialization into the current struct must fail...
assert!(postcard::from_bytes::<Descriptor>(&bytes).is_err());
// ...while `decode` falls back to the older layout.
let back = Descriptor::decode(&bytes).unwrap();
assert_eq!(back.dim, 16);
assert_eq!(back.metric, DistanceMetric::L2);
assert_eq!(back.index, IndexSpec::default());
}
// `decode` must return current-layout payloads unchanged.
#[test]
fn decode_reads_current_layout() {
    let index = IndexSpec {
        kind: IndexKind::Ivf,
        pq_subspaces: Some(8),
    };
    let original = Descriptor::new(8, Dtype::F32, DistanceMetric::Dot).with_index(index);
    let encoded = postcard::to_allocvec(&original).unwrap();
    let decoded = Descriptor::decode(&encoded).unwrap();
    assert_eq!(decoded, original);
}
// A payload written before `filterable` existed decodes through `decode`,
// keeps its stored index, and ends up with no filterable fields.
#[test]
fn pre_filterable_descriptor_decodes_and_keeps_its_index() {
// Local stand-in for the layout without `filterable`; used only to produce the bytes.
#[derive(serde::Serialize)]
struct DescriptorV2 {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
index: IndexSpec,
}
let old = DescriptorV2 {
dim: 8,
dtype: Dtype::F32,
metric: DistanceMetric::L2,
index: IndexSpec {
kind: IndexKind::DiskVamana,
pq_subspaces: Some(16),
},
};
let bytes = postcard::to_allocvec(&old).unwrap();
// Direct deserialization into the current struct must fail.
assert!(postcard::from_bytes::<Descriptor>(&bytes).is_err());
let back = Descriptor::decode(&bytes).unwrap();
assert_eq!(back.dim, 8);
assert_eq!(back.index.kind, IndexKind::DiskVamana);
assert_eq!(back.index.pq_subspaces, Some(16));
assert!(back.filterable.is_empty());
}
// Filterable fields of both types survive an encode/`decode` round trip.
#[test]
fn descriptor_with_filterable_roundtrips() {
    let fields = vec![
        FilterableField::keyword("city"),
        FilterableField::numeric("age"),
    ];
    let original = Descriptor::new(4, Dtype::F32, DistanceMetric::L2).with_filterable(fields);
    let encoded = postcard::to_allocvec(&original).unwrap();
    let decoded = Descriptor::decode(&encoded).unwrap();
    assert_eq!(decoded, original);
}
// The multivector flag survives an encode/`decode` round trip.
#[test]
fn descriptor_with_multivector_roundtrips() {
    let original = Descriptor::new(128, Dtype::F32, DistanceMetric::Cosine).with_multivector(true);
    let encoded = postcard::to_allocvec(&original).unwrap();
    let decoded = Descriptor::decode(&encoded).unwrap();
    assert_eq!(decoded, original);
    assert!(decoded.multivector);
}
// A payload written before `multivector` existed decodes through `decode`,
// keeps its filterable fields and index, and reports `multivector == false`.
#[test]
fn pre_multivector_descriptor_decodes_and_keeps_filterable() {
// Local stand-in for the layout without `multivector`; used only to produce the bytes.
#[derive(serde::Serialize)]
struct DescriptorV3 {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
index: IndexSpec,
filterable: Vec<FilterableField>,
}
let old = DescriptorV3 {
dim: 8,
dtype: Dtype::F32,
metric: DistanceMetric::Cosine,
index: IndexSpec {
kind: IndexKind::Ivf,
pq_subspaces: Some(8),
},
filterable: vec![FilterableField::keyword("city")],
};
let bytes = postcard::to_allocvec(&old).unwrap();
// Direct deserialization into the current struct must fail.
assert!(postcard::from_bytes::<Descriptor>(&bytes).is_err());
let back = Descriptor::decode(&bytes).unwrap();
assert_eq!(back.filterable, vec![FilterableField::keyword("city")]);
assert!(!back.multivector);
assert_eq!(back.index.kind, IndexKind::Ivf);
}
// A non-default encryption mode survives an encode/`decode` round trip.
#[test]
fn descriptor_with_vector_encryption_roundtrips() {
    let mode = VectorEncryption::ClientSide;
    let original = Descriptor::new(64, Dtype::F32, DistanceMetric::L2).with_vector_encryption(mode);
    let encoded = postcard::to_allocvec(&original).unwrap();
    let decoded = Descriptor::decode(&encoded).unwrap();
    assert_eq!(decoded, original);
    assert_eq!(decoded.vector_encryption, VectorEncryption::ClientSide);
}
// An older layout ended in a `bool` named `encrypted_vectors` where the
// current layout has `vector_encryption`. The asserts below check that
// `true` reads back as `Dcpe` and `false` as `None`.
#[test]
fn legacy_encrypted_vectors_bool_decodes_as_the_enum() {
// Local stand-in for the bool-terminated layout; used only to produce the bytes.
#[derive(serde::Serialize)]
struct OldDescriptor {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
index: IndexSpec,
filterable: Vec<FilterableField>,
multivector: bool,
encrypted_vectors: bool,
}
// Builds an old-layout value that differs only in the trailing flag.
let make = |enc: bool| OldDescriptor {
dim: 8,
dtype: Dtype::F32,
metric: DistanceMetric::L2,
index: IndexSpec::default(),
filterable: Vec::new(),
multivector: false,
encrypted_vectors: enc,
};
let dcpe = postcard::to_allocvec(&make(true)).unwrap();
assert_eq!(
Descriptor::decode(&dcpe).unwrap().vector_encryption,
VectorEncryption::Dcpe
);
let none = postcard::to_allocvec(&make(false)).unwrap();
assert_eq!(
Descriptor::decode(&none).unwrap().vector_encryption,
VectorEncryption::None
);
}
// A payload written before `vector_encryption` existed decodes through
// `decode`, keeps `multivector` and `filterable`, and reports
// `VectorEncryption::None`.
#[test]
fn pre_vector_encryption_descriptor_decodes_and_keeps_multivector() {
// Local stand-in for the layout without `vector_encryption`; used only to produce the bytes.
#[derive(serde::Serialize)]
struct DescriptorV4 {
dim: u32,
dtype: Dtype,
metric: DistanceMetric,
index: IndexSpec,
filterable: Vec<FilterableField>,
multivector: bool,
}
let old = DescriptorV4 {
dim: 8,
dtype: Dtype::F32,
metric: DistanceMetric::Cosine,
index: IndexSpec::default(),
filterable: vec![FilterableField::numeric("score")],
multivector: true,
};
let bytes = postcard::to_allocvec(&old).unwrap();
// Direct deserialization into the current struct must fail.
assert!(postcard::from_bytes::<Descriptor>(&bytes).is_err());
let back = Descriptor::decode(&bytes).unwrap();
assert!(back.multivector);
assert_eq!(back.filterable, vec![FilterableField::numeric("score")]);
assert_eq!(back.vector_encryption, VectorEncryption::None);
}
}