tract 0.22.1

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use tract_gpu::memory::DeviceMemSchema;
use tract_hir::internal::*;

/// Memory arena figures evaluated from a `DeviceMemSchema` for one concrete
/// assignment of symbol values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MemArenaUsage {
    /// Result of `DeviceMemSchema::eval_memory_size` (presumably bytes — TODO confirm).
    arena_memory_size: i64,
    /// Result of `DeviceMemSchema::eval_peak_memory_size` (presumably bytes — TODO confirm).
    peak_memory_size: i64,
    /// Result of `DeviceMemSchema::eval_usage`
    /// (presumably the peak/arena ratio — TODO confirm against tract_gpu).
    peak_memory_usage: f32,
}

impl MemArenaUsage {
    /// Evaluates the arena size, peak size and usage figure of `schema` for
    /// the supplied `symbol_values`.
    ///
    /// # Errors
    /// Propagates the first error returned by any of the three schema
    /// evaluations; later evaluations are not attempted once one fails.
    pub fn eval_from_schema(
        schema: &DeviceMemSchema,
        symbol_values: &SymbolValues,
    ) -> TractResult<Self> {
        // Evaluate in field order so a failure surfaces at the same point
        // regardless of how the struct is assembled afterwards.
        let arena_memory_size: i64 = schema.eval_memory_size(symbol_values)?;
        let peak_memory_size: i64 = schema.eval_peak_memory_size(symbol_values)?;
        let peak_memory_usage: f32 = schema.eval_usage(symbol_values)?;

        Ok(Self { arena_memory_size, peak_memory_size, peak_memory_usage })
    }
}

/// Summary of memory arena utilization gathered by sweeping a
/// `DeviceMemSchema` over a range of `S` and `P` symbol values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MemArenaMetrics {
    /// String rendering of `DeviceMemSchema::memory_size()`.
    memory_size: String,
    /// String rendering of each entry returned by `DeviceMemSchema::size_by_partition()`.
    size_by_partition: Vec<String>,
    /// Usage per `S` value, evaluated with `P` fixed to 0
    /// (logged as "Prompt processing").
    pp: BTreeMap<i64, MemArenaUsage>,
    /// Usage per `P` value, evaluated with `S` fixed to 1
    /// (logged as "Token generation").
    tg: BTreeMap<i64, MemArenaUsage>,
    /// Largest `arena_memory_size` seen across every evaluated point of both sweeps.
    max_memory_size: i64,
    /// Sum of peak sizes divided by sum of arena sizes over both sweeps,
    /// truncated to two decimal places.
    aggregate_usage: f32,
}

impl MemArenaMetrics {
    pub fn from_schema(schema: &DeviceMemSchema) -> TractResult<Self> {
        log::info!("Analyzing memory arena utilization...");
        const MAX_GEN_TOKENS: i64 = 2048;
        const MAX_PROMPT_TOKENS: i64 = 2048;

        const STEP_TOKENS: i64 = 16;
        let memory_size: String = schema.memory_size().to_string();
        let size_by_partition: Vec<String> =
            schema.size_by_partition().iter().map(|it| it.to_string()).collect();
        let symbol_scope = SymbolScope::default();
        let sequence_length = symbol_scope.sym("S");
        let past_sequence_length = symbol_scope.sym("P");

        let mut pp = BTreeMap::new();
        let mut max_memory_size: i64 = 0;
        let mut sum_size: i64 = 0;
        let mut sum_used: i64 = 0;
        for s in (STEP_TOKENS..MAX_PROMPT_TOKENS + 1).step_by(STEP_TOKENS as usize) {
            log::info!("Prompt processing: P: 0, S: {s}");
            let symbol_values =
                SymbolValues::default().with(&sequence_length, s).with(&past_sequence_length, 0);
            let usage = MemArenaUsage::eval_from_schema(schema, &symbol_values)?;
            max_memory_size = max_memory_size.max(usage.arena_memory_size);
            sum_size += usage.arena_memory_size;
            sum_used += usage.peak_memory_size;
            pp.insert(s, usage);
        }
        let mut tg = BTreeMap::new();
        for p in (0..MAX_GEN_TOKENS + 1).step_by(STEP_TOKENS as usize) {
            log::info!("Token generation: P: {p}, S: 1");
            let symbol_values =
                SymbolValues::default().with(&sequence_length, 1).with(&past_sequence_length, p);
            let usage = MemArenaUsage::eval_from_schema(schema, &symbol_values)?;
            max_memory_size = max_memory_size.max(usage.arena_memory_size);
            sum_size += usage.arena_memory_size;
            sum_used += usage.peak_memory_size;
            tg.insert(p, usage);
        }

        let aggregate_usage = ((sum_used * 100 / sum_size.max(1)) as f32) / 100.0;
        Ok(Self { memory_size, size_by_partition, pp, tg, max_memory_size, aggregate_usage })
    }
}

pub fn dump_metrics(
    model: &TypedModel,
    options: &PlanOptions,
    path: impl AsRef<std::path::Path>,
) -> TractResult<()> {
    log::info!("Analyzing Metal memory schema utilization...");
    const SCHEMA_HINT_S: i64 = 1024;
    const SCHEMA_HINT_P: i64 = 0;

    let plan = SimplePlan::new_with_options(model, options)?;
    let order = plan.order_without_consts();
    let mut symbol_values = SymbolValues::default();
    let sequence_length = model.symbols.get("S").context("Could not find symbol S in model")?;
    let past_sequence_length =
        model.symbols.get("P").context("Could not find symbol P in model")?;

    symbol_values.set(&sequence_length, SCHEMA_HINT_S);
    symbol_values.set(&past_sequence_length, SCHEMA_HINT_P);

    let schema = DeviceMemSchema::build(model, order, &symbol_values)?;

    println!("resolved_memory_size: {}", schema.eval_memory_size(&symbol_values)?);
    println!("Schema:\n{schema}");

    let metrics = MemArenaMetrics::from_schema(&schema)?;

    std::fs::write(path.as_ref(), serde_json::to_string(&metrics)?).expect("Unable to write file");

    Ok(())
}