use crate::{
bucket::Bucket,
structs::{ColdDiskLevelIndex, FloatElement, LevelIndexConfig},
DliError, DliResult, Id,
};
use serde::{Deserialize, Serialize};
use std::{
collections::HashSet,
fs::File,
io::{Seek, SeekFrom, Write as _},
marker::PhantomData,
os::unix::fs::FileExt as _,
path::{Path, PathBuf},
};
/// Size in bytes of one on-disk block. Block `i` of the data file starts at byte offset
/// `i * BLOCK_SIZE`.
pub const BLOCK_SIZE: usize = 65_536;
/// Bytes reserved at the start of every block, before the first record. The first two bytes
/// hold the record count as a little-endian `u16`; the remaining header bytes are left zeroed.
const HEADER_BYTES: usize = 4;
/// Returns how many records fit into a single block.
///
/// A record is `input_shape` elements of `F` followed by one `Id`; the block header is
/// excluded from the usable space. The division truncates, so the result is `0` when a
/// single record is larger than the block payload.
pub fn records_per_block<F: FloatElement>(input_shape: usize) -> usize {
    let vector_bytes = input_shape * std::mem::size_of::<F>();
    let record_bytes = vector_bytes + std::mem::size_of::<Id>();
    let payload_bytes = BLOCK_SIZE - HEADER_BYTES;
    payload_bytes / record_bytes
}
/// Metadata describing where one bucket's records live inside the cold data file.
///
/// A list of these is what `save_metadata` / `load_metadata` serialize as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColdDiskBucket {
/// Indices (in units of `BLOCK_SIZE`) of the blocks that hold this bucket's records, in the
/// order the blocks were appended.
pub extents: Vec<u32>,
/// Number of records appended to this bucket. Falls back to `0` when the field is absent
/// from the serialized metadata. `ColdStorage::delete` does not decrement it.
#[serde(default)]
pub count: usize,
}
/// Append-only store that keeps bucketed vectors of element type `F` in fixed-size blocks of a
/// single data file, with per-bucket block lists kept in memory.
#[derive(Debug)]
pub struct ColdStorage<F: FloatElement> {
/// Per-bucket block lists and record counts, indexed by bucket index.
pub disk_buckets: Vec<ColdDiskBucket>,
/// Handle to the data file, opened for both reading and writing.
data_file: File,
/// Path of the data file; also used to derive the metadata file path.
data_path: PathBuf,
/// Number of `F` elements in each record.
pub input_shape: usize,
/// Ids considered live. Records read back from disk are dropped unless their id is in
/// this set, which is how `delete` takes effect without rewriting blocks.
pub ids: HashSet<Id>,
/// Stored at construction but not read by any method of this type in this file;
/// presumably the nominal bucket capacity — TODO confirm against callers.
pub bucket_size: usize,
/// Ties the storage to its element type without storing a value of it.
_marker: PhantomData<F>,
}
/// Serializes up to one block's worth of records into a `BLOCK_SIZE`-byte buffer.
///
/// Layout: a little-endian `u16` record count in the first two bytes, then, starting at
/// `HEADER_BYTES`, each record as its `input_shape` raw `F` values followed by its id in
/// little-endian order. Unused bytes stay zero.
///
/// # Panics
///
/// Panics if `records.len() != ids.len() * input_shape`, or if there are more ids than
/// `records_per_block::<F>(input_shape)`.
pub fn encode_block<F: FloatElement>(
    records: &[F],
    ids: &[Id],
    input_shape: usize,
) -> [u8; BLOCK_SIZE] {
    assert_eq!(records.len(), ids.len() * input_shape);
    let n = ids.len();
    assert!(n <= records_per_block::<F>(input_shape));

    let vector_len = input_shape * std::mem::size_of::<F>();
    let record_len = vector_len + std::mem::size_of::<Id>();

    let mut block = [0u8; BLOCK_SIZE];
    block[..2].copy_from_slice(&(n as u16).to_le_bytes());

    for (slot, id) in ids.iter().enumerate() {
        let start = HEADER_BYTES + slot * record_len;
        // Split the record's region into its vector part and its trailing id part.
        let (vector_dst, id_dst) = block[start..start + record_len].split_at_mut(vector_len);
        let vector = &records[slot * input_shape..(slot + 1) * input_shape];
        vector_dst.copy_from_slice(bytemuck::cast_slice(vector));
        id_dst.copy_from_slice(&id.to_le_bytes());
    }
    block
}
/// Parses a block produced by `encode_block` back into flat records and their ids.
///
/// Returns the record values concatenated in block order (`count * input_shape` elements)
/// together with the `count` ids, where `count` is the little-endian `u16` stored in the
/// first two bytes of the block.
///
/// # Errors
///
/// Returns `DliError::IoError` with kind `InvalidData` when the stored count is larger than
/// `records_per_block::<F>(input_shape)`, i.e. the header cannot describe a valid block.
pub fn decode_block<F: FloatElement>(
    buf: &[u8; BLOCK_SIZE],
    input_shape: usize,
) -> DliResult<(Vec<F>, Vec<Id>)> {
    let count = u16::from_le_bytes([buf[0], buf[1]]) as usize;
    let rpb = records_per_block::<F>(input_shape);
    if count > rpb {
        return Err(DliError::IoError(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("cold block count {count} exceeds max {rpb}"),
        )));
    }
    let elem_size = std::mem::size_of::<F>();
    let bytes_per_vector = input_shape * elem_size;
    let bytes_per_record = bytes_per_vector + std::mem::size_of::<Id>();
    let mut records: Vec<F> = Vec::with_capacity(count * input_shape);
    let mut ids: Vec<Id> = Vec::with_capacity(count);
    for i in 0..count {
        let base = HEADER_BYTES + i * bytes_per_record;
        let (float_bytes, id_bytes) =
            buf[base..base + bytes_per_record].split_at(bytes_per_vector);
        // `buf` is only guaranteed to be byte-aligned, and records start at arbitrary
        // offsets inside it, so reinterpreting the bytes as `&[F]` could panic on an
        // alignment mismatch. Copy each element out with an unaligned read instead.
        records.extend(
            float_bytes
                .chunks_exact(elem_size)
                .map(bytemuck::pod_read_unaligned::<F>),
        );
        let id_bytes: [u8; 4] = id_bytes.try_into().expect("id slice is always 4 bytes");
        ids.push(u32::from_le_bytes(id_bytes));
    }
    Ok((records, ids))
}
impl<F: FloatElement> ColdStorage<F> {
/// Opens (creating it if necessary) the data file at `data_path` and returns a storage with
/// `n_buckets` empty buckets and no live ids.
///
/// An existing file is opened as-is and never truncated.
///
/// # Errors
///
/// Returns an error if the data file cannot be created or opened for reading and writing.
pub fn new(
    data_path: &Path,
    n_buckets: usize,
    input_shape: usize,
    bucket_size: usize,
) -> DliResult<Self> {
    // A single open call with `create(true)` replaces the former "check existence, then
    // create, then reopen" sequence, which could truncate a file that appeared between the
    // check and the create.
    let data_file = File::options()
        .read(true)
        .write(true)
        .create(true)
        .truncate(false)
        .open(data_path)?;
    let disk_buckets = ColdStorage::<F>::empty_disk_buckets(n_buckets);
    Ok(Self {
        disk_buckets,
        data_file,
        data_path: data_path.to_path_buf(),
        input_shape,
        _marker: PhantomData,
        ids: HashSet::new(),
        bucket_size,
    })
}
/// Builds `n_buckets` bucket descriptors, each with no extents and a record count of zero.
fn empty_disk_buckets(n_buckets: usize) -> Vec<ColdDiskBucket> {
    std::iter::repeat_with(|| ColdDiskBucket {
        extents: Vec::new(),
        count: 0,
    })
    .take(n_buckets)
    .collect()
}
pub fn load(
data_path: &Path,
meta_path: &Path,
input_shape: usize,
bucket_size: usize,
ids: HashSet<Id>,
) -> DliResult<Self> {
let disk_buckets = load_metadata(meta_path)?;
let data_file = File::options().read(true).write(true).open(data_path)?;
Ok(Self {
disk_buckets,
data_file,
data_path: data_path.to_path_buf(),
input_shape,
_marker: PhantomData,
ids,
bucket_size,
})
}
/// Marks `id` as deleted by dropping it from the live-id set.
///
/// Returns `true` if the id was live. Neither the data file nor the per-bucket counts are
/// modified; the record is simply filtered out on subsequent reads.
pub fn delete(&mut self, id: Id) -> bool {
    self.ids.take(&id).is_some()
}
pub fn insert(
&mut self,
records: Vec<F>,
ids: Vec<Id>,
assignments: &[usize],
) -> DliResult<()> {
let mut bucket_records: std::collections::HashMap<usize, (Vec<F>, Vec<Id>)> =
std::collections::HashMap::new();
let input_shape = self.input_shape;
let mut record_offset = 0;
for (id, &bucket_idx) in ids.iter().zip(assignments.iter()) {
let record_slice = &records[record_offset..record_offset + input_shape];
let bucket_entry = bucket_records
.entry(bucket_idx)
.or_insert_with(|| (Vec::with_capacity(input_shape * 10), Vec::new()));
bucket_entry.0.extend_from_slice(record_slice);
bucket_entry.1.push(*id);
record_offset += input_shape;
}
for (bucket_idx, (recs, ids)) in bucket_records {
self.append(bucket_idx, &recs, &ids)?;
}
let meta_path = meta_path_for(&self.data_path);
save_metadata(&self.disk_buckets, &meta_path)?;
Ok(())
}
fn append(&mut self, bucket_idx: usize, records: &[F], ids: &[Id]) -> DliResult<()> {
if records.is_empty() {
return Ok(());
}
debug_assert_eq!(records.len(), ids.len() * self.input_shape);
let rpb = records_per_block::<F>(self.input_shape);
let mut record_offset = 0;
let mut id_offset = 0;
while id_offset < ids.len() {
let remaining_ids = ids.len() - id_offset;
let records_in_this_block = std::cmp::min(rpb, remaining_ids);
let records_end = record_offset + records_in_this_block * self.input_shape;
let block_records = &records[record_offset..records_end];
let block_ids = &ids[id_offset..id_offset + records_in_this_block];
let block = encode_block(block_records, block_ids, self.input_shape);
self.data_file.seek(SeekFrom::End(0))?;
self.data_file.write_all(&block)?;
let block_idx = (self.data_file.stream_position()? / BLOCK_SIZE as u64) - 1;
if let Some(disk_bucket) = self.disk_buckets.get_mut(bucket_idx) {
disk_bucket.extents.push(block_idx as u32);
disk_bucket.count += records_in_this_block;
} else {
self.disk_buckets.push(ColdDiskBucket {
extents: vec![block_idx as u32],
count: records_in_this_block,
});
}
record_offset = records_end;
id_offset += records_in_this_block;
}
self.ids.extend(ids);
Ok(())
}
/// Reads every block of bucket `bucket_id` from disk and returns its live records.
///
/// Blocks are visited in extent order; a record is kept only if its id is currently in the
/// live-id set, so deleted records are skipped.
///
/// NOTE(review): the filter is by id only, so if an id is deleted and later inserted again,
/// every stored copy of that id (old and new) passes the filter — confirm callers never
/// reuse ids, or deduplicate downstream.
///
/// # Errors
///
/// Returns an error if a block cannot be read in full or fails to decode.
///
/// # Panics
///
/// Panics if `bucket_id` is not a valid bucket index.
pub fn read_bucket(&self, bucket_id: usize) -> DliResult<Bucket<F>> {
    let bucket = &self.disk_buckets[bucket_id];
    let input_shape = self.input_shape;
    if bucket.extents.is_empty() {
        return Ok(Bucket::<F>::from_parts(vec![], vec![], input_shape));
    }

    let total_count = bucket.count;
    let mut all_records: Vec<F> = Vec::with_capacity(total_count * input_shape);
    let mut all_ids: Vec<Id> = Vec::with_capacity(total_count);

    // One scratch buffer for all extents: `read_exact_at` overwrites it completely on every
    // iteration, so there is no need to allocate and zero 64 KiB per block.
    let mut buf = Box::new([0u8; BLOCK_SIZE]);
    for &block_idx in &bucket.extents {
        let byte_offset = u64::from(block_idx) * BLOCK_SIZE as u64;
        self.data_file
            .read_exact_at(buf.as_mut_slice(), byte_offset)?;
        let (recs, ids) = decode_block::<F>(&buf, input_shape)?;
        for (rec, id) in recs.chunks_exact(input_shape).zip(ids) {
            if self.ids.contains(&id) {
                all_records.extend_from_slice(rec);
                all_ids.push(id);
            }
        }
    }
    Ok(Bucket::<F>::from_parts(all_records, all_ids, input_shape))
}
/// Drains the storage: returns every live record (bucket by bucket, in extent order) and
/// then resets the in-memory bucket descriptors and the live-id set.
///
/// The whole data file is read into memory first. The data file is not truncated and the
/// metadata file is not rewritten by this method. When the data file is empty the method
/// returns empty vectors without resetting anything.
///
/// # Errors
///
/// Returns `DliError::IoError` if the file cannot be read, if an extent points outside the
/// file, or if a block fails to decode. On error the in-memory state is left untouched.
pub fn get_data(&mut self) -> DliResult<(Vec<F>, Vec<Id>)> {
    let file_size = self.data_file.metadata()?.len();
    if file_size == 0 {
        return Ok((vec![], vec![]));
    }
    let file_len = usize::try_from(file_size).map_err(|_| {
        DliError::IoError(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "cold storage file is too large to load into memory",
        ))
    })?;
    let mut buf = vec![0u8; file_len];
    self.data_file.read_exact_at(&mut buf, 0)?;

    let input_shape = self.input_shape;
    let mut all_records: Vec<F> = Vec::new();
    let mut all_ids: Vec<Id> = Vec::new();
    for bucket in &self.disk_buckets {
        for &block_idx in &bucket.extents {
            // Checked slicing: an extent that points past the end of the file (stale or
            // corrupt metadata) must surface as an error, not as an out-of-bounds panic.
            let block = (block_idx as usize)
                .checked_mul(BLOCK_SIZE)
                .and_then(|start| buf.get(start..)?.get(..BLOCK_SIZE))
                .and_then(|bytes| <&[u8; BLOCK_SIZE]>::try_from(bytes).ok())
                .ok_or_else(|| {
                    DliError::IoError(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Invalid block size in cold storage file",
                    ))
                })?;
            let (recs, ids) = decode_block::<F>(block, input_shape)?;
            for (id, rec) in ids.into_iter().zip(recs.chunks_exact(input_shape)) {
                if self.ids.contains(&id) {
                    all_records.extend_from_slice(rec);
                    all_ids.push(id);
                }
            }
        }
    }

    self.disk_buckets = ColdStorage::<F>::empty_disk_buckets(self.disk_buckets.len());
    self.ids.clear();
    Ok((all_records, all_ids))
}
/// Returns the recorded number of records appended to bucket `bucket_idx`.
///
/// This is the stored counter, which deletions do not decrease.
///
/// # Panics
///
/// Panics if `bucket_idx` is not a valid bucket index.
pub fn bucket_occupied(&self, bucket_idx: usize) -> usize {
    assert!(bucket_idx < self.disk_buckets.len());
    let bucket = &self.disk_buckets[bucket_idx];
    bucket.count
}
pub fn n_buckets(&self) -> usize {
self.disk_buckets.len()
}
pub fn dump(
&self,
_working_dir: &Path,
_level_id: usize,
config: &LevelIndexConfig,
) -> DliResult<ColdDiskLevelIndex> {
Ok(ColdDiskLevelIndex {
cold_data_path: self.data_path.clone(),
ids: self.ids.iter().cloned().collect::<Vec<_>>(),
config: config.clone(),
})
}
}
/// Derives the metadata file path that belongs to a data file.
///
/// The data path's final extension (if any) is replaced by `cold.meta.json`; for example
/// `level0.data` becomes `level0.cold.meta.json`.
pub fn meta_path_for(data_path: &Path) -> std::path::PathBuf {
    let mut meta_path = data_path.to_path_buf();
    meta_path.set_extension("cold.meta.json");
    meta_path
}
/// Serializes the bucket descriptors as JSON and writes them to `meta_path`, replacing any
/// existing file content.
///
/// # Errors
///
/// Returns an error if serialization or the file write fails.
pub fn save_metadata(disk_buckets: &[ColdDiskBucket], meta_path: &Path) -> DliResult<()> {
    let payload = serde_json::to_string(disk_buckets)?;
    Ok(std::fs::write(meta_path, payload)?)
}
/// Reads `meta_path` and parses it as a JSON list of bucket descriptors.
///
/// # Errors
///
/// Returns an error if the file cannot be read as UTF-8 text or is not valid JSON for the
/// expected shape.
pub fn load_metadata(meta_path: &Path) -> DliResult<Vec<ColdDiskBucket>> {
    let raw = std::fs::read_to_string(meta_path)?;
    let disk_buckets: Vec<ColdDiskBucket> = serde_json::from_str(&raw)?;
    Ok(disk_buckets)
}
#[cfg(test)]
mod tests {
use super::*;
use half::f16;
#[test]
fn test_encode_decode_round_trip_f32() {
    // Two 4-dimensional f32 vectors, [0, 1, 2, 3] and [4, 5, 6, 7], must survive a round trip.
    let dim = 4;
    let expected_ids: Vec<Id> = vec![10, 20];
    let expected_records: Vec<f32> = (0..8u8).map(f32::from).collect();

    let encoded = encode_block::<f32>(&expected_records, &expected_ids, dim);
    let (records, ids) = decode_block::<f32>(&encoded, dim).unwrap();

    assert_eq!(records, expected_records);
    assert_eq!(ids, expected_ids);
}
#[test]
fn test_encode_decode_round_trip_f16() {
    // Same round trip as the f32 case, but with half-precision elements.
    let dim = 4;
    let expected_ids: Vec<Id> = vec![1, 2];
    let expected_records: Vec<f16> = (0..8u8)
        .map(|value| f16::from_f32(f32::from(value)))
        .collect();

    let encoded = encode_block::<f16>(&expected_records, &expected_ids, dim);
    let (records, ids) = decode_block::<f16>(&encoded, dim).unwrap();

    assert_eq!(records, expected_records);
    assert_eq!(ids, expected_ids);
}
#[test]
fn test_encode_decode_empty_block() {
    // A block encoded from no records must decode to no records and no ids.
    let dim = 4;
    let no_records: [f32; 0] = [];
    let no_ids: [Id; 0] = [];

    let encoded = encode_block::<f32>(&no_records, &no_ids, dim);
    let (records, ids) = decode_block::<f32>(&encoded, dim).unwrap();

    assert!(records.is_empty());
    assert!(ids.is_empty());
}
#[test]
fn test_encode_decode_partial_block() {
    // A single 768-dimensional record leaves most of the block unused.
    let dim = 768;
    let expected_ids: Vec<Id> = vec![42];
    let expected_records: Vec<f32> = std::iter::repeat(0.5f32).take(dim).collect();

    let encoded = encode_block::<f32>(&expected_records, &expected_ids, dim);

    // The header must carry the record count as a little-endian u16.
    let header_count = u16::from_le_bytes([encoded[0], encoded[1]]);
    assert_eq!(header_count, 1);

    let (records, ids) = decode_block::<f32>(&encoded, dim).unwrap();
    assert_eq!(records, expected_records);
    assert_eq!(ids, expected_ids);
}
#[test]
fn test_encode_decode_full_block_f16_768() {
    // Fill a block to capacity: (65536 - 4) / (768 * 2 + 4) = 42 records.
    let dim = 768;
    let capacity = records_per_block::<f16>(dim);
    assert_eq!(capacity, 42, "expected 42 records/block for f16 dim=768");

    let mut expected_records: Vec<f16> = Vec::with_capacity(capacity * dim);
    for i in 0..capacity * dim {
        expected_records.push(f16::from_f32((i % 100) as f32));
    }
    let expected_ids: Vec<Id> = (0..capacity as u32).collect();

    let encoded = encode_block::<f16>(&expected_records, &expected_ids, dim);
    let (records, ids) = decode_block::<f16>(&encoded, dim).unwrap();

    assert_eq!(records.len(), capacity * dim);
    assert_eq!(ids.len(), capacity);
    assert_eq!(records, expected_records);
    assert_eq!(ids, expected_ids);
}
#[test]
fn test_append_single_record() {
    use tempfile::TempDir;

    let temp_dir = TempDir::new().unwrap();
    let data_path = temp_dir.path().join("test.cold.data");
    let input_shape = 4;
    let n_buckets = 4;
    let bucket_size = 100;
    let mut cold_storage =
        ColdStorage::<f32>::new(&data_path, n_buckets, input_shape, bucket_size).unwrap();

    let records = vec![1.0f32, 2.0f32, 3.0f32, 4.0f32];
    let ids = vec![10];
    cold_storage.append(0, &records, &ids).unwrap();

    // Bucket 0 received exactly one record in the first block of the file.
    assert_eq!(
        cold_storage.disk_buckets[0].count, 1,
        "bucket 0 should have 1 record"
    );
    assert_eq!(
        cold_storage.disk_buckets[0].extents.len(),
        1,
        "bucket 0 should have 1 extent"
    );
    assert_eq!(
        cold_storage.disk_buckets[0].extents[0], 0,
        "first extent should be at block 0"
    );

    // Every other bucket must be untouched.
    assert_eq!(cold_storage.disk_buckets.len(), n_buckets);
    for (idx, other) in cold_storage.disk_buckets.iter().enumerate().skip(1) {
        assert!(
            other.extents.is_empty(),
            "bucket {idx} should have no extents"
        );
    }

    // Reading the bucket back yields the record that was appended.
    let bucket = cold_storage.read_bucket(0).unwrap();
    assert_eq!(bucket.occupied(), 1);
    assert_eq!(bucket.ids, &[10]);
    assert_eq!(bucket.record(0), &records[..input_shape]);
}
#[test]
fn test_append_multiple_records() {
    use tempfile::TempDir;

    let dim = 3;
    let scratch = TempDir::new().unwrap();
    let path = scratch.path().join("test.cold.data");
    let mut storage = ColdStorage::<f32>::new(&path, 2, dim, 100).unwrap();

    // Two records appended in one call share a single block.
    let vectors = vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32];
    let vector_ids = vec![100, 101];
    storage.append(0, &vectors, &vector_ids).unwrap();

    let meta = &storage.disk_buckets[0];
    assert_eq!(meta.count, 2);
    assert_eq!(meta.extents.len(), 1);

    let bucket = storage.read_bucket(0).unwrap();
    assert_eq!(bucket.occupied(), 2);
    assert_eq!(bucket.ids, &[100, 101]);
    let (first, second) = vectors.split_at(dim);
    assert_eq!(bucket.record(0), first);
    assert_eq!(bucket.record(1), second);
}
#[test]
fn test_append_to_different_buckets() {
    use tempfile::TempDir;

    let temp_dir = TempDir::new().unwrap();
    let data_path = temp_dir.path().join("test.cold.data");
    let input_shape = 2;
    let n_buckets = 3;
    let bucket_size = 100;
    let mut cold_storage =
        ColdStorage::<f32>::new(&data_path, n_buckets, input_shape, bucket_size).unwrap();

    let records_b0 = vec![1.0f32, 2.0f32];
    let ids_b0 = vec![10];
    cold_storage.append(0, &records_b0, &ids_b0).unwrap();

    let records_b2 = vec![3.0f32, 4.0f32, 5.0f32, 6.0f32];
    let ids_b2 = vec![20, 21];
    cold_storage.append(2, &records_b2, &ids_b2).unwrap();

    // Bucket 0 was written first, so it owns block 0.
    assert_eq!(cold_storage.disk_buckets[0].count, 1);
    assert_eq!(cold_storage.disk_buckets[0].extents.len(), 1);
    assert_eq!(cold_storage.disk_buckets[0].extents[0], 0);

    // Bucket 1 never received data.
    assert!(cold_storage.disk_buckets[1].extents.is_empty());

    // Bucket 2 was written second, so it owns block 1.
    assert_eq!(cold_storage.disk_buckets[2].count, 2);
    assert_eq!(cold_storage.disk_buckets[2].extents.len(), 1);
    assert_eq!(cold_storage.disk_buckets[2].extents[0], 1);

    let bucket_0 = cold_storage.read_bucket(0).unwrap();
    assert_eq!(bucket_0.ids, &[10]);
    let bucket_2 = cold_storage.read_bucket(2).unwrap();
    assert_eq!(bucket_2.ids, &[20, 21]);
}
#[test]
fn test_append_multiple_blocks_to_same_bucket() {
    use tempfile::TempDir;

    let dim = 4;
    let scratch = TempDir::new().unwrap();
    let path = scratch.path().join("test.cold.data");
    let mut storage = ColdStorage::<f32>::new(&path, 1, dim, 100).unwrap();

    // First append: one record, one block.
    let first_vector = vec![1.0f32, 2.0f32, 3.0f32, 4.0f32];
    let first_ids = vec![1];
    storage.append(0, &first_vector, &first_ids).unwrap();
    {
        let meta = &storage.disk_buckets[0];
        assert_eq!(meta.count, 1);
        assert_eq!(meta.extents.len(), 1);
        assert_eq!(meta.extents[0], 0);
    }

    // Second append to the same bucket: every append call starts a new block.
    let second_vector = vec![5.0f32, 6.0f32, 7.0f32, 8.0f32];
    let second_ids = vec![2];
    storage.append(0, &second_vector, &second_ids).unwrap();
    {
        let meta = &storage.disk_buckets[0];
        assert_eq!(meta.count, 2, "bucket should have 2 total records");
        assert_eq!(meta.extents.len(), 2, "bucket should have 2 extents");
        assert_eq!(meta.extents[0], 0, "first extent at block 0");
        assert_eq!(meta.extents[1], 1, "second extent at block 1");
    }

    let bucket = storage.read_bucket(0).unwrap();
    assert_eq!(bucket.occupied(), 2);
    assert_eq!(bucket.ids, &[1, 2]);
}
#[test]
fn test_append_empty_records() {
    use tempfile::TempDir;

    let temp_dir = TempDir::new().unwrap();
    let data_path = temp_dir.path().join("test.cold.data");
    let input_shape = 4;
    let n_buckets = 2;
    let bucket_size = 100;
    let mut cold_storage =
        ColdStorage::<f32>::new(&data_path, n_buckets, input_shape, bucket_size).unwrap();

    let records: Vec<f32> = vec![];
    let ids: Vec<Id> = vec![];
    cold_storage.append(0, &records, &ids).unwrap();

    // Appending nothing must neither register an extent nor touch the data file.
    assert!(cold_storage.disk_buckets[0].extents.is_empty());
    let file_size = cold_storage.data_file.metadata().unwrap().len();
    assert_eq!(
        file_size, 0,
        "file should remain empty after appending empty records"
    );
}
#[test]
fn test_insert_empty_bucket() {
    use tempfile::TempDir;

    let temp_dir = TempDir::new().unwrap();
    let data_path = temp_dir.path().join("test.cold.data");
    let input_shape = 4;
    let n_buckets = 3;
    let bucket_size = 100;
    let mut cold_storage =
        ColdStorage::<f32>::new(&data_path, n_buckets, input_shape, bucket_size).unwrap();

    // A single record assigned to the middle bucket.
    let records = vec![1.0f32, 2.0f32, 3.0f32, 4.0f32];
    let ids = vec![42];
    let assignments = vec![1];
    cold_storage
        .insert(records.clone(), ids.clone(), &assignments)
        .unwrap();

    // Only bucket 1 may have data; its neighbours stay empty.
    assert!(cold_storage.disk_buckets[0].extents.is_empty());
    assert_eq!(cold_storage.disk_buckets[1].count, 1);
    assert_eq!(cold_storage.disk_buckets[1].extents.len(), 1);
    assert!(cold_storage.disk_buckets[2].extents.is_empty());

    let bucket = cold_storage.read_bucket(1).unwrap();
    assert_eq!(bucket.records_slice(), &records);
}
#[test]
fn test_insert_multiple_buckets() {
    use tempfile::TempDir;

    let temp_dir = TempDir::new().unwrap();
    let data_path = temp_dir.path().join("test.cold.data");
    let input_shape = 3;
    let n_buckets = 4;
    let bucket_size = 100;
    let mut cold_storage =
        ColdStorage::<f32>::new(&data_path, n_buckets, input_shape, bucket_size).unwrap();

    // Six 3-dimensional records, two each for buckets 0, 2 and 3; bucket 1 gets nothing.
    let records = vec![
        1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, 9.0f32, 10.0f32,
        11.0f32, 12.0f32, 13.0f32, 14.0f32, 15.0f32, 16.0f32, 17.0f32, 18.0f32,
    ];
    let ids = vec![100, 101, 102, 103, 104, 105];
    let assignments = vec![0, 0, 2, 2, 3, 3];
    cold_storage
        .insert(records.clone(), ids.clone(), &assignments)
        .unwrap();

    assert_eq!(cold_storage.disk_buckets[0].count, 2);
    assert_eq!(cold_storage.disk_buckets[0].extents.len(), 1);
    assert!(cold_storage.disk_buckets[1].extents.is_empty());
    assert_eq!(cold_storage.disk_buckets[2].count, 2);
    assert_eq!(cold_storage.disk_buckets[2].extents.len(), 1);
    assert_eq!(cold_storage.disk_buckets[3].count, 2);
    assert_eq!(cold_storage.disk_buckets[3].extents.len(), 1);

    // Each populated bucket holds 2 records * 3 values.
    let bucket_0 = cold_storage.read_bucket(0).unwrap();
    let bucket_2 = cold_storage.read_bucket(2).unwrap();
    let bucket_3 = cold_storage.read_bucket(3).unwrap();
    assert_eq!(bucket_0.records_slice().len(), 6);
    assert_eq!(bucket_2.records_slice().len(), 6);
    assert_eq!(bucket_3.records_slice().len(), 6);
}
#[test]
fn test_insert_overflowing_bucket() {
    use tempfile::TempDir;

    let temp_dir = TempDir::new().unwrap();
    let data_path = temp_dir.path().join("test.cold.data");
    let input_shape = 4;
    let n_buckets = 1;
    // Deliberately smaller than the number of records inserted below.
    let bucket_size = 2;
    let mut cold_storage =
        ColdStorage::<f32>::new(&data_path, n_buckets, input_shape, bucket_size).unwrap();

    // First batch: exactly `bucket_size` records.
    let records_1 = vec![
        1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
    ];
    let ids_1 = vec![10, 11];
    let assignments_1 = vec![0, 0];
    cold_storage
        .insert(records_1, ids_1, &assignments_1)
        .unwrap();
    assert_eq!(cold_storage.disk_buckets[0].count, 2);
    assert_eq!(cold_storage.disk_buckets[0].extents.len(), 1);

    // Second batch pushes the bucket past `bucket_size`; the storage must still accept it.
    let records_2 = vec![
        9.0f32, 10.0f32, 11.0f32, 12.0f32, 13.0f32, 14.0f32, 15.0f32, 16.0f32, 17.0f32,
        18.0f32, 19.0f32, 20.0f32,
    ];
    let ids_2 = vec![20, 21, 22];
    let assignments_2 = vec![0, 0, 0];
    cold_storage
        .insert(records_2, ids_2, &assignments_2)
        .unwrap();
    assert_eq!(
        cold_storage.disk_buckets[0].count, 5,
        "bucket should have 5 records after overflow"
    );
    // Each insert call writes its own block, so the bucket owns blocks 0 and 1.
    assert_eq!(
        cold_storage.disk_buckets[0].extents,
        vec![0, 1],
        "bucket should own one block per insert call"
    );

    // All five records come back, in insertion order.
    let bucket = cold_storage.read_bucket(0).unwrap();
    assert_eq!(bucket.records_slice().len(), 20);
    assert_eq!(bucket.ids, &[10, 11, 20, 21, 22]);
}
#[test]
fn test_insert_single_record_per_bucket() {
    use tempfile::TempDir;

    let temp_dir = TempDir::new().unwrap();
    let data_path = temp_dir.path().join("test.cold.data");
    let input_shape = 2;
    let n_buckets = 5;
    let bucket_size = 100;
    let mut cold_storage =
        ColdStorage::<f32>::new(&data_path, n_buckets, input_shape, bucket_size).unwrap();

    // Five records, one per bucket.
    let records = vec![
        1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, 9.0f32, 10.0f32,
    ];
    let ids = vec![1, 2, 3, 4, 5];
    let assignments = vec![0, 1, 2, 3, 4];
    cold_storage.insert(records, ids, &assignments).unwrap();

    assert_eq!(cold_storage.disk_buckets.len(), n_buckets);
    for (i, bucket) in cold_storage.disk_buckets.iter().enumerate() {
        assert_eq!(bucket.count, 1, "bucket {} should have 1 record", i);
        assert_eq!(
            bucket.extents.len(),
            1,
            "bucket {} should have 1 extent",
            i
        );
    }

    // Buckets are written in an unspecified order, but five single-block appends to a fresh
    // file must occupy exactly blocks 0..5, each owned by a different bucket.
    let mut used_blocks: Vec<u32> = cold_storage
        .disk_buckets
        .iter()
        .map(|bucket| bucket.extents[0])
        .collect();
    used_blocks.sort_unstable();
    assert_eq!(
        used_blocks,
        vec![0, 1, 2, 3, 4],
        "should use exactly 5 distinct blocks for 5 records"
    );
}
#[test]
fn test_insert_interleaved_assignments() {
    use tempfile::TempDir;

    let dim = 2;
    let scratch = TempDir::new().unwrap();
    let path = scratch.path().join("test.cold.data");
    let mut storage = ColdStorage::<f32>::new(&path, 2, dim, 100).unwrap();

    // Records alternate between bucket 0 and bucket 1 in the input.
    let vectors = vec![
        1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
    ];
    let vector_ids = vec![10, 20, 11, 21];
    let targets = vec![0, 1, 0, 1];
    storage.insert(vectors, vector_ids, &targets).unwrap();

    // Grouping must gather each bucket's two records into a single block.
    for bucket_idx in [0usize, 1] {
        let meta = &storage.disk_buckets[bucket_idx];
        assert_eq!(meta.count, 2);
        assert_eq!(meta.extents.len(), 1);
    }

    let bucket_0 = storage.read_bucket(0).unwrap();
    let bucket_1 = storage.read_bucket(1).unwrap();
    assert_eq!(bucket_0.records_slice().len(), 4);
    assert_eq!(bucket_1.records_slice().len(), 4);
}
#[test]
fn test_insert_actual_block_overflow() {
    use tempfile::TempDir;

    let scratch = TempDir::new().unwrap();
    let path = scratch.path().join("test.cold.data");
    let input_shape = 4;
    let mut storage = ColdStorage::<f32>::new(&path, 1, input_shape, 10000).unwrap();

    // More records than one block can hold for f32 with 4 dimensions, all sent to bucket 0.
    // Record `i` is [i, i + 0.1, i + 0.2, i + 0.3].
    let num_records = 3286;
    let records: Vec<f32> = (0..num_records)
        .flat_map(|i| (0..input_shape).map(move |j| (i as f32) + (j as f32) * 0.1))
        .collect();
    let ids: Vec<u32> = (0..num_records).map(|i| i as u32).collect();
    let assignments = vec![0usize; num_records];

    storage.insert(records, ids, &assignments).unwrap();

    assert_eq!(
        storage.disk_buckets[0].count, num_records,
        "bucket should have {} records",
        num_records
    );
    assert!(
        storage.disk_buckets[0].extents.len() >= 2,
        "should have at least 2 extents to hold 3286 records"
    );

    let bucket = storage.read_bucket(0).unwrap();
    let read_records = bucket.records_slice();
    assert_eq!(
        read_records.len(),
        num_records * input_shape,
        "should read back all {} records",
        num_records
    );

    // Spot-check the first record...
    assert_eq!(read_records[0], 0.0);
    assert_eq!(read_records[1], 0.1);
    assert_eq!(read_records[2], 0.2);
    assert_eq!(read_records[3], 0.3);

    // ...and the last one, which lives in the second block.
    let last_idx = (num_records - 1) * input_shape;
    assert_eq!(read_records[last_idx], 3285.0);
    assert_eq!(read_records[last_idx + 1], 3285.1);
    assert_eq!(read_records[last_idx + 2], 3285.2);
    assert_eq!(read_records[last_idx + 3], 3285.3);
}
}