// tract_libcli/profile.rs
1use crate::model::Model;
2use crate::tensor::RunTensors;
3use crate::tensor::make_inputs_for_model;
4use crate::{annotations::*, capture_gpu_trace};
5use std::any::TypeId;
6use std::time::{Duration, Instant};
7use tract_core::internal::*;
8use tract_core::num_traits::Zero;
9use tract_core::ops::submodel::TypedModelOpState;
10
11pub fn reusable_state(runnable: &Arc<dyn Runnable>) -> bool {
12    runnable.typed_model().is_some_and(|model| model.properties().contains_key("pulse.delay"))
13}
14
15pub fn run_one_step(
16    runnable: &Arc<dyn Runnable>,
17    state: &mut Box<dyn State>,
18    inputs: &RunTensors,
19) -> TractResult<Duration> {
20    if !reusable_state(runnable) {
21        *state = runnable.spawn()?;
22        state.init_state(&inputs.state_initializers.clone())?;
23    }
24    let start = Instant::now();
25    for source in &inputs.sources {
26        state.run(source.clone())?;
27    }
28    Ok(start.elapsed())
29}
30
/// Limits bounding the warmup and benchmark loops.
///
/// For each phase, a zero loop count or zero duration means "no limit on that
/// axis" (the consumers substitute `usize::MAX` / `Duration::MAX`); when both
/// limits of a phase are zero, that phase is skipped entirely.
pub struct BenchLimits {
    /// Number of warmup iterations (0 = unbounded, rely on `warmup_time`).
    pub warmup_loops: usize,
    /// Wall-clock budget for warmup (zero = unbounded, rely on `warmup_loops`).
    pub warmup_time: std::time::Duration,
    /// Maximum number of measured iterations.
    pub max_loops: usize,
    /// Maximum wall-clock time spent measuring.
    pub max_time: std::time::Duration,
}
37
38impl Default for BenchLimits {
39    fn default() -> Self {
40        BenchLimits {
41            warmup_loops: 0,
42            warmup_time: Duration::default(),
43            max_loops: 100_000,
44            max_time: std::time::Duration::from_secs(5),
45        }
46    }
47}
48
49impl BenchLimits {
50    pub fn warmup(&self, runnable: &Arc<dyn Runnable>, inputs: &RunTensors) -> TractResult<()> {
51        if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
52            return Ok(());
53        }
54        let reuse = reusable_state(runnable);
55        let mut state = runnable.spawn()?;
56        state.init_state(&inputs.state_initializers.clone())?;
57
58        let mut iters = 0;
59        let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
60        let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };
61
62        let start_warmup = Instant::now();
63        info!("Warming up before profiling...");
64        while iters < max_loops && start_warmup.elapsed() < max_time {
65            if !reuse {
66                state = runnable.spawn()?;
67                state.init_state(&inputs.state_initializers.clone())?;
68            }
69            state.run(inputs.sources[0].clone())?;
70            iters += 1;
71        }
72        info!("Done warming up.");
73
74        Ok(())
75    }
76
77    pub fn bench(
78        &self,
79        runnable: &Arc<dyn Runnable>,
80        inputs: &RunTensors,
81    ) -> TractResult<(usize, Duration)> {
82        if self.max_time.is_zero() && self.max_loops.is_zero() {
83            return Ok(Default::default());
84        }
85        let reuse = reusable_state(runnable);
86        let mut state = runnable.spawn()?;
87        state.init_state(&inputs.state_initializers.clone())?;
88
89        let mut iters = 0;
90        let max_loops = if self.max_loops.is_zero() { usize::MAX } else { self.max_loops };
91        let max_time = if self.max_time.is_zero() { Duration::MAX } else { self.max_time };
92
93        let mut dur = Duration::default();
94        let start = Instant::now();
95        while iters < max_loops && start.elapsed() < max_time {
96            if !reuse {
97                state = runnable.spawn()?;
98                state.init_state(&inputs.state_initializers.clone())?;
99            }
100            let start_inner = Instant::now();
101            state.run(inputs.sources[0].clone())?;
102            dur += start_inner.elapsed();
103            iters += 1;
104        }
105
106        Ok((iters, dur))
107    }
108}
109
110pub fn profile(
111    runnable: &Arc<dyn Runnable>,
112    bench_limits: &BenchLimits,
113    dg: &mut Annotations,
114    inputs: &RunTensors,
115    custom_profiler: Option<HashMap<TypeId, Profiler>>,
116    folded: bool,
117) -> TractResult<()> {
118    let Some(plan) = runnable.typed_plan() else {
119        bail!("Can only profile TypedRunnable");
120    };
121    info!("Running entire network");
122    let mut iters = 0usize;
123    let prefix = tvec!();
124
125    bench_limits.warmup(runnable, inputs)?;
126
127    let reuse = reusable_state(runnable);
128    let mut state = plan.spawn()?;
129    state.init_state(&inputs.state_initializers.clone())?;
130
131    let mut dur = Duration::default();
132    let mut time_accounted_by_inner_nodes = Duration::default();
133    while iters < bench_limits.max_loops && dur < bench_limits.max_time {
134        if !reuse {
135            state = plan.spawn()?;
136            state.init_state(&inputs.state_initializers.clone())?;
137        }
138        let start = Instant::now();
139
140        for source in &inputs.sources {
141            rec_profiler(
142                &mut state,
143                dg,
144                source,
145                custom_profiler.as_ref(),
146                &prefix,
147                None,
148                &mut time_accounted_by_inner_nodes,
149                folded,
150            )?;
151        }
152        dur += start.elapsed();
153        iters += 1;
154    }
155
156    dur -= time_accounted_by_inner_nodes;
157
158    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
159    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
160
161    let denum = (iters as f32).recip();
162    let entire = dur.mul_f32(denum);
163    for d in dg.tags.values_mut() {
164        if let Some(d) = d.profile.as_mut() {
165            *d = d.mul_f32(denum);
166        }
167
168        if let Some(d) = d.accelerator_profile.as_mut() {
169            *d = d.mul_f32(denum);
170        }
171    }
172    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
173    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
174    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
175    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
176    Ok(())
177}
178
179pub fn profile_gpu(
180    runnable: &Arc<dyn Runnable>,
181    bench_limits: &BenchLimits,
182    sub_matches: &clap::ArgMatches,
183    dg: &mut Annotations,
184    inputs: &RunTensors,
185) -> TractResult<()> {
186    let Some(plan) = runnable.typed_plan() else {
187        bail!("Can only profile TypedRunnable");
188    };
189    info!("Running entire network");
190    let mut iters = 0usize;
191    let prefix = tvec!();
192
193    bench_limits.warmup(runnable, inputs)?;
194
195    let reuse = reusable_state(runnable);
196    let mut state = plan.spawn()?;
197    state.init_state(&inputs.state_initializers.clone())?;
198
199    let mut dur = Duration::default();
200
201    capture_gpu_trace(sub_matches, || -> TractResult<()> {
202        while iters < bench_limits.max_loops && dur < bench_limits.max_time {
203            if !reuse {
204                state = plan.spawn()?;
205                state.init_state(&inputs.state_initializers.clone())?;
206            }
207            let start = Instant::now();
208            for source in &inputs.sources {
209                rec_profiler_gpu(&mut state, dg, source, &prefix)?;
210            }
211            dur += start.elapsed();
212            iters += 1;
213        }
214        Ok(())
215    })?;
216
217    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
218    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
219
220    let denum = (iters as f32).recip();
221    let entire = dur.mul_f32(denum);
222    for d in dg.tags.values_mut() {
223        if let Some(d) = d.profile.as_mut() {
224            *d = d.mul_f32(denum);
225        }
226
227        if let Some(d) = d.accelerator_profile.as_mut() {
228            *d = d.mul_f32(denum);
229        }
230    }
231    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
232    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
233    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
234    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
235    Ok(())
236}
237
238pub fn rec_profiler_gpu(
239    state: &mut TypedSimpleState,
240    dg: &mut Annotations,
241    inputs: &TVec<TValue>,
242    prefix: &[(usize, String)],
243) -> TractResult<TVec<TValue>> {
244    let r = state.run_plan_with_eval(
245        inputs.clone(),
246        |session_state, mut node_state, node, input| {
247            // Profile node
248            let start = crate::time::now();
249            let res = tract_core::plan::eval(
250                session_state,
251                node_state.as_deref_mut(),
252                node,
253                input.clone(),
254            );
255            let elapsed = start.elapsed();
256            let node_id = NodeQId(prefix.into(), node.id);
257            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
258
259            res
260        },
261    )?;
262
263    Ok(r)
264}
265
/// Runs the plan once, timing each node's eval and cumulating the elapsed
/// time into `dg` under the node's qualified id (`prefix` + node id).
///
/// * `multiplier` — scales every measured time (used for scan loops, where a
///   single body run stands for `iteration_count` executions); defaults to 1.
/// * `folded` — when false, recurses into submodels via [`profile_submodel`];
///   the wall-clock time spent in that recursion is added to
///   `time_accounted_by_inner_nodes` so the caller can deduct it from its own
///   measurement.
/// * For nested runs (`prefix` non-empty), each node's time is subtracted
///   from every ancestor along the prefix, so parents only report their own
///   overhead.
#[allow(clippy::too_many_arguments)]
pub fn rec_profiler(
    state: &mut TypedSimpleState,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            // Profile node
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            // Scale by the loop multiplier (scan iteration count), default 1.
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            if !folded {
                // Recurse into this node's submodel (if any). The time spent in
                // the recursion is reported separately so the outer caller can
                // deduct it and avoid double counting.
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // Update parent nodes if any (childs timings are deducted from parents)
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                // Walk every ancestor path prefix[..1], prefix[..2], ...
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        // min() clamps the deduction so the parent never underflows.
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}
328
/// Profiles the inner model of `node`, if it has one.
///
/// Dispatch order on the node's op state:
/// 1. a custom [`Profiler`] registered for the state's `TypeId`,
/// 2. a scan loop (body timings multiplied by the loop's iteration count),
/// 3. a plain nested typed model.
/// Nodes with no state, or whose state matches none of the above, are left
/// untouched.
fn profile_submodel(
    node: &TypedNode,
    mut node_state: Option<&mut dyn OpState>,
    input: TVec<TValue>,
    dg: &mut Annotations,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    time_accounted_by_inner_nodes: &mut Duration,
) -> TractResult<()> {
    if let Some(ref mut op_state) = node_state {
        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            // Outputs and duration are discarded: the custom profiler records
            // its findings directly into `dg`.
            let (_, _) =
                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
        } else if let Some(scan_state) = op_state.downcast_mut::<tract_core::ops::scan::State>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "loop".to_string()));

            // The scan body runs once on synthetic inputs; each body node's
            // time is multiplied by the loop's iteration count.
            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
            let multi = scan_state.iteration_count(&input);

            rec_profiler(
                &mut scan_state.model_state,
                dg,
                &scan_inputs,
                None,
                &new_prefix,
                Some(multi),
                time_accounted_by_inner_nodes,
                false,
            )?;
        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            // Nested typed models are profiled recursively with the node's
            // own inputs, no multiplier.
            rec_profiler(
                typed_model_state,
                dg,
                &input,
                None,
                &new_prefix,
                None,
                time_accounted_by_inner_nodes,
                false,
            )?;
        }
    }

    Ok(())
}
381
/// Signature of a custom node profiler: given the node's op state, its
/// inputs, the annotation store, the node path prefix, and the inner-time
/// accumulator, returns the node's eval result paired with a measured
/// duration.
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;
389
/// A named custom profiling routine, registered per op-state `TypeId` in the
/// profiler map consumed by [`rec_profiler`] / [`profile_submodel`].
#[derive(Clone)]
pub struct Profiler {
    /// The profiling entry point.
    pub func: ProfilerFn,
    /// Stable identifier; also the sole input to this type's `Hash` impl.
    pub name: &'static str,
}
395
impl Hash for Profiler {
    // Hash by name only: `func` is a fn pointer and `name` is presumably a
    // unique identifier for the profiler — TODO confirm uniqueness invariant.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state)
    }
}
401
/// Annotates every node of `model` (recursing into nested models) with its
/// declared cost, compute costs being scaled by the iteration count of the
/// enclosing loops and all costs evaluated against `extra_symbols`.
pub fn extract_costs(
    annotations: &mut Annotations,
    model: &dyn Model,
    extra_symbols: &SymbolValues,
) -> TractResult<()> {
    // Recursive walk, carrying the accumulated loop `multiplier` and the
    // `prefix` path identifying the current (sub)model.
    fn extract_costs_rec(
        annotations: &mut Annotations,
        model: &dyn Model,
        prefix: &[(usize, String)],
        multiplier: TDim,
        extra_symbols: &SymbolValues,
    ) -> TractResult<()> {
        // Costs are only defined for typed models; other model kinds are skipped.
        if let Some(model) = model.downcast_ref::<TypedModel>() {
            for node_id in 0..model.nodes().len() {
                let inputs = model.node_input_facts(node_id)?;
                let cost = model
                    .node(node_id)
                    .op
                    .cost(&inputs)
                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
                // Only compute costs are scaled by the loop multiplier; other
                // cost kinds are kept as declared.
                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
                    .into_iter()
                    .map(|(k, v)| {
                        let cost = if k.is_compute() { v * &multiplier } else { v };
                        (k, cost.eval(extra_symbols))
                    })
                    .collect();

                // Recurse into nested models, multiplying by the nested
                // iteration count when it is known (1 otherwise).
                let nested_subs = model.nested_models(node_id);
                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
                for (name, sub) in nested_subs {
                    let mut prefix: TVec<_> = prefix.into();
                    prefix.push((node_id, name.to_string()));
                    extract_costs_rec(
                        annotations,
                        sub,
                        &prefix,
                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
                        extra_symbols,
                    )?;
                }
            }
        }
        Ok(())
    }
    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
}