use tract_core::internal::*;
use tract_core::num_traits::Zero;
use tract_core::ops::scan::State;
use tract_core::ops::submodel::TypedModelOpState;

use crate::annotations::*;
use crate::model::Model;
use crate::tensor::make_inputs_for_model;
use std::any::TypeId;
use std::time::Duration;

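/// Iteration and wall-clock limits for warming up and benchmarking a model.
///
/// `warmup_loops` and `warmup_time` bound the warmup phase, `max_loops` and `max_time`
/// bound the measurement phase; each phase stops as soon as either of its limits is hit.
/// For example, to cap the measurement at 1000 iterations while keeping the other
/// defaults:
///
/// ```ignore
/// let limits = BenchLimits { max_loops: 1_000, ..BenchLimits::default() };
/// ```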
pub struct BenchLimits {
    pub warmup_loops: usize,
    pub warmup_time: std::time::Duration,
    pub max_loops: usize,
    pub max_time: std::time::Duration,
}

impl Default for BenchLimits {
    fn default() -> Self {
        BenchLimits {
            warmup_loops: 0,
            warmup_time: Duration::default(),
            max_loops: 100_000,
            max_time: std::time::Duration::from_secs(5),
        }
    }
}

impl BenchLimits {
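    /// Runs the model for `warmup_loops` iterations or `warmup_time`, whichever comes
    /// first, so caches and lazy allocations are primed before profiling. Does nothing
    /// when both limits are zero.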
    pub fn warmup(&self, model: &TypedModel, inputs: &TVec<TValue>) -> TractResult<()> {
        if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
            return Ok(());
        }
        let plan = TypedSimplePlan::new(model.clone())?;
        let mut state = TypedSimpleState::new(Arc::new(plan))?;
        let mut iters = 0;
        let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
        let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };

        let start_warmup = crate::time::now();
        debug!("Warming up before profiling...");
        while iters < max_loops && start_warmup.elapsed() < max_time {
            state.run(inputs.clone())?;
            iters += 1;
        }
        debug!("Done warming up.");
        Ok(())
    }
}

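/// Benchmarks the whole network, then attributes the measured time to individual nodes
/// in `dg` and fills in `dg.profile_summary`. Unless `folded` is set, nested models
/// (scan bodies, submodels) are profiled recursively, using `custom_profiler` entries
/// when one is registered for the inner op state's `TypeId`.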
pub fn profile(
    model: &TypedModel,
    bench_limits: &BenchLimits,
    dg: &mut Annotations,
    plan_options: &PlanOptions,
    inputs: &TVec<TValue>,
    custom_profiler: Option<HashMap<TypeId, Profiler>>,
    folded: bool,
) -> TractResult<()> {
    info!("Running entire network");
    let mut iters = 0usize;
    let prefix = tvec!();

    bench_limits.warmup(model, inputs)?;

    let plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
    let mut state = TypedSimpleState::new(Arc::new(plan))?;

    let start = crate::time::now();
    let mut time_accounted_by_inner_nodes = Duration::default();
    while iters < bench_limits.max_loops && start.elapsed() < bench_limits.max_time {
        rec_profiler(
            &mut state,
            dg,
            inputs,
            custom_profiler.as_ref(),
            &prefix,
            None,
            &mut time_accounted_by_inner_nodes,
            folded,
        )?;

        iters += 1;
    }

    let entire = start.elapsed() - time_accounted_by_inner_nodes;

    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());

    let denum = (iters as f32).recip();
    let entire = entire.mul_f32(denum);
    for d in dg.tags.values_mut() {
        if let Some(d) = d.profile.as_mut() {
            *d = d.mul_f32(denum);
        }

        if let Some(d) = d.accelerator_profile.as_mut() {
            *d = d.mul_f32(denum);
        }
    }
    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
    Ok(())
}

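/// Metal variant of [`profile`]: runs the plan through a `MetalSessionHandler` and
/// records both CPU time (`profile`) and GPU time (`accelerator_profile`) per node.
/// Only available on macOS and iOS.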
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub fn profile_metal(
    model: &TypedModel,
    bench_limits: &BenchLimits,
    dg: &mut Annotations,
    plan_options: &PlanOptions,
    inputs: &TVec<TValue>,
) -> TractResult<()> {
    info!("Running entire network");
    let mut iters = 0usize;
    let prefix = tvec!();

    bench_limits.warmup(model, inputs)?;

    let mut plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
    let state = TypedSimpleState::new_from_inputs(&plan, inputs.clone())?;

    let session_handler =
        tract_metal::MetalSessionHandler::from_plan(&plan, &state.session_state.resolved_symbols)?;

    plan = plan.with_session_handler(session_handler);

    let mut state = TypedSimpleState::new(Arc::new(plan))?;

    let mut entire = Duration::default();
    while iters < bench_limits.max_loops && entire < bench_limits.max_time {
        entire += rec_profiler_metal(&mut state, dg, inputs, &prefix)?.1;

        iters += 1;
    }

    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());

    let denum = (iters as f32).recip();
    let entire = entire.mul_f32(denum);
    for d in dg.tags.values_mut() {
        if let Some(d) = d.profile.as_mut() {
            *d = d.mul_f32(denum);
        }

        if let Some(d) = d.accelerator_profile.as_mut() {
            *d = d.mul_f32(denum);
        }
    }
    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
    Ok(())
}

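/// Runs one profiled evaluation of the plan on Metal. CPU time is measured around each
/// node's `eval` call, while per-node GPU durations come from the Metal context profiler
/// and are rescaled onto the CPU timeline. Returns the model outputs together with the
/// wall-clock duration of the run.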
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub fn rec_profiler_metal(
    state: &mut TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    prefix: &[(usize, String)],
) -> TractResult<(TVec<TValue>, Duration)> {
    tract_metal::METAL_CONTEXT.with_borrow(|ctxt| {
        // Sample correlated CPU/GPU timestamps before and after the run so GPU durations
        // can later be rescaled onto the CPU timeline.
        let (mut cpu_start, mut gpu_start): (u64, u64) = (0, 0);
        ctxt.device().sample_timestamps(&mut cpu_start, &mut gpu_start);

        let n_nodes = state.plan().model().nodes_len();
        let (result, eval_dur, profiler) = ctxt.profile(n_nodes, || {
            let profile_start = crate::time::now();
            let r = state.run_plan_with_eval(
                inputs.clone(),
                |session_state, mut node_state, node, input| {
                    let start = crate::time::now();
                    let res = tract_core::plan::eval(
                        session_state,
                        node_state.as_deref_mut(),
                        node,
                        input.clone(),
                    );
                    let elapsed = start.elapsed();
                    let node_id = NodeQId(prefix.into(), node.id);
                    *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

                    res
                },
            )?;

            Ok((r, profile_start.elapsed()))
        })?;

        let (mut cpu_end, mut gpu_end): (u64, u64) = (0, 0);
        ctxt.device().sample_timestamps(&mut cpu_end, &mut gpu_end);

        // Attribute the per-node GPU durations, rescaled to CPU time, as accelerator profile.
        profiler.iter().enumerate().for_each(|(node_id, duration)| {
            let node_id = NodeQId(prefix.into(), node_id);
            *dg.node_mut(node_id).accelerator_profile.get_or_insert(Duration::default()) +=
                Duration::from_nanos(tract_metal::utils::rescale_gpu_duration(
                    *duration, cpu_start, cpu_end, gpu_start, gpu_end,
                ));
        });

        Ok((result, eval_dur))
    })
}

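/// Runs one profiled evaluation of the plan, adding each node's eval time (scaled by
/// `multiplier`, if any) to its entry in `dg`. Unless `folded` is set, scan bodies and
/// submodels are profiled recursively through `profile_submodel`, and the time spent
/// doing so is accumulated into `time_accounted_by_inner_nodes` so the caller can
/// exclude it from the overall measurement.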
#[allow(clippy::too_many_arguments)]
pub fn rec_profiler(
    state: &mut TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            // Unless timings are folded, recurse into the node's inner model (scan body
            // or submodel) and keep track of the time spent doing so.
            if !folded {
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // When profiling a nested model, subtract this node's time from every
            // ancestor's total so inner nodes are not counted twice.
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}

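/// Profiles the inner model of a node, if it has one: a custom [`Profiler`] registered
/// for the op state's `TypeId` takes precedence, then scan states are profiled as loops
/// (scaled by their iteration count), and plain submodels are profiled directly. Nodes
/// without inner state are left untouched.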
fn profile_submodel(
    node: &TypedNode,
    mut node_state: Option<&mut dyn OpState>,
    input: TVec<TValue>,
    dg: &mut Annotations,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    time_accounted_by_inner_nodes: &mut Duration,
) -> TractResult<()> {
    if let Some(ref mut op_state) = node_state {
        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            let (_, _) =
                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
        } else if let Some(scan_state) = op_state.downcast_mut::<State>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "loop".to_string()));

            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
            let multi = scan_state.iteration_count(&input);

            rec_profiler(
                &mut scan_state.model_state,
                dg,
                &scan_inputs,
                None,
                &new_prefix,
                Some(multi),
                time_accounted_by_inner_nodes,
                false,
            )?;
        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            rec_profiler(
                typed_model_state,
                dg,
                &input,
                None,
                &new_prefix,
                None,
                time_accounted_by_inner_nodes,
                false,
            )?;
        }
    }

    Ok(())
}

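/// Signature of a custom profiling hook: given an op state, its inputs, the annotation
/// store, the node prefix and the inner-time accumulator, it returns the op's result
/// together with the time it took.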
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;

/// A named custom profiler, keyed in [`profile`] by the `TypeId` of the op state it handles.
#[derive(Clone)]
pub struct Profiler {
    pub func: ProfilerFn,
    pub name: &'static str,
}

impl Hash for Profiler {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state)
    }
}

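/// Walks `model` and its nested models recursively, storing every node's declared cost
/// in `annotations`. Compute costs of nested models are multiplied by their expected
/// iteration count, and all costs are evaluated against `extra_symbols`.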
pub fn extract_costs(
    annotations: &mut Annotations,
    model: &dyn Model,
    extra_symbols: &SymbolValues,
) -> TractResult<()> {
    fn extract_costs_rec(
        annotations: &mut Annotations,
        model: &dyn Model,
        prefix: &[(usize, String)],
        multiplier: TDim,
        extra_symbols: &SymbolValues,
    ) -> TractResult<()> {
        if let Some(model) = model.downcast_ref::<TypedModel>() {
            for node_id in 0..model.nodes().len() {
                let inputs = model.node_input_facts(node_id)?;
                let cost = model
                    .node(node_id)
                    .op
                    .cost(&inputs)
                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
                    .into_iter()
                    .map(|(k, v)| {
                        let cost = if k.is_compute() { v * &multiplier } else { v };
                        (k, cost.eval(extra_symbols))
                    })
                    .collect();

                let nested_subs = model.nested_models(node_id);
                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
                for (name, sub) in nested_subs {
                    let mut prefix: TVec<_> = prefix.into();
                    prefix.push((node_id, name.to_string()));
                    extract_costs_rec(
                        annotations,
                        sub,
                        &prefix,
                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
                        extra_symbols,
                    )?;
                }
            }
        }
        Ok(())
    }
    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
}