mod common;
use native_neural_network::modules::{
activations::ActivationKind,
layers::{DenseLayerDesc, LayerSpec},
losses::LossKind,
trainer::{DenseSgdConfig, TrainError},
};
use std::fs;
use std::path::Path;
use std::time::{Duration, Instant};
// Bytes per layer record in the snapshot's layer_meta blob: u32 LE input size
// at offset 0, u32 LE output size at offset 4, activation byte at offset 16
// (the remaining bytes are not read by this binary's parser).
const LAYER_META_SIZE: usize = 20;
// File name of the model snapshot this binary trains, under trained/<precision>/.
const MODEL_FILE: &str = "small.rnn";
// Wall-clock budget for each precision's training loop.
const TRAIN_SECONDS: u64 = 30;
// Training-order selector forwarded to the batch builder.
const TRAIN_ORDER: common::TrainingOrder = common::TrainingOrder(0);
// Target compute fraction handed to the CPU pacing helper.
const CPU_TARGET_UTIL: f64 = 0.75;
// Samples trained per pacing/snapshot cycle.
const TRAIN_SAMPLES_PER_CYCLE: usize = 4;
// Minimum interval between live benchmark snapshot writes.
const LIVE_SNAPSHOT_INTERVAL_MS: u64 = 200;
// SGD step size.
const LEARNING_RATE: f32 = 0.01;
// Per-step gradient clip magnitude; `None` disables clipping.
const GRADIENT_CLIP: Option<f32> = Some(1.0);
/// Entry point: run the full training pipeline, reporting any failure on
/// stderr and exiting non-zero.
fn main() {
    match run_training_pipeline() {
        Ok(()) => {}
        Err(err) => {
            eprintln!("training error: {err}");
            std::process::exit(1);
        }
    }
}
/// Train both stored precisions in sequence, stopping at the first failure.
fn run_training_pipeline() -> Result<(), String> {
    for precision in [Precision::F32, Precision::F64] {
        train_one_precision(precision)?;
    }
    Ok(())
}
/// Numeric precision a snapshot is stored in.
#[derive(Clone, Copy)]
enum Precision {
    F32,
    F64,
}
impl Precision {
    /// Directory / label name for this precision's artifacts ("f32" or "f64").
    fn folder(self) -> &'static str {
        if matches!(self, Self::F32) {
            "f32"
        } else {
            "f64"
        }
    }
}
/// Size in bytes of one stored element for the given precision (4 or 8).
fn precision_elem_size(precision: Precision) -> u64 {
    let bytes = match precision {
        Precision::F32 => core::mem::size_of::<f32>(),
        Precision::F64 => core::mem::size_of::<f64>(),
    };
    bytes as u64
}
/// Warm-start, train, and re-pack the small model for one stored precision.
///
/// Reads the existing `.rnn` snapshot, decodes it into f32 parameters, runs
/// the timed SGD loop (resuming counters from a prior session if present),
/// then finalizes benchmark outputs and re-encodes the trained model in its
/// original precision. Returns a human-readable error string on failure.
fn train_one_precision(precision: Precision) -> Result<(), String> {
    eprintln!("[train_small_model] start precision={}", precision.folder());
    // Artifacts for this precision live under trained/<f32|f64>/.
    let trained_dir = Path::new("trained").join(precision.folder());
    // Warm start: make sure a source snapshot exists, then read it whole.
    let source_small_path =
        common::ensure_source_model(&trained_dir, precision.folder(), MODEL_FILE)?;
    let source_bytes = fs::read(&source_small_path)
        .map_err(|e| format!("failed to read {}: {e}", source_small_path.display()))?;
    // Training always runs in f32; f64 payloads are narrowed during decode.
    let decoded = decode_native_dense_to_f32(&source_bytes, precision)?;
    let topology = decoded.topology;
    let hidden_activation = decoded.hidden_activation;
    let output_activation = decoded.output_activation;
    let weights = decoded.weights_f32;
    let biases = decoded.biases_f32;
    // Source file size, reported to the benchmark finalizer as the baseline.
    let baseline_output_bytes = source_bytes.len();
    eprintln!(
        "[train_small_model] warm-start from {}",
        source_small_path.display()
    );
    eprintln!(
        "[train_small_model] topology for {}: {:?}",
        precision.folder(),
        topology
    );
    eprintln!(
        "[train_small_model] creating trained dir {}",
        trained_dir.display()
    );
    fs::create_dir_all(&trained_dir)
        .map_err(|e| format!("failed to create {}: {e}", trained_dir.display()))?;
    let benchmark_dir = trained_dir.join("benchmark");
    fs::create_dir_all(&benchmark_dir)
        .map_err(|e| format!("failed to create {}: {e}", benchmark_dir.display()))?;
    let stem = MODEL_FILE.strip_suffix(".rnn").unwrap_or(MODEL_FILE);
    // Resume iteration/elapsed counters from a previous session, if any.
    let resume = common::load_resume_cursor(&benchmark_dir, stem)?;
    let mut live_snapshot =
        common::create_live_snapshot(&benchmark_dir, stem, LIVE_SNAPSHOT_INTERVAL_MS)?;
    eprintln!(
        "[train_small_model] starting training loop for {} seconds",
        TRAIN_SECONDS
    );
    let train_result = train_f32_model(
        &topology,
        hidden_activation,
        output_activation,
        weights,
        biases,
        precision.folder(),
        resume.iterations,
        resume.elapsed,
        &mut live_snapshot,
    )?;
    eprintln!("[train_small_model] training finished: iterations={} elapsed={:?} avg_loss={} last_loss={}", train_result.iterations, train_result.elapsed, train_result.avg_loss, train_result.last_loss);
    let trained_path = trained_dir.join(MODEL_FILE);
    // Finalize benchmark outputs; the closure re-packs the trained parameters
    // back into the snapshot's original storage precision.
    let _out_bytes = common::finalize_benchmark_outputs(
        common::BenchmarkFinalizeRequest {
            model_file: MODEL_FILE,
            stem,
            topology: &topology,
            precision_label: precision.folder(),
            element_size_bytes: precision_elem_size(precision),
            iterations: train_result.iterations,
            elapsed: train_result.elapsed,
            avg_loss: train_result.avg_loss,
            last_loss: train_result.last_loss,
            train_samples_per_cycle: TRAIN_SAMPLES_PER_CYCLE,
            baseline_output_bytes,
            trained_path: trained_path.clone(),
            benchmark_dir: benchmark_dir.clone(),
        },
        |benchmark_blob| {
            eprintln!(
                "[train_small_model] packing model (direct rnn_api) for precision={}",
                precision.folder()
            );
            match precision {
                Precision::F32 => common::pack_dense_native_with_benchmark_f32(
                    &topology,
                    hidden_activation,
                    output_activation,
                    &train_result.weights,
                    &train_result.biases,
                    benchmark_blob,
                ),
                Precision::F64 => {
                    // f64 snapshots: widen the trained f32 parameters back
                    // to f64 before packing.
                    let weights_f64: Vec<f64> = train_result
                        .weights
                        .iter()
                        .copied()
                        .map(f64::from)
                        .collect();
                    let biases_f64: Vec<f64> =
                        train_result.biases.iter().copied().map(f64::from).collect();
                    common::pack_dense_native_with_benchmark_f64(
                        &topology,
                        hidden_activation,
                        output_activation,
                        &weights_f64,
                        &biases_f64,
                        benchmark_blob,
                    )
                }
            }
        },
    )?;
    Ok(())
}
/// Final parameters and loss statistics produced by `train_f32_model`.
struct TrainResultF32 {
    /// Flat weight buffer after training.
    weights: Vec<f32>,
    /// Flat bias buffer after training.
    biases: Vec<f32>,
    /// Total iteration count, including iterations resumed from prior runs.
    iterations: usize,
    /// Total elapsed training time, including resumed prior elapsed time.
    elapsed: Duration,
    /// Mean loss over this session's steps only.
    avg_loss: f32,
    /// Loss from the final training step.
    last_loss: f32,
}
/// Dense model decoded from a native `.rnn` snapshot, normalized to f32.
struct DecodedModelF32 {
    /// Layer widths: input size followed by each layer's output size.
    topology: Vec<usize>,
    /// Activation shared by all hidden layers (Identity for single-layer nets).
    hidden_activation: ActivationKind,
    /// Activation of the final layer.
    output_activation: ActivationKind,
    /// Flat weights, narrowed to f32 when the source was stored as f64.
    weights_f32: Vec<f32>,
    /// Flat biases, narrowed to f32 when the source was stored as f64.
    biases_f32: Vec<f32>,
}
/// Recover the network topology and activation kinds from the raw
/// `layer_meta` blob (one `LAYER_META_SIZE`-byte record per dense layer).
///
/// Each record stores the layer's input size (u32 LE at offset 0), output
/// size (u32 LE at offset 4) and an activation byte (offset 16). Returns
/// `(topology, hidden_activation, output_activation)`, or an error when the
/// records are malformed, chained inconsistently, or hidden activations are
/// not uniform (the trainer supports only one shared hidden activation).
fn parse_topology_from_layer_meta(
    layer_meta: &[u8],
) -> Result<(Vec<usize>, ActivationKind, ActivationKind), String> {
    if layer_meta.is_empty() || layer_meta.len() % LAYER_META_SIZE != 0 {
        return Err("invalid layer_meta length".to_string());
    }
    let layer_count = layer_meta.len() / LAYER_META_SIZE;
    let mut topology = Vec::with_capacity(layer_count + 1);
    let mut hidden_activation = ActivationKind::Identity;
    let mut output_activation = ActivationKind::Identity;
    for (idx, record) in layer_meta.chunks_exact(LAYER_META_SIZE).enumerate() {
        // Little-endian u32 field at the given offset within this record.
        let read_u32 = |at: usize| -> usize {
            u32::from_le_bytes([record[at], record[at + 1], record[at + 2], record[at + 3]])
                as usize
        };
        let input_size = read_u32(0);
        let output_size = read_u32(4);
        let activation = ActivationKind::from_u8(record[16])
            .ok_or_else(|| "invalid activation byte".to_string())?;
        if idx == 0 {
            // The first record also contributes the network's input width.
            topology.push(input_size);
            if layer_count > 1 {
                hidden_activation = activation;
            }
        } else {
            // Every later layer must consume exactly the previous layer's output.
            let prev_out = *topology
                .last()
                .ok_or_else(|| "invalid topology".to_string())?;
            if prev_out != input_size {
                return Err("layer chain mismatch in layer_meta".to_string());
            }
            if idx < layer_count - 1 && activation != hidden_activation {
                return Err("non-uniform hidden activation; unsupported by trainer".to_string());
            }
        }
        topology.push(output_size);
        if idx == layer_count - 1 {
            output_activation = activation;
        }
    }
    Ok((topology, hidden_activation, output_activation))
}
/// Decode a native `.rnn` dense snapshot into f32 parameters plus topology.
///
/// Handles both stored precisions: f64 payloads are narrowed to f32 so the
/// trainer can operate on a single numeric type. Returns an error string on
/// any structural problem (bad magic, truncated data, malformed layer_meta,
/// or blob sizes that disagree with the parsed topology).
fn decode_native_dense_to_f32(
    bytes: &[u8],
    precision: Precision,
) -> Result<DecodedModelF32, String> {
    // A file may contain several appended snapshots; only the newest matters.
    let bytes = common::extract_last_rnn_snapshot(bytes)?;
    if bytes.len() < 12 {
        return Err("invalid rnn bytes: too short".to_string());
    }
    if &bytes[0..4] != b"RNN\x00" {
        return Err("invalid rnn bytes: bad magic".to_string());
    }
    let (layer_meta, weights_blob, biases_blob) = common::extract_core_dense_blobs(&bytes)?;
    // layer_meta length is validated inside parse_topology_from_layer_meta,
    // which reports the same "invalid layer_meta length" error.
    let (topology, hidden_activation, output_activation) =
        parse_topology_from_layer_meta(layer_meta)?;
    let (weights_f32, biases_f32) = match precision {
        Precision::F32 => (parse_f32_blob(weights_blob)?, parse_f32_blob(biases_blob)?),
        Precision::F64 => (
            parse_f64_blob_to_f32(weights_blob)?,
            parse_f64_blob_to_f32(biases_blob)?,
        ),
    };
    // Cross-check blob sizes against the parameter counts the topology implies.
    let (expected_weights, expected_biases) = count_params(&topology);
    if weights_f32.len() != expected_weights || biases_f32.len() != expected_biases {
        return Err("weights/biases size mismatch with topology".to_string());
    }
    Ok(DecodedModelF32 {
        // `topology` is already an owned Vec; the previous `.to_vec()` here
        // was a redundant clone.
        topology,
        hidden_activation,
        output_activation,
        weights_f32,
        biases_f32,
    })
}
/// Timed f32 SGD training loop over deterministically generated batches.
///
/// Starts from `weights`/`biases` and trains until `TRAIN_SECONDS` of wall
/// time elapse, pacing compute toward `CPU_TARGET_UTIL`. `start_iteration`
/// and `elapsed_offset` carry counters resumed from a previous session;
/// `live_snapshot` receives periodic progress points plus one forced final
/// point. Returns the updated parameters and loss statistics.
#[allow(clippy::too_many_arguments)]
fn train_f32_model(
    topology: &[usize],
    hidden_activation: ActivationKind,
    output_activation: ActivationKind,
    mut weights: Vec<f32>,
    mut biases: Vec<f32>,
    precision_label: &'static str,
    start_iteration: usize,
    elapsed_offset: Duration,
    live_snapshot: &mut common::LiveBenchmarkSnapshot,
) -> Result<TrainResultF32, String> {
    let layer_count = topology.len().saturating_sub(1);
    let train_buf_len = required_train_len(topology)?;
    // Scratch layer descriptors passed mutably into train_dense_step each
    // call; the values here are placeholders (NOTE(review): presumably the
    // trainer rewrites them per step — confirm in the trainer crate).
    let mut layer_specs_scratch = vec![
        LayerSpec::Dense(DenseLayerDesc {
            input_size: 1,
            output_size: 1,
            weight_offset: 0,
            bias_offset: 0,
            activation: ActivationKind::Identity,
        });
        layer_count
    ];
    // Forward/backward buffers sized once and reused across all steps.
    let mut activations_scratch = vec![0.0f32; train_buf_len];
    let mut deltas_scratch = vec![0.0f32; train_buf_len];
    let config = DenseSgdConfig {
        learning_rate: LEARNING_RATE,
        hidden_activation,
        output_activation,
        loss: LossKind::Mse,
        gradient_clip: GRADIENT_CLIP,
    };
    let input_size = topology[0];
    let output_size = *topology
        .last()
        .ok_or_else(|| "invalid topology".to_string())?;
    let start = Instant::now();
    // `iterations` continues the resumed global count; `run_iterations`
    // counts only this session's steps (basis for the session average loss).
    let mut iterations = start_iteration;
    let mut run_iterations = 0usize;
    let mut loss_sum = 0.0f32;
    let mut last_loss = 0.0f32;
    let mut input_buf = vec![0.0f32; input_size];
    let mut target_buf = vec![0.0f32; output_size];
    let mut compute_elapsed = Duration::ZERO;
    while start.elapsed() < Duration::from_secs(TRAIN_SECONDS) {
        let cycle_compute_start = Instant::now();
        // Batch contents are derived from the iteration counter, so a
        // resumed run continues the same deterministic sample sequence.
        let batch_samples = common::build_parallel_factored_batch(
            iterations,
            input_size,
            output_size,
            TRAIN_SAMPLES_PER_CYCLE,
            TRAIN_ORDER,
            common::kernel_config_for_model(MODEL_FILE),
        )?;
        for sample in batch_samples {
            input_buf.copy_from_slice(&sample.input);
            target_buf.copy_from_slice(&sample.target);
            let loss = common::train_dense_step(
                topology,
                &mut weights,
                &mut biases,
                &input_buf,
                &target_buf,
                &mut layer_specs_scratch,
                &mut activations_scratch,
                &mut deltas_scratch,
                config,
            )
            .map_err(map_train_error)?;
            last_loss = loss;
            loss_sum += loss;
            iterations = iterations.saturating_add(1);
            run_iterations = run_iterations.saturating_add(1);
            // Honor the wall-clock budget even mid-batch.
            if start.elapsed() >= Duration::from_secs(TRAIN_SECONDS) {
                break;
            }
        }
        // Pace toward the target CPU utilization after each compute cycle.
        common::pace_cpu_target_utilization(
            start,
            &mut compute_elapsed,
            cycle_compute_start.elapsed(),
            CPU_TARGET_UTIL,
        );
        let elapsed_now = elapsed_offset + start.elapsed();
        let avg_loss_now = if run_iterations > 0 {
            loss_sum / run_iterations as f32
        } else {
            0.0
        };
        // Non-forced write: the snapshot may skip it based on its interval.
        live_snapshot.maybe_write(
            common::LiveSnapshotPoint {
                model_file: MODEL_FILE,
                precision_label,
                elapsed: elapsed_now,
                iterations,
                avg_loss: avg_loss_now,
                last_loss,
                train_samples_per_cycle: TRAIN_SAMPLES_PER_CYCLE,
            },
            false,
        )?;
    }
    let elapsed = elapsed_offset + start.elapsed();
    let avg_loss = if run_iterations > 0 {
        loss_sum / run_iterations as f32
    } else {
        0.0
    };
    // Final snapshot is forced (`true`) so the end state is always recorded.
    live_snapshot.maybe_write(
        common::LiveSnapshotPoint {
            model_file: MODEL_FILE,
            precision_label,
            elapsed,
            iterations,
            avg_loss,
            last_loss,
            train_samples_per_cycle: TRAIN_SAMPLES_PER_CYCLE,
        },
        true,
    )?;
    Ok(TrainResultF32 {
        weights,
        biases,
        iterations,
        elapsed,
        avg_loss,
        last_loss,
    })
}
/// Decode a blob of little-endian f32 values; errors unless the length is a
/// whole number of 4-byte elements.
fn parse_f32_blob(blob: &[u8]) -> Result<Vec<f32>, String> {
    const ELEM: usize = core::mem::size_of::<f32>();
    if blob.len() % ELEM != 0 {
        return Err("invalid f32 blob length".to_string());
    }
    Ok(blob
        .chunks_exact(ELEM)
        .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("chunk is 4 bytes")))
        .collect())
}
/// Decode a blob of little-endian f64 values, narrowing each to f32; errors
/// unless the length is a whole number of 8-byte elements.
fn parse_f64_blob_to_f32(blob: &[u8]) -> Result<Vec<f32>, String> {
    const ELEM: usize = core::mem::size_of::<f64>();
    if blob.len() % ELEM != 0 {
        return Err("invalid f64 blob length".to_string());
    }
    Ok(blob
        .chunks_exact(ELEM)
        .map(|chunk| {
            let wide = f64::from_le_bytes(chunk.try_into().expect("chunk is 8 bytes"));
            wide as f32
        })
        .collect())
}
/// Total (weights, biases) a dense network with the given layer widths holds.
/// Arithmetic saturates rather than overflowing on pathological topologies.
fn count_params(topology: &[usize]) -> (usize, usize) {
    topology.windows(2).fold((0usize, 0usize), |(w, b), pair| {
        let (inp, out) = (pair[0], pair[1]);
        (
            w.saturating_add(inp.saturating_mul(out)),
            b.saturating_add(out),
        )
    })
}
/// Scratch-buffer length the trainer needs for this topology, converted into
/// this binary's String-based error channel.
fn required_train_len(layers: &[usize]) -> Result<usize, String> {
    match native_neural_network::trainer::required_train_buffer_len(layers) {
        Some(len) => Ok(len),
        None => Err("failed to compute train buffer length".to_string()),
    }
}
/// Convert a trainer error into this binary's String-based error channel,
/// using the Debug representation.
fn map_train_error(err: TrainError) -> String {
    format!("training failed: {:?}", err)
}