use native_neural_network_std::std::activations_std::ActivationKind;
use native_neural_network_std::std::engine_std::forward_plan;
use native_neural_network_std::std::layers_std as layers_std_mod;
use native_neural_network_std::std::layers_std::{DenseLayerDesc, LayerPlanStd, LayerSpec};
/// Selects the CPU compute backend, checks that the getter reports it, and runs one dense
/// layer through `forward_plan_big_kernel`.
///
/// The backend is selected through free functions rather than a handle, so this test saves
/// the previously active backend and puts it back. The restore is done by a drop guard:
/// unlike a trailing `set_compute_backend(previous)` call, the guard also runs when one of
/// the assertions below panics, so a failure here cannot leave the CPU selection behind for
/// other tests.
#[test]
fn cpu_backend_abstraction_real() {
    // Puts the wrapped backend back when dropped, including during unwinding.
    struct BackendRestore(Option<native_neural_network_std::std::engine_std::ComputeBackend>);

    impl Drop for BackendRestore {
        fn drop(&mut self) {
            // `take` moves the value out, so no `Copy`/`Clone` bound is needed.
            if let Some(previous) = self.0.take() {
                native_neural_network_std::std::engine_std::set_compute_backend(previous);
            }
        }
    }

    // Declared first so it is dropped last, after every other local in this test.
    let _restore = BackendRestore(Some(
        native_neural_network_std::std::engine_std::get_compute_backend(),
    ));

    native_neural_network_std::std::engine_std::set_compute_backend(
        native_neural_network_std::std::engine_std::ComputeBackend::Cpu,
    );
    assert_eq!(
        native_neural_network_std::std::engine_std::get_compute_backend(),
        native_neural_network_std::std::engine_std::ComputeBackend::Cpu
    );

    // A single dense 2 -> 1 layer with identity activation.
    let layers = vec![
        native_neural_network_std::std::layers_std::LayerSpec::Dense(
            native_neural_network_std::std::layers_std::DenseLayerDesc {
                input_size: 2,
                output_size: 1,
                weight_offset: 0,
                bias_offset: 0,
                activation:
                    native_neural_network_std::std::activations_std::ActivationKind::Identity,
            },
        ),
    ];
    let plan = native_neural_network_std::std::layers_std::LayerPlanStd::new(
        layers,
        vec![2.0f32, 3.0f32],
        vec![1.0f32],
    );

    let input = vec![1.0f32, 2.0f32];
    let mut output = vec![0.0f32; 1];

    // NOTE(review): if the scratch-size query fails, this falls back to an empty scratch
    // buffer; the forward call below is then what reports the problem.
    let scratch_len =
        native_neural_network_std::std::engine_std::required_batch_scratch_len(&plan, 1)
            .unwrap_or(0);
    let mut scratch = vec![0.0f32; scratch_len];

    // The literal `1` matches the value passed to `required_batch_scratch_len` above
    // (presumably the batch size — confirm against the engine API).
    native_neural_network_std::std::engine_std::forward_plan_big_kernel(
        &plan,
        &input,
        &mut output,
        1,
        &mut scratch,
    )
    .expect("cpu forward");

    // Expected value: 2.0 * 1.0 + 3.0 * 2.0 + 1.0 = 9.0.
    assert!((output[0] - 9.0).abs() < 1e-6);
}
/// `validate_forward_io` must return `false` when called with `(1, 1)` for a plan whose
/// only layer is a dense 2 -> 2 layer (the two integers are presumably the input and
/// output lengths — confirm against the engine API).
#[test]
fn validate_forward_io_mismatch() {
    let dense = DenseLayerDesc {
        input_size: 2,
        output_size: 2,
        weight_offset: 0,
        bias_offset: 0,
        activation: ActivationKind::Identity,
    };
    // Four weights and two biases, all zero.
    let plan = LayerPlanStd::new(vec![LayerSpec::Dense(dense)], vec![0f32; 4], vec![0f32; 2]);

    let accepted = native_neural_network_std::std::engine_std::validate_forward_io(&plan, 1, 1);
    assert!(!accepted);
}
/// A forward pass that cannot succeed must come back as `Err` instead of panicking.
///
/// NOTE(review): the plan, input and output sizes agree with each other (3 -> 1), so the
/// one-element scratch buffer is presumably what triggers the error — confirm.
#[test]
fn engine_forward_errors_do_not_panic() {
    let dense = DenseLayerDesc {
        input_size: 3,
        output_size: 1,
        weight_offset: 0,
        bias_offset: 0,
        activation: ActivationKind::Identity,
    };
    let plan = LayerPlanStd::new(vec![LayerSpec::Dense(dense)], vec![0f32; 3], vec![0f32; 1]);

    let input = [0f32; 3];
    let mut output = vec![0f32; 1];
    let mut scratch = vec![0f32; 1];

    let outcome = native_neural_network_std::std::engine_std::forward_plan(
        &plan,
        &input,
        &mut output,
        &mut scratch,
    );
    assert!(outcome.is_err());
}
/// Converts a layer spec list with `to_native_vec`, converts it back with
/// `fill_std_slice_from_native`, and checks that the first layer's input size survived.
#[test]
fn native_conversion_roundtrip() {
    let specs = vec![LayerSpec::Dense(DenseLayerDesc {
        input_size: 2,
        output_size: 2,
        weight_offset: 0,
        bias_offset: 0,
        activation: ActivationKind::Identity,
    })];

    let native_buf = layers_std_mod::to_native_vec(&specs);
    let used = native_buf.len();

    // Destination slots start out zero-sized so that a successful conversion is visible.
    let placeholder = LayerSpec::Dense(DenseLayerDesc {
        input_size: 0,
        output_size: 0,
        weight_offset: 0,
        bias_offset: 0,
        activation: ActivationKind::Identity,
    });
    let mut out = vec![placeholder; used];

    layers_std_mod::fill_std_slice_from_native(&native_buf[..used], &mut out);
    assert_eq!(out[0].input_size(), 2);
}
/// `validate_forward_io` must return `true` when called with `(2, 1)` for a plan whose
/// only layer is a dense 2 -> 1 layer.
#[test]
fn validate_forward_io_ok() {
    let dense = DenseLayerDesc {
        input_size: 2,
        output_size: 1,
        weight_offset: 0,
        bias_offset: 0,
        activation: ActivationKind::Identity,
    };
    let plan = LayerPlanStd::new(
        vec![LayerSpec::Dense(dense)],
        vec![1.0f32, 1.0f32],
        vec![0.0f32],
    );

    let accepted = native_neural_network_std::std::engine_std::validate_forward_io(&plan, 2, 1);
    assert!(accepted);
}
/// `forward_plan` must return `Err` when the output buffer holds one element although the
/// plan's only layer produces two.
#[test]
fn forward_plan_returns_err_on_bad_outputs() {
    let dense = DenseLayerDesc {
        input_size: 3,
        output_size: 2,
        weight_offset: 0,
        bias_offset: 0,
        activation: ActivationKind::Identity,
    };
    let plan = LayerPlanStd::new(vec![LayerSpec::Dense(dense)], vec![0f32; 6], vec![0f32; 2]);

    let input = [0f32; 3];
    // Deliberately one element short of the layer's `output_size`.
    let mut undersized_out = vec![0f32; 1];
    let mut scratch = vec![0f32; 16];

    let outcome = forward_plan(&plan, &input, &mut undersized_out, &mut scratch);
    assert!(outcome.is_err());
}
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, OnceLock};
/// Set to `true` by `init_hook_set`; lets tests observe that the hook was invoked.
static INIT_HOOK_CALLED: AtomicBool = AtomicBool::new(false);
/// Lazily created mutex handed out by `init_test_guard`.
static INIT_TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

/// Acquires the lock that the init/shutdown tests in this file take before touching the
/// library's initialization state, and returns its guard.
///
/// A poisoned lock is recovered instead of panicking. The mutex protects no data of its
/// own (`()`), so poisoning carries no information here; without the recovery a single
/// failing test would make every later test that takes this guard fail with a
/// `PoisonError` as well, hiding the real failure.
fn init_test_guard() -> std::sync::MutexGuard<'static, ()> {
    INIT_TEST_LOCK
        .get_or_init(|| Mutex::new(()))
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}
/// Init hook that records its own invocation in `INIT_HOOK_CALLED` and always succeeds.
fn init_hook_set() -> Result<(), native_neural_network_std::InitError> {
    let invoked = true;
    INIT_HOOK_CALLED.store(invoked, Ordering::SeqCst);
    Ok(())
}
/// Init hook that always fails with `InitError::HookFailed("failing", "boom")`.
fn failing_hook() -> Result<(), native_neural_network_std::InitError> {
    let failure = native_neural_network_std::InitError::HookFailed("failing", "boom".into());
    Err(failure)
}
/// Checks that `init()` runs a hook that was registered beforehand.
///
/// `INIT_HOOK_CALLED` is shared with every other test that registers `init_hook_set`, so
/// it is cleared before the hook is registered. Without the reset, a `true` left behind by
/// an earlier test would satisfy the flag assertion even if `init()` never ran this hook.
#[test]
fn init_executes_registered_hook() {
    let _guard = init_test_guard();
    native_neural_network_std::shutdown();

    // Start from a known state so the assertion below proves this run invoked the hook.
    INIT_HOOK_CALLED.store(false, Ordering::SeqCst);

    assert!(native_neural_network_std::register_init_hook(
        "test_set",
        native_neural_network_std::InitSubsystem::Other("test"),
        init_hook_set
    )
    .is_ok());

    let res = native_neural_network_std::init();
    assert!(res.is_ok());
    assert!(INIT_HOOK_CALLED.load(Ordering::SeqCst));
    // NOTE(review): the flag was just asserted to be true, so this disjunction always
    // holds; tighten it to `is_initialized()` alone if `init()` guarantees that state.
    assert!(native_neural_network_std::is_initialized() || INIT_HOOK_CALLED.load(Ordering::SeqCst));

    native_neural_network_std::shutdown();
}
/// Registers a succeeding hook under `InitSubsystem::Crypto` and a failing hook under
/// `InitSubsystem::Other("bad")`, then requests only `Crypto` from `try_init` and expects
/// success and an initialized library.
#[test]
fn try_init_runs_only_requested_subsystems() {
    let _guard = init_test_guard();
    native_neural_network_std::shutdown();

    let crypto_registered = native_neural_network_std::register_init_hook(
        "crypto_set",
        native_neural_network_std::InitSubsystem::Crypto,
        init_hook_set,
    )
    .is_ok();
    assert!(crypto_registered);

    let failing_registered = native_neural_network_std::register_init_hook(
        "failing",
        native_neural_network_std::InitSubsystem::Other("bad"),
        failing_hook,
    )
    .is_ok();
    assert!(failing_registered);

    // Only the Crypto subsystem is requested.
    let requested = [native_neural_network_std::InitSubsystem::Crypto];
    let outcome = native_neural_network_std::try_init(&requested);
    assert!(outcome.is_ok());
    assert!(native_neural_network_std::is_initialized());

    native_neural_network_std::shutdown();
}
/// Once `init()` has succeeded, `register_init_hook` must return an error.
#[test]
fn register_after_init_returns_error() {
    let _guard = init_test_guard();
    native_neural_network_std::shutdown();

    let initialized = native_neural_network_std::init().is_ok();
    assert!(initialized);

    let late_registration = native_neural_network_std::register_init_hook(
        "late",
        native_neural_network_std::InitSubsystem::Other("x"),
        init_hook_set,
    );
    assert!(late_registration.is_err());

    native_neural_network_std::shutdown();
}