use rayon::prelude::*;
use crate::math::tensor::rank_2::vector_list::VectorList;
use crate::math::tensor::{RandType, TensorRandFiller};
/// Strategy used by [`sample_vectors`] to fill a [`VectorList`].
///
/// Per-axis parameters are borrowed slices that must hold exactly one entry per
/// vector dimension; `sample_vectors` rejects any other length.
///
/// The enum only holds `f64`s and shared slices, so it is `Copy`: one method
/// value can be passed to several `sample_vectors` calls.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VectorSamplingMethod<'a> {
    /// Every component is drawn independently between `low` and `high`.
    /// Both bounds must be finite with `low < high`.
    Uniform { low: f64, high: f64 },
    /// Component `k` is drawn uniformly from a box of width `box_size[k]`
    /// centred on 0. Widths must be finite and non-negative.
    UniformCentered {
        box_size: &'a [f64],
    },
    /// Component `k` is `mean[k] + std[k] * z`, with `z` standard normal.
    /// `mean` must be finite; `std` must be finite and non-negative.
    GaussianPerAxis {
        mean: &'a [f64],
        std: &'a [f64],
    },
    /// Vectors are placed on a regular lattice with per-axis `spacings` and
    /// then perturbed by Gaussian noise with per-axis standard deviation
    /// `sigmas`. Both must be finite and non-negative.
    JitteredLattice {
        spacings: &'a [f64],
        sigmas: &'a [f64],
    },
}
/// Reason why [`sample_vectors`] rejected its sampling parameters.
#[derive(Debug, Clone, PartialEq)]
pub enum VectorSamplingError {
    /// `Uniform` bounds were non-finite or did not satisfy `low < high`.
    InvalidUniformBounds {
        low: f64,
        high: f64,
    },
    /// A per-axis parameter slice did not have one entry per dimension.
    InvalidParameterLength {
        parameter: &'static str,
        expected: usize,
        got: usize,
    },
    /// Entry `index` of a per-axis parameter violated `rule`
    /// (e.g. `"finite"` or `"finite and non-negative"`).
    InvalidParameterValue {
        parameter: &'static str,
        index: usize,
        value: f64,
        rule: &'static str,
    },
}

impl std::fmt::Display for VectorSamplingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidUniformBounds { low, high } => write!(
                f,
                "invalid uniform bounds: low = {}, high = {} (need finite low < high)",
                low, high
            ),
            Self::InvalidParameterLength {
                parameter,
                expected,
                got,
            } => write!(
                f,
                "parameter `{}` has length {}, expected {}",
                parameter, got, expected
            ),
            Self::InvalidParameterValue {
                parameter,
                index,
                value,
                rule,
            } => write!(
                f,
                "parameter `{}[{}]` = {} must be {}",
                parameter, index, value, rule
            ),
        }
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` and
// similar error-chaining types.
impl std::error::Error for VectorSamplingError {}
/// Fills every vector in `vectors` in place according to `method`.
///
/// The underlying data is reached through `vectors.as_tensor_mut().data` and
/// processed in chunks of `dim` values, one chunk per vector.
/// NOTE(review): this assumes row-major storage of exactly `num_vectors * dim`
/// values — confirm against `VectorList`.
///
/// * `Uniform` — the whole tensor is refreshed with `RandType::Uniform { low, high }`.
/// * `UniformCentered` — components are first drawn with
///   `RandType::Uniform { low: 0.0, high: 1.0 }`. Each value `u` is then mapped to
///   `(2u - 1) * box_size[k] / 2`, which centres axis `k` on 0.
/// * `GaussianPerAxis` — each standard-normal draw `z` becomes `mean[k] + z * std[k]`.
/// * `JitteredLattice` — vector `i` is assigned lattice site `i` of a hypercube.
///   The hypercube has `side` sites per axis, where `side` is the smallest value
///   with `side^dim >= n`. Axis 0 varies fastest.
///   Component `k` becomes `site[k] * spacings[k] + z * sigmas[k]`, with `z`
///   standard normal.
///
/// When `dim == 0` or there are no vectors the function returns `Ok(())`
/// immediately, without validating `method`.
///
/// # Errors
///
/// * [`VectorSamplingError::InvalidUniformBounds`] if `low`/`high` are
///   non-finite or `low >= high`.
/// * [`VectorSamplingError::InvalidParameterLength`] if a per-axis slice's
///   length differs from `dim`.
/// * [`VectorSamplingError::InvalidParameterValue`] if a per-axis entry is
///   non-finite, or is negative where non-negativity is required
///   (`box_size`, `std`, `spacings`, `sigmas`).
///
/// Validation happens before any random values are written, so on error
/// `vectors` is left untouched.
pub fn sample_vectors(
    vectors: &mut VectorList<f64>,
    method: VectorSamplingMethod<'_>,
) -> Result<(), VectorSamplingError> {
    let dim = vectors.dim();
    let n = vectors.num_vectors();
    if dim == 0 || n == 0 {
        return Ok(());
    }
    match method {
        VectorSamplingMethod::Uniform { low, high } => {
            if !low.is_finite() || !high.is_finite() || low >= high {
                return Err(VectorSamplingError::InvalidUniformBounds { low, high });
            }
            let mut filler = TensorRandFiller::new(RandType::Uniform { low, high }, None);
            filler.refresh(vectors.as_tensor_mut());
        }
        VectorSamplingMethod::UniformCentered { box_size } => {
            validate_len("box_size", box_size.len(), dim)?;
            validate_finite_nonnegative("box_size", box_size)?;
            let mut filler = TensorRandFiller::new(
                RandType::Uniform {
                    low: 0.0,
                    high: 1.0,
                },
                None,
            );
            filler.refresh(vectors.as_tensor_mut());
            vectors
                .as_tensor_mut()
                .data
                .par_chunks_mut(dim)
                .for_each(|row| {
                    for (x, &size) in row.iter_mut().zip(box_size) {
                        // Map the unit draw u to (2u - 1) * half-width, centring the axis on 0.
                        *x = (2.0 * *x - 1.0) * (0.5 * size);
                    }
                });
        }
        VectorSamplingMethod::GaussianPerAxis { mean, std } => {
            validate_len("mean", mean.len(), dim)?;
            validate_len("std", std.len(), dim)?;
            validate_finite("mean", mean)?;
            validate_finite_nonnegative("std", std)?;
            let mut filler = TensorRandFiller::new(
                RandType::Normal {
                    mean: 0.0,
                    std: 1.0,
                },
                None,
            );
            filler.refresh(vectors.as_tensor_mut());
            vectors
                .as_tensor_mut()
                .data
                .par_chunks_mut(dim)
                .for_each(|row| {
                    for ((x, &mu), &sd) in row.iter_mut().zip(mean).zip(std) {
                        // Shift/scale a standard-normal draw to N(mu, sd^2).
                        *x = mu + *x * sd;
                    }
                });
        }
        VectorSamplingMethod::JitteredLattice { spacings, sigmas } => {
            validate_len("spacings", spacings.len(), dim)?;
            validate_len("sigmas", sigmas.len(), dim)?;
            validate_finite_nonnegative("spacings", spacings)?;
            validate_finite_nonnegative("sigmas", sigmas)?;
            let mut filler = TensorRandFiller::new(
                RandType::Normal {
                    mean: 0.0,
                    std: 1.0,
                },
                None,
            );
            filler.refresh(vectors.as_tensor_mut());
            // Computed exactly so every vector gets a distinct lattice site.
            let side = lattice_side(n, dim);
            vectors
                .as_tensor_mut()
                .data
                .par_chunks_mut(dim)
                .enumerate()
                .for_each(|(vector_idx, row)| {
                    // Decode vector_idx as a base-`side` number; digit k is the site on axis k.
                    let mut remaining = vector_idx;
                    for ((x, &spacing), &sigma) in row.iter_mut().zip(spacings).zip(sigmas) {
                        let grid_coord = remaining % side;
                        remaining /= side;
                        *x = grid_coord as f64 * spacing + *x * sigma;
                    }
                });
        }
    }
    Ok(())
}

/// Returns the smallest `side >= 1` such that `side^dim >= n`.
///
/// This is the edge length of the smallest hypercubic lattice that has a
/// distinct site for each of `n` points.
///
/// A floating-point `ceil(n.powf(1.0 / dim))` alone is not reliable. The rounded
/// exponent can push an exact root just above an integer: `n = 100_000,
/// dim = 5` evaluates to `10.000000000000002`, which rounds up to 11. It can
/// also, for very large `n`, land just below the true root, so `side^dim < n`
/// and lattice indices wrap around and collide. The float result is therefore
/// only a starting guess, corrected with exact saturating integer arithmetic.
fn lattice_side(n: usize, dim: usize) -> usize {
    // True iff side^dim >= n; saturates instead of overflowing and stops early.
    let covers = |side: usize| -> bool {
        let mut acc: usize = 1;
        for _ in 0..dim {
            acc = acc.saturating_mul(side);
            if acc >= n {
                return true;
            }
        }
        acc >= n
    };
    let mut side = ((n as f64).powf(1.0 / dim as f64).ceil() as usize).max(1);
    while !covers(side) {
        side += 1;
    }
    while side > 1 && covers(side - 1) {
        side -= 1;
    }
    side
}
/// Ensures a per-axis parameter slice has exactly one entry per dimension.
///
/// `got` is the slice's actual length and `expected` the vector dimension.
/// On mismatch, returns `InvalidParameterLength` naming `parameter`.
fn validate_len(
    parameter: &'static str,
    got: usize,
    expected: usize,
) -> Result<(), VectorSamplingError> {
    if got == expected {
        Ok(())
    } else {
        Err(VectorSamplingError::InvalidParameterLength {
            parameter,
            expected,
            got,
        })
    }
}
/// Ensures every entry of `values` is finite (not NaN or ±infinity).
///
/// Reports the first offending entry as `InvalidParameterValue` with rule
/// `"finite"`.
fn validate_finite(parameter: &'static str, values: &[f64]) -> Result<(), VectorSamplingError> {
    let first_bad = values
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, candidate)| !candidate.is_finite());
    match first_bad {
        Some((index, value)) => Err(VectorSamplingError::InvalidParameterValue {
            parameter,
            index,
            value,
            rule: "finite",
        }),
        None => Ok(()),
    }
}
/// Ensures every entry of `values` is finite and `>= 0.0`.
///
/// Note that `-0.0` passes, since it is not less than zero. Reports the first
/// offending entry as `InvalidParameterValue` with rule
/// `"finite and non-negative"`.
fn validate_finite_nonnegative(
    parameter: &'static str,
    values: &[f64],
) -> Result<(), VectorSamplingError> {
    let acceptable = |candidate: f64| candidate.is_finite() && candidate >= 0.0;
    let first_bad = values
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, candidate)| !acceptable(candidate));
    match first_bad {
        Some((index, value)) => Err(VectorSamplingError::InvalidParameterValue {
            parameter,
            index,
            value,
            rule: "finite and non-negative",
        }),
        None => Ok(()),
    }
}