use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
/// Tuning parameters for [`RotationPipeline`].
///
/// The worker count, both queue capacities and the batch size should all be nonzero.
#[derive(Debug, Clone)]
pub struct RotationConfig {
    /// Number of worker threads spawned by the pipeline.
    pub num_workers: usize,
    /// Maximum number of submitted vectors waiting to be rotated.
    pub input_capacity: usize,
    /// Maximum number of rotated vectors waiting to be received.
    pub output_capacity: usize,
    /// Vector dimensionality. Informational only: submitted vectors are not checked against it.
    pub dim: usize,
    /// Maximum number of vectors a worker takes from the input queue in one go.
    pub batch_size: usize,
}

impl Default for RotationConfig {
    /// Four workers, 1024-slot input and output queues, 768-dimensional vectors, batches of 16.
    fn default() -> Self {
        const QUEUE_SLOTS: usize = 1024;
        RotationConfig {
            dim: 768,
            num_workers: 4,
            batch_size: 16,
            input_capacity: QUEUE_SLOTS,
            output_capacity: QUEUE_SLOTS,
        }
    }
}
/// Caller-chosen identifier that is carried unchanged from a submission to its result.
pub type VectorKey = u64;

/// A vector queued for rotation.
#[derive(Debug, Clone)]
pub struct RotationInput {
    /// Identifier supplied by the caller.
    pub key: VectorKey,
    /// The vector to rotate; it is consumed and rotated in place by a worker.
    pub vector: Vec<f32>,
    /// Sequence number assigned by the pipeline at submission time.
    pub seq: u64,
}

/// The result of rotating one [`RotationInput`].
#[derive(Debug, Clone)]
pub struct RotationOutput {
    /// Identifier copied from the corresponding input.
    pub key: VectorKey,
    /// The rotated vector (same length as the input vector).
    pub rotated: Vec<f32>,
    /// Sequence number copied from the corresponding input; use it to restore submission order.
    pub seq: u64,
    /// Elapsed time in nanoseconds measured around the rotation of this vector.
    pub rotation_time_ns: u64,
}
/// A point-in-time snapshot of pipeline counters.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Vectors accepted by the pipeline so far.
    pub submitted: u64,
    /// Vectors rotated and handed to the output queue so far.
    pub completed: u64,
    /// Sum of the per-vector rotation times, in nanoseconds.
    pub total_rotation_ns: u64,
    /// Submitted vectors not yet completed.
    pub in_flight: u64,
}

impl PipelineStats {
    /// Mean rotation time per completed vector in nanoseconds, or `0.0` if nothing completed.
    pub fn avg_rotation_ns(&self) -> f64 {
        match self.completed {
            0 => 0.0,
            count => self.total_rotation_ns as f64 / count as f64,
        }
    }

    /// Completed vectors per second of *cumulative* rotation time, or `0.0` if no time was recorded.
    ///
    /// The denominator is the sum of per-vector rotation times, not elapsed wall-clock
    /// time, so with several workers this understates the pipeline's real throughput.
    pub fn throughput(&self) -> f64 {
        if self.total_rotation_ns > 0 {
            let seconds = self.total_rotation_ns as f64 / 1e9;
            self.completed as f64 / seconds
        } else {
            0.0
        }
    }
}
/// A fixed-capacity, mutex-protected FIFO queue shared between threads.
///
/// Items are popped in the order they were pushed. The oldest pending work is therefore
/// handed out first and cannot be starved by newer submissions.
struct BoundedChannel<T> {
    buffer: Mutex<std::collections::VecDeque<T>>,
    capacity: usize,
}

impl<T> BoundedChannel<T> {
    /// Creates an empty queue that holds at most `capacity` items.
    fn new(capacity: usize) -> Self {
        Self {
            buffer: Mutex::new(std::collections::VecDeque::with_capacity(capacity)),
            capacity,
        }
    }

    /// Locks the queue, recovering the guard if the mutex was poisoned.
    ///
    /// Each critical section below performs a single `VecDeque` operation, so a poisoned
    /// lock does not imply a half-updated queue. Recovering keeps one panicking thread
    /// from turning every later channel operation into a panic as well.
    fn lock(&self) -> std::sync::MutexGuard<'_, std::collections::VecDeque<T>> {
        self.buffer.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Appends `item` unless the queue is full, in which case it is handed back as `Err`.
    fn try_push(&self, item: T) -> Result<(), T> {
        let mut buffer = self.lock();
        if buffer.len() >= self.capacity {
            return Err(item);
        }
        buffer.push_back(item);
        Ok(())
    }

    /// Like [`Self::try_push`], but reports success as a `bool` and drops a rejected item.
    // Not called anywhere in this file; kept as a convenience wrapper.
    #[allow(dead_code)]
    fn push_single(&self, item: T) -> bool {
        self.try_push(item).is_ok()
    }

    /// Appends `item`, sleeping 10 µs between attempts while the queue is full.
    ///
    /// A rejected item is handed back by `try_push` and retried as-is. No `Clone` bound is
    /// needed, and nothing is copied while waiting.
    fn push_blocking(&self, item: T) {
        let mut pending = item;
        loop {
            match self.try_push(pending) {
                Ok(()) => return,
                Err(rejected) => {
                    pending = rejected;
                    std::thread::sleep(std::time::Duration::from_micros(10));
                }
            }
        }
    }

    /// Removes and returns the oldest queued item, if any.
    fn try_pop(&self) -> Option<T> {
        self.lock().pop_front()
    }

    /// Removes up to `max` of the oldest queued items and returns them oldest first.
    fn try_pop_batch(&self, max: usize) -> Vec<T> {
        let mut buffer = self.lock();
        let count = buffer.len().min(max);
        buffer.drain(..count).collect()
    }

    /// Number of items currently queued.
    fn len(&self) -> usize {
        self.lock().len()
    }
}
/// A multi-threaded pipeline that rotates vectors with [`hadamard_transform`].
///
/// Vectors enter through a bounded input queue and are rotated by a pool of worker
/// threads spawned in [`RotationPipeline::new`]. Results are collected from a bounded
/// output queue. With several workers, results are not guaranteed to arrive in submission
/// order; use [`RotationOutput::seq`] to restore it.
pub struct RotationPipeline {
    // Stored for reference but not read after construction, hence the allow.
    #[allow(dead_code)]
    config: RotationConfig,
    input: Arc<BoundedChannel<RotationInput>>,
    output: Arc<BoundedChannel<RotationOutput>>,
    workers: Vec<JoinHandle<()>>,
    /// Asks workers to exit once they find the input queue empty.
    shutdown: Arc<AtomicBool>,
    /// Source of the per-submission sequence numbers.
    seq_counter: AtomicU64,
    stats: Arc<PipelineStatsInner>,
}

/// Atomic counters shared between the pipeline handle and its workers.
struct PipelineStatsInner {
    /// Vectors accepted by `submit`.
    submitted: AtomicU64,
    /// Vectors rotated and pushed onto the output queue.
    completed: AtomicU64,
    /// Sum of the per-vector rotation times, in nanoseconds.
    total_rotation_ns: AtomicU64,
}

impl RotationPipeline {
    /// Creates the pipeline and immediately spawns its worker threads.
    ///
    /// A zero worker count, queue capacity, or batch size would leave the pipeline unable
    /// to make progress: submits or flushes would wait forever. Each of these is therefore
    /// clamped to at least 1. `config` itself is stored unmodified.
    pub fn new(config: RotationConfig) -> Self {
        let num_workers = config.num_workers.max(1);
        let batch_size = config.batch_size.max(1);
        let input = Arc::new(BoundedChannel::new(config.input_capacity.max(1)));
        let output = Arc::new(BoundedChannel::new(config.output_capacity.max(1)));
        let shutdown = Arc::new(AtomicBool::new(false));
        let stats = Arc::new(PipelineStatsInner {
            submitted: AtomicU64::new(0),
            completed: AtomicU64::new(0),
            total_rotation_ns: AtomicU64::new(0),
        });
        let workers = (0..num_workers)
            .map(|_| {
                let input = Arc::clone(&input);
                let output = Arc::clone(&output);
                let shutdown = Arc::clone(&shutdown);
                let stats = Arc::clone(&stats);
                thread::spawn(move || worker_loop(input, output, shutdown, stats, batch_size))
            })
            .collect();
        Self {
            config,
            input,
            output,
            workers,
            shutdown,
            seq_counter: AtomicU64::new(0),
            stats,
        }
    }

    /// Queues `vector` for rotation under `key`, waiting (10 µs polls) while the input queue is full.
    pub fn submit(&self, key: VectorKey, vector: Vec<f32>) {
        let seq = self.seq_counter.fetch_add(1, Ordering::Relaxed);
        let input = RotationInput { key, vector, seq };
        self.input.push_blocking(input);
        self.stats.submitted.fetch_add(1, Ordering::Relaxed);
    }

    /// Submits every `(key, vector)` pair in order; see [`Self::submit`].
    pub fn submit_batch(&self, items: Vec<(VectorKey, Vec<f32>)>) {
        for (key, vector) in items {
            self.submit(key, vector);
        }
    }

    /// Returns a finished result if one is ready, without waiting.
    pub fn try_recv(&self) -> Option<RotationOutput> {
        self.output.try_pop()
    }

    /// Waits (10 µs polls) until a finished result is available.
    ///
    /// Returns `None` once nothing is left in flight: every submitted vector has completed,
    /// the input queue is empty, and the output queue has been emptied. Waiting longer could
    /// not produce a result without another `submit`.
    pub fn recv(&self) -> Option<RotationOutput> {
        loop {
            if let Some(output) = self.output.try_pop() {
                return Some(output);
            }
            if self.is_drained() {
                // Workers push a result before counting it as completed, so after the
                // drained check this final pop sees every remaining result.
                return self.output.try_pop();
            }
            std::thread::sleep(std::time::Duration::from_micros(10));
        }
    }

    /// Returns up to `max` finished results that are ready now, without waiting.
    pub fn recv_batch(&self, max: usize) -> Vec<RotationOutput> {
        self.output.try_pop_batch(max)
    }

    /// Takes a snapshot of the pipeline counters.
    pub fn stats(&self) -> PipelineStats {
        let submitted = self.stats.submitted.load(Ordering::Relaxed);
        // Acquire: `completed` is what `flush`/`recv` wait on, so it should synchronize
        // with the workers' increments of it.
        let completed = self.stats.completed.load(Ordering::Acquire);
        PipelineStats {
            submitted,
            completed,
            total_rotation_ns: self.stats.total_rotation_ns.load(Ordering::Relaxed),
            in_flight: submitted.saturating_sub(completed),
        }
    }

    /// Waits until every submitted vector has been rotated and returns all pending results.
    ///
    /// Results are drained while waiting, so workers never stall on a full output queue.
    pub fn flush(&self) -> Vec<RotationOutput> {
        let mut results = Vec::new();
        loop {
            // Sample completion *before* draining. Once this reports done, every result is
            // already queued, and the unbounded drain below collects all of them.
            let done = self.is_drained();
            results.extend(self.output.try_pop_batch(usize::MAX));
            if done {
                return results;
            }
            std::thread::sleep(std::time::Duration::from_micros(100));
        }
    }

    /// Processes all outstanding input, stops the workers, and returns every uncollected result.
    pub fn shutdown(mut self) -> Vec<RotationOutput> {
        self.shutdown.store(true, Ordering::Release);
        // Collect results before joining. A worker blocked in `push_blocking` on a full
        // output queue would otherwise never get back to checking the flag, and the join
        // would hang.
        let mut results = self.flush();
        for handle in self.workers.drain(..) {
            // A worker that panicked has nothing left to return; ignore its join error.
            let _ = handle.join();
        }
        results.extend(self.output.try_pop_batch(usize::MAX));
        results
    }

    /// True when every submitted vector has completed and nothing is waiting in the input queue.
    fn is_drained(&self) -> bool {
        let stats = self.stats();
        stats.completed >= stats.submitted && self.input.len() == 0
    }
}

impl Drop for RotationPipeline {
    /// Signals the workers to exit once the input queue is empty, so they do not poll forever.
    ///
    /// The threads are not joined here, because a worker blocked on a full, undrained output
    /// queue would make the drop hang. Use [`RotationPipeline::shutdown`] to wait for them
    /// and collect outstanding results.
    fn drop(&mut self) {
        self.shutdown.store(true, Ordering::Release);
    }
}
fn worker_loop(
input: Arc<BoundedChannel<RotationInput>>,
output: Arc<BoundedChannel<RotationOutput>>,
shutdown: Arc<AtomicBool>,
stats: Arc<PipelineStatsInner>,
batch_size: usize,
) {
loop {
let batch = input.try_pop_batch(batch_size);
if batch.is_empty() {
if shutdown.load(Ordering::Acquire) {
break;
}
std::thread::sleep(std::time::Duration::from_micros(10));
continue;
}
for item in batch {
let start = std::time::Instant::now();
let mut rotated = item.vector;
hadamard_transform(&mut rotated);
let rotation_time_ns = start.elapsed().as_nanos() as u64;
let result = RotationOutput {
key: item.key,
rotated,
seq: item.seq,
rotation_time_ns,
};
output.push_blocking(result);
stats.completed.fetch_add(1, Ordering::Relaxed);
stats
.total_rotation_ns
.fetch_add(rotation_time_ns, Ordering::Relaxed);
}
}
}
/// Rotates `data` in place with an orthonormal Walsh–Hadamard transform.
///
/// When `data.len()` is a power of two, this is the normalized fast Walsh–Hadamard
/// transform (FWHT): the butterfly network followed by scaling by `1/sqrt(n)`.
///
/// Other lengths are split into consecutive power-of-two blocks that follow the binary
/// representation of the length, largest first (e.g. `768 = 512 + 256`). Each block is
/// transformed independently. The result is block-diagonal, so coordinates only mix within
/// their own block. The transform is still orthogonal for every length: the Euclidean norm
/// is preserved up to floating-point rounding, and applying it twice restores the input.
///
/// An empty slice is left unchanged.
pub fn hadamard_transform(data: &mut [f32]) {
    let len = data.len();
    let mut offset = 0;
    while offset < len {
        let remaining = len - offset;
        // Largest power of two that fits in what is left.
        let block = if remaining.is_power_of_two() {
            remaining
        } else {
            remaining.next_power_of_two() / 2
        };
        fwht_normalized(&mut data[offset..offset + block]);
        offset += block;
    }
}

/// In-place normalized FWHT of a slice whose length is a power of two.
fn fwht_normalized(data: &mut [f32]) {
    let n = data.len();
    debug_assert!(n.is_power_of_two());
    let mut h = 1;
    while h < n {
        for pair in data.chunks_exact_mut(2 * h) {
            let (lo, hi) = pair.split_at_mut(h);
            for (x, y) in lo.iter_mut().zip(hi.iter_mut()) {
                let (a, b) = (*x, *y);
                *x = a + b;
                *y = a - b;
            }
        }
        h *= 2;
    }
    let scale = 1.0 / (n as f32).sqrt();
    for x in data.iter_mut() {
        *x *= scale;
    }
}
/// Single-threaded counterpart of [`RotationPipeline`] that applies [`hadamard_transform`]
/// on the caller's thread.
pub struct SyncRotator {
    // Allocated with length `dim` in `new` but not read by any method; the allow keeps the
    // field without a dead-code warning.
    #[allow(dead_code)]
    buffer: Vec<f32>,
}

impl SyncRotator {
    /// Creates a rotator. `dim` only sizes the internal buffer; inputs of any length are accepted.
    pub fn new(dim: usize) -> Self {
        Self {
            buffer: vec![0.0; dim],
        }
    }

    /// Rotates `data` in place.
    pub fn rotate_inplace(&self, data: &mut [f32]) {
        hadamard_transform(data);
    }

    /// Returns a rotated copy of `vector`, leaving the input untouched.
    pub fn rotate(&self, vector: &[f32]) -> Vec<f32> {
        let mut rotated = vector.to_vec();
        hadamard_transform(&mut rotated);
        rotated
    }

    /// Returns a rotated copy of every vector in `vectors`, in the same order.
    pub fn rotate_batch(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
        vectors.iter().map(|v| self.rotate(v)).collect()
    }

    /// Rotates each consecutive `dim`-sized chunk of `flat_data` in place.
    ///
    /// Trailing elements that do not fill a whole chunk are left untouched. `dim == 0`
    /// is a no-op rather than a division-by-zero panic.
    pub fn rotate_batch_flat(&self, flat_data: &mut [f32], dim: usize) {
        if dim == 0 {
            return;
        }
        for chunk in flat_data.chunks_exact_mut(dim) {
            hadamard_transform(chunk);
        }
    }
}

impl Default for SyncRotator {
    /// A rotator sized for 768-dimensional vectors.
    fn default() -> Self {
        Self::new(768)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hadamard_basic() {
        let mut data = vec![1.0, 0.0, 0.0, 0.0];
        hadamard_transform(&mut data);
        for &x in &data {
            assert!((x - 0.5).abs() < 0.01, "x = {}", x);
        }
    }

    #[test]
    fn test_hadamard_known_values() {
        // H4 / 2 applied to [1, 2, 3, 4] in natural (Sylvester) ordering.
        let mut data = vec![1.0f32, 2.0, 3.0, 4.0];
        hadamard_transform(&mut data);
        let expected = [5.0f32, -1.0, -2.0, 0.0];
        for (&got, &want) in data.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn test_hadamard_preserves_norm() {
        let mut data: Vec<f32> = (0..16).map(|i| i as f32 / 16.0).collect();
        let original_norm: f32 = data.iter().map(|x| x * x).sum();
        hadamard_transform(&mut data);
        let transformed_norm: f32 = data.iter().map(|x| x * x).sum();
        assert!(
            (original_norm - transformed_norm).abs() < 0.01,
            "norm changed: {} -> {}",
            original_norm,
            transformed_norm
        );
    }

    #[test]
    fn test_sync_rotator() {
        let rotator = SyncRotator::new(4);
        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let rotated = rotator.rotate(&vector);
        assert_eq!(rotated.len(), 4);
        assert_eq!(vector, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_pipeline_basic() {
        let config = RotationConfig {
            num_workers: 2,
            input_capacity: 16,
            output_capacity: 16,
            dim: 4,
            batch_size: 4,
        };
        let pipeline = RotationPipeline::new(config);
        for i in 0..10 {
            let vector = vec![i as f32; 4];
            pipeline.submit(i, vector);
        }
        let results = pipeline.flush();
        assert_eq!(results.len(), 10);
        // Every submitted key comes back exactly once.
        let mut keys: Vec<VectorKey> = results.iter().map(|r| r.key).collect();
        keys.sort_unstable();
        assert_eq!(keys, (0..10).collect::<Vec<VectorKey>>());
    }

    #[test]
    fn test_pipeline_ordering() {
        let config = RotationConfig {
            num_workers: 1,
            input_capacity: 32,
            output_capacity: 32,
            dim: 4,
            batch_size: 1,
        };
        let pipeline = RotationPipeline::new(config);
        for i in 0..5 {
            pipeline.submit(i as u64, vec![i as f32; 4]);
        }
        let mut results = pipeline.flush();
        assert_eq!(results.len(), 5);
        // Sequence numbers follow submission order, so sorting by them restores it.
        results.sort_by_key(|r| r.seq);
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.key, i as u64);
        }
    }

    #[test]
    fn test_pipeline_stats() {
        let config = RotationConfig::default();
        let pipeline = RotationPipeline::new(config);
        for i in 0..5 {
            pipeline.submit(i, vec![0.0; 768]);
        }
        let initial_stats = pipeline.stats();
        assert_eq!(initial_stats.submitted, 5);
        let _ = pipeline.flush();
        let final_stats = pipeline.stats();
        assert_eq!(final_stats.completed, 5);
        assert!(final_stats.total_rotation_ns > 0);
    }

    #[test]
    fn test_pipeline_shutdown() {
        let config = RotationConfig {
            num_workers: 2,
            dim: 4,
            ..Default::default()
        };
        let pipeline = RotationPipeline::new(config);
        pipeline.submit(1, vec![1.0; 4]);
        pipeline.submit(2, vec![2.0; 4]);
        // Workers drain the input queue before exiting, so both results are returned.
        let results = pipeline.shutdown();
        assert_eq!(results.len(), 2);
        let mut keys: Vec<VectorKey> = results.iter().map(|r| r.key).collect();
        keys.sort_unstable();
        assert_eq!(keys, vec![1, 2]);
    }
}