use rand::{RngExt, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rand_xoshiro::Xoshiro256PlusPlus;
use rayon::prelude::*;
use crate::{bitstream, simd};
/// Minimum number of inputs at which input bitstream encoding is spread over the rayon
/// thread pool instead of running sequentially.
const RAYON_ENCODE_THRESHOLD: usize = 1 << 7;
/// Minimum number of neurons at which per-neuron activation loops switch from sequential
/// iteration to rayon-parallel iteration.
const RAYON_NEURON_THRESHOLD: usize = 1 << 3;
/// A fully connected layer evaluated with packed stochastic bitstreams.
///
/// Every weight is a probability that is encoded as a packed bitstream of `length` bits.
/// A neuron's activation is the popcount total produced against the encoded inputs,
/// scaled by `1 / length`.
#[derive(Debug, Clone)]
pub struct DenseLayer {
    /// Number of inputs every neuron is connected to.
    pub n_inputs: usize,
    /// Number of neurons, i.e. the width of the output vector.
    pub n_neurons: usize,
    /// Number of bits in every encoded bitstream.
    pub length: usize,
    /// Cached `1.0 / length`, used to scale popcount totals into activations.
    pub inv_length: f64,
    /// Number of `u64` words per packed bitstream (`length.div_ceil(64)`).
    pub words_per_input: usize,
    /// Weight probabilities indexed as `weights[neuron][input]`.
    ///
    /// This is the source of truth for `packed_weights_flat`. After mutating it directly,
    /// call `refresh_packed_weights` so the packed copy stays in sync.
    pub weights: Vec<Vec<f64>>,
    /// Packed weight bitstreams, flattened in `[neuron][input][word]` order.
    packed_weights_flat: Vec<u64>,
    /// Base seed for weight encoding; neuron `i` is encoded from `weight_seed + i`.
    weight_seed: u64,
}
impl DenseLayer {
pub fn new(n_inputs: usize, n_neurons: usize, length: usize, seed: u64) -> Self {
assert!(length > 0, "bitstream length must be > 0");
assert!(n_inputs > 0, "n_inputs must be > 0");
assert!(n_neurons > 0, "n_neurons must be > 0");
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut weights = vec![vec![0.0; n_inputs]; n_neurons];
for row in &mut weights {
for p in row {
*p = rng.random::<f64>();
}
}
let mut layer = Self {
n_inputs,
n_neurons,
length,
inv_length: 1.0 / length as f64,
words_per_input: length.div_ceil(64),
weights,
packed_weights_flat: vec![],
weight_seed: seed.wrapping_add(1),
};
layer.refresh_packed_weights();
layer
}
#[inline]
pub fn packed_weights_flat(&self) -> &[u64] {
&self.packed_weights_flat
}
#[inline]
fn weight_slice(&self, neuron: usize, input: usize) -> &[u64] {
let start = (neuron * self.n_inputs + input) * self.words_per_input;
&self.packed_weights_flat[start..start + self.words_per_input]
}
pub fn get_weights(&self) -> Vec<Vec<f64>> {
self.weights.clone()
}
pub fn set_weights(&mut self, weights: Vec<Vec<f64>>) -> Result<(), String> {
if weights.len() != self.n_neurons {
return Err(format!(
"Expected {} rows, got {}.",
self.n_neurons,
weights.len()
));
}
for (row_idx, row) in weights.iter().enumerate() {
if row.len() != self.n_inputs {
return Err(format!(
"Row {} has length {}, expected {}.",
row_idx,
row.len(),
self.n_inputs
));
}
}
self.weights = weights;
self.refresh_packed_weights();
Ok(())
}
pub fn refresh_packed_weights(&mut self) {
let n_inputs = self.n_inputs;
let words = self.words_per_input;
let length = self.length;
let weight_seed = self.weight_seed;
let weights = &self.weights;
let mut packed_weights_flat = vec![0_u64; self.n_neurons * n_inputs * words];
packed_weights_flat
.par_chunks_mut(n_inputs * words)
.enumerate()
.for_each(|(neuron_idx, neuron_chunk)| {
let mut rng =
ChaCha8Rng::seed_from_u64(weight_seed.wrapping_add(neuron_idx as u64));
for (input_idx, input_chunk) in neuron_chunk.chunks_mut(words).enumerate() {
let weight_prob = weights[neuron_idx][input_idx];
if weight_prob <= 0.0 {
input_chunk.fill(0);
} else if weight_prob >= 1.0 {
input_chunk.fill(u64::MAX);
if !length.is_multiple_of(64) {
input_chunk[words - 1] = (1_u64 << (length % 64)) - 1;
}
} else {
let packed =
bitstream::bernoulli_packed_simd(weight_prob, length, &mut rng);
input_chunk.copy_from_slice(&packed);
}
}
});
self.packed_weights_flat = packed_weights_flat;
}
pub fn forward(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
if input_values.len() != self.n_inputs {
return Err(format!(
"Expected input of length {}, got {}.",
self.n_inputs,
input_values.len()
));
}
let words = self.words_per_input;
let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
let mut packed_inputs_flat = vec![0_u64; self.n_inputs * words];
for (idx, p) in input_values.iter().copied().enumerate() {
let packed = bitstream::bernoulli_packed(p, self.length, &mut rng);
packed_inputs_flat[idx * words..(idx + 1) * words].copy_from_slice(&packed);
}
let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
let n_inputs = self.n_inputs;
self.packed_weights_flat
.par_chunks_exact(n_inputs * words)
.map(|neuron_weights| {
let total =
simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
total as f64 * self.inv_length
})
.collect()
} else {
let n_inputs = self.n_inputs;
self.packed_weights_flat
.chunks_exact(n_inputs * words)
.map(|neuron_weights| {
let total =
simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
total as f64 * self.inv_length
})
.collect()
};
Ok(out)
}
pub fn forward_fast(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
if input_values.len() != self.n_inputs {
return Err(format!(
"Expected input of length {}, got {}.",
self.n_inputs,
input_values.len()
));
}
let words = self.words_per_input;
let mut packed_inputs_flat = vec![0_u64; self.n_inputs * words];
if self.n_inputs >= RAYON_ENCODE_THRESHOLD {
packed_inputs_flat
.par_chunks_mut(words)
.enumerate()
.for_each(|(idx, chunk)| {
let p = input_values[idx];
let input_seed = seed.wrapping_add(idx as u64);
let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
let packed = bitstream::bernoulli_packed_simd(p, self.length, &mut rng);
chunk.copy_from_slice(&packed);
});
} else {
packed_inputs_flat
.chunks_mut(words)
.enumerate()
.for_each(|(idx, chunk)| {
let p = input_values[idx];
let input_seed = seed.wrapping_add(idx as u64);
let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
let packed = bitstream::bernoulli_packed_simd(p, self.length, &mut rng);
chunk.copy_from_slice(&packed);
});
}
let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
let n_inputs = self.n_inputs;
self.packed_weights_flat
.par_chunks_exact(n_inputs * words)
.map(|neuron_weights| {
let total =
simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
total as f64 * self.inv_length
})
.collect()
} else {
let n_inputs = self.n_inputs;
self.packed_weights_flat
.chunks_exact(n_inputs * words)
.map(|neuron_weights| {
let total =
simd::fused_and_popcount_dispatch(neuron_weights, &packed_inputs_flat);
total as f64 * self.inv_length
})
.collect()
};
Ok(out)
}
pub fn forward_fused(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
if input_values.len() != self.n_inputs {
return Err(format!(
"Expected input of length {}, got {}.",
self.n_inputs,
input_values.len()
));
}
let out: Vec<f64> = if self.n_neurons >= RAYON_NEURON_THRESHOLD {
(0..self.n_neurons)
.into_par_iter()
.map(|neuron_idx| {
let total: u64 = input_values
.iter()
.enumerate()
.map(|(input_idx, &p)| {
let input_seed = seed.wrapping_add(input_idx as u64);
let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
bitstream::encode_and_popcount(
self.weight_slice(neuron_idx, input_idx),
p,
self.length,
&mut rng,
)
})
.sum();
total as f64 * self.inv_length
})
.collect()
} else {
(0..self.n_neurons)
.map(|neuron_idx| {
let total: u64 = input_values
.iter()
.enumerate()
.map(|(input_idx, &p)| {
let input_seed = seed.wrapping_add(input_idx as u64);
let mut rng = Xoshiro256PlusPlus::seed_from_u64(input_seed);
bitstream::encode_and_popcount(
self.weight_slice(neuron_idx, input_idx),
p,
self.length,
&mut rng,
)
})
.sum();
total as f64 * self.inv_length
})
.collect()
};
Ok(out)
}
pub fn forward_batch_into(
&self,
inputs_flat: &[f64],
n_samples: usize,
seed: u64,
output: &mut [f64],
) -> Result<(), String> {
let expected_inputs = n_samples.checked_mul(self.n_inputs).ok_or_else(|| {
"Input size overflow when validating n_samples * n_inputs.".to_string()
})?;
if inputs_flat.len() != expected_inputs {
return Err(format!(
"Expected {} values ({}×{}), got {}.",
expected_inputs,
n_samples,
self.n_inputs,
inputs_flat.len()
));
}
let expected_outputs = n_samples.checked_mul(self.n_neurons).ok_or_else(|| {
"Output size overflow when validating n_samples * n_neurons.".to_string()
})?;
if output.len() != expected_outputs {
return Err(format!(
"Expected output length {} ({}×{}), got {}.",
expected_outputs,
n_samples,
self.n_neurons,
output.len()
));
}
output
.par_chunks_mut(self.n_neurons)
.enumerate()
.for_each(|(sample_idx, out_row)| {
let start = sample_idx * self.n_inputs;
let end = start + self.n_inputs;
let input_row = &inputs_flat[start..end];
let sample_seed = seed.wrapping_add((sample_idx as u64).wrapping_mul(1_000_000));
if let Ok(res) = self.forward_fast(input_row, sample_seed) {
out_row.copy_from_slice(&res);
}
});
Ok(())
}
pub fn forward_batch(
&self,
inputs_flat: &[f64],
n_samples: usize,
seed: u64,
) -> Result<Vec<f64>, String> {
let output_len = n_samples.checked_mul(self.n_neurons).ok_or_else(|| {
"Output size overflow when allocating n_samples * n_neurons.".to_string()
})?;
let mut output = vec![0.0_f64; output_len];
self.forward_batch_into(inputs_flat, n_samples, seed, &mut output)?;
Ok(output)
}
pub fn forward_prepacked(&self, packed_inputs: &[Vec<u64>]) -> Result<Vec<f64>, String> {
if packed_inputs.len() != self.n_inputs {
return Err(format!(
"Expected {} packed inputs, got {}.",
self.n_inputs,
packed_inputs.len()
));
}
let expected_words = self.length.div_ceil(64);
for (idx, pi) in packed_inputs.iter().enumerate() {
if pi.len() != expected_words {
return Err(format!(
"Packed input {} has {} words, expected {}.",
idx,
pi.len(),
expected_words
));
}
}
let out = (0..self.n_neurons)
.into_par_iter()
.map(|neuron_idx| {
let total: u64 = packed_inputs
.iter()
.enumerate()
.map(|(input_idx, input_words)| {
simd::fused_and_popcount_dispatch(
self.weight_slice(neuron_idx, input_idx),
input_words,
)
})
.sum();
total as f64 * self.inv_length
})
.collect();
Ok(out)
}
pub fn forward_prepacked_2d(
&self,
packed_flat: &[u64],
n_inputs: usize,
words: usize,
) -> Result<Vec<f64>, String> {
if n_inputs != self.n_inputs {
return Err(format!(
"Expected {} packed inputs, got {}.",
self.n_inputs, n_inputs
));
}
let expected_words = self.length.div_ceil(64);
if words != expected_words {
return Err(format!(
"Expected {} words per input, got {}.",
expected_words, words
));
}
if packed_flat.len() != n_inputs * words {
return Err(format!(
"Flat buffer length {} != n_inputs({}) * words({}).",
packed_flat.len(),
n_inputs,
words
));
}
let out = (0..self.n_neurons)
.into_par_iter()
.map(|neuron_idx| {
let total: u64 = (0..self.n_inputs)
.map(|input_idx| {
let row_start = input_idx * words;
let input_words = &packed_flat[row_start..row_start + words];
simd::fused_and_popcount_dispatch(
self.weight_slice(neuron_idx, input_idx),
input_words,
)
})
.sum();
total as f64 * self.inv_length
})
.collect();
Ok(out)
}
pub fn forward_numpy_inner(&self, input_values: &[f64], seed: u64) -> Result<Vec<f64>, String> {
self.forward_fast(input_values, seed)
}
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_xoshiro::Xoshiro256PlusPlus;
    use super::DenseLayer;
    use crate::bitstream;

    /// Encodes `inputs` exactly as `forward_fast` does: input `i` uses its own
    /// `Xoshiro256PlusPlus` seeded with `seed + i`.
    fn encode_like_forward_fast(inputs: &[f64], length: usize, seed: u64) -> Vec<Vec<u64>> {
        inputs
            .iter()
            .enumerate()
            .map(|(idx, &p)| {
                let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(idx as u64));
                bitstream::bernoulli_packed_simd(p, length, &mut rng)
            })
            .collect()
    }

    #[test]
    fn flat_weight_roundtrip() {
        let layer = DenseLayer::new(3, 2, 130, 42);
        let words = 130_usize.div_ceil(64);
        assert_eq!(layer.words_per_input, words);
        assert_eq!(layer.packed_weights_flat.len(), 3 * 2 * words);
        for neuron in 0..2 {
            let mut rng = ChaCha8Rng::seed_from_u64(43 + neuron as u64);
            for input in 0..3 {
                let expected =
                    bitstream::bernoulli_packed_simd(layer.weights[neuron][input], 130, &mut rng);
                assert_eq!(layer.weight_slice(neuron, input), expected.as_slice());
            }
        }
    }

    #[test]
    fn forward_fused_matches_forward_fast() {
        let layer = DenseLayer::new(16, 8, 1024, 42);
        let inputs: Vec<f64> = (0..16).map(|i| (i as f64) / 16.0).collect();
        let seed = 999_u64;
        let fast = layer
            .forward_fast(&inputs, seed)
            .expect("forward_fast should succeed");
        let fused = layer
            .forward_fused(&inputs, seed)
            .expect("forward_fused should succeed");
        assert_eq!(
            fast, fused,
            "forward_fused must be bit-identical to forward_fast"
        );
    }

    #[test]
    fn forward_batch_matches_sequential_fused() {
        let layer = DenseLayer::new(4, 3, 256, 123);
        let n_samples = 5;
        let inputs_flat: Vec<f64> = (0..(n_samples * 4))
            .map(|i| ((i * 17 + 11) % 100) as f64 / 100.0)
            .collect();
        let seed = 77_u64;
        let batch = layer
            .forward_batch(&inputs_flat, n_samples, seed)
            .expect("forward_batch should succeed");
        for sample_idx in 0..n_samples {
            let row = &inputs_flat[sample_idx * 4..(sample_idx + 1) * 4];
            let sample_seed = seed.wrapping_add((sample_idx as u64).wrapping_mul(1_000_000));
            let expected = layer
                .forward_fused(row, sample_seed)
                .expect("forward_fused should succeed");
            let got = &batch[sample_idx * 3..(sample_idx + 1) * 3];
            assert_eq!(got, expected.as_slice(), "sample_idx={sample_idx}");
        }
    }

    #[test]
    #[should_panic(expected = "bitstream length must be > 0")]
    fn new_rejects_zero_length() {
        let _ = DenseLayer::new(1, 1, 0, 0);
    }

    #[test]
    #[should_panic(expected = "n_inputs must be > 0")]
    fn new_rejects_zero_inputs() {
        let _ = DenseLayer::new(0, 1, 64, 0);
    }

    #[test]
    #[should_panic(expected = "n_neurons must be > 0")]
    fn new_rejects_zero_neurons() {
        let _ = DenseLayer::new(1, 0, 64, 0);
    }

    #[test]
    fn set_weights_validates_shape_and_repacks() {
        let mut layer = DenseLayer::new(3, 2, 70, 9);
        let before = layer.get_weights();
        assert_eq!(
            layer.set_weights(vec![vec![0.5; 3]]),
            Err("Expected 2 rows, got 1.".to_string())
        );
        assert_eq!(
            layer.set_weights(vec![vec![0.5; 3], vec![0.5; 2]]),
            Err("Row 1 has length 2, expected 3.".to_string())
        );
        // Rejected updates must leave the layer untouched.
        assert_eq!(layer.get_weights(), before);

        let weights = vec![vec![1.0; 3], vec![0.0; 3]];
        layer
            .set_weights(weights.clone())
            .expect("well-shaped weights should be accepted");
        assert_eq!(layer.get_weights(), weights);
        // 70 bits = one full word plus 6 tail bits; bits past `length` must stay clear.
        for input in 0..3 {
            assert_eq!(
                layer.weight_slice(0, input),
                [u64::MAX, (1_u64 << 6) - 1].as_slice()
            );
            assert_eq!(layer.weight_slice(1, input), [0_u64, 0].as_slice());
        }

        let inputs = [0.3, 0.6, 0.9];
        assert_eq!(layer.forward(&inputs, 5).expect("forward should succeed")[1], 0.0);
        assert_eq!(
            layer.forward_fast(&inputs, 5).expect("forward_fast should succeed")[1],
            0.0
        );
        assert_eq!(
            layer.forward_fused(&inputs, 5).expect("forward_fused should succeed")[1],
            0.0
        );
    }

    #[test]
    fn forward_variants_reject_wrong_input_length() {
        let layer = DenseLayer::new(4, 2, 64, 5);
        let short = [0.5_f64; 3];
        let expected: Result<Vec<f64>, String> =
            Err("Expected input of length 4, got 3.".to_string());
        assert_eq!(layer.forward(&short, 0), expected);
        assert_eq!(layer.forward_fast(&short, 0), expected);
        assert_eq!(layer.forward_fused(&short, 0), expected);
        assert_eq!(layer.forward_numpy_inner(&short, 0), expected);
    }

    #[test]
    fn forward_batch_validates_buffers() {
        let layer = DenseLayer::new(4, 2, 64, 5);
        let mut out = vec![0.0_f64; 4];
        assert_eq!(
            layer.forward_batch_into(&[0.5_f64; 7], 2, 0, &mut out),
            Err("Expected 8 values (2×4), got 7.".to_string())
        );
        let mut short_out = vec![0.0_f64; 3];
        assert_eq!(
            layer.forward_batch_into(&[0.5_f64; 8], 2, 0, &mut short_out),
            Err("Expected output length 4 (2×2), got 3.".to_string())
        );
        let mut empty_out: Vec<f64> = Vec::new();
        assert_eq!(
            layer.forward_batch_into(&[], usize::MAX, 0, &mut empty_out),
            Err("Input size overflow when validating n_samples * n_inputs.".to_string())
        );
        assert_eq!(
            layer.forward_batch(&[], usize::MAX, 0),
            Err("Output size overflow when allocating n_samples * n_neurons.".to_string())
        );
        assert_eq!(layer.forward_batch(&[], 0, 0), Ok(Vec::<f64>::new()));
    }

    #[test]
    fn forward_prepacked_validates_shape() {
        let layer = DenseLayer::new(4, 2, 64, 5);
        assert_eq!(
            layer.forward_prepacked(&[vec![0_u64], vec![0_u64], vec![0_u64]]),
            Err("Expected 4 packed inputs, got 3.".to_string())
        );
        assert_eq!(
            layer.forward_prepacked(&[vec![0_u64], vec![0_u64], vec![0_u64, 0], vec![0_u64]]),
            Err("Packed input 2 has 2 words, expected 1.".to_string())
        );
        assert_eq!(
            layer.forward_prepacked_2d(&[0_u64; 4], 3, 1),
            Err("Expected 4 packed inputs, got 3.".to_string())
        );
        assert_eq!(
            layer.forward_prepacked_2d(&[0_u64; 8], 4, 2),
            Err("Expected 1 words per input, got 2.".to_string())
        );
        assert_eq!(
            layer.forward_prepacked_2d(&[0_u64; 3], 4, 1),
            Err("Flat buffer length 3 != n_inputs(4) * words(1).".to_string())
        );
        // All-zero inputs can never overlap any weight bit.
        assert_eq!(layer.forward_prepacked_2d(&[0_u64; 4], 4, 1), Ok(vec![0.0, 0.0]));
    }

    #[test]
    fn forward_fast_matches_prepacked_paths() {
        let length = 200;
        let layer = DenseLayer::new(5, 3, length, 7);
        let inputs = [0.1, 0.5, 0.9, 0.3, 0.7];
        let seed = 1234_u64;
        let packed = encode_like_forward_fast(&inputs, length, seed);
        let fast = layer
            .forward_fast(&inputs, seed)
            .expect("forward_fast should succeed");
        let prepacked = layer
            .forward_prepacked(&packed)
            .expect("forward_prepacked should succeed");
        assert_eq!(fast, prepacked);
        let flat: Vec<u64> = packed.concat();
        let flat_out = layer
            .forward_prepacked_2d(&flat, inputs.len(), length.div_ceil(64))
            .expect("forward_prepacked_2d should succeed");
        assert_eq!(flat_out, prepacked);
    }
}