use std::time::Instant;
use zeropool::BufferPool;
// Workload parameters for the simulated checkpoint loader.
//
// NOTE(review): the `allow` attribute below attaches only to the single item
// that follows it (`TOTAL_MODEL_SIZE_GB`), not to the whole file — confirm
// whether a crate-level `#![allow(missing_docs)]` was intended.
//
// Nominal model size in GB. In the code visible here it is only printed in
// the startup banner; it does not drive any allocation size.
#[allow(missing_docs)]
const TOTAL_MODEL_SIZE_GB: usize = 2;
// NUM_EPOCHS: how many times `main` runs a full `load_checkpoint` pass.
// NUM_LAYERS: how many times `load_checkpoint` calls `load_transformer_layer`
// per pass.
const NUM_EPOCHS: usize = 1000; const NUM_LAYERS: usize = 48;
// Number of metadata-sized buffers acquired per transformer layer.
const METADATA_ITEMS_PER_LAYER: usize = 10;
// Buffer sizes requested from the pool (the buffers are used as `[u8]`, so
// these are presumably byte counts — verify against the pool's API):
//   METADATA_SIZE          =   4 * 1024        (4 KiB)
//   EMBEDDING_WEIGHT_SIZE  = 128 * 1024 * 1024 (128 MiB)
//   ATTENTION_WEIGHT_SIZE  =  64 * 1024 * 1024 (64 MiB)
//   FFN_WEIGHT_SIZE        = 256 * 1024 * 1024 (256 MiB)
//   LAYER_NORM_SIZE        =   1 * 1024 * 1024 (1 MiB)
const METADATA_SIZE: usize = 4 * 1024; const EMBEDDING_WEIGHT_SIZE: usize = 128 * 1024 * 1024; const ATTENTION_WEIGHT_SIZE: usize = 64 * 1024 * 1024; const FFN_WEIGHT_SIZE: usize = 256 * 1024 * 1024; const LAYER_NORM_SIZE: usize = 1024 * 1024;
/// Runs the checkpoint-loading workload and reports timing on stderr.
///
/// Steps, in order:
/// 1. Print a banner with the configured model size, epoch count and layer count.
/// 2. Build a `BufferPool` (8 shards, TLS cache size 4, at most 32 buffers per
///    shard, minimum buffer size 1 MiB) and pre-populate it.
/// 3. Call `load_checkpoint` once per epoch, timing the whole loop.
/// 4. Print total time, average time per epoch, and `pool.len()`.
fn main() {
    // --- Banner -----------------------------------------------------------
    eprintln!("=== ML Checkpoint Loader Profiling ===");
    eprintln!("Model size: {TOTAL_MODEL_SIZE_GB} GB");
    eprintln!("Epochs: {NUM_EPOCHS}");
    eprintln!("Layers: {NUM_LAYERS}");
    eprintln!();

    // --- Pool setup -------------------------------------------------------
    let buffer_pool = BufferPool::builder()
        .num_shards(8)
        .tls_cache_size(4)
        .max_buffers_per_shard(32)
        .min_buffer_size(1024 * 1024)
        .build();

    eprintln!("Pre-allocating buffers...");
    preallocate_buffers(&buffer_pool);

    // --- Timed workload ---------------------------------------------------
    let started_at = Instant::now();
    for epoch in 1..=NUM_EPOCHS {
        // Report progress on epochs 1, 11, 21, ... and always on the last one.
        let is_final_epoch = epoch == NUM_EPOCHS;
        let is_reporting_epoch = epoch % 10 == 1;
        if is_final_epoch || is_reporting_epoch {
            eprintln!("Processing epoch {epoch}/{NUM_EPOCHS}...");
        }

        load_checkpoint(&buffer_pool);
    }
    let total_duration = started_at.elapsed();

    // --- Summary ----------------------------------------------------------
    eprintln!("\n=== Profiling Complete ===");
    eprintln!("Total time: {total_duration:.2?}");
    eprintln!(
        "Average epoch time: {:.2?}",
        total_duration / NUM_EPOCHS as u32
    );
    eprintln!("Pool stats: {} buffers in pool", buffer_pool.len());
}
/// Pre-populates `pool` before the timed loop starts.
///
/// Issues one `preallocate(count, size)` call per entry of the plan below, in
/// the listed order. Note that no buffers of `LAYER_NORM_SIZE` are requested
/// here.
fn preallocate_buffers(pool: &BufferPool) {
    // (buffer count, buffer size) pairs, applied top to bottom.
    let plan = [
        (8, EMBEDDING_WEIGHT_SIZE),
        (16, ATTENTION_WEIGHT_SIZE),
        (16, FFN_WEIGHT_SIZE),
        (32, METADATA_SIZE),
    ];

    for (count, size) in plan {
        pool.preallocate(count, size);
    }
}
/// Simulates loading one full checkpoint through `pool`.
///
/// Order of work: embeddings first, then each of the `NUM_LAYERS` transformer
/// layers in ascending index order, then the output head.
fn load_checkpoint(pool: &BufferPool) {
    load_embeddings(pool);

    (0..NUM_LAYERS).for_each(|layer_idx| load_transformer_layer(pool, layer_idx));

    load_output_head(pool);
}
/// Simulates loading the embedding section of a checkpoint.
///
/// Acquires one buffer of `EMBEDDING_WEIGHT_SIZE` and then, while that buffer
/// is still held, a second one of half that size. Each buffer is touched by
/// `simulate_disk_read` and handed to `process_weights`.
fn load_embeddings(pool: &BufferPool) {
    const POSITIONAL_SIZE: usize = EMBEDDING_WEIGHT_SIZE / 2;

    let mut main_table = pool.get(EMBEDDING_WEIGHT_SIZE);
    simulate_disk_read(&mut main_table);
    process_weights(&main_table);

    // Requested while `main_table` is still checked out, so the pool cannot
    // satisfy this request with the buffer acquired above.
    let mut positional_table = pool.get(POSITIONAL_SIZE);
    simulate_disk_read(&mut positional_table);
    process_weights(&positional_table);

    // Release in reverse acquisition order (the same order the implicit
    // end-of-scope drops would use).
    drop(positional_table);
    drop(main_table);
}
/// Simulates loading a single transformer layer through `pool`.
///
/// Buffers are acquired one at a time and released at the end of each loop
/// iteration, in four phases:
/// 1. `METADATA_ITEMS_PER_LAYER` buffers of `METADATA_SIZE`, passed to
///    `parse_metadata`;
/// 2. four buffers of `ATTENTION_WEIGHT_SIZE`;
/// 3. two buffers of `FFN_WEIGHT_SIZE`;
/// 4. two buffers of `LAYER_NORM_SIZE`.
///
/// Phases 2-4 pass their buffers to `process_weights`. `_layer_idx` is
/// accepted but not used.
fn load_transformer_layer(pool: &BufferPool, _layer_idx: usize) {
    // Per-layer repeat counts for each weight category.
    const ATTENTION_BUFFERS: usize = 4;
    const FFN_BUFFERS: usize = 2;
    const NORM_BUFFERS: usize = 2;

    // Phase 1: small metadata records.
    for _ in 0..METADATA_ITEMS_PER_LAYER {
        let mut record = pool.get(METADATA_SIZE);
        simulate_disk_read(&mut record);
        parse_metadata(&record);
    }

    // Phase 2: attention weights.
    for _ in 0..ATTENTION_BUFFERS {
        let mut attention = pool.get(ATTENTION_WEIGHT_SIZE);
        simulate_disk_read(&mut attention);
        process_weights(&attention);
    }

    // Phase 3: feed-forward weights.
    for _ in 0..FFN_BUFFERS {
        let mut feed_forward = pool.get(FFN_WEIGHT_SIZE);
        simulate_disk_read(&mut feed_forward);
        process_weights(&feed_forward);
    }

    // Phase 4: layer-norm weights.
    for _ in 0..NORM_BUFFERS {
        let mut layer_norm = pool.get(LAYER_NORM_SIZE);
        simulate_disk_read(&mut layer_norm);
        process_weights(&layer_norm);
    }
}
/// Simulates loading the output-head section of a checkpoint.
///
/// Acquires one buffer of `LAYER_NORM_SIZE` and then, while it is still held,
/// one of `EMBEDDING_WEIGHT_SIZE`. Each buffer is touched by
/// `simulate_disk_read` and handed to `process_weights`.
fn load_output_head(pool: &BufferPool) {
    let mut closing_norm = pool.get(LAYER_NORM_SIZE);
    simulate_disk_read(&mut closing_norm);
    process_weights(&closing_norm);

    // Requested while `closing_norm` is still checked out.
    let mut head_weights = pool.get(EMBEDDING_WEIGHT_SIZE);
    simulate_disk_read(&mut head_weights);
    process_weights(&head_weights);

    // Release in reverse acquisition order (the same order the implicit
    // end-of-scope drops would use).
    drop(head_weights);
    drop(closing_norm);
}
/// Stands in for a disk read by writing `1` into the first and last bytes of
/// `buffer`.
///
/// Interior bytes are left as they were. An empty buffer is left untouched,
/// and a one-byte buffer has its single byte set to `1`.
fn simulate_disk_read(buffer: &mut [u8]) {
    match buffer {
        // Nothing to touch.
        [] => {}
        // First and last byte coincide.
        [only] => *only = 1,
        // Touch both ends, skip everything in between.
        [head, .., tail] => {
            *head = 1;
            *tail = 1;
        }
    }
}
/// Stands in for weight processing by making the buffer's length and a
/// reference to its first byte opaque to the optimizer.
///
/// No data is modified. An empty `buffer` is accepted: only its length is
/// observed in that case. (The previous `&buffer[0]` indexing panicked on an
/// empty slice, unlike `simulate_disk_read`, which tolerates one.)
fn process_weights(buffer: &[u8]) {
    std::hint::black_box(buffer.len());

    // `first()` yields `None` for an empty slice instead of panicking.
    if let Some(first_byte) = buffer.first() {
        std::hint::black_box(first_byte);
    }
}
/// Stands in for metadata parsing by making the buffer's length opaque to the
/// optimizer.
///
/// The contents of `buffer` are never read.
fn parse_metadata(buffer: &[u8]) {
    let byte_count = buffer.len();
    let _ = std::hint::black_box(byte_count);
}