//! native_neural_network_std 0.2.1
//!
//! Ergonomic std wrapper for the `no_std` `native_neural_network` crate: std-friendly
//! re-exports and utilities.
use native_neural_network_std as nn;
// Shared test helpers (such as `XorShift32`) pasted in textually from `test_utils.rs`.
// `include!` resolves that path relative to this source file.
mod utils {
    include!("test_utils.rs");
}

/// Checks that selecting the CPU compute backend takes effect, and that a single
/// identity-activated dense layer evaluates correctly through
/// `forward_plan_big_kernel`.
///
/// The compute backend is set through free functions rather than a value owned by
/// this test. The previously active backend is therefore restored by a drop guard.
/// The guard puts the backend back even if an assertion or the forward pass panics,
/// instead of leaving `Cpu` installed for whatever runs afterwards.
#[test]
fn cpu_backend_abstraction_real() {
    use native_neural_network_std::std::activations_std::ActivationKind;
    use native_neural_network_std::std::engine_std as engine;
    use native_neural_network_std::std::layers_std::{DenseLayerDesc, LayerPlanStd, LayerSpec};

    /// Re-installs the wrapped backend when dropped (including while unwinding).
    struct RestoreBackend(Option<engine::ComputeBackend>);

    impl Drop for RestoreBackend {
        fn drop(&mut self) {
            // `take` moves the backend out so `ComputeBackend` need not be `Clone`.
            if let Some(backend) = self.0.take() {
                engine::set_compute_backend(backend);
            }
        }
    }

    // Bound to a named `_restore` (not `_`) so the guard lives until the end of the test.
    let _restore = RestoreBackend(Some(engine::get_compute_backend()));
    engine::set_compute_backend(engine::ComputeBackend::Cpu);
    assert_eq!(engine::get_compute_backend(), engine::ComputeBackend::Cpu);

    // One dense layer (2 inputs -> 1 output) with weights [2.0, 3.0] and bias [1.0].
    let layers = vec![LayerSpec::Dense(DenseLayerDesc {
        input_size: 2,
        output_size: 1,
        weight_offset: 0,
        bias_offset: 0,
        activation: ActivationKind::Identity,
    })];
    let plan = LayerPlanStd::new(layers, vec![2.0f32, 3.0f32], vec![1.0f32]);

    let input = vec![1.0f32, 2.0f32];
    let mut output = vec![0.0f32; 1];
    // Fail here with a clear message rather than handing the kernel an empty buffer.
    let scratch_len = engine::required_batch_scratch_len(&plan, 1)
        .expect("scratch length for a batch of 1");
    let mut scratch = vec![0.0f32; scratch_len];
    engine::forward_plan_big_kernel(&plan, &input, &mut output, 1, &mut scratch)
        .expect("cpu forward");

    // 2.0 * 1.0 + 3.0 * 2.0 + 1.0 = 9.0
    assert!(
        (output[0] - 9.0).abs() < 1e-6,
        "expected 9.0, got {}",
        output[0]
    );
}

/// Checks that `XorShift32` generators built from the same seed yield the same
/// stream, and that networks whose parameters come from such streams agree.
///
/// Each network draws its parameters from a fresh generator with the same seed,
/// weights first and then biases. The two parameter sets, and the networks built
/// from them, must therefore be identical.
#[test]
fn deterministic_same_seed_runs_equal() {
    use nn::std::network_std::{self as network, NeuralNetworkStd};

    /// Draws `len` consecutive values from `rng`, in order.
    fn draw(rng: &mut utils::XorShift32, len: usize) -> Vec<f32> {
        (0..len).map(|_| rng.next_f32()).collect()
    }

    let seed = 0xfeed_face_u32;

    // Identically seeded generators must produce exactly the same values.
    let v1 = draw(&mut utils::XorShift32::new(seed), 64);
    let v2 = draw(&mut utils::XorShift32::new(seed), 64);
    assert_eq!(v1, v2);

    let layers = vec![16usize, 32, 8];
    // Fail loudly instead of silently building zero-parameter networks.
    let wcount = NeuralNetworkStd::expected_weights_count(&layers)
        .expect("weight count for a valid topology");
    let bcount = NeuralNetworkStd::expected_biases_count(&layers)
        .expect("bias count for a valid topology");

    // One stream per network; draw order (weights, then biases) must match.
    let mut rng_a = utils::XorShift32::new(seed);
    let weights_a = draw(&mut rng_a, wcount);
    let biases_a = draw(&mut rng_a, bcount);
    let mut rng_b = utils::XorShift32::new(seed);
    let weights_b = draw(&mut rng_b, wcount);
    let biases_b = draw(&mut rng_b, bcount);
    assert_eq!(weights_a, weights_b);
    assert_eq!(biases_a, biases_b);

    // The parameter vectors are not needed afterwards, so move rather than clone them.
    let net_a =
        NeuralNetworkStd::from_parts(layers.clone(), weights_a, biases_a).expect("construct");
    let net_b =
        NeuralNetworkStd::from_parts(layers.clone(), weights_b, biases_b).expect("construct");
    assert!(network::validate_network_parts(
        &layers,
        &net_a.weights,
        &net_a.biases
    ));
    assert!(network::validate_network_parts(
        &layers,
        &net_b.weights,
        &net_b.biases
    ));
    // Count-based stats alone would not catch differing stored parameters.
    assert_eq!(net_a.weights, net_b.weights);
    assert_eq!(net_a.biases, net_b.biases);

    let stats_a = network::network_stats(&layers, &net_a.weights, &net_a.biases);
    let stats_b = network::network_stats(&layers, &net_b.weights, &net_b.biases);
    assert_eq!(
        stats_a.map(|s| (s.total_weights, s.total_biases)),
        stats_b.map(|s| (s.total_weights, s.total_biases))
    );
}