use native_neural_network_std as nn;
use nn::modules::activations::ActivationKind;
use nn::std::engine_std;
use nn::std::layers_std::{DenseLayerDesc, LayerPlanStd, LayerSpec};
use nn::std::rnn_std as rstd;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
// Layer widths of the template model; `run_sync` treats the first entry as the
// per-sample input width and the training code treats the last entry as the output width.
const TOPOLOGY: [usize; 6] = [96, 192, 256, 192, 128, 64];
// Number of samples in the input batch (also stored in `TrainConfig::batch`).
const BATCH_SIZE_SLOT: usize = 2;
// Base learning rate handed to the training loop via `TrainConfig::learning_rate`.
const LEARNING_RATE_SLOT: f32 = 0.006;
// Scale passed to `random_params` when drawing the initial weights and biases.
const WEIGHT_SCALE_SLOT: f32 = 0.008;
// Placeholder for the comma-separated input values (BATCH_SIZE_SLOT * TOPOLOGY[0] = 192 floats).
// While the `__INSERT` token is still present, `resolve_input_slot` reads the values from
// the `NNS_VERY_COMPLEX_INPUT_VALUES` environment variable instead.
const INPUT_VALUES_SLOT: &str = "__INSERT_192_FLOAT_VALUES_COMMA_SEPARATED__";
// Helper module pulled in by explicit relative path. This file uses it for the training
// settings (`train_runtime_seconds`, `train_early_stop_patience`, `train_min_delta`,
// `train_lr_decay`, `train_worker_threads`) and for `log_progress_every_second`.
#[path = "../dataset/generate_dataset.rs"]
mod generate_dataset;
/// Minimal deterministic xorshift pseudo-random generator with 32 bits of state.
#[derive(Clone)]
struct XorShift32 {
    /// Current generator state; `new` never stores zero here.
    state: u32,
}

impl XorShift32 {
    /// Creates a generator from `seed`, substituting a fixed non-zero constant when
    /// `seed` is zero (an all-zero state would only ever produce zeros).
    fn new(seed: u32) -> Self {
        let state = match seed {
            0 => 0xA341_316C,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state with the 13 / 17 / 5 shift-xor steps and returns the new state.
    fn next_u32(&mut self) -> u32 {
        let step1 = self.state ^ (self.state << 13);
        let step2 = step1 ^ (step1 >> 17);
        let step3 = step2 ^ (step2 << 5);
        self.state = step3;
        step3
    }

    /// Returns the next draw mapped linearly onto the closed interval `[-1.0, 1.0]`.
    fn next_f32_signed(&mut self) -> f32 {
        let raw = self.next_u32() as f32;
        let unit = raw / u32::MAX as f32;
        unit * 2.0 - 1.0
    }
}
/// Returns the directory generated models are written to: the value of the
/// `NNS_OUTPUT_DIR` environment variable when it can be read, otherwise `rnn_examples`.
fn output_dir() -> PathBuf {
    match std::env::var("NNS_OUTPUT_DIR") {
        Ok(configured) => PathBuf::from(configured),
        Err(_) => PathBuf::from("rnn_examples"),
    }
}
/// Creates the output directory (and any missing parents), panicking if that fails.
fn ensure_output_dirs() {
    let target = output_dir();
    fs::create_dir_all(&target).expect("create output directory");
}
/// Parses a comma-separated list of floats taken from a template slot.
///
/// Tokens are trimmed and empty tokens (e.g. from a trailing comma) are skipped.
///
/// * `slot` - the raw comma-separated text.
/// * `expected` - the exact number of values the slot must contain.
/// * `name` - slot name, used only to make panic messages actionable.
///
/// # Panics
/// Panics if a token is not a valid `f32` (the message names the slot and the offending
/// token) or if the number of parsed values differs from `expected`.
fn parse_slot_csv(slot: &str, expected: usize, name: &str) -> Vec<f32> {
    let values: Vec<f32> = slot
        .split(',')
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(|token| {
            // Report which slot and which token failed; with ~200 values a bare
            // "invalid float" message is very hard to act on.
            token.parse::<f32>().unwrap_or_else(|err| {
                panic!("{name}: invalid float in slot: {token:?} ({err})")
            })
        })
        .collect();
    if values.len() != expected {
        panic!("{name} expects {expected} values, got {}", values.len());
    }
    values
}
/// Resolves the text of a template slot.
///
/// A slot that no longer contains the `__INSERT` marker is returned as-is. Otherwise
/// the value is read from the environment variable `env_key`.
///
/// # Panics
/// Panics when the slot is still a placeholder and `env_key` cannot be read.
fn resolve_input_slot(slot: &str, env_key: &str) -> String {
    let still_placeholder = slot.contains("__INSERT");
    if still_placeholder {
        match std::env::var(env_key) {
            Ok(from_env) => from_env,
            Err(_) => panic!(
                "template slot unresolved: set {env_key} or replace the __INSERT__ token in source"
            ),
        }
    } else {
        String::from(slot)
    }
}
/// Asks the library for the counts derived from `topology` (this file reads the
/// `weights`, `biases` and `layers` fields), panicking if the lookup fails.
fn required_counts(topology: &[usize]) -> nn::std::model_format_std::DecodedCountsStd {
    let lookup = rstd::rnn_required_from_topology(topology);
    lookup.expect("required counts from topology")
}
/// Draws deterministic pseudo-random initial parameters from a single `XorShift32` stream.
///
/// All `weights_len` weights are drawn first, each in `[-scale, scale]`; the
/// `biases_len` biases follow from the same stream, additionally multiplied by `0.12`.
/// Returns `(weights, biases)`.
fn random_params(
    weights_len: usize,
    biases_len: usize,
    scale: f32,
    seed: u32,
) -> (Vec<f32>, Vec<f32>) {
    let mut generator = XorShift32::new(seed);
    let weights: Vec<f32> = (0..weights_len)
        .map(|_| generator.next_f32_signed() * scale)
        .collect();
    // Biases continue the same stream, so they must be generated after the weights.
    let biases: Vec<f32> = (0..biases_len)
        .map(|_| generator.next_f32_signed() * scale * 0.12)
        .collect();
    (weights, biases)
}
/// Builds `count` zeroed dense-layer specs with identity activation.
///
/// `write_container_model` hands the result to `rnn_pack_v1` as a mutable buffer;
/// presumably the packer overwrites these placeholders (confirm against the library).
fn make_layer_scratch(count: usize) -> Vec<LayerSpec> {
    (0..count)
        .map(|_| {
            LayerSpec::Dense(DenseLayerDesc {
                input_size: 0,
                output_size: 0,
                weight_offset: 0,
                bias_offset: 0,
                activation: ActivationKind::Identity,
            })
        })
        .collect()
}
/// Builds a `LayerPlanStd` for `topology` that owns copies of `weights` and `biases`.
///
/// Both activation arguments passed to the spec builder are `ActivationKind::Identity`.
///
/// # Panics
/// Panics if the library cannot build specs for the given topology and parameter counts.
fn model_plan(topology: &[usize], weights: &[f32], biases: &[f32]) -> LayerPlanStd {
    let (weight_count, bias_count) = (weights.len(), biases.len());
    let built = nn::std::layers_std::build_specs_from_layers(
        topology,
        ActivationKind::Identity,
        ActivationKind::Identity,
        weight_count,
        bias_count,
    );
    let specs = built.expect("build dense specs");
    let owned_weights = weights.to_vec();
    let owned_biases = biases.to_vec();
    LayerPlanStd::new(specs, owned_weights, owned_biases)
}
/// Runs one batched forward pass and returns a buffer of `batch * output_width` floats,
/// where `output_width` is the last entry of `topology`.
///
/// A fresh plan, output buffer and scratch buffer are allocated on every call.
///
/// # Panics
/// Panics if the plan cannot be built, `topology` is empty, the scratch size cannot be
/// determined, or the forward kernel reports an error.
fn run_vectorized(
    topology: &[usize],
    weights: &[f32],
    biases: &[f32],
    input_batch: &[f32],
    batch: usize,
) -> Vec<f32> {
    let plan = model_plan(topology, weights, biases);
    let out_width = *topology.last().expect("output size");
    let mut result = vec![0f32; batch * out_width];
    let mut scratch = {
        let needed = engine_std::required_batch_scratch_len(&plan, batch)
            .expect("required batch scratch");
        vec![0f32; needed]
    };
    let status =
        engine_std::forward_batch_big_kernel(&plan, input_batch, &mut result, batch, &mut scratch);
    status.expect("vectorized forward");
    result
}
/// Mean squared error between the model outputs and the targets `0.75 * input`.
///
/// Only the first `min(input_width, output_width)` features of each sample are compared,
/// where the widths are the first and last entries of `topology`. Both batches are read
/// row-major with their respective widths. Returns `0.0` when there is nothing to compare.
///
/// # Panics
/// Panics if `topology` is empty or either slice is too short for `batch` samples.
fn eval_reconstruction_loss(
    topology: &[usize],
    output_batch: &[f32],
    input_batch: &[f32],
    batch: usize,
) -> f32 {
    let in_width = topology[0];
    let out_width = *topology.last().expect("output size");
    let shared = in_width.min(out_width);
    if batch == 0 || shared == 0 {
        return 0.0;
    }
    // Accumulate sample by sample, feature by feature, so the f32 summation order is fixed.
    let (squared_sum, terms) = (0..batch)
        .flat_map(|row| (0..shared).map(move |col| (row, col)))
        .fold((0.0f32, 0usize), |(acc, seen), (row, col)| {
            let target = input_batch[row * in_width + col] * 0.75;
            let diff = output_batch[row * out_width + col] - target;
            (acc + diff * diff, seen + 1)
        });
    squared_sum / terms.max(1) as f32
}
/// Standardizes each feature column of a row-major `batch x feature_dim` buffer in place.
///
/// For every feature the mean and the (population) variance over the batch are computed,
/// then each value becomes `(value - mean) / sqrt(variance + 1e-6)`. Does nothing when
/// `feature_dim` or `batch` is zero.
///
/// # Panics
/// Panics if `values` holds fewer than `batch * feature_dim` elements.
fn normalize_batch_in_place(values: &mut [f32], feature_dim: usize, batch: usize) {
    if batch == 0 || feature_dim == 0 {
        return;
    }
    let samples = batch as f32;
    for col in 0..feature_dim {
        // Flat index of this feature in a given sample.
        let at = |row: usize| row * feature_dim + col;

        let mean = (0..batch).fold(0.0f32, |acc, row| acc + values[at(row)]) / samples;
        let squared_dev = (0..batch).fold(0.0f32, |acc, row| {
            let delta = values[at(row)] - mean;
            acc + delta * delta
        });
        // The epsilon keeps the divisor non-zero for constant columns.
        let std_dev = (squared_dev / samples + 1e-6).sqrt();

        for row in 0..batch {
            let slot = &mut values[at(row)];
            *slot = (*slot - mean) / std_dev;
        }
    }
}
/// Settings consumed by `self_train_parallel_vectorized`.
#[derive(Clone, Copy)]
struct TrainConfig {
/// Wall-clock training budget in seconds; the training loop clamps it to at least 1.
runtime_seconds: u64,
/// Number of samples in the input batch.
batch: usize,
/// Base learning rate; the training loop divides it by `1 + lr_decay * iteration`.
learning_rate: f32,
/// Seed mixed into the training noise generator and the per-chunk generators.
seed: u32,
}
/// Runs the time-boxed training loop for the template model and returns the best loss seen.
///
/// The loop always executes at least once and otherwise runs until the wall-clock budget
/// (`cfg.runtime_seconds`, at least one second) is used up or early stopping triggers.
/// Each iteration:
///
/// 1. copies `base_input_batch`, adds noise of amplitude `0.003` and standardizes it;
/// 2. runs the forward pass and measures the reconstruction loss;
/// 3. if the loss beats the best one by more than `train_min_delta()`, snapshots the
///    parameters that produced it; otherwise counts a non-improving iteration and stops
///    once `train_early_stop_patience()` is reached;
/// 4. moves every bias against the batch-mean output error, and updates the weights in
///    parallel chunks: each weight is scaled by `0.9994` and stepped by the error with a
///    pseudo-random sign.
///
/// On return `weights` and `biases` hold the best snapshot, i.e. exactly the parameters
/// whose loss is reported (the snapshot is taken *before* the update of that iteration).
///
/// # Panics
/// Panics if `topology` is empty, if the slices do not match the topology / batch size,
/// or if the forward pass fails.
fn self_train_parallel_vectorized(
    topology: &[usize],
    weights: &mut [f32],
    biases: &mut [f32],
    base_input_batch: &[f32],
    cfg: TrainConfig,
) -> f32 {
    let input_size = topology[0];
    let output_size = *topology.last().expect("output size");
    let common = input_size.min(output_size);
    let mut rng = XorShift32::new(cfg.seed ^ 0x77AA_1133);
    let mut last_loss = 0.0f32;
    let start = Instant::now();
    let budget = Duration::from_secs(cfg.runtime_seconds.max(1));
    let mut iters = 0usize;
    let mut best_metric = f32::INFINITY;
    let mut best_weights = weights.to_vec();
    let mut best_biases = biases.to_vec();
    let mut no_improve = 0usize;
    // Back-date the log timestamp so the first progress line is emitted immediately.
    // `Instant - Duration` may panic when the result is not representable, so fall back
    // to "now" instead of subtracting unconditionally.
    let mut last_log = Instant::now()
        .checked_sub(Duration::from_secs(2))
        .unwrap_or_else(Instant::now);
    let patience = generate_dataset::train_early_stop_patience();
    let min_delta = generate_dataset::train_min_delta();
    let lr_decay = generate_dataset::train_lr_decay();
    // Loop-invariant: look the worker count up once and never divide by zero.
    let threads = generate_dataset::train_worker_threads().max(1);
    let chunk = (weights.len() / threads).max(1);
    let batch_divisor = cfg.batch.max(1) as f32;
    // Buffers reused by every iteration.
    let mut epoch_input = vec![0.0f32; base_input_batch.len()];
    let mut err_by_output = vec![0.0f32; output_size.max(1)];

    while start.elapsed() < budget || iters == 0 {
        iters += 1;
        let lr_now = cfg.learning_rate / (1.0 + lr_decay * iters as f32);

        // Fresh noisy, standardized copy of the base batch.
        epoch_input.copy_from_slice(base_input_batch);
        for v in &mut epoch_input {
            *v += rng.next_f32_signed() * 0.003;
        }
        normalize_batch_in_place(&mut epoch_input, input_size, cfg.batch);

        let outputs = run_vectorized(topology, weights, biases, &epoch_input, cfg.batch);
        last_loss = eval_reconstruction_loss(topology, &outputs, &epoch_input, cfg.batch);
        generate_dataset::log_progress_every_second(
            "very_complex",
            iters,
            last_loss,
            start,
            budget,
            &mut last_log,
        );

        // Checkpoint BEFORE touching the parameters: `last_loss` was measured on the
        // current weights, so these are the values that belong to `best_metric`.
        if best_metric - last_loss > min_delta {
            best_metric = last_loss;
            best_weights.copy_from_slice(weights);
            best_biases.copy_from_slice(biases);
            no_improve = 0;
        } else {
            no_improve += 1;
            if no_improve >= patience {
                // The pending update would be discarded by the restore below anyway.
                break;
            }
        }

        // Batch-mean error per output unit; outputs beyond the shared width target 0.
        err_by_output.fill(0.0);
        for b in 0..cfg.batch {
            for o in 0..output_size {
                let target = if o < common {
                    epoch_input[b * input_size + o] * 0.75
                } else {
                    0.0
                };
                err_by_output[o] += outputs[b * output_size + o] - target;
            }
        }
        for e in &mut err_by_output {
            *e /= batch_divisor;
        }

        // Parameters are paired with output errors by index modulo the output width.
        for (i, b) in biases.iter_mut().enumerate() {
            *b -= lr_now * err_by_output[i % err_by_output.len()];
        }

        std::thread::scope(|scope| {
            for (chunk_idx, wchunk) in weights.chunks_mut(chunk).enumerate() {
                let err_ref = &err_by_output;
                // Each chunk gets its own generator, seeded from the shared stream on the
                // spawning thread so the result does not depend on thread scheduling.
                let mut local_rng = XorShift32::new(
                    cfg.seed ^ (chunk_idx as u32).wrapping_mul(0x85EB_CA6B) ^ rng.next_u32(),
                );
                scope.spawn(move || {
                    for (j, w) in wchunk.iter_mut().enumerate() {
                        let global_idx = chunk_idx * chunk + j;
                        let e = err_ref[global_idx % err_ref.len()];
                        let sign = if local_rng.next_u32() & 1 == 0 {
                            1.0
                        } else {
                            -1.0
                        };
                        *w = (*w * 0.9994) - lr_now * 0.09 * e * sign;
                    }
                });
            }
        });
    }

    // Hand back the best checkpoint rather than whatever the last update produced.
    weights.copy_from_slice(&best_weights);
    biases.copy_from_slice(&best_biases);
    last_loss = best_metric.min(last_loss);
    println!("very_complex training iterations={}", iters);
    last_loss
}
/// Packs the model with `rnn_pack_v1`, wraps the payload in a container named `"model"`
/// and writes it to `path`.
///
/// The pack buffer starts at an estimate derived from the required counts; whenever the
/// packer reports `CapacityTooSmall` the buffer is doubled (to at least 8192 bytes) and
/// packing is retried.
///
/// # Panics
/// Panics on any other packing error or if the file cannot be written.
fn write_container_model(path: &Path, topology: &[usize], weights: &[f32], biases: &[f32]) {
    let counts = required_counts(topology);
    let mut spec_scratch = make_layer_scratch(counts.layers);
    let initial_capacity = counts.weights * 4 + counts.biases * 4 + counts.layers * 64 + 4096;
    let mut buffer = vec![0u8; initial_capacity];

    let payload_len = loop {
        match rstd::rnn_pack_v1(
            topology,
            ActivationKind::Identity,
            ActivationKind::Identity,
            weights,
            biases,
            &mut spec_scratch,
            &mut buffer,
        ) {
            Ok(len) => break len,
            Err(native_neural_network::rnn_api::RnnApiError::CapacityTooSmall) => {
                // Grow geometrically and try again.
                let grown = buffer.len().saturating_mul(2).max(8192);
                buffer.resize(grown, 0);
            }
            Err(other) => panic!("pack dense model: {:?}", other),
        }
    };

    let payload = &buffer[..payload_len];
    let container = rstd::rnn_wrap_payload_in_container_std(payload, "model");
    fs::write(path, container).expect("write container model");
}
/// End-to-end driver: resolves and parses the input slot, draws initial parameters,
/// trains, writes `very_complex.rnn` into the output directory and prints a summary line.
fn run_sync() {
    ensure_output_dirs();

    let layer_sizes = TOPOLOGY.to_vec();
    let raw_inputs = resolve_input_slot(INPUT_VALUES_SLOT, "NNS_VERY_COMPLEX_INPUT_VALUES");
    // One row of `layer_sizes[0]` values per sample in the batch.
    let inputs = parse_slot_csv(
        &raw_inputs,
        BATCH_SIZE_SLOT * layer_sizes[0],
        "INPUT_VALUES_SLOT",
    );

    let required = required_counts(&layer_sizes);
    let (mut weights, mut biases) = random_params(
        required.weights,
        required.biases,
        WEIGHT_SCALE_SLOT,
        0x5500_0001,
    );

    let final_loss = self_train_parallel_vectorized(
        &layer_sizes,
        &mut weights,
        &mut biases,
        &inputs,
        TrainConfig {
            runtime_seconds: generate_dataset::train_runtime_seconds(),
            batch: BATCH_SIZE_SLOT,
            learning_rate: LEARNING_RATE_SLOT,
            seed: 0x55AB_CDEF,
        },
    );

    let model_path = output_dir().join("very_complex.rnn");
    write_container_model(&model_path, &layer_sizes, &weights, &biases);
    println!(
        "very_complex template -> {} | topology={:?} batch={} runtime={}s loss={:.6e}",
        model_path.display(),
        layer_sizes,
        BATCH_SIZE_SLOT,
        generate_dataset::train_runtime_seconds(),
        final_loss,
    );
}
/// Program entry point; all work happens in `run_sync`.
fn main() {
run_sync();
}