use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
/// Numeric encoding used for the stored components of an embedding.
///
/// Each variant's discriminant is the raw byte kept in
/// `VectorEmbeddingHeader::quantization`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum QuantizationType {
    /// Raw byte `0`; this is also the `Default` variant.
    #[default]
    F32 = 0,
    /// Raw byte `1`.
    F16 = 1,
    /// Raw byte `2`.
    Int8 = 2,
    /// Raw byte `3`.
    Binary = 3,
}

impl From<u8> for QuantizationType {
    /// Decodes a raw quantization byte.
    ///
    /// The conversion is infallible: any byte outside `0..=3` is mapped to
    /// [`QuantizationType::F32`] rather than being reported as an error.
    fn from(value: u8) -> Self {
        // Indexed by raw byte: position `i` holds the variant whose discriminant is `i`.
        const BY_RAW_BYTE: [QuantizationType; 4] = [
            QuantizationType::F32,
            QuantizationType::F16,
            QuantizationType::Int8,
            QuantizationType::Binary,
        ];
        BY_RAW_BYTE
            .get(usize::from(value))
            .copied()
            .unwrap_or(QuantizationType::F32)
    }
}
/// Fixed-size metadata record describing one stored vector embedding.
///
/// [`to_bytes`](Self::to_bytes) and [`from_bytes`](Self::from_bytes) use an
/// explicit little-endian layout of [`Self::SIZE`] bytes:
///
/// | bytes    | field          |
/// |----------|----------------|
/// | `0..8`   | `object_id`    |
/// | `8..12`  | `model_id`     |
/// | `12..16` | `dimensions`   |
/// | `16`     | `quantization` |
/// | `17..24` | `reserved`     |
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct VectorEmbeddingHeader {
    /// Identifier of the object the embedding belongs to.
    pub object_id: u64,
    /// Identifier of the model associated with the embedding.
    pub model_id: u32,
    /// Number of vector components.
    pub dimensions: u32,
    /// Raw quantization byte; decode it with
    /// [`quantization_type`](Self::quantization_type).
    pub quantization: u8,
    /// Reserved bytes; [`new`](Self::new) initialises them to zero.
    pub reserved: [u8; 7],
}

impl VectorEmbeddingHeader {
    /// Size of the serialized header in bytes.
    pub const SIZE: usize = 24;

    /// Creates a header with all reserved bytes set to zero.
    pub fn new(
        object_id: u64,
        model_id: u32,
        dimensions: u32,
        quantization: QuantizationType,
    ) -> Self {
        Self {
            object_id,
            model_id,
            dimensions,
            quantization: quantization as u8,
            reserved: [0u8; 7],
        }
    }

    /// Decodes the raw `quantization` byte.
    ///
    /// Relies on `QuantizationType::from`, so an unrecognised byte is
    /// interpreted by that conversion rather than rejected here.
    pub fn quantization_type(&self) -> QuantizationType {
        QuantizationType::from(self.quantization)
    }

    /// Number of payload bytes implied by `dimensions` and the quantization.
    ///
    /// * `F32`: 4 bytes per dimension
    /// * `F16`: 2 bytes per dimension
    /// * `Int8`: 1 byte per dimension plus 8 extra bytes
    ///   (NOTE(review): presumably quantization parameters — confirm)
    /// * `Binary`: 1 bit per dimension, rounded up to whole bytes
    ///
    /// The arithmetic saturates at `usize::MAX` instead of overflowing, which
    /// matters on targets where `usize` is 32 bits wide.
    pub fn data_size(&self) -> usize {
        let dims = self.dimensions as usize;
        match self.quantization_type() {
            QuantizationType::F32 => dims.saturating_mul(4),
            QuantizationType::F16 => dims.saturating_mul(2),
            QuantizationType::Int8 => dims.saturating_add(8),
            QuantizationType::Binary => dims.div_ceil(8),
        }
    }

    /// Header size plus [`data_size`](Self::data_size), saturating at `usize::MAX`.
    pub fn total_size(&self) -> usize {
        Self::SIZE.saturating_add(self.data_size())
    }

    /// Serializes the header into its little-endian byte layout.
    ///
    /// The `reserved` bytes are written out as stored, so a header obtained
    /// from [`from_bytes`](Self::from_bytes) serializes back to the same bytes.
    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
        let mut buf = [0u8; Self::SIZE];
        buf[0..8].copy_from_slice(&self.object_id.to_le_bytes());
        buf[8..12].copy_from_slice(&self.model_id.to_le_bytes());
        buf[12..16].copy_from_slice(&self.dimensions.to_le_bytes());
        buf[16] = self.quantization;
        buf[17..24].copy_from_slice(&self.reserved);
        buf
    }

    /// Deserializes a header from its little-endian byte layout.
    ///
    /// No validation is performed: the quantization byte and the reserved
    /// bytes are stored exactly as found in `buf`.
    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
        let mut object_id = [0u8; 8];
        object_id.copy_from_slice(&buf[0..8]);
        let mut model_id = [0u8; 4];
        model_id.copy_from_slice(&buf[8..12]);
        let mut dimensions = [0u8; 4];
        dimensions.copy_from_slice(&buf[12..16]);
        let mut reserved = [0u8; 7];
        reserved.copy_from_slice(&buf[17..24]);

        Self {
            object_id: u64::from_le_bytes(object_id),
            model_id: u32::from_le_bytes(model_id),
            dimensions: u32::from_le_bytes(dimensions),
            quantization: buf[16],
            reserved,
        }
    }
}
/// An embedding vector held as `f32` components together with its header.
#[derive(Debug, Clone)]
pub struct VectorEmbedding {
    /// Metadata describing the embedding.
    pub header: VectorEmbeddingHeader,
    /// The vector components.
    pub data: Vec<f32>,
}

impl VectorEmbedding {
    /// Builds an embedding whose header is tagged `QuantizationType::F32`.
    ///
    /// The header's `dimensions` is `data.len()` cast to `u32`.
    pub fn new(object_id: u64, model_id: u32, data: Vec<f32>) -> Self {
        Self::with_quantization(object_id, model_id, data, QuantizationType::F32)
    }

    /// Builds an embedding whose header is tagged with `quantization`.
    ///
    /// Only the header records the quantization; `data` is stored as given.
    /// The header's `dimensions` is `data.len()` cast to `u32`.
    pub fn with_quantization(
        object_id: u64,
        model_id: u32,
        data: Vec<f32>,
        quantization: QuantizationType,
    ) -> Self {
        let header =
            VectorEmbeddingHeader::new(object_id, model_id, data.len() as u32, quantization);
        Self { header, data }
    }

    /// Object identifier recorded in the header.
    pub fn object_id(&self) -> u64 {
        self.header.object_id
    }

    /// Dimension count recorded in the header (not `data.len()`).
    pub fn dimensions(&self) -> usize {
        self.header.dimensions as usize
    }

    /// Borrows the components as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Euclidean (L2) norm: the square root of the sum of squared components.
    pub fn l2_norm(&self) -> f32 {
        let squared_sum = self.data.iter().map(|component| component * component).sum::<f32>();
        libm::sqrtf(squared_sum)
    }

    /// Scales the components in place so that the L2 norm becomes 1.
    ///
    /// The vector is left untouched unless the current norm is greater than
    /// `1e-10`; this avoids dividing by a (near-)zero norm.
    pub fn normalize(&mut self) {
        let norm = self.l2_norm();
        if norm > 1e-10 {
            self.data.iter_mut().for_each(|component| *component /= norm);
        }
    }

    /// Returns a normalized clone, leaving `self` unchanged.
    pub fn normalized(&self) -> Self {
        let mut unit = self.clone();
        unit.normalize();
        unit
    }
}
/// Upper bound on the number of neighbor links stored in one [`HnswNode`].
pub const HNSW_MAX_NEIGHBORS: usize = 32;

/// A graph node with a fixed-capacity neighbor list.
///
/// Only the first `neighbor_count` entries of `neighbors` are live. The
/// methods of this type keep the unused tail zeroed.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct HnswNode {
    /// Identifier of the object this node represents.
    pub object_id: u64,
    /// Layer value stored for this node.
    pub layer: u8,
    /// Number of live entries at the front of `neighbors`.
    pub neighbor_count: u8,
    /// Reserved bytes; [`new`](Self::new) initialises them to zero.
    pub reserved: [u8; 6],
    /// Neighbor ids; entries at index `neighbor_count` and beyond are unused.
    pub neighbors: [u64; HNSW_MAX_NEIGHBORS],
}

impl HnswNode {
    /// Serialized size in bytes: id (8) + layer (1) + count (1) + reserved (6)
    /// + `HNSW_MAX_NEIGHBORS` little-endian `u64` ids.
    pub const SIZE: usize = 8 + 1 + 1 + 6 + (HNSW_MAX_NEIGHBORS * 8);

    /// Byte offset of the first neighbor id in the serialized form.
    const NEIGHBORS_OFFSET: usize = 16;

    /// Creates a node with no neighbors and zeroed reserved bytes.
    pub fn new(object_id: u64, layer: u8) -> Self {
        Self {
            object_id,
            layer,
            neighbor_count: 0,
            reserved: [0u8; 6],
            neighbors: [0u64; HNSW_MAX_NEIGHBORS],
        }
    }

    /// Number of live neighbors, clamped to the array capacity.
    ///
    /// `neighbor_count` is a public field and `from_bytes` does not validate
    /// it, so it may exceed `HNSW_MAX_NEIGHBORS`; clamping keeps every slice
    /// operation in bounds instead of panicking.
    fn active_len(&self) -> usize {
        (self.neighbor_count as usize).min(HNSW_MAX_NEIGHBORS)
    }

    /// Appends `neighbor_id`; returns `false` (and changes nothing) when the
    /// list is already full. Duplicates are not detected.
    pub fn add_neighbor(&mut self, neighbor_id: u64) -> bool {
        let len = self.neighbor_count as usize;
        if len >= HNSW_MAX_NEIGHBORS {
            return false;
        }
        self.neighbors[len] = neighbor_id;
        self.neighbor_count += 1;
        true
    }

    /// Removes the first occurrence of `neighbor_id`, preserving the order of
    /// the remaining neighbors. Returns `false` when the id is not present.
    pub fn remove_neighbor(&mut self, neighbor_id: u64) -> bool {
        let live = self.active_len();
        let Some(pos) = self.neighbors[..live]
            .iter()
            .position(|&id| id == neighbor_id)
        else {
            return false;
        };
        // Close the gap, then zero the slot that fell off the end.
        self.neighbors.copy_within(pos + 1..live, pos);
        self.neighbors[live - 1] = 0;
        self.neighbor_count = (live - 1) as u8;
        true
    }

    /// Returns `true` when `neighbor_id` is among the live neighbors.
    pub fn has_neighbor(&self, neighbor_id: u64) -> bool {
        self.get_neighbors().contains(&neighbor_id)
    }

    /// The live neighbors, in insertion order.
    pub fn get_neighbors(&self) -> &[u64] {
        &self.neighbors[..self.active_len()]
    }

    /// Replaces the neighbor list with `neighbors`.
    ///
    /// At most `HNSW_MAX_NEIGHBORS` ids are kept; any excess is silently
    /// dropped. Slots beyond the new length are zeroed so that ids from a
    /// previous, longer list do not linger in memory or in `to_bytes` output.
    pub fn set_neighbors(&mut self, neighbors: &[u64]) {
        let count = neighbors.len().min(HNSW_MAX_NEIGHBORS);
        let (live, unused) = self.neighbors.split_at_mut(count);
        live.copy_from_slice(&neighbors[..count]);
        unused.fill(0);
        self.neighbor_count = count as u8;
    }

    /// Serializes the node: little-endian integers, reserved bytes at
    /// `10..16`, and all `HNSW_MAX_NEIGHBORS` neighbor slots from byte 16.
    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
        let mut buf = [0u8; Self::SIZE];
        buf[0..8].copy_from_slice(&self.object_id.to_le_bytes());
        buf[8] = self.layer;
        buf[9] = self.neighbor_count;
        // Written as stored so that `from_bytes(..).to_bytes()` is lossless.
        buf[10..16].copy_from_slice(&self.reserved);
        let slots = buf[Self::NEIGHBORS_OFFSET..].chunks_exact_mut(8);
        for (slot, neighbor) in slots.zip(self.neighbors.iter()) {
            slot.copy_from_slice(&neighbor.to_le_bytes());
        }
        buf
    }

    /// Deserializes a node written by [`to_bytes`](Self::to_bytes).
    ///
    /// The buffer is not validated: `neighbor_count` is stored exactly as
    /// read, even if it exceeds `HNSW_MAX_NEIGHBORS`. The accessor methods
    /// clamp it, so such a node is safe to query.
    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
        let mut object_id = [0u8; 8];
        object_id.copy_from_slice(&buf[0..8]);
        let mut reserved = [0u8; 6];
        reserved.copy_from_slice(&buf[10..16]);

        let mut neighbors = [0u64; HNSW_MAX_NEIGHBORS];
        let slots = buf[Self::NEIGHBORS_OFFSET..].chunks_exact(8);
        for (neighbor, slot) in neighbors.iter_mut().zip(slots) {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(slot);
            *neighbor = u64::from_le_bytes(raw);
        }

        Self {
            object_id: u64::from_le_bytes(object_id),
            layer: buf[8],
            neighbor_count: buf[9],
            reserved,
            neighbors,
        }
    }
}
/// A single hit produced by a vector search.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// Identifier of the matched object.
    pub object_id: u64,
    /// Path of the matched object, when known.
    pub path: Option<String>,
    /// Score of the hit; the ordering below treats higher scores as better.
    pub score: f32,
    /// Distance value of the hit.
    pub distance: f32,
}

impl VectorSearchResult {
    /// Creates a result without a path.
    pub fn new(object_id: u64, score: f32, distance: f32) -> Self {
        Self {
            object_id,
            path: None,
            score,
            distance,
        }
    }

    /// Creates a result that carries the object's path.
    pub fn with_path(object_id: u64, path: String, score: f32, distance: f32) -> Self {
        Self {
            object_id,
            path: Some(path),
            score,
            distance,
        }
    }
}

/// Two results are equal when they refer to the same `object_id`; score,
/// distance and path are ignored.
///
/// NOTE(review): this is coarser than the `Ord` impl below (equal ids with
/// different scores compare unequal under `cmp`). Left unchanged because
/// callers may rely on id-based equality.
impl PartialEq for VectorSearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.object_id == other.object_id
    }
}

impl Eq for VectorSearchResult {}

impl PartialOrd for VectorSearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders results best-first: higher `score` sorts earlier.
///
/// The order is total, as `sort` and ordered collections require:
/// * a NaN score sorts after every non-NaN score (it never counts as "best"),
/// * results with equal scores (or both NaN) are ordered by ascending
///   `object_id`, which makes the outcome deterministic.
impl Ord for VectorSearchResult {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let by_score = match (self.score.is_nan(), other.score.is_nan()) {
            (true, true) => core::cmp::Ordering::Equal,
            (true, false) => core::cmp::Ordering::Greater,
            (false, true) => core::cmp::Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`;
            // operands are swapped to get descending order.
            (false, false) => other
                .score
                .partial_cmp(&self.score)
                .unwrap_or(core::cmp::Ordering::Equal),
        };
        by_score.then_with(|| self.object_id.cmp(&other.object_id))
    }
}
/// Distance / similarity function selectable for an index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DistanceMetric {
    /// The `Default` metric.
    #[default]
    Cosine,
    Euclidean,
    DotProduct,
    Manhattan,
    Hamming,
}

impl DistanceMetric {
    /// Parses a metric name, ignoring ASCII case.
    ///
    /// Accepted spellings:
    /// * `"cosine"`, `"cos"`
    /// * `"euclidean"`, `"l2"`, `"euclid"`
    /// * `"dot"`, `"dotproduct"`, `"inner"`
    /// * `"manhattan"`, `"l1"`
    /// * `"hamming"`
    ///
    /// Returns `None` for anything else. The input is not trimmed. The
    /// comparison is done in place, so parsing performs no heap allocation.
    pub fn from_str(s: &str) -> Option<Self> {
        const ALIASES: &[(&str, DistanceMetric)] = &[
            ("cosine", DistanceMetric::Cosine),
            ("cos", DistanceMetric::Cosine),
            ("euclidean", DistanceMetric::Euclidean),
            ("l2", DistanceMetric::Euclidean),
            ("euclid", DistanceMetric::Euclidean),
            ("dot", DistanceMetric::DotProduct),
            ("dotproduct", DistanceMetric::DotProduct),
            ("inner", DistanceMetric::DotProduct),
            ("manhattan", DistanceMetric::Manhattan),
            ("l1", DistanceMetric::Manhattan),
            ("hamming", DistanceMetric::Hamming),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| s.eq_ignore_ascii_case(alias))
            .map(|&(_, metric)| metric)
    }

    /// Canonical lowercase name; every returned name is accepted by
    /// [`from_str`](Self::from_str).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Cosine => "cosine",
            Self::Euclidean => "euclidean",
            Self::DotProduct => "dot",
            Self::Manhattan => "manhattan",
            Self::Hamming => "hamming",
        }
    }
}
/// Kind of vector index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum IndexType {
    /// The `Default` index type.
    #[default]
    Hnsw,
    Ivf,
    Flat,
}

impl IndexType {
    /// Parses an index type name, ignoring ASCII case.
    ///
    /// Accepted spellings: `"hnsw"`, `"ivf"`, and `"flat"` or `"brute"` (both
    /// map to [`IndexType::Flat`]). Returns `None` for anything else. The
    /// input is not trimmed. The comparison is done in place, so parsing
    /// performs no heap allocation.
    pub fn from_str(s: &str) -> Option<Self> {
        const ALIASES: &[(&str, IndexType)] = &[
            ("hnsw", IndexType::Hnsw),
            ("ivf", IndexType::Ivf),
            ("flat", IndexType::Flat),
            ("brute", IndexType::Flat),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| s.eq_ignore_ascii_case(alias))
            .map(|&(_, index_type)| index_type)
    }

    /// Canonical lowercase name; every returned name is accepted by
    /// [`from_str`](Self::from_str).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Hnsw => "hnsw",
            Self::Ivf => "ivf",
            Self::Flat => "flat",
        }
    }
}
/// Tunable parameters for building and querying a vector index.
///
/// The `hnsw_*` and `ivf_*` fields are named after the index types in
/// `IndexType`. NOTE(review): presumably each group only applies to its
/// matching index type — confirm against the index implementations.
#[derive(Debug, Clone)]
pub struct IndexParams {
    /// Which kind of index to use.
    pub index_type: IndexType,
    /// Distance metric to use.
    pub metric: DistanceMetric,
    /// HNSW `m` parameter (default 16).
    pub hnsw_m: usize,
    /// HNSW `ef_construction` parameter (default 200).
    pub hnsw_ef_construction: usize,
    /// HNSW `ef_search` parameter (default 50).
    pub hnsw_ef_search: usize,
    /// IVF `nlist` parameter (default 100).
    pub ivf_nlist: usize,
    /// IVF `nprobe` parameter (default 10).
    pub ivf_nprobe: usize,
}

impl Default for IndexParams {
    /// HNSW index with cosine distance and the default tuning values listed
    /// on the individual fields.
    fn default() -> Self {
        let (hnsw_m, hnsw_ef_construction, hnsw_ef_search) = (16, 200, 50);
        let (ivf_nlist, ivf_nprobe) = (100, 10);
        Self {
            index_type: IndexType::Hnsw,
            metric: DistanceMetric::Cosine,
            hnsw_m,
            hnsw_ef_construction,
            hnsw_ef_search,
            ivf_nlist,
            ivf_nprobe,
        }
    }
}
/// Descriptive metadata for a vector index.
#[derive(Debug, Clone)]
pub struct VectorIndexMeta {
    /// Kind of index.
    pub index_type: IndexType,
    /// Vector count recorded for the index.
    pub vector_count: u64,
    /// Dimension count recorded for the index.
    pub dimensions: u32,
    /// Distance metric recorded for the index.
    pub distance_metric: DistanceMetric,
    /// HNSW `m` parameter (default 16).
    pub hnsw_m: u16,
    /// HNSW `ef_construction` parameter (default 200).
    pub hnsw_ef_construction: u16,
    /// Layer value; NOTE(review): presumably the highest HNSW layer in use — confirm.
    pub max_layer: u8,
    /// Entry point value; NOTE(review): presumably an object id — confirm.
    pub entry_point: u64,
}

impl Default for VectorIndexMeta {
    /// Metadata for an HNSW/cosine index with zero vectors, zero dimensions,
    /// `max_layer` 0 and `entry_point` 0.
    fn default() -> Self {
        let (hnsw_m, hnsw_ef_construction) = (16, 200);
        Self {
            index_type: IndexType::Hnsw,
            distance_metric: DistanceMetric::Cosine,
            hnsw_m,
            hnsw_ef_construction,
            vector_count: 0,
            dimensions: 0,
            max_layer: 0,
            entry_point: 0,
        }
    }
}
/// Optional constraints applied to search results.
///
/// Every field defaults to `None`, meaning "no constraint". Only `min_score`,
/// `path_prefix` and `extensions` are evaluated by
/// [`matches_basic`](Self::matches_basic); the size and modification-time
/// fields are stored for callers to evaluate themselves.
#[derive(Debug, Clone, Default)]
pub struct SearchFilter {
    /// Allowed file extensions, given without the leading dot.
    pub extensions: Option<Vec<String>>,
    /// Lower size bound.
    pub min_size: Option<u64>,
    /// Upper size bound.
    pub max_size: Option<u64>,
    /// Lower bound on the modification timestamp.
    pub modified_after: Option<u64>,
    /// Upper bound on the modification timestamp.
    pub modified_before: Option<u64>,
    /// Required path prefix (compared case-sensitively).
    pub path_prefix: Option<String>,
    /// Minimum score a result must reach.
    pub min_score: Option<f32>,
}

impl SearchFilter {
    /// Creates a filter with no constraints.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the allowed extensions (without leading dot, e.g. `"jpg"`).
    pub fn with_extensions(mut self, extensions: Vec<String>) -> Self {
        self.extensions = Some(extensions);
        self
    }

    /// Sets `min_size`.
    pub fn with_min_size(mut self, size: u64) -> Self {
        self.min_size = Some(size);
        self
    }

    /// Sets `max_size`.
    pub fn with_max_size(mut self, size: u64) -> Self {
        self.max_size = Some(size);
        self
    }

    /// Sets `modified_after`.
    pub fn with_modified_after(mut self, timestamp: u64) -> Self {
        self.modified_after = Some(timestamp);
        self
    }

    /// Sets `modified_before`; the counterpart of
    /// [`with_modified_after`](Self::with_modified_after).
    pub fn with_modified_before(mut self, timestamp: u64) -> Self {
        self.modified_before = Some(timestamp);
        self
    }

    /// Sets the required path prefix.
    pub fn with_path_prefix(mut self, prefix: String) -> Self {
        self.path_prefix = Some(prefix);
        self
    }

    /// Sets the minimum score.
    pub fn with_min_score(mut self, score: f32) -> Self {
        self.min_score = Some(score);
        self
    }

    /// Checks the constraints that can be decided from the result alone.
    ///
    /// * `min_score`: the result is rejected when its score is below it.
    /// * `path_prefix`: the path must start with the prefix (case-sensitive).
    /// * `extensions`: the lowercased path must end with `"."` followed by
    ///   one of the lowercased extensions. An empty list rejects every path.
    ///
    /// A result without a path passes the prefix and extension checks. The
    /// size and modification-time constraints are not evaluated here.
    pub fn matches_basic(&self, result: &VectorSearchResult) -> bool {
        if self.min_score.is_some_and(|min_score| result.score < min_score) {
            return false;
        }

        // The remaining checks only apply when the result carries a path.
        let Some(path) = result.path.as_deref() else {
            return true;
        };

        if let Some(prefix) = &self.path_prefix {
            if !path.starts_with(prefix.as_str()) {
                return false;
            }
        }

        if let Some(extensions) = &self.extensions {
            // Lowercase the path once instead of once per candidate extension.
            let lowered_path = path.to_lowercase();
            let has_valid_ext = extensions
                .iter()
                .any(|ext| lowered_path.ends_with(format!(".{}", ext.to_lowercase()).as_str()));
            if !has_valid_ext {
                return false;
            }
        }

        true
    }
}
/// Errors reported by the vector storage and index code.
#[derive(Debug, Clone)]
pub enum VectorError {
    /// A vector had `actual` dimensions where `expected` were required.
    DimensionMismatch {
        expected: usize,
        actual: usize,
    },
    /// No vector index was found.
    IndexNotFound,
    /// The object with the given id is not in the index.
    ObjectNotFound(u64),
    /// The index contains no vectors.
    EmptyIndex,
    /// The given raw byte is not a valid quantization type.
    InvalidQuantization(u8),
    /// Serialization failed.
    SerializationError,
    /// An I/O failure, with a description.
    IoError(String),
    /// The index data is corrupted.
    CorruptedIndex,
    /// The named dataset was not found.
    DatasetNotFound(String),
    /// The named feature is not supported.
    NotSupported(String),
}

impl core::fmt::Display for VectorError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Variants without payload: fixed text.
            Self::IndexNotFound => f.write_str("vector index not found"),
            Self::EmptyIndex => f.write_str("vector index is empty"),
            Self::SerializationError => f.write_str("serialization error"),
            Self::CorruptedIndex => f.write_str("corrupted vector index"),
            // Variants with payload: text plus the carried value(s).
            Self::DimensionMismatch { expected, actual } => {
                write!(f, "dimension mismatch: expected {expected}, got {actual}")
            }
            Self::ObjectNotFound(id) => write!(f, "object {id} not found in index"),
            Self::InvalidQuantization(raw) => write!(f, "invalid quantization type: {raw}"),
            Self::IoError(msg) => write!(f, "IO error: {msg}"),
            Self::DatasetNotFound(name) => write!(f, "dataset not found: {name}"),
            Self::NotSupported(feature) => write!(f, "feature not supported: {feature}"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A header survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn test_vector_embedding_header_serialization() {
        let original = VectorEmbeddingHeader::new(12345, 42, 512, QuantizationType::F32);
        let decoded = VectorEmbeddingHeader::from_bytes(&original.to_bytes());

        assert_eq!(decoded.object_id, 12345);
        assert_eq!(decoded.model_id, 42);
        assert_eq!(decoded.dimensions, 512);
        assert_eq!(decoded.quantization_type(), QuantizationType::F32);
    }

    /// `data_size` reflects the per-quantization storage cost of 512 dimensions.
    #[test]
    fn test_vector_embedding_data_size() {
        let cases = [
            (QuantizationType::F32, 512 * 4),
            (QuantizationType::F16, 512 * 2),
            (QuantizationType::Int8, 512 + 8),
            (QuantizationType::Binary, 64),
        ];
        for &(quantization, expected_bytes) in cases.iter() {
            let header = VectorEmbeddingHeader::new(1, 1, 512, quantization);
            assert_eq!(header.data_size(), expected_bytes);
        }
    }

    /// Adding, querying and removing neighbors keeps the list compact and ordered.
    #[test]
    fn test_hnsw_node_neighbors() {
        let mut node = HnswNode::new(1, 0);
        assert_eq!(node.neighbor_count, 0);

        for id in 2..=4 {
            assert!(node.add_neighbor(id));
        }
        assert_eq!(node.neighbor_count, 3);

        assert!(node.has_neighbor(2));
        assert!(node.has_neighbor(3));
        assert!(!node.has_neighbor(5));

        assert!(node.remove_neighbor(3));
        assert_eq!(node.neighbor_count, 2);
        assert!(!node.has_neighbor(3));

        let remaining = node.get_neighbors();
        assert_eq!(remaining.len(), 2);
        assert_eq!(remaining[0], 2);
        assert_eq!(remaining[1], 4);
    }

    /// A node survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn test_hnsw_node_serialization() {
        let mut original = HnswNode::new(12345, 2);
        for id in [100, 200, 300].iter() {
            original.add_neighbor(*id);
        }

        let decoded = HnswNode::from_bytes(&original.to_bytes());

        assert_eq!(decoded.object_id, 12345);
        assert_eq!(decoded.layer, 2);
        assert_eq!(decoded.neighbor_count, 3);
        assert_eq!(decoded.neighbors[0], 100);
        assert_eq!(decoded.neighbors[1], 200);
        assert_eq!(decoded.neighbors[2], 300);
    }

    /// Normalizing a 3-4-5 vector yields unit length with components 0.6 and 0.8.
    #[test]
    fn test_vector_embedding_normalize() {
        let mut embedding = VectorEmbedding::new(1, 1, vec![3.0, 4.0]);
        embedding.normalize();

        let unit_norm = embedding.l2_norm();
        assert!((unit_norm - 1.0).abs() < 1e-6);
        assert!((embedding.data[0] - 0.6).abs() < 1e-6);
        assert!((embedding.data[1] - 0.8).abs() < 1e-6);
    }

    /// Sorting puts the highest score first.
    #[test]
    fn test_search_result_ordering() {
        let mut results = [
            VectorSearchResult::new(1, 0.5, 0.5),
            VectorSearchResult::new(2, 0.9, 0.1),
            VectorSearchResult::new(3, 0.7, 0.3),
        ];
        results.sort();

        let order = [
            results[0].object_id,
            results[1].object_id,
            results[2].object_id,
        ];
        assert_eq!(order, [2, 3, 1]);
    }

    /// The filter rejects low scores and disallowed extensions.
    #[test]
    fn test_search_filter() {
        let filter = SearchFilter::new()
            .with_min_score(0.5)
            .with_extensions(vec!["jpg".into(), "png".into()]);

        let accepted = VectorSearchResult::with_path(1, "/photos/cat.jpg".into(), 0.8, 0.2);
        assert!(filter.matches_basic(&accepted));

        let below_min_score = VectorSearchResult::with_path(2, "/photos/dog.jpg".into(), 0.3, 0.7);
        assert!(!filter.matches_basic(&below_min_score));

        let disallowed_ext = VectorSearchResult::with_path(3, "/docs/readme.txt".into(), 0.9, 0.1);
        assert!(!filter.matches_basic(&disallowed_ext));
    }

    /// Metric names parse case-insensitively; unknown names yield `None`.
    #[test]
    fn test_distance_metric_from_str() {
        let cases = [
            ("cosine", Some(DistanceMetric::Cosine)),
            ("L2", Some(DistanceMetric::Euclidean)),
            ("dot", Some(DistanceMetric::DotProduct)),
            ("invalid", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(DistanceMetric::from_str(input), expected);
        }
    }
}