//! Compute backend selection and Jidoka-style numeric safety guards.
//!
//! Re-exports the relevant `trueno` types and provides size-based backend
//! heuristics, NaN/Inf training guards, backend-divergence checks, and
//! deterministic experiment seeding.

use crate::error::{AprenderError, Result};
pub use trueno::simulation::{
BackendCategory, BackendSelector, BackendTolerance, JidokaAction, JidokaCondition, JidokaError,
JidokaGuard, StressTestConfig, StressThresholds,
};
pub use trueno::{Backend, Matrix, Vector};
/// Element count at or above which the GPU backend is selected (when available).
const GPU_THRESHOLD: usize = 100_000;
/// Element count at or above which multi-threaded SIMD is selected.
const PARALLEL_THRESHOLD: usize = 1_000;
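/// Selects a compute backend category from the problem size and GPU
/// availability: small inputs stay on single-threaded SIMD, mid-sized
/// inputs add thread parallelism, and only large inputs with a GPU
/// present are routed to the GPU (falling back to parallel SIMD otherwise).
///
/// A minimal usage sketch, assuming this module is re-exported as
/// `aprender::compute` (the exact path is illustrative):
///
/// ```ignore
/// use aprender::compute::{select_backend, BackendCategory};
///
/// // 500 elements: below PARALLEL_THRESHOLD (1_000), so plain SIMD.
/// assert!(matches!(select_backend(500, true), BackendCategory::SimdOnly));
/// // One million elements but no GPU: falls back to parallel SIMD.
/// assert!(matches!(
///     select_backend(1_000_000, false),
///     BackendCategory::SimdParallel
/// ));
/// ```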
#[must_use]
pub fn select_backend(size: usize, gpu_available: bool) -> BackendCategory {
// Pre-condition contract hooks; the matching post-conditions run on
// `result` below.
contract_pre_backend_selection!();
contract_pre_gpu_detection_accuracy!();
let result = if size < PARALLEL_THRESHOLD {
BackendCategory::SimdOnly
} else if size < GPU_THRESHOLD {
BackendCategory::SimdParallel
} else if gpu_available {
BackendCategory::Gpu
} else {
BackendCategory::SimdParallel
};
contract_post_backend_selection!(&result);
contract_post_gpu_detection_accuracy!(&result);
result
}
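/// Returns `true` when `size` is large enough to justify GPU dispatch.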
#[must_use]
pub fn should_use_gpu(size: usize) -> bool {
size >= GPU_THRESHOLD
}
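/// Returns `true` when `size` is large enough to justify thread parallelism.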
#[must_use]
pub fn should_use_parallel(size: usize) -> bool {
size >= PARALLEL_THRESHOLD
}
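/// Jidoka-style training guard: stops the line as soon as NaN or Inf values
/// appear in gradients, weights, or the loss, attaching the configured
/// context to every error message.
///
/// A minimal sketch, with an illustrative context label and re-export path
/// (`aprender::compute`):
///
/// ```ignore
/// use aprender::compute::TrainingGuard;
///
/// let guard = TrainingGuard::new("mlp:layer0");
/// assert!(guard.check_gradients(&[0.1, -0.2, 0.3]).is_ok());
/// assert!(guard.check_loss(f32::NAN).is_err()); // NaN halts training
/// ```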
#[derive(Debug, Clone)]
pub struct TrainingGuard {
nan_guard: JidokaGuard,
inf_guard: JidokaGuard,
context: String,
}
impl TrainingGuard {
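/// Creates a guard labelled with `context`; the label is embedded in every
/// error message (and suffixed with `:nan` / `:inf` for the inner guards).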
#[must_use]
pub fn new(context: impl Into<String>) -> Self {
let ctx = context.into();
Self {
nan_guard: JidokaGuard::nan_guard(format!("{ctx}:nan")),
inf_guard: JidokaGuard::inf_guard(format!("{ctx}:inf")),
context: ctx,
}
}
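/// Halts with a validation error if any gradient is NaN or infinite.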
pub fn check_gradients(&self, gradients: &[f32]) -> Result<()> {
self.check_values(gradients, "gradients")
}
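/// Halts with a validation error if any weight is NaN or infinite.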
pub fn check_weights(&self, weights: &[f32]) -> Result<()> {
self.check_values(weights, "weights")
}
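/// Halts with a validation error if the loss is NaN or infinite.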
pub fn check_loss(&self, loss: f32) -> Result<()> {
if loss.is_nan() {
return Err(AprenderError::ValidationError {
message: format!("Jidoka: NaN loss detected at {}", self.context),
});
}
if loss.is_infinite() {
return Err(AprenderError::ValidationError {
message: format!("Jidoka: Infinite loss detected at {}", self.context),
});
}
Ok(())
}
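/// Shared NaN/Inf scan behind `check_gradients` and `check_weights`.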
fn check_values(&self, values: &[f32], kind: &str) -> Result<()> {
self.nan_guard
.check_output(values)
.map_err(|e| AprenderError::ValidationError {
message: format!("Jidoka: NaN in {kind} at {}: {e}", self.context),
})?;
self.inf_guard
.check_output(values)
.map_err(|e| AprenderError::ValidationError {
message: format!("Jidoka: Inf in {kind} at {}: {e}", self.context),
})?;
Ok(())
}
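/// Like `check_values` but for `f64` slices, reporting the index of the
/// first non-finite element.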
pub fn check_f64(&self, values: &[f64], kind: &str) -> Result<()> {
for (i, &v) in values.iter().enumerate() {
if v.is_nan() {
return Err(AprenderError::ValidationError {
message: format!("Jidoka: NaN in {kind}[{i}] at {}", self.context),
});
}
if v.is_infinite() {
return Err(AprenderError::ValidationError {
message: format!("Jidoka: Inf in {kind}[{i}] at {}", self.context),
});
}
}
Ok(())
}
}
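/// Guard that checks two backends for numerical divergence, e.g. a GPU
/// result against its SIMD reference.
///
/// A minimal sketch (illustrative path and data):
///
/// ```ignore
/// use aprender::compute::DivergenceGuard;
///
/// let guard = DivergenceGuard::default_tolerance("dot:gpu_vs_simd");
/// let simd = [1.0_f32, 2.0, 3.0];
/// let gpu = [1.0_f32, 2.0, 3.0];
/// assert!(guard.check(&simd, &gpu).is_ok());
/// ```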
#[derive(Debug, Clone)]
pub struct DivergenceGuard {
guard: JidokaGuard,
}
impl DivergenceGuard {
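/// Creates a guard with the given divergence `tolerance`.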
#[must_use]
pub fn new(tolerance: f32, context: impl Into<String>) -> Self {
Self {
guard: JidokaGuard::divergence_guard(tolerance, context),
}
}
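/// Creates a guard with the default tolerance of `1e-5`.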
#[must_use]
pub fn default_tolerance(context: impl Into<String>) -> Self {
Self::new(1e-5, context)
}
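/// Returns a validation error if `a` and `b` diverge beyond the tolerance.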
pub fn check(&self, a: &[f32], b: &[f32]) -> Result<()> {
self.guard
.check_divergence(a, b)
.map_err(|e| AprenderError::ValidationError {
message: format!("Backend divergence: {e}"),
})
}
}
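/// Per-purpose RNG seeds derived from a single master seed, so data
/// shuffling, weight initialization, and dropout each draw from an
/// independent stream while the experiment stays reproducible.
///
/// A minimal sketch (illustrative re-export path):
///
/// ```ignore
/// use aprender::compute::ExperimentSeed;
///
/// let seeds = ExperimentSeed::from_master(42);
/// // Deterministic: the same master always yields the same sub-seeds.
/// assert_eq!(seeds.data_shuffle, ExperimentSeed::from_master(42).data_shuffle);
/// // Decorrelated: each purpose gets a distinct seed.
/// assert_ne!(seeds.data_shuffle, seeds.weight_init);
/// ```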
#[derive(Debug, Clone, Copy)]
pub struct ExperimentSeed {
pub master: u64,
pub data_shuffle: u64,
pub weight_init: u64,
pub dropout: u64,
}
impl ExperimentSeed {
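/// Derives decorrelated sub-seeds from `master` via wrapping multiplication
/// by well-known odd PRNG constants (the MMIX LCG multiplier and increment,
/// and the xorshift64* multiplier).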
#[must_use]
pub fn from_master(master: u64) -> Self {
Self {
master,
data_shuffle: master.wrapping_mul(6_364_136_223_846_793_005),
weight_init: master.wrapping_mul(1_442_695_040_888_963_407),
dropout: master.wrapping_mul(2_685_821_657_736_338_717),
}
}
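/// Builds a seed set from explicitly supplied per-purpose seeds.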
#[must_use]
pub const fn new(master: u64, data_shuffle: u64, weight_init: u64, dropout: u64) -> Self {
Self {
master,
data_shuffle,
weight_init,
dropout,
}
}
}
impl Default for ExperimentSeed {
fn default() -> Self {
Self::from_master(42)
}
}
#[cfg(test)]
#[path = "compute_tests.rs"]
mod tests;