use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct BatchConfig {
pub dim: usize,
pub enable_rotation: bool,
pub parallel_threshold: usize,
pub rotation_threads: usize,
pub initial_capacity: usize,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
dim: 768,
enable_rotation: true,
parallel_threshold: 100,
rotation_threads: 4,
initial_capacity: 10_000,
}
}
}
pub type VectorKey = u64;
#[derive(Debug, Clone, Default)]
pub struct BatchWriteStats {
pub vectors_added: usize,
pub bytes_processed: usize,
pub rotation_time_ns: u64,
pub copy_time_ns: u64,
pub batches_processed: usize,
}
impl BatchWriteStats {
pub fn rotation_mb_per_sec(&self) -> f64 {
if self.rotation_time_ns == 0 {
return 0.0;
}
let mb = self.bytes_processed as f64 / (1024.0 * 1024.0);
mb / (self.rotation_time_ns as f64 / 1e9)
}
}
#[derive(Clone)]
pub struct StoredVector {
pub key: VectorKey,
pub data: Vec<f32>,
pub index: u32,
}
#[derive(Debug, Clone)]
pub enum BatchWriteError {
DimensionMismatch { expected: usize, actual: usize },
KeyCountMismatch { vectors: usize, keys: usize },
DuplicateKey(VectorKey),
BuildError(String),
}
impl std::fmt::Display for BatchWriteError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::DimensionMismatch { expected, actual } => {
write!(
f,
"dimension mismatch: expected {}, got {}",
expected, actual
)
}
Self::KeyCountMismatch { vectors, keys } => {
write!(f, "key count mismatch: {} vectors, {} keys", vectors, keys)
}
Self::DuplicateKey(k) => write!(f, "duplicate key: {}", k),
Self::BuildError(s) => write!(f, "build error: {}", s),
}
}
}
impl std::error::Error for BatchWriteError {}
pub struct BatchSegmentWriter {
config: BatchConfig,
vectors: Vec<StoredVector>,
key_to_index: HashMap<VectorKey, u32>,
#[allow(dead_code)]
rotation_buffer: Vec<f32>,
stats: BatchWriteStats,
}
impl BatchSegmentWriter {
pub fn new(config: BatchConfig) -> Self {
let initial_capacity = config.initial_capacity;
let dim = config.dim;
Self {
config,
vectors: Vec::with_capacity(initial_capacity),
key_to_index: HashMap::with_capacity(initial_capacity),
rotation_buffer: vec![0.0; dim],
stats: BatchWriteStats::default(),
}
}
pub fn add(&mut self, key: VectorKey, vector: &[f32]) -> Result<u32, BatchWriteError> {
if vector.len() != self.config.dim {
return Err(BatchWriteError::DimensionMismatch {
expected: self.config.dim,
actual: vector.len(),
});
}
if self.key_to_index.contains_key(&key) {
return Err(BatchWriteError::DuplicateKey(key));
}
let index = self.vectors.len() as u32;
let data = if self.config.enable_rotation {
let start = std::time::Instant::now();
let rotated = self.rotate_vector(vector);
self.stats.rotation_time_ns += start.elapsed().as_nanos() as u64;
rotated
} else {
vector.to_vec()
};
self.vectors.push(StoredVector { key, data, index });
self.key_to_index.insert(key, index);
self.stats.vectors_added += 1;
self.stats.bytes_processed += vector.len() * 4;
Ok(index)
}
pub fn add_batch(
&mut self,
keys: &[VectorKey],
vectors: &[Vec<f32>],
) -> Result<Vec<u32>, BatchWriteError> {
if keys.len() != vectors.len() {
return Err(BatchWriteError::KeyCountMismatch {
vectors: vectors.len(),
keys: keys.len(),
});
}
let mut indices = Vec::with_capacity(keys.len());
for (key, vector) in keys.iter().zip(vectors.iter()) {
let index = self.add(*key, vector)?;
indices.push(index);
}
self.stats.batches_processed += 1;
Ok(indices)
}
pub fn add_batch_contiguous(
&mut self,
flat_data: &[f32],
keys: &[VectorKey],
) -> Result<Vec<u32>, BatchWriteError> {
let dim = self.config.dim;
let num_vectors = flat_data.len() / dim;
if flat_data.len() % dim != 0 {
return Err(BatchWriteError::DimensionMismatch {
expected: dim * keys.len(),
actual: flat_data.len(),
});
}
if keys.len() != num_vectors {
return Err(BatchWriteError::KeyCountMismatch {
vectors: num_vectors,
keys: keys.len(),
});
}
for key in keys {
if self.key_to_index.contains_key(key) {
return Err(BatchWriteError::DuplicateKey(*key));
}
}
let start_index = self.vectors.len() as u32;
let mut indices = Vec::with_capacity(num_vectors);
if self.config.enable_rotation && num_vectors >= self.config.parallel_threshold {
let rotated_vectors = self.rotate_batch_parallel(flat_data, num_vectors);
for (i, (key, data)) in keys.iter().zip(rotated_vectors.into_iter()).enumerate() {
let index = start_index + i as u32;
self.vectors.push(StoredVector {
key: *key,
data,
index,
});
self.key_to_index.insert(*key, index);
indices.push(index);
}
} else {
for (i, key) in keys.iter().enumerate() {
let start = i * dim;
let vector = &flat_data[start..start + dim];
let data = if self.config.enable_rotation {
let start_time = std::time::Instant::now();
let rotated = self.rotate_vector(vector);
self.stats.rotation_time_ns += start_time.elapsed().as_nanos() as u64;
rotated
} else {
vector.to_vec()
};
let index = start_index + i as u32;
self.vectors.push(StoredVector {
key: *key,
data,
index,
});
self.key_to_index.insert(*key, index);
indices.push(index);
}
}
self.stats.vectors_added += num_vectors;
self.stats.bytes_processed += flat_data.len() * 4;
self.stats.batches_processed += 1;
Ok(indices)
}
fn rotate_vector(&self, vector: &[f32]) -> Vec<f32> {
let mut rotated = vector.to_vec();
hadamard_transform(&mut rotated);
rotated
}
fn rotate_batch_parallel(&self, flat_data: &[f32], num_vectors: usize) -> Vec<Vec<f32>> {
use std::thread;
let start = std::time::Instant::now();
let dim = self.config.dim;
let num_threads = self.config.rotation_threads.min(num_vectors);
let chunk_size = (num_vectors + num_threads - 1) / num_threads;
let flat_data = Arc::new(flat_data.to_vec());
let mut handles = Vec::with_capacity(num_threads);
for t in 0..num_threads {
let flat_data = Arc::clone(&flat_data);
let start_vec = t * chunk_size;
let end_vec = (start_vec + chunk_size).min(num_vectors);
handles.push(thread::spawn(move || {
let mut results = Vec::with_capacity(end_vec - start_vec);
for i in start_vec..end_vec {
let start_idx = i * dim;
let mut rotated = flat_data[start_idx..start_idx + dim].to_vec();
hadamard_transform(&mut rotated);
results.push(rotated);
}
results
}));
}
let mut all_results = Vec::with_capacity(num_vectors);
for handle in handles {
let chunk_results = handle.join().unwrap();
all_results.extend(chunk_results);
}
let _elapsed = start.elapsed().as_nanos() as u64;
all_results
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
pub fn stats(&self) -> &BatchWriteStats {
&self.stats
}
pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
self.key_to_index
.get(&key)
.map(|&idx| self.vectors[idx as usize].data.as_slice())
}
pub fn get_by_index(&self, index: u32) -> Option<&[f32]> {
self.vectors.get(index as usize).map(|v| v.data.as_slice())
}
pub fn build(self) -> Result<BuiltSegment, BatchWriteError> {
Ok(BuiltSegment {
vectors: self.vectors,
key_to_index: self.key_to_index,
dim: self.config.dim,
stats: self.stats,
})
}
}
pub struct BuiltSegment {
pub vectors: Vec<StoredVector>,
pub key_to_index: HashMap<VectorKey, u32>,
pub dim: usize,
pub stats: BatchWriteStats,
}
impl BuiltSegment {
pub fn get(&self, key: VectorKey) -> Option<&[f32]> {
self.key_to_index
.get(&key)
.map(|&idx| self.vectors[idx as usize].data.as_slice())
}
pub fn get_all_data(&self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.vectors.len() * self.dim);
for v in &self.vectors {
data.extend_from_slice(&v.data);
}
data
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
}
fn hadamard_transform(data: &mut [f32]) {
let n = data.len();
if n == 0 || (n & (n - 1)) != 0 {
return;
}
let mut h = 1;
while h < n {
for i in (0..n).step_by(h * 2) {
for j in i..(i + h) {
let x = data[j];
let y = data[j + h];
data[j] = x + y;
data[j + h] = x - y;
}
}
h *= 2;
}
let scale = 1.0 / (n as f32).sqrt();
for x in data.iter_mut() {
*x *= scale;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_batch_writer_basic() {
let config = BatchConfig {
dim: 4,
enable_rotation: false,
..Default::default()
};
let mut writer = BatchSegmentWriter::new(config);
let idx = writer.add(1, &[1.0, 2.0, 3.0, 4.0]).unwrap();
assert_eq!(idx, 0);
let retrieved = writer.get(1).unwrap();
assert_eq!(retrieved, &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_batch_writer_contiguous() {
let config = BatchConfig {
dim: 4,
enable_rotation: false,
..Default::default()
};
let mut writer = BatchSegmentWriter::new(config);
let flat_data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ];
let keys = vec![10, 20, 30];
let indices = writer.add_batch_contiguous(&flat_data, &keys).unwrap();
assert_eq!(indices, vec![0, 1, 2]);
assert_eq!(writer.len(), 3);
assert_eq!(writer.get(10).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
assert_eq!(writer.get(20).unwrap(), &[5.0, 6.0, 7.0, 8.0]);
assert_eq!(writer.get(30).unwrap(), &[9.0, 10.0, 11.0, 12.0]);
}
#[test]
fn test_batch_writer_rotation() {
let config = BatchConfig {
dim: 4, enable_rotation: true,
..Default::default()
};
let mut writer = BatchSegmentWriter::new(config);
let _ = writer.add(1, &[1.0, 0.0, 0.0, 0.0]).unwrap();
let rotated = writer.get(1).unwrap();
let norm_sq: f32 = rotated.iter().map(|x| x * x).sum();
assert!((norm_sq - 1.0).abs() < 0.1, "norm_sq = {}", norm_sq);
}
#[test]
fn test_duplicate_key_error() {
let config = BatchConfig {
dim: 4,
enable_rotation: false,
..Default::default()
};
let mut writer = BatchSegmentWriter::new(config);
writer.add(1, &[1.0, 2.0, 3.0, 4.0]).unwrap();
let result = writer.add(1, &[5.0, 6.0, 7.0, 8.0]);
assert!(matches!(result, Err(BatchWriteError::DuplicateKey(1))));
}
#[test]
fn test_dimension_mismatch_error() {
let config = BatchConfig {
dim: 4,
enable_rotation: false,
..Default::default()
};
let mut writer = BatchSegmentWriter::new(config);
let result = writer.add(1, &[1.0, 2.0, 3.0]);
assert!(matches!(
result,
Err(BatchWriteError::DimensionMismatch {
expected: 4,
actual: 3
})
));
}
#[test]
fn test_build_segment() {
let config = BatchConfig {
dim: 4,
enable_rotation: false,
..Default::default()
};
let mut writer = BatchSegmentWriter::new(config);
let flat_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let keys = vec![100, 200];
writer.add_batch_contiguous(&flat_data, &keys).unwrap();
let segment = writer.build().unwrap();
assert_eq!(segment.len(), 2);
assert_eq!(segment.get(100).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
assert_eq!(segment.get(200).unwrap(), &[5.0, 6.0, 7.0, 8.0]);
}
#[test]
fn test_hadamard_transform() {
let mut data = vec![1.0, 0.0, 0.0, 0.0];
hadamard_transform(&mut data);
for &x in &data {
assert!((x - 0.5).abs() < 0.01, "x = {}", x);
}
}
}