Skip to main content

tract_libcli/
profile.rs

1use crate::model::Model;
2use crate::tensor::RunTensors;
3use crate::tensor::make_inputs_for_model;
4use crate::{annotations::*, capture_gpu_trace};
5use std::any::TypeId;
6use std::time::{Duration, Instant};
7use tract_core::internal::*;
8use tract_core::num_traits::Zero;
9use tract_core::ops::submodel::TypedModelOpState;
10
11pub fn reusable_state(runnable: &Arc<dyn Runnable>) -> bool {
12    runnable.typed_model().is_some_and(|model| model.properties().contains_key("pulse.delay"))
13}
14
15pub fn run_one_step(
16    runnable: &Arc<dyn Runnable>,
17    state: &mut Box<dyn State>,
18    inputs: &RunTensors,
19) -> TractResult<Duration> {
20    if !reusable_state(runnable) {
21        *state = runnable.spawn()?;
22    }
23    let start = Instant::now();
24    for source in &inputs.sources {
25        state.run(source.clone())?;
26    }
27    Ok(start.elapsed())
28}
29
/// Budgets for the two phases of a benchmark: an unmeasured warmup and the measured run.
///
/// Each phase is bounded both by an iteration count and by a time budget. See
/// [`BenchLimits::warmup`] and [`BenchLimits::bench`] for how zero values are interpreted.
pub struct BenchLimits {
    /// Maximum number of warmup iterations.
    pub warmup_loops: usize,
    /// Time budget for the warmup phase.
    pub warmup_time: Duration,
    /// Maximum number of measured iterations.
    pub max_loops: usize,
    /// Time budget for the measured phase.
    pub max_time: Duration,
}

impl Default for BenchLimits {
    /// No warmup, then at most 100 000 measured iterations or 5 seconds.
    fn default() -> Self {
        Self {
            warmup_loops: 0,
            warmup_time: Duration::ZERO,
            max_loops: 100_000,
            max_time: Duration::from_secs(5),
        }
    }
}
47
/// Structured output of a single bench run: named metrics (e.g. ("evaltime", secs),
/// ("pp512", tok/s)) plus the loop iteration count for the human report line. The
/// `bench`/`llm-bench` runners return this so callers — the interactive subcommand or
/// the bench suite — consume data instead of parsing stdout.
#[derive(Clone, Debug, Default)]
pub struct BenchResult {
    /// `(name, value)` pairs, emitted in this order by [`BenchResult::emit_jsonl`].
    pub metrics: Vec<(String, f64)>,
    /// Loop iteration count, shown in the human-readable report line.
    pub iters: usize,
}

impl BenchResult {
    /// Emit each metric as a `{"metric":<name>,"value":<f64>}` JSON line on stdout.
    /// This is the bench-suite child→orchestrator contract: stdout is pure JSONL
    /// (logs go to stderr), so the orchestrator can validate every line and treat
    /// anything that does not parse as a hard failure.
    ///
    /// Metric names are escaped as proper JSON strings. Rust's `{:?}` formatting is not
    /// used because it writes control characters as `\0` or `\u{..}`, which JSON rejects.
    /// Values use Rust's `f64` display. A non-finite value (`NaN`, `inf`) therefore
    /// produces a line that is not valid JSON, and the orchestrator will reject it.
    pub fn emit_jsonl(&self) {
        for (name, value) in &self.metrics {
            println!("{}", metric_jsonl_line(name, *value));
        }
    }
}

/// Format one `{"metric":<name>,"value":<value>}` line, without the trailing newline.
fn metric_jsonl_line(name: &str, value: f64) -> String {
    format!(r#"{{"metric":{},"value":{value}}}"#, json_string(name))
}

/// Quote `s` as a JSON string literal (RFC 8259 §7).
fn json_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Every other control character must use the \uXXXX form in JSON.
            c if u32::from(c) < 0x20 => out.push_str(&format!("\\u{:04x}", u32::from(c))),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}
69
/// Load-pipeline checkpoints whose readings the bench suite tracks: the dotted
/// pattern matched against a normalized event label, and the metric-name fragment.
/// Lines are normalized by turning every `_` and `-` into `.`, so `model.ready`
/// matches a `model_ready` line and `before.optimize` matches `after_"before-optimize"`.
const READINGS_STAGES: &[(&str, &str)] =
    &[("model.ready", "model_ready"), ("before.optimize", "before_optimize")];

/// Extract the load-time readings the bench suite reports from a readings-probe
/// output file.
///
/// For each tracked checkpoint, emit:
/// - `time_to_<stage>` (elapsed seconds),
/// - `rsz_at_<stage>` (resident bytes),
/// - `active_at_<stage>` (alloc − free bytes).
///
/// A missing file or absent checkpoint is skipped; the orchestrator decides which
/// metrics are required.
///
/// For each stage, the reading comes from the first line that both matches the pattern
/// and has parseable numeric columns. A matching line that does not parse, such as a
/// header mentioning the label, no longer hides a later valid reading.
pub fn stage_metrics_from_readings(path: impl AsRef<std::path::Path>) -> Vec<(String, f64)> {
    // Whitespace-separated column layout of a readings-probe line.
    const TIME: usize = 0;
    const RSZ: usize = 3;
    const ALLOC: usize = 9;
    const FREE: usize = 10;

    let Ok(content) = std::fs::read_to_string(path) else { return vec![] };
    // Normalize each line once, rather than once per tracked stage.
    let lines: Vec<(String, &str)> =
        content.lines().map(|line| (line.replace(['_', '-'], "."), line)).collect();
    let mut out = vec![];
    for (pattern, name) in READINGS_STAGES {
        let reading = lines
            .iter()
            .filter(|(normalized, _)| normalized.contains(pattern))
            .find_map(|(_, line)| {
                let fields: Vec<&str> = line.split_whitespace().collect();
                let parse = |i: usize| fields.get(i).and_then(|s| s.parse::<f64>().ok());
                Some((parse(TIME)?, parse(RSZ)?, parse(ALLOC)?, parse(FREE)?))
            });
        let Some((time, rsz, alloc, free)) = reading else { continue };
        out.push((format!("time_to_{name}"), time));
        out.push((format!("rsz_at_{name}"), rsz));
        out.push((format!("active_at_{name}"), alloc - free));
    }
    out
}
100
101impl BenchLimits {
102    pub fn warmup(&self, runnable: &Arc<dyn Runnable>, inputs: &RunTensors) -> TractResult<()> {
103        if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
104            return Ok(());
105        }
106        let reuse = reusable_state(runnable);
107        let mut state = runnable.spawn()?;
108
109        let mut iters = 0;
110        let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
111        let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };
112
113        let start_warmup = Instant::now();
114        info!("Warming up before profiling...");
115        while iters < max_loops && start_warmup.elapsed() < max_time {
116            if !reuse {
117                state = runnable.spawn()?;
118            }
119            state.run(inputs.sources[0].clone())?;
120            iters += 1;
121        }
122        info!("Done warming up.");
123
124        Ok(())
125    }
126
127    pub fn bench(
128        &self,
129        runnable: &Arc<dyn Runnable>,
130        inputs: &RunTensors,
131    ) -> TractResult<(usize, Duration)> {
132        if self.max_time.is_zero() && self.max_loops.is_zero() {
133            return Ok(Default::default());
134        }
135        let reuse = reusable_state(runnable);
136        let mut state = runnable.spawn()?;
137
138        let mut iters = 0;
139        let max_loops = if self.max_loops.is_zero() { usize::MAX } else { self.max_loops };
140        let max_time = if self.max_time.is_zero() { Duration::MAX } else { self.max_time };
141
142        let mut dur = Duration::default();
143        let start = Instant::now();
144        while iters < max_loops && start.elapsed() < max_time {
145            if !reuse {
146                state = runnable.spawn()?;
147            }
148            let start_inner = Instant::now();
149            state.run(inputs.sources[0].clone())?;
150            dur += start_inner.elapsed();
151            iters += 1;
152        }
153
154        Ok((iters, dur))
155    }
156}
157
158pub fn profile(
159    runnable: &Arc<dyn Runnable>,
160    bench_limits: &BenchLimits,
161    dg: &mut Annotations,
162    inputs: &RunTensors,
163    custom_profiler: Option<HashMap<TypeId, Profiler>>,
164    folded: bool,
165) -> TractResult<()> {
166    let Some(plan) = runnable.typed_plan() else {
167        bail!("Can only profile TypedRunnable");
168    };
169    info!("Running entire network");
170    let mut iters = 0usize;
171    let prefix = tvec!();
172
173    bench_limits.warmup(runnable, inputs)?;
174
175    let reuse = reusable_state(runnable);
176    let mut state = plan.spawn()?;
177
178    let mut dur = Duration::default();
179    let mut time_accounted_by_inner_nodes = Duration::default();
180    while iters < bench_limits.max_loops && dur < bench_limits.max_time {
181        if !reuse {
182            state = plan.spawn()?;
183        }
184        let start = Instant::now();
185
186        for source in &inputs.sources {
187            rec_profiler(
188                &mut state,
189                dg,
190                source,
191                custom_profiler.as_ref(),
192                &prefix,
193                None,
194                &mut time_accounted_by_inner_nodes,
195                folded,
196            )?;
197        }
198        dur += start.elapsed();
199        iters += 1;
200    }
201
202    dur -= time_accounted_by_inner_nodes;
203
204    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
205    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
206
207    let denum = (iters as f32).recip();
208    let entire = dur.mul_f32(denum);
209    for d in dg.tags.values_mut() {
210        if let Some(d) = d.profile.as_mut() {
211            *d = d.mul_f32(denum);
212        }
213
214        if let Some(d) = d.accelerator_profile.as_mut() {
215            *d = d.mul_f32(denum);
216        }
217    }
218    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
219    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
220    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
221    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
222    Ok(())
223}
224
225#[allow(clippy::type_complexity)]
226pub fn profile_gpu(
227    runnable: &Arc<dyn Runnable>,
228    bench_limits: &BenchLimits,
229    sub_matches: &clap::ArgMatches,
230    dg: &mut Annotations,
231    inputs: &RunTensors,
232    before_node: &dyn Fn(usize),
233    after_iteration: &dyn Fn(&mut Annotations, &[(usize, String)]) -> TractResult<()>,
234) -> TractResult<()> {
235    let Some(plan) = runnable.typed_plan() else {
236        bail!("Can only profile TypedRunnable");
237    };
238    info!("Running entire network");
239    let mut iters = 0usize;
240    let prefix = tvec!();
241
242    bench_limits.warmup(runnable, inputs)?;
243
244    let reuse = reusable_state(runnable);
245    let mut state = plan.spawn()?;
246
247    let mut dur = Duration::default();
248
249    capture_gpu_trace(sub_matches, || -> TractResult<()> {
250        while iters < bench_limits.max_loops && dur < bench_limits.max_time {
251            if !reuse {
252                state = plan.spawn()?;
253            }
254            let start = Instant::now();
255            for source in &inputs.sources {
256                rec_profiler_gpu(&mut state, dg, source, &prefix, before_node)?;
257            }
258            after_iteration(dg, &prefix)?;
259            dur += start.elapsed();
260            iters += 1;
261        }
262        Ok(())
263    })?;
264
265    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
266    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
267
268    let denum = (iters as f32).recip();
269    let entire = dur.mul_f32(denum);
270    for d in dg.tags.values_mut() {
271        if let Some(d) = d.profile.as_mut() {
272            *d = d.mul_f32(denum);
273        }
274
275        if let Some(d) = d.accelerator_profile.as_mut() {
276            *d = d.mul_f32(denum);
277        }
278    }
279    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
280    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
281    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
282    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
283    Ok(())
284}
285
286pub fn rec_profiler_gpu(
287    state: &mut TypedSimpleState,
288    dg: &mut Annotations,
289    inputs: &TVec<TValue>,
290    prefix: &[(usize, String)],
291    before_node: &dyn Fn(usize),
292) -> TractResult<TVec<TValue>> {
293    let r = state.run_plan_with_eval(
294        inputs.clone(),
295        |session_state, mut node_state, node, input| {
296            before_node(node.id);
297            // Profile node
298            let start = crate::time::now();
299            let res = tract_core::plan::eval(
300                session_state,
301                node_state.as_deref_mut(),
302                node,
303                input.clone(),
304            );
305            let elapsed = start.elapsed();
306            let node_id = NodeQId(prefix.into(), node.id);
307            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
308
309            res
310        },
311    )?;
312
313    Ok(r)
314}
315
316#[allow(clippy::too_many_arguments)]
317pub fn rec_profiler(
318    state: &mut TypedSimpleState,
319    dg: &mut Annotations,
320    inputs: &TVec<TValue>,
321    profilers: Option<&HashMap<TypeId, Profiler>>,
322    prefix: &[(usize, String)],
323    multiplier: Option<usize>,
324    time_accounted_by_inner_nodes: &mut Duration,
325    folded: bool,
326) -> TractResult<TVec<TValue>> {
327    let r = state.run_plan_with_eval(
328        inputs.clone(),
329        |session_state, mut node_state, node, input| {
330            // Profile node
331            let start = crate::time::now();
332            let res = tract_core::plan::eval(
333                session_state,
334                node_state.as_deref_mut(),
335                node,
336                input.clone(),
337            );
338            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
339            let node_id = NodeQId(prefix.into(), node.id);
340            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
341
342            if !folded {
343                let start = crate::time::now();
344                profile_submodel(
345                    node,
346                    node_state,
347                    input,
348                    dg,
349                    profilers,
350                    prefix,
351                    time_accounted_by_inner_nodes,
352                )?;
353                *time_accounted_by_inner_nodes += start.elapsed();
354            }
355
356            // Update parent nodes if any (childs timings are deducted from parents)
357            let prefix_vec = prefix.to_vec();
358            if !prefix_vec.is_empty() {
359                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
360                    |parent_path| {
361                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
362                        let parent = dg
363                            .node_mut(NodeQId(
364                                parent_path[..parent_path.len() - 1].into(),
365                                parent_node,
366                            ))
367                            .profile
368                            .get_or_insert(Duration::default());
369                        *parent -= elapsed.min(*parent);
370                    },
371                );
372            }
373            res
374        },
375    )?;
376    Ok(r)
377}
378
379fn profile_submodel(
380    node: &TypedNode,
381    mut node_state: Option<&mut dyn OpState>,
382    input: TVec<TValue>,
383    dg: &mut Annotations,
384    profilers: Option<&HashMap<TypeId, Profiler>>,
385    prefix: &[(usize, String)],
386    time_accounted_by_inner_nodes: &mut Duration,
387) -> TractResult<()> {
388    if let Some(ref mut op_state) = node_state {
389        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
390            let mut new_prefix: TVec<_> = prefix.into();
391            new_prefix.push((node.id, "submodel".to_string()));
392
393            let (_, _) =
394                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
395        } else if let Some(scan_state) = op_state.downcast_mut::<tract_core::ops::scan::State>() {
396            let mut new_prefix: TVec<_> = prefix.into();
397            new_prefix.push((node.id, "loop".to_string()));
398
399            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
400            let multi = scan_state.iteration_count(&input);
401
402            rec_profiler(
403                &mut scan_state.model_state,
404                dg,
405                &scan_inputs,
406                None,
407                &new_prefix,
408                Some(multi),
409                time_accounted_by_inner_nodes,
410                false,
411            )?;
412        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
413            let mut new_prefix: TVec<_> = prefix.into();
414            new_prefix.push((node.id, "submodel".to_string()));
415
416            rec_profiler(
417                typed_model_state,
418                dg,
419                &input,
420                None,
421                &new_prefix,
422                None,
423                time_accounted_by_inner_nodes,
424                false,
425            )?;
426        }
427    }
428
429    Ok(())
430}
431
432type ProfilerFn = fn(
433    &mut dyn OpState,
434    TVec<TValue>,
435    &mut Annotations,
436    &[(usize, String)],
437    &mut Duration,
438) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;
439
440#[derive(Clone)]
441pub struct Profiler {
442    pub func: ProfilerFn,
443    pub name: &'static str,
444}
445
446impl Hash for Profiler {
447    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
448        self.name.hash(state)
449    }
450}
451
452pub fn extract_costs(
453    annotations: &mut Annotations,
454    model: &dyn Model,
455    extra_symbols: &SymbolValues,
456) -> TractResult<()> {
457    fn extract_costs_rec(
458        annotations: &mut Annotations,
459        model: &dyn Model,
460        prefix: &[(usize, String)],
461        multiplier: TDim,
462        extra_symbols: &SymbolValues,
463    ) -> TractResult<()> {
464        if let Some(model) = model.downcast_ref::<TypedModel>() {
465            for node_id in 0..model.nodes().len() {
466                let inputs = model.node_input_facts(node_id)?;
467                let cost = model
468                    .node(node_id)
469                    .op
470                    .cost(&inputs)
471                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
472                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
473                    .into_iter()
474                    .map(|(k, v)| {
475                        let cost = if k.is_compute() { v * &multiplier } else { v };
476                        (k, cost.eval(extra_symbols))
477                    })
478                    .collect();
479
480                let nested_subs = model.nested_models(node_id);
481                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
482                for (name, sub) in nested_subs {
483                    let mut prefix: TVec<_> = prefix.into();
484                    prefix.push((node_id, name.to_string()));
485                    extract_costs_rec(
486                        annotations,
487                        sub,
488                        &prefix,
489                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
490                        extra_symbols,
491                    )?;
492                }
493            }
494        }
495        Ok(())
496    }
497    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
498}