1use crate::model::Model;
2use crate::tensor::RunTensors;
3use crate::tensor::make_inputs_for_model;
4use crate::{annotations::*, capture_gpu_trace};
5use std::any::TypeId;
6use std::time::{Duration, Instant};
7use tract_core::internal::*;
8use tract_core::num_traits::Zero;
9use tract_core::ops::submodel::TypedModelOpState;
10
11pub fn reusable_state(runnable: &Arc<dyn Runnable>) -> bool {
12 runnable.typed_model().is_some_and(|model| model.properties().contains_key("pulse.delay"))
13}
14
15pub fn run_one_step(
16 runnable: &Arc<dyn Runnable>,
17 state: &mut Box<dyn State>,
18 inputs: &RunTensors,
19) -> TractResult<Duration> {
20 if !reusable_state(runnable) {
21 *state = runnable.spawn()?;
22 }
23 let start = Instant::now();
24 for source in &inputs.sources {
25 state.run(source.clone())?;
26 }
27 Ok(start.elapsed())
28}
29
/// Iteration and wall-clock budgets for warmup and benchmarking loops.
///
/// By convention elsewhere in this module (`warmup`, `bench`), a zero value
/// on one axis means "no limit on that axis"; both values zero disables the
/// corresponding phase entirely.
pub struct BenchLimits {
    // Maximum number of warmup iterations (0 = unlimited, see `warmup`).
    pub warmup_loops: usize,
    // Maximum warmup wall-clock time (zero = unlimited, see `warmup`).
    pub warmup_time: std::time::Duration,
    // Maximum number of measured iterations.
    pub max_loops: usize,
    // Maximum measured wall-clock time.
    pub max_time: std::time::Duration,
}
36
37impl Default for BenchLimits {
38 fn default() -> Self {
39 BenchLimits {
40 warmup_loops: 0,
41 warmup_time: Duration::default(),
42 max_loops: 100_000,
43 max_time: std::time::Duration::from_secs(5),
44 }
45 }
46}
47
48impl BenchLimits {
49 pub fn warmup(&self, runnable: &Arc<dyn Runnable>, inputs: &RunTensors) -> TractResult<()> {
50 if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
51 return Ok(());
52 }
53 let reuse = reusable_state(runnable);
54 let mut state = runnable.spawn()?;
55
56 let mut iters = 0;
57 let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
58 let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };
59
60 let start_warmup = Instant::now();
61 info!("Warming up before profiling...");
62 while iters < max_loops && start_warmup.elapsed() < max_time {
63 if !reuse {
64 state = runnable.spawn()?;
65 }
66 state.run(inputs.sources[0].clone())?;
67 iters += 1;
68 }
69 info!("Done warming up.");
70
71 Ok(())
72 }
73
74 pub fn bench(
75 &self,
76 runnable: &Arc<dyn Runnable>,
77 inputs: &RunTensors,
78 ) -> TractResult<(usize, Duration)> {
79 if self.max_time.is_zero() && self.max_loops.is_zero() {
80 return Ok(Default::default());
81 }
82 let reuse = reusable_state(runnable);
83 let mut state = runnable.spawn()?;
84
85 let mut iters = 0;
86 let max_loops = if self.max_loops.is_zero() { usize::MAX } else { self.max_loops };
87 let max_time = if self.max_time.is_zero() { Duration::MAX } else { self.max_time };
88
89 let mut dur = Duration::default();
90 let start = Instant::now();
91 while iters < max_loops && start.elapsed() < max_time {
92 if !reuse {
93 state = runnable.spawn()?;
94 }
95 let start_inner = Instant::now();
96 state.run(inputs.sources[0].clone())?;
97 dur += start_inner.elapsed();
98 iters += 1;
99 }
100
101 Ok((iters, dur))
102 }
103}
104
105pub fn profile(
106 runnable: &Arc<dyn Runnable>,
107 bench_limits: &BenchLimits,
108 dg: &mut Annotations,
109 inputs: &RunTensors,
110 custom_profiler: Option<HashMap<TypeId, Profiler>>,
111 folded: bool,
112) -> TractResult<()> {
113 let Some(plan) = runnable.typed_plan() else {
114 bail!("Can only profile TypedRunnable");
115 };
116 info!("Running entire network");
117 let mut iters = 0usize;
118 let prefix = tvec!();
119
120 bench_limits.warmup(runnable, inputs)?;
121
122 let reuse = reusable_state(runnable);
123 let mut state = plan.spawn()?;
124
125 let mut dur = Duration::default();
126 let mut time_accounted_by_inner_nodes = Duration::default();
127 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
128 if !reuse {
129 state = plan.spawn()?;
130 }
131 let start = Instant::now();
132
133 for source in &inputs.sources {
134 rec_profiler(
135 &mut state,
136 dg,
137 source,
138 custom_profiler.as_ref(),
139 &prefix,
140 None,
141 &mut time_accounted_by_inner_nodes,
142 folded,
143 )?;
144 }
145 dur += start.elapsed();
146 iters += 1;
147 }
148
149 dur -= time_accounted_by_inner_nodes;
150
151 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
152 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
153
154 let denum = (iters as f32).recip();
155 let entire = dur.mul_f32(denum);
156 for d in dg.tags.values_mut() {
157 if let Some(d) = d.profile.as_mut() {
158 *d = d.mul_f32(denum);
159 }
160
161 if let Some(d) = d.accelerator_profile.as_mut() {
162 *d = d.mul_f32(denum);
163 }
164 }
165 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
166 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
167 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
168 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
169 Ok(())
170}
171
172pub fn profile_gpu(
173 runnable: &Arc<dyn Runnable>,
174 bench_limits: &BenchLimits,
175 sub_matches: &clap::ArgMatches,
176 dg: &mut Annotations,
177 inputs: &RunTensors,
178 before_node: &dyn Fn(usize),
179 after_iteration: &dyn Fn(&mut Annotations, &[(usize, String)]) -> TractResult<()>,
180) -> TractResult<()> {
181 let Some(plan) = runnable.typed_plan() else {
182 bail!("Can only profile TypedRunnable");
183 };
184 info!("Running entire network");
185 let mut iters = 0usize;
186 let prefix = tvec!();
187
188 bench_limits.warmup(runnable, inputs)?;
189
190 let reuse = reusable_state(runnable);
191 let mut state = plan.spawn()?;
192
193 let mut dur = Duration::default();
194
195 capture_gpu_trace(sub_matches, || -> TractResult<()> {
196 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
197 if !reuse {
198 state = plan.spawn()?;
199 }
200 let start = Instant::now();
201 for source in &inputs.sources {
202 rec_profiler_gpu(&mut state, dg, source, &prefix, before_node)?;
203 }
204 after_iteration(dg, &prefix)?;
205 dur += start.elapsed();
206 iters += 1;
207 }
208 Ok(())
209 })?;
210
211 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
212 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
213
214 let denum = (iters as f32).recip();
215 let entire = dur.mul_f32(denum);
216 for d in dg.tags.values_mut() {
217 if let Some(d) = d.profile.as_mut() {
218 *d = d.mul_f32(denum);
219 }
220
221 if let Some(d) = d.accelerator_profile.as_mut() {
222 *d = d.mul_f32(denum);
223 }
224 }
225 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
226 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
227 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
228 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
229 Ok(())
230}
231
232pub fn rec_profiler_gpu(
233 state: &mut TypedSimpleState,
234 dg: &mut Annotations,
235 inputs: &TVec<TValue>,
236 prefix: &[(usize, String)],
237 before_node: &dyn Fn(usize),
238) -> TractResult<TVec<TValue>> {
239 let r = state.run_plan_with_eval(
240 inputs.clone(),
241 |session_state, mut node_state, node, input| {
242 before_node(node.id);
243 let start = crate::time::now();
245 let res = tract_core::plan::eval(
246 session_state,
247 node_state.as_deref_mut(),
248 node,
249 input.clone(),
250 );
251 let elapsed = start.elapsed();
252 let node_id = NodeQId(prefix.into(), node.id);
253 *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
254
255 res
256 },
257 )?;
258
259 Ok(r)
260}
261
/// Run one plan execution, timing every node and accumulating per-node
/// durations into `dg` under `prefix`.
///
/// * `profilers` — optional custom op-state profilers keyed by the state's
///   `TypeId`, consulted by `profile_submodel`.
/// * `multiplier` — scale applied to each measured duration (e.g. a scan
///   body's iteration count); `None` means 1.
/// * `time_accounted_by_inner_nodes` — incremented by the wall-clock time
///   spent recursively profiling inner models, so callers can subtract it
///   from their own totals and avoid double counting.
/// * `folded` — when true, inner models are not profiled recursively.
#[allow(clippy::too_many_arguments)]
pub fn rec_profiler(
    state: &mut TypedSimpleState,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            // Time this node's evaluation, scaled by the iteration multiplier.
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            // Accumulate into the node's profile annotation (created on first hit).
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            if !folded {
                // Recurse into the node's inner model (custom-profiled op, scan
                // body, or submodel) and record how long that recursion took.
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // This node's time is nested inside every ancestor along `prefix`:
            // deduct `elapsed` from each ancestor's recorded profile, clamped
            // so a parent's duration never underflows.
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}
324
/// Recursively profile the inner model(s) hidden behind a node's op state.
///
/// Dispatch, in order:
/// 1. a custom `Profiler` registered for the op state's concrete `TypeId`
///    (results land under the "submodel" prefix);
/// 2. a scan loop (`tract_core::ops::scan::State`): the body is profiled
///    with freshly generated inputs, each measurement scaled by the loop's
///    iteration count, under the "loop" prefix;
/// 3. a plain nested `TypedModelOpState`, profiled with the node's own
///    inputs under the "submodel" prefix.
///
/// Nodes with no state, or a state of any other type, are silently skipped.
fn profile_submodel(
    node: &TypedNode,
    mut node_state: Option<&mut dyn OpState>,
    input: TVec<TValue>,
    dg: &mut Annotations,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    time_accounted_by_inner_nodes: &mut Duration,
) -> TractResult<()> {
    if let Some(ref mut op_state) = node_state {
        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
            // A registered custom profiler takes precedence over built-in handling.
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            let (_, _) =
                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
        } else if let Some(scan_state) = op_state.downcast_mut::<tract_core::ops::scan::State>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "loop".to_string()));

            // The loop body runs once on synthetic inputs; its per-node times
            // are scaled by the iteration count the real input would trigger.
            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
            let multi = scan_state.iteration_count(&input);

            rec_profiler(
                &mut scan_state.model_state,
                dg,
                &scan_inputs,
                None,
                &new_prefix,
                Some(multi),
                time_accounted_by_inner_nodes,
                false,
            )?;
        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            // Plain submodel: profile it with the node's actual inputs.
            rec_profiler(
                typed_model_state,
                dg,
                &input,
                None,
                &new_prefix,
                None,
                time_accounted_by_inner_nodes,
                false,
            )?;
        }
    }

    Ok(())
}
377
/// Signature of a custom op-state profiler: runs the given op state on the
/// given inputs, writing timing annotations under the given path prefix and
/// crediting inner-node time into the `&mut Duration` accumulator. Returns
/// the run's result paired with the measured duration.
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;
385
/// A pluggable profiler for a specific op-state type, registered by the
/// state's `TypeId` and invoked from `profile_submodel`.
#[derive(Clone)]
pub struct Profiler {
    // Entry point called to profile a matching op state.
    pub func: ProfilerFn,
    // Human-readable identifier; also the value hashed by this type's `Hash` impl.
    pub name: &'static str,
}
391
392impl Hash for Profiler {
393 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
394 self.name.hash(state)
395 }
396}
397
/// Annotate every node of `model` — and, recursively, of its nested models —
/// with its estimated cost.
///
/// Compute-class costs are multiplied by the accumulated nested-iteration
/// multiplier (e.g. a scan body running N times); other cost kinds are left
/// unscaled. Symbolic dims are evaluated against `extra_symbols`.
///
/// Only `TypedModel`s are costed; other model types are silently skipped.
pub fn extract_costs(
    annotations: &mut Annotations,
    model: &dyn Model,
    extra_symbols: &SymbolValues,
) -> TractResult<()> {
    // Walk one (sub)model, carrying the node-path prefix and the compounded
    // iteration multiplier inherited from enclosing loops.
    fn extract_costs_rec(
        annotations: &mut Annotations,
        model: &dyn Model,
        prefix: &[(usize, String)],
        multiplier: TDim,
        extra_symbols: &SymbolValues,
    ) -> TractResult<()> {
        if let Some(model) = model.downcast_ref::<TypedModel>() {
            for node_id in 0..model.nodes().len() {
                let inputs = model.node_input_facts(node_id)?;
                let cost = model
                    .node(node_id)
                    .op
                    .cost(&inputs)
                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
                // Scale compute costs by how many times this node actually runs.
                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
                    .into_iter()
                    .map(|(k, v)| {
                        let cost = if k.is_compute() { v * &multiplier } else { v };
                        (k, cost.eval(extra_symbols))
                    })
                    .collect();

                // Recurse into nested models, extending the prefix with this
                // node and compounding the iteration multiplier (unknown
                // nested iteration counts default to 1).
                let nested_subs = model.nested_models(node_id);
                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
                for (name, sub) in nested_subs {
                    let mut prefix: TVec<_> = prefix.into();
                    prefix.push((node_id, name.to_string()));
                    extract_costs_rec(
                        annotations,
                        sub,
                        &prefix,
                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
                        extra_symbols,
                    )?;
                }
            }
        }
        Ok(())
    }
    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
}