use crate::hnsw::HnswIndex;
use crate::persistence::header::{FileHeader, Flags, MetadataSectionHeader};
use crate::storage::VectorStorage;
use std::cmp::min;
/// Lower bound on the chunk size; equals the serialized `FileHeader` size
/// (64 bytes) so the header always fits in a single chunk.
pub const MIN_CHUNK_SIZE: usize = 64;
/// Exports a storage/index pair as a stream of byte chunks, suitable for
/// incremental writing (e.g. to a file or network sink).
pub trait ChunkedWriter {
    /// Returns an iterator over serialized chunks, each at most `chunk_size`
    /// bytes. Values below [`MIN_CHUNK_SIZE`] are clamped up to it.
    fn export_chunked(&self, chunk_size: usize) -> ChunkIter<'_>;
}
/// Iterator yielding the serialized file image as `Vec<u8>` chunks.
///
/// Serialization advances through `SerializationState`; the `*_offset` /
/// `*_index` fields record how far each section has been emitted so a
/// section can resume across chunk boundaries.
pub struct ChunkIter<'a> {
    storage: &'a VectorStorage,     // source of raw vector data and tombstones
    index: &'a HnswIndex,           // source of nodes, neighbors, metadata
    chunk_size: usize,              // max bytes per yielded chunk (>= MIN_CHUNK_SIZE)
    state: SerializationState,      // section currently being emitted
    buffer: Vec<u8>,                // scratch buffer reused across next() calls
    header_bytes: [u8; 64],         // pre-serialized, checksummed file header
    vector_data_offset: usize,      // f32s already emitted from storage.raw_data()
    node_index: usize,              // index nodes already emitted
    neighbor_offset: usize,         // bytes already emitted from the neighbor buffer
    tombstone_offset: usize,        // tombstone-bitmap bytes already emitted
    metadata_section: Vec<u8>,      // pre-built metadata section (empty if none)
    metadata_section_offset: usize, // bytes already emitted from metadata_section
}
/// State machine tracking which file section the iterator is emitting.
/// Sections are produced in declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SerializationState {
    Header,          // 64-byte file header
    VectorData,      // raw f32 vector data
    IndexNodes,      // HNSW node records (16 bytes each)
    IndexNeighbors,  // raw neighbor adjacency buffer
    Tombstones,      // deleted-flags bitmap, packed 8 bits per byte (LSB first)
    MetadataSection, // optional postcard-encoded metadata (may be empty)
    Done,            // all sections emitted; next() returns None
}
impl<'a> ChunkedWriter for (&'a VectorStorage, &'a HnswIndex) {
    /// Prepares a `ChunkIter` over the full on-disk image of this
    /// storage/index pair.
    ///
    /// The file header (including its checksum) and the optional metadata
    /// section are serialized eagerly here; the bulk sections are streamed
    /// lazily by the returned iterator. `chunk_size` is clamped up to
    /// `MIN_CHUNK_SIZE` so the 64-byte header always fits in one chunk.
    fn export_chunked(&self, chunk_size: usize) -> ChunkIter<'a> {
        let &(storage, index) = self;
        let effective_chunk = chunk_size.max(MIN_CHUNK_SIZE);

        // Section sizes drive the absolute offsets recorded in the header.
        // Vectors are f32 (4 bytes each); nodes serialize to 16 bytes each.
        let vector_bytes = (storage.raw_data().len() * 4) as u64;
        let node_bytes = (index.node_count() * 16) as u64;
        let neighbor_bytes = index.neighbors.buffer.len() as u64;
        let index_offset = 64 + vector_bytes;
        let tombstone_start = index_offset + node_bytes + neighbor_bytes;

        // Build the optional metadata section up front: a 16-byte section
        // header (payload length + CRC32) followed by the postcard payload.
        // A failed serialization is treated as "no metadata" (best-effort).
        let metadata_section: Vec<u8> = if index.metadata.is_empty() {
            Vec::new()
        } else {
            index.metadata.to_postcard().map_or_else(
                |_| Vec::new(),
                |payload| {
                    let checksum = crc32fast::hash(&payload);
                    #[allow(clippy::cast_possible_truncation)]
                    let section_header =
                        MetadataSectionHeader::new_postcard(payload.len() as u32, checksum);
                    let mut section = Vec::with_capacity(16 + payload.len());
                    section.extend_from_slice(section_header.as_bytes());
                    section.extend_from_slice(&payload);
                    section
                },
            )
        };

        let mut header = FileHeader::new(storage.dimensions());
        header.vector_count = storage.len() as u64;
        header.index_offset = index_offset;
        header.metadata_offset = tombstone_start;
        header.hnsw_m = index.config.m;
        header.hnsw_m0 = index.config.m0;
        if !metadata_section.is_empty() {
            header.flags |= Flags::HAS_METADATA;
        }
        #[allow(clippy::cast_possible_truncation)]
        {
            header.deleted_count = index.deleted_count as u32;
        }
        header.update_checksum();

        ChunkIter {
            storage,
            index,
            chunk_size: effective_chunk,
            state: SerializationState::Header,
            buffer: Vec::with_capacity(effective_chunk),
            header_bytes: *header.as_bytes(),
            vector_data_offset: 0,
            node_index: 0,
            neighbor_offset: 0,
            tombstone_offset: 0,
            metadata_section,
            metadata_section_offset: 0,
        }
    }
}
impl Iterator for ChunkIter<'_> {
    type Item = Vec<u8>;

    /// Produces the next chunk of serialized output, or `None` once all
    /// sections have been emitted.
    ///
    /// Each call fills the internal buffer up to `chunk_size` bytes,
    /// advancing through the section state machine; a section that does not
    /// fit in the remaining space resumes in the next chunk via its offset
    /// field. Record-based sections (vectors, nodes) only ever emit whole
    /// records, so a chunk may be slightly shorter than `chunk_size`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.state == SerializationState::Done {
            return None;
        }
        // Reuse the buffer allocation across calls; a clone is yielded below.
        self.buffer.clear();
        while self.buffer.len() < self.chunk_size && self.state != SerializationState::Done {
            let space_left = self.chunk_size - self.buffer.len();
            match self.state {
                SerializationState::Header => {
                    let bytes = &self.header_bytes;
                    if space_left >= bytes.len() {
                        self.buffer.extend_from_slice(bytes);
                        self.state = SerializationState::VectorData;
                    } else {
                        // Unreachable in practice: export_chunked clamps
                        // chunk_size to MIN_CHUNK_SIZE (= header size), so the
                        // header always fits into an empty chunk.
                        // NOTE(review): if this branch were ever reached, no
                        // header offset is tracked, so the next call would
                        // re-emit the header from the start — confirm the
                        // clamp invariant holds for all constructors.
                        debug_assert!(
                            false,
                            "chunk_size {} < MIN_CHUNK_SIZE {}",
                            self.chunk_size, MIN_CHUNK_SIZE
                        );
                        self.buffer.extend_from_slice(&bytes[..space_left]);
                        break;
                    }
                }
                SerializationState::VectorData => {
                    let data = self.storage.raw_data();
                    let remaining_floats = data.len() - self.vector_data_offset;
                    if remaining_floats == 0 {
                        // Section exhausted (or empty) — fall through to the
                        // next section within the same chunk.
                        self.state = SerializationState::IndexNodes;
                        continue;
                    }
                    // Emit whole f32s only (4 bytes each); a partial float is
                    // deferred to the next chunk.
                    let floats_to_copy = min(remaining_floats, space_left / 4);
                    if floats_to_copy > 0 {
                        let end = self.vector_data_offset + floats_to_copy;
                        let slice = &data[self.vector_data_offset..end];
                        let byte_slice = bytemuck::cast_slice(slice);
                        self.buffer.extend_from_slice(byte_slice);
                        self.vector_data_offset += floats_to_copy;
                    }
                    if self.vector_data_offset == data.len() {
                        self.state = SerializationState::IndexNodes;
                    } else if floats_to_copy == 0 {
                        // Less than 4 bytes of space left: flush this chunk.
                        break;
                    }
                }
                SerializationState::IndexNodes => {
                    let nodes = &self.index.nodes;
                    let remaining_nodes = nodes.len() - self.node_index;
                    if remaining_nodes == 0 {
                        self.state = SerializationState::IndexNeighbors;
                        continue;
                    }
                    // Emit whole node records only (16 bytes each, matching
                    // the node_count() * 16 sizing used for the header).
                    let nodes_to_copy = min(remaining_nodes, space_left / 16);
                    if nodes_to_copy > 0 {
                        let end = self.node_index + nodes_to_copy;
                        let slice = &nodes[self.node_index..end];
                        let byte_slice: &[u8] = bytemuck::cast_slice(slice);
                        self.buffer.extend_from_slice(byte_slice);
                        self.node_index += nodes_to_copy;
                    }
                    if self.node_index == nodes.len() {
                        self.state = SerializationState::IndexNeighbors;
                    } else if nodes_to_copy == 0 {
                        // Not enough space for a full record: flush this chunk.
                        break;
                    }
                }
                SerializationState::IndexNeighbors => {
                    // Neighbor data is already a raw byte buffer, so it can be
                    // split at any byte boundary.
                    let neighbors = &self.index.neighbors.buffer;
                    let remaining_bytes = neighbors.len() - self.neighbor_offset;
                    if remaining_bytes == 0 {
                        self.state = SerializationState::Tombstones;
                        continue;
                    }
                    let bytes_to_copy = min(remaining_bytes, space_left);
                    let end = self.neighbor_offset + bytes_to_copy;
                    self.buffer
                        .extend_from_slice(&neighbors[self.neighbor_offset..end]);
                    self.neighbor_offset += bytes_to_copy;
                    if self.neighbor_offset == neighbors.len() {
                        self.state = SerializationState::Tombstones;
                    } else if bytes_to_copy == 0 {
                        break;
                    }
                }
                SerializationState::Tombstones => {
                    // Pack the deleted-flags bitmap on the fly: one bit per
                    // slot, LSB-first within each byte; the final partial byte
                    // is zero-padded.
                    let total_bits = self.storage.deleted.len();
                    let total_bytes = (total_bits + 7) / 8; // ceil(bits / 8)
                    let remaining_bytes = total_bytes - self.tombstone_offset;
                    if remaining_bytes == 0 {
                        self.state = SerializationState::MetadataSection;
                        continue;
                    }
                    let bytes_to_produce = min(remaining_bytes, space_left);
                    for _ in 0..bytes_to_produce {
                        let byte_idx = self.tombstone_offset;
                        let start_bit = byte_idx * 8;
                        let mut byte: u8 = 0;
                        for bit_offset in 0..8 {
                            let bit_idx = start_bit + bit_offset;
                            if bit_idx < total_bits {
                                if self.storage.deleted[bit_idx] {
                                    byte |= 1 << bit_offset;
                                }
                            }
                        }
                        self.buffer.push(byte);
                        self.tombstone_offset += 1;
                    }
                    if self.tombstone_offset == total_bytes {
                        self.state = SerializationState::MetadataSection;
                    } else if bytes_to_produce == 0 {
                        break;
                    }
                }
                SerializationState::MetadataSection => {
                    // The section was fully built in export_chunked; just
                    // stream its bytes (empty when there is no metadata).
                    let remaining_bytes =
                        self.metadata_section.len() - self.metadata_section_offset;
                    if remaining_bytes == 0 {
                        self.state = SerializationState::Done;
                        continue;
                    }
                    let bytes_to_copy = min(remaining_bytes, space_left);
                    let start = self.metadata_section_offset;
                    let end = start + bytes_to_copy;
                    self.buffer
                        .extend_from_slice(&self.metadata_section[start..end]);
                    self.metadata_section_offset += bytes_to_copy;
                    if self.metadata_section_offset == self.metadata_section.len() {
                        self.state = SerializationState::Done;
                    } else if bytes_to_copy == 0 {
                        break;
                    }
                }
                // Guarded by the loop condition; kept for match exhaustiveness.
                SerializationState::Done => break,
            }
        }
        // An empty buffer means every remaining section was empty and the
        // state advanced to Done without producing bytes.
        if self.buffer.is_empty() {
            None
        } else {
            Some(self.buffer.clone())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::{HnswConfig, HnswIndex};
    use crate::storage::VectorStorage;

    /// An empty store still yields a valid, parseable 64-byte header chunk.
    #[test]
    fn test_chunked_export_empty() {
        let config = HnswConfig::new(128);
        let storage = VectorStorage::new(&config, None);
        let index = HnswIndex::new(config, &storage).unwrap();
        let writer = (&storage, &index);
        let mut iter = writer.export_chunked(1024);
        let chunk1 = iter.next();
        assert!(chunk1.is_some());
        let data = chunk1.unwrap();
        assert!(data.len() >= 64);
        let header = FileHeader::from_bytes(&data[0..64]).unwrap();
        assert_eq!(header.vector_count, 0);
    }

    /// Total exported bytes equal the sum of all section sizes even when a
    /// small chunk size forces the output to span multiple chunks.
    #[test]
    fn test_chunked_export_data() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);
        #[allow(clippy::cast_precision_loss)]
        for i in 0..10 {
            storage.insert(&[i as f32; 4]).unwrap();
        }
        let index = HnswIndex::new(config, &storage).unwrap();
        let writer = (&storage, &index);
        let chunk_size = 70;
        let iter = writer.export_chunked(chunk_size);
        let mut total_bytes = 0;
        for chunk in iter {
            assert!(chunk.len() <= chunk_size);
            total_bytes += chunk.len();
        }
        // 64-byte header + 10 vectors * 4 dims * 4 bytes (160) + 2 tombstone
        // bytes (ceil(10 bits / 8)); node/neighbor sections are empty here —
        // presumably a freshly constructed index serializes no nodes.
        let expected = 64 + 160 + 2;
        assert_eq!(total_bytes, expected);
    }

    /// chunk_size = 0 must be clamped to MIN_CHUNK_SIZE, not loop or panic.
    #[test]
    fn test_chunk_size_zero_edge_case() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);
        storage.insert(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let index = HnswIndex::new(config, &storage).unwrap();
        let writer = (&storage, &index);
        let iter = writer.export_chunked(0);
        let chunks: Vec<_> = iter.collect();
        assert!(!chunks.is_empty(), "Should produce at least one chunk");
        for chunk in &chunks {
            assert!(
                chunk.len() <= MIN_CHUNK_SIZE,
                "Chunk size {} exceeds clamped minimum {}",
                chunk.len(),
                MIN_CHUNK_SIZE
            );
        }
        let header = FileHeader::from_bytes(&chunks[0][0..64]).unwrap();
        assert_eq!(header.vector_count, 1);
    }

    /// chunk_size = 1 is clamped up; the first chunk is exactly the header.
    #[test]
    fn test_chunk_size_one_edge_case() {
        let config = HnswConfig::new(4);
        let storage = VectorStorage::new(&config, None);
        let index = HnswIndex::new(config, &storage).unwrap();
        let writer = (&storage, &index);
        let iter = writer.export_chunked(1);
        let chunks: Vec<_> = iter.collect();
        assert!(!chunks.is_empty());
        assert_eq!(chunks[0].len(), MIN_CHUNK_SIZE);
    }

    /// One below the minimum (63) also clamps to MIN_CHUNK_SIZE.
    #[test]
    fn test_chunk_size_just_below_minimum() {
        let config = HnswConfig::new(4);
        let storage = VectorStorage::new(&config, None);
        let index = HnswIndex::new(config, &storage).unwrap();
        let writer = (&storage, &index);
        let iter = writer.export_chunked(63);
        let chunks: Vec<_> = iter.collect();
        assert!(!chunks.is_empty());
        assert_eq!(chunks[0].len(), MIN_CHUNK_SIZE);
    }

    /// Exactly MIN_CHUNK_SIZE: output spans multiple chunks and the byte
    /// total still matches the section-size sum.
    #[test]
    fn test_chunk_size_exactly_minimum() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);
        #[allow(clippy::cast_precision_loss)]
        for i in 0..5 {
            storage.insert(&[i as f32; 4]).unwrap();
        }
        let index = HnswIndex::new(config, &storage).unwrap();
        let writer = (&storage, &index);
        let iter = writer.export_chunked(MIN_CHUNK_SIZE);
        let mut total_bytes = 0;
        let mut chunk_count = 0;
        for chunk in iter {
            assert!(
                chunk.len() <= MIN_CHUNK_SIZE,
                "Chunk {} has size {} > {}",
                chunk_count,
                chunk.len(),
                MIN_CHUNK_SIZE
            );
            total_bytes += chunk.len();
            chunk_count += 1;
        }
        assert!(
            chunk_count > 1,
            "Expected multiple chunks, got {chunk_count}"
        );
        // 64-byte header + 5 vectors * 4 dims * 4 bytes (80) + 1 tombstone
        // byte (ceil(5 bits / 8)).
        let expected = 64 + 80 + 1;
        assert_eq!(total_bytes, expected);
    }

    /// Reassembling all chunks reproduces the header and the exact vector
    /// payload, byte for byte.
    #[test]
    fn test_chunk_size_edge_case_data_integrity() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);
        let test_vectors = vec![
            [1.0_f32, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
        ];
        for v in &test_vectors {
            storage.insert(v).unwrap();
        }
        let index = HnswIndex::new(config, &storage).unwrap();
        let writer = (&storage, &index);
        let iter = writer.export_chunked(0);
        let mut combined = Vec::new();
        for chunk in iter {
            combined.extend_from_slice(&chunk);
        }
        let header = FileHeader::from_bytes(&combined[0..64]).unwrap();
        assert_eq!(header.vector_count, 3);
        assert_eq!(header.dimensions, 4);
        // Vector data starts right after the 64-byte header:
        // 3 vectors * 4 dims * 4 bytes = 48 bytes.
        let vector_bytes = &combined[64..64 + 48];
        let vectors: &[f32] = bytemuck::cast_slice(vector_bytes);
        assert_eq!(vectors[0..4], test_vectors[0]);
        assert_eq!(vectors[4..8], test_vectors[1]);
        assert_eq!(vectors[8..12], test_vectors[2]);
    }
}