//! tract_libcli/profile.rs — benchmarking loops and per-node profiling
//! helpers for tract models.
1use crate::model::Model;
2use crate::tensor::RunTensors;
3use crate::tensor::make_inputs_for_model;
4use crate::{annotations::*, capture_gpu_trace};
5use std::any::TypeId;
6use std::time::{Duration, Instant};
7use tract_core::internal::*;
8use tract_core::num_traits::Zero;
9use tract_core::ops::submodel::TypedModelOpState;
10
11pub fn reusable_state(runnable: &Arc<dyn Runnable>) -> bool {
12    runnable.typed_model().is_some_and(|model| model.properties().contains_key("pulse.delay"))
13}
14
15pub fn run_one_step(
16    runnable: &Arc<dyn Runnable>,
17    state: &mut Box<dyn State>,
18    inputs: &RunTensors,
19) -> TractResult<Duration> {
20    if !reusable_state(runnable) {
21        *state = runnable.spawn()?;
22    }
23    let start = Instant::now();
24    for source in &inputs.sources {
25        state.run(source.clone())?;
26    }
27    Ok(start.elapsed())
28}
29
/// Iteration-count and wall-clock budgets for the warmup and benchmark loops.
///
/// For each phase, a zero value on one limit makes that axis unbounded, while
/// zero on both limits of a phase disables the phase entirely (see `warmup`
/// and `bench`).
pub struct BenchLimits {
    // Maximum warmup iterations (0 = no iteration cap).
    pub warmup_loops: usize,
    // Maximum warmup wall time (zero = no time cap).
    pub warmup_time: std::time::Duration,
    // Maximum benchmark iterations (0 = no iteration cap).
    pub max_loops: usize,
    // Maximum benchmark wall time (zero = no time cap).
    pub max_time: std::time::Duration,
}
36
37impl Default for BenchLimits {
38    fn default() -> Self {
39        BenchLimits {
40            warmup_loops: 0,
41            warmup_time: Duration::default(),
42            max_loops: 100_000,
43            max_time: std::time::Duration::from_secs(5),
44        }
45    }
46}
47
48impl BenchLimits {
49    pub fn warmup(&self, runnable: &Arc<dyn Runnable>, inputs: &RunTensors) -> TractResult<()> {
50        if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
51            return Ok(());
52        }
53        let reuse = reusable_state(runnable);
54        let mut state = runnable.spawn()?;
55
56        let mut iters = 0;
57        let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
58        let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };
59
60        let start_warmup = Instant::now();
61        info!("Warming up before profiling...");
62        while iters < max_loops && start_warmup.elapsed() < max_time {
63            if !reuse {
64                state = runnable.spawn()?;
65            }
66            state.run(inputs.sources[0].clone())?;
67            iters += 1;
68        }
69        info!("Done warming up.");
70
71        Ok(())
72    }
73
74    pub fn bench(
75        &self,
76        runnable: &Arc<dyn Runnable>,
77        inputs: &RunTensors,
78    ) -> TractResult<(usize, Duration)> {
79        if self.max_time.is_zero() && self.max_loops.is_zero() {
80            return Ok(Default::default());
81        }
82        let reuse = reusable_state(runnable);
83        let mut state = runnable.spawn()?;
84
85        let mut iters = 0;
86        let max_loops = if self.max_loops.is_zero() { usize::MAX } else { self.max_loops };
87        let max_time = if self.max_time.is_zero() { Duration::MAX } else { self.max_time };
88
89        let mut dur = Duration::default();
90        let start = Instant::now();
91        while iters < max_loops && start.elapsed() < max_time {
92            if !reuse {
93                state = runnable.spawn()?;
94            }
95            let start_inner = Instant::now();
96            state.run(inputs.sources[0].clone())?;
97            dur += start_inner.elapsed();
98            iters += 1;
99        }
100
101        Ok((iters, dur))
102    }
103}
104
105pub fn profile(
106    runnable: &Arc<dyn Runnable>,
107    bench_limits: &BenchLimits,
108    dg: &mut Annotations,
109    inputs: &RunTensors,
110    custom_profiler: Option<HashMap<TypeId, Profiler>>,
111    folded: bool,
112) -> TractResult<()> {
113    let Some(plan) = runnable.typed_plan() else {
114        bail!("Can only profile TypedRunnable");
115    };
116    info!("Running entire network");
117    let mut iters = 0usize;
118    let prefix = tvec!();
119
120    bench_limits.warmup(runnable, inputs)?;
121
122    let reuse = reusable_state(runnable);
123    let mut state = plan.spawn()?;
124
125    let mut dur = Duration::default();
126    let mut time_accounted_by_inner_nodes = Duration::default();
127    while iters < bench_limits.max_loops && dur < bench_limits.max_time {
128        if !reuse {
129            state = plan.spawn()?;
130        }
131        let start = Instant::now();
132
133        for source in &inputs.sources {
134            rec_profiler(
135                &mut state,
136                dg,
137                source,
138                custom_profiler.as_ref(),
139                &prefix,
140                None,
141                &mut time_accounted_by_inner_nodes,
142                folded,
143            )?;
144        }
145        dur += start.elapsed();
146        iters += 1;
147    }
148
149    dur -= time_accounted_by_inner_nodes;
150
151    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
152    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
153
154    let denum = (iters as f32).recip();
155    let entire = dur.mul_f32(denum);
156    for d in dg.tags.values_mut() {
157        if let Some(d) = d.profile.as_mut() {
158            *d = d.mul_f32(denum);
159        }
160
161        if let Some(d) = d.accelerator_profile.as_mut() {
162            *d = d.mul_f32(denum);
163        }
164    }
165    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
166    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
167    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
168    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
169    Ok(())
170}
171
172pub fn profile_gpu(
173    runnable: &Arc<dyn Runnable>,
174    bench_limits: &BenchLimits,
175    sub_matches: &clap::ArgMatches,
176    dg: &mut Annotations,
177    inputs: &RunTensors,
178    before_node: &dyn Fn(usize),
179    after_iteration: &dyn Fn(&mut Annotations, &[(usize, String)]) -> TractResult<()>,
180) -> TractResult<()> {
181    let Some(plan) = runnable.typed_plan() else {
182        bail!("Can only profile TypedRunnable");
183    };
184    info!("Running entire network");
185    let mut iters = 0usize;
186    let prefix = tvec!();
187
188    bench_limits.warmup(runnable, inputs)?;
189
190    let reuse = reusable_state(runnable);
191    let mut state = plan.spawn()?;
192
193    let mut dur = Duration::default();
194
195    capture_gpu_trace(sub_matches, || -> TractResult<()> {
196        while iters < bench_limits.max_loops && dur < bench_limits.max_time {
197            if !reuse {
198                state = plan.spawn()?;
199            }
200            let start = Instant::now();
201            for source in &inputs.sources {
202                rec_profiler_gpu(&mut state, dg, source, &prefix, before_node)?;
203            }
204            after_iteration(dg, &prefix)?;
205            dur += start.elapsed();
206            iters += 1;
207        }
208        Ok(())
209    })?;
210
211    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
212    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
213
214    let denum = (iters as f32).recip();
215    let entire = dur.mul_f32(denum);
216    for d in dg.tags.values_mut() {
217        if let Some(d) = d.profile.as_mut() {
218            *d = d.mul_f32(denum);
219        }
220
221        if let Some(d) = d.accelerator_profile.as_mut() {
222            *d = d.mul_f32(denum);
223        }
224    }
225    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
226    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
227    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
228    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
229    Ok(())
230}
231
232pub fn rec_profiler_gpu(
233    state: &mut TypedSimpleState,
234    dg: &mut Annotations,
235    inputs: &TVec<TValue>,
236    prefix: &[(usize, String)],
237    before_node: &dyn Fn(usize),
238) -> TractResult<TVec<TValue>> {
239    let r = state.run_plan_with_eval(
240        inputs.clone(),
241        |session_state, mut node_state, node, input| {
242            before_node(node.id);
243            // Profile node
244            let start = crate::time::now();
245            let res = tract_core::plan::eval(
246                session_state,
247                node_state.as_deref_mut(),
248                node,
249                input.clone(),
250            );
251            let elapsed = start.elapsed();
252            let node_id = NodeQId(prefix.into(), node.id);
253            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
254
255            res
256        },
257    )?;
258
259    Ok(r)
260}
261
/// Runs one evaluation of the plan held by `state`, timing every node and
/// accumulating the elapsed time into that node's profile annotation.
///
/// * `multiplier` — scales each measured duration (used by scan profiling to
///   extrapolate one loop-body run to the whole loop's iteration count).
/// * `time_accounted_by_inner_nodes` — accumulates wall time spent profiling
///   submodels so callers can subtract it from their own totals.
/// * `folded` — when true, inner models are not profiled individually.
/// * Every ancestor named in `prefix` gets this node's elapsed time deducted
///   from its own profile, so inner time is not counted twice.
#[allow(clippy::too_many_arguments)]
pub fn rec_profiler(
    state: &mut TypedSimpleState,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            // Profile node
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            // Scale by `multiplier` (defaults to 1) before recording.
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            if !folded {
                // Profile the node's inner model (if any) and account for the
                // wall time this recursion itself consumes.
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // Update parent nodes if any (childs timings are deducted from parents)
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                // Walk every ancestor path prefix, from outermost to innermost.
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        // Clamped so a parent's profile never underflows.
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}
324
/// Profiles the inner model of `node`, if it carries one, dispatching on the
/// node's op state:
/// 1. a custom profiler registered for the state's `TypeId`,
/// 2. a scan loop — profiled on generated inputs, each node's time scaled by
///    the loop's iteration count,
/// 3. a plain typed submodel.
///
/// Nodes without a matching state kind are left untouched. Inner wall time is
/// accumulated into `time_accounted_by_inner_nodes` by the recursive
/// `rec_profiler` calls.
fn profile_submodel(
    node: &TypedNode,
    mut node_state: Option<&mut dyn OpState>,
    input: TVec<TValue>,
    dg: &mut Annotations,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    time_accounted_by_inner_nodes: &mut Duration,
) -> TractResult<()> {
    if let Some(ref mut op_state) = node_state {
        // Custom profilers take precedence over the built-in handlers.
        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            let (_, _) =
                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
        } else if let Some(scan_state) = op_state.downcast_mut::<tract_core::ops::scan::State>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "loop".to_string()));

            // Scan bodies run on synthetic inputs; `multi` extrapolates one
            // body evaluation to the loop's full iteration count.
            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
            let multi = scan_state.iteration_count(&input);

            rec_profiler(
                &mut scan_state.model_state,
                dg,
                &scan_inputs,
                None,
                &new_prefix,
                Some(multi),
                time_accounted_by_inner_nodes,
                false,
            )?;
        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            rec_profiler(
                typed_model_state,
                dg,
                &input,
                None,
                &new_prefix,
                None,
                time_accounted_by_inner_nodes,
                false,
            )?;
        }
    }

    Ok(())
}
377
/// Signature of a custom submodel profiler: (op state, inputs, annotations,
/// node path prefix, inner-time accumulator) -> (submodel outputs, duration).
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;

/// A named custom profiler, registered per op-state `TypeId` through
/// `profile`'s `custom_profiler` argument.
#[derive(Clone)]
pub struct Profiler {
    pub func: ProfilerFn,
    pub name: &'static str,
}
391
392impl Hash for Profiler {
393    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
394        self.name.hash(state)
395    }
396}
397
/// Evaluates each node's declared cost and stores it in `annotations`,
/// recursing into nested models (e.g. scan bodies).
///
/// Compute-class costs are multiplied by the nesting multiplier accumulated
/// from enclosing loops (typically an iteration count); symbolic dimensions
/// are resolved against `extra_symbols`. Only `TypedModel`s are costed; any
/// other model kind is skipped silently.
pub fn extract_costs(
    annotations: &mut Annotations,
    model: &dyn Model,
    extra_symbols: &SymbolValues,
) -> TractResult<()> {
    // Recursive worker carrying the node-path prefix and the multiplier
    // accumulated from enclosing nested models.
    fn extract_costs_rec(
        annotations: &mut Annotations,
        model: &dyn Model,
        prefix: &[(usize, String)],
        multiplier: TDim,
        extra_symbols: &SymbolValues,
    ) -> TractResult<()> {
        if let Some(model) = model.downcast_ref::<TypedModel>() {
            for node_id in 0..model.nodes().len() {
                let inputs = model.node_input_facts(node_id)?;
                let cost = model
                    .node(node_id)
                    .op
                    .cost(&inputs)
                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
                // Only compute costs scale with loop repetition; memory-style
                // costs are stored as-is.
                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
                    .into_iter()
                    .map(|(k, v)| {
                        let cost = if k.is_compute() { v * &multiplier } else { v };
                        (k, cost.eval(extra_symbols))
                    })
                    .collect();

                // Recurse into nested models, compounding their iteration
                // multiplier (defaults to 1 when unknown).
                let nested_subs = model.nested_models(node_id);
                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
                for (name, sub) in nested_subs {
                    let mut prefix: TVec<_> = prefix.into();
                    prefix.push((node_id, name.to_string()));
                    extract_costs_rec(
                        annotations,
                        sub,
                        &prefix,
                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
                        extra_symbols,
                    )?;
                }
            }
        }
        Ok(())
    }
    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
}