runmat_runtime/builtins/timing/
timeit.rs1use runmat_time::Instant;
7use std::cmp::Ordering;
8
9use runmat_builtins::Value;
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ReductionNaN, ResidencyPolicy, ShapeRequirements,
15};
16use crate::builtins::timing::type_resolvers::timeit_type;
17
/// Name reported on every error raised by this builtin.
const BUILTIN_NAME: &str = "timeit";

/// Calibration stops doubling the loop count once a batch takes at least this many seconds.
const TARGET_BATCH_SECONDS: f64 = 0.005;
/// A batch taking at least this many seconds ends calibration and cuts sampling short.
const MAX_BATCH_SECONDS: f64 = 0.25;
/// Largest number of calls placed in a single timed batch (2^20).
const LOOP_COUNT_LIMIT: usize = 1_048_576;

/// Number of batches sampled unless a slow batch ends sampling early.
const MIN_SAMPLE_COUNT: usize = 7;
/// Upper bound on the number of samples checked inside the sampling loop.
const MAX_SAMPLE_COUNT: usize = 21;
24
25fn timeit_error(message: impl Into<String>) -> crate::RuntimeError {
26 crate::build_runtime_error(message)
27 .with_builtin(BUILTIN_NAME)
28 .build()
29}
30
31#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::timeit")]
32pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
33 name: "timeit",
34 op_kind: GpuOpKind::Custom("timer"),
35 supported_precisions: &[],
36 broadcast: BroadcastSemantics::None,
37 provider_hooks: &[],
38 constant_strategy: ConstantStrategy::InlineLiteral,
39 residency: ResidencyPolicy::GatherImmediately,
40 nan_mode: ReductionNaN::Include,
41 two_pass_threshold: None,
42 workgroup_size: None,
43 accepts_nan_mode: false,
44 notes: "Host-side helper; GPU kernels execute only if invoked by the timed function.",
45};
46
47#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::timeit")]
48pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
49 name: "timeit",
50 shape: ShapeRequirements::Any,
51 constant_strategy: ConstantStrategy::InlineLiteral,
52 elementwise: None,
53 reduction: None,
54 emits_nan: false,
55 notes: "Timing helper; excluded from fusion planning.",
56};
57
58#[runtime_builtin(
59 name = "timeit",
60 category = "timing",
61 summary = "Measure the execution time of a zero-argument function handle.",
62 keywords = "timeit,benchmark,timing,performance,gpu",
63 accel = "helper",
64 type_resolver(timeit_type),
65 builtin_path = "crate::builtins::timing::timeit"
66)]
67async fn timeit_builtin(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
68 let requested_outputs = parse_num_outputs(&rest)?;
69 let callable = prepare_callable(func, requested_outputs)?;
70
71 callable.invoke().await?;
73
74 let loop_count = determine_loop_count(&callable).await?;
75 let samples = collect_samples(&callable, loop_count).await?;
76 if samples.is_empty() {
77 return Ok(Value::Num(0.0));
78 }
79
80 Ok(Value::Num(compute_median(samples)))
81}
82
83fn parse_num_outputs(rest: &[Value]) -> Result<Option<usize>, crate::RuntimeError> {
84 match rest.len() {
85 0 => Ok(None),
86 1 => parse_non_negative_integer(&rest[0]).map(Some),
87 _ => Err(timeit_error("timeit: too many input arguments")),
88 }
89}
90
91fn parse_non_negative_integer(value: &Value) -> Result<usize, crate::RuntimeError> {
92 match value {
93 Value::Int(iv) => {
94 let raw = iv.to_i64();
95 if raw < 0 {
96 Err(timeit_error(
97 "timeit: numOutputs must be a nonnegative integer",
98 ))
99 } else {
100 Ok(raw as usize)
101 }
102 }
103 Value::Num(n) => {
104 if !n.is_finite() {
105 return Err(timeit_error("timeit: numOutputs must be finite"));
106 }
107 if *n < 0.0 {
108 return Err(timeit_error(
109 "timeit: numOutputs must be a nonnegative integer",
110 ));
111 }
112 let rounded = n.round();
113 if (rounded - n).abs() > f64::EPSILON {
114 return Err(timeit_error("timeit: numOutputs must be an integer value"));
115 }
116 Ok(rounded as usize)
117 }
118 _ => Err(timeit_error(
119 "timeit: numOutputs must be a scalar numeric value",
120 )),
121 }
122}
123
124async fn determine_loop_count(callable: &TimeitCallable) -> Result<usize, crate::RuntimeError> {
125 let mut loops = 1usize;
126 loop {
127 let elapsed = run_batch(callable, loops).await?;
128 if elapsed >= TARGET_BATCH_SECONDS
129 || elapsed >= MAX_BATCH_SECONDS
130 || loops >= LOOP_COUNT_LIMIT
131 {
132 return Ok(loops);
133 }
134 loops = loops.saturating_mul(2);
135 if loops == 0 {
136 return Ok(LOOP_COUNT_LIMIT);
137 }
138 }
139}
140
141async fn collect_samples(
142 callable: &TimeitCallable,
143 loop_count: usize,
144) -> Result<Vec<f64>, crate::RuntimeError> {
145 let mut samples = Vec::with_capacity(MIN_SAMPLE_COUNT);
146 while samples.len() < MIN_SAMPLE_COUNT {
147 let elapsed = run_batch(callable, loop_count).await?;
148 let per_iter = elapsed / loop_count as f64;
149 samples.push(per_iter);
150 if samples.len() >= MAX_SAMPLE_COUNT || elapsed >= MAX_BATCH_SECONDS {
151 break;
152 }
153 }
154 Ok(samples)
155}
156
157async fn run_batch(
158 callable: &TimeitCallable,
159 loop_count: usize,
160) -> Result<f64, crate::RuntimeError> {
161 let start = Instant::now();
162 for _ in 0..loop_count {
163 let value = callable.invoke().await?;
164 drop(value);
165 }
166 Ok(start.elapsed().as_secs_f64())
167}
168
/// Returns the median of `samples`, or `0.0` when the vector is empty.
///
/// Values are sorted ascending with every NaN placed after all non-NaN values.
/// For an even count the result is the mean of the two middle elements.
fn compute_median(mut samples: Vec<f64>) -> f64 {
    let count = samples.len();
    if count == 0 {
        return 0.0;
    }

    // `false < true`, so comparing the NaN flags first pushes NaNs to the end;
    // only two non-NaN values ever reach the numeric comparison.
    samples.sort_by(|lhs, rhs| {
        lhs.is_nan()
            .cmp(&rhs.is_nan())
            .then_with(|| lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal))
    });

    let upper = count / 2;
    match count % 2 {
        1 => samples[upper],
        _ => (samples[upper - 1] + samples[upper]) * 0.5,
    }
}
192
193#[derive(Clone)]
194struct TimeitCallable {
195 handle: Value,
196 num_outputs: Option<usize>,
197}
198
199impl TimeitCallable {
200 async fn invoke(&self) -> Result<Value, crate::RuntimeError> {
201 if let Some(0) = self.num_outputs {
206 let value =
207 crate::call_builtin_async("feval", std::slice::from_ref(&self.handle)).await?;
208 drop(value);
209 Ok(Value::Num(0.0))
210 } else {
211 Ok(crate::call_builtin_async("feval", std::slice::from_ref(&self.handle)).await?)
212 }
213 }
214}
215
216fn prepare_callable(
217 func: Value,
218 num_outputs: Option<usize>,
219) -> Result<TimeitCallable, crate::RuntimeError> {
220 match func {
221 Value::String(text) => parse_handle_string(&text).map(|handle| TimeitCallable {
222 handle: Value::String(handle),
223 num_outputs,
224 }),
225 Value::CharArray(arr) => {
226 if arr.rows != 1 {
227 Err(timeit_error(
228 "timeit: function handle must be a string scalar or function handle",
229 ))
230 } else {
231 let text: String = arr.data.iter().collect();
232 parse_handle_string(&text).map(|handle| TimeitCallable {
233 handle: Value::String(handle),
234 num_outputs,
235 })
236 }
237 }
238 Value::StringArray(sa) => {
239 if sa.data.len() == 1 {
240 parse_handle_string(&sa.data[0]).map(|handle| TimeitCallable {
241 handle: Value::String(handle),
242 num_outputs,
243 })
244 } else {
245 Err(timeit_error(
246 "timeit: function handle must be a string scalar or function handle",
247 ))
248 }
249 }
250 Value::FunctionHandle(name) => Ok(TimeitCallable {
251 handle: Value::String(format!("@{name}")),
252 num_outputs,
253 }),
254 Value::Closure(closure) => Ok(TimeitCallable {
255 handle: Value::Closure(closure),
256 num_outputs,
257 }),
258 other => Err(timeit_error(format!(
259 "timeit: first argument must be a function handle, got {other:?}"
260 ))),
261 }
262}
263
264fn parse_handle_string(text: &str) -> Result<String, crate::RuntimeError> {
265 let trimmed = text.trim();
266 if let Some(rest) = trimmed.strip_prefix('@') {
267 if rest.trim().is_empty() {
268 Err(timeit_error("timeit: empty function handle string"))
269 } else {
270 Ok(format!("@{}", rest.trim()))
271 }
272 } else {
273 Err(timeit_error(
274 "timeit: expected a function handle string beginning with '@'",
275 ))
276 }
277}
278
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // One invocation counter per helper builtin.
    static COUNTER_DEFAULT: AtomicUsize = AtomicUsize::new(0);
    static COUNTER_NUM_OUTPUTS: AtomicUsize = AtomicUsize::new(0);
    static COUNTER_INVALID: AtomicUsize = AtomicUsize::new(0);
    static COUNTER_ZERO_OUTPUTS: AtomicUsize = AtomicUsize::new(0);

    // Handle strings naming the helper builtins registered below.
    const DEFAULT_HANDLE: &str = "@__timeit_helper_counter_default";
    const OUTPUTS_HANDLE: &str = "@__timeit_helper_counter_outputs";
    const INVALID_HANDLE: &str = "@__timeit_helper_counter_invalid";
    const ZERO_OUTPUTS_HANDLE: &str = "@__timeit_helper_zero_outputs";

    // Helper builtin: bumps COUNTER_DEFAULT and returns 1.0.
    #[runtime_builtin(
        name = "__timeit_helper_counter_default",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_default() -> crate::BuiltinResult<Value> {
        COUNTER_DEFAULT.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(1.0))
    }

    // Helper builtin: bumps COUNTER_NUM_OUTPUTS and returns 1.0.
    #[runtime_builtin(
        name = "__timeit_helper_counter_outputs",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_outputs() -> crate::BuiltinResult<Value> {
        COUNTER_NUM_OUTPUTS.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(1.0))
    }

    // Helper builtin: bumps COUNTER_INVALID and returns 1.0.
    #[runtime_builtin(
        name = "__timeit_helper_counter_invalid",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_invalid() -> crate::BuiltinResult<Value> {
        COUNTER_INVALID.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(1.0))
    }

    // Helper builtin: bumps COUNTER_ZERO_OUTPUTS and returns 0.0.
    #[runtime_builtin(
        name = "__timeit_helper_zero_outputs",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_zero_outputs() -> crate::BuiltinResult<Value> {
        COUNTER_ZERO_OUTPUTS.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(0.0))
    }

    /// Wraps handle text in the string value form accepted by `timeit`.
    fn handle(text: &str) -> Value {
        Value::String(text.to_string())
    }

    /// Case-insensitive check that the error message mentions `needle`.
    fn assert_timeit_error_contains(err: crate::RuntimeError, needle: &str) {
        let haystack = err.message().to_ascii_lowercase();
        let wanted = needle.to_ascii_lowercase();
        assert!(
            haystack.contains(&wanted),
            "unexpected error text: {}",
            err.message()
        );
    }

    /// Asserts that `counter` recorded at least `MIN_SAMPLE_COUNT` calls.
    fn assert_min_invocations(counter: &AtomicUsize) {
        assert!(
            counter.load(Ordering::SeqCst) >= MIN_SAMPLE_COUNT,
            "expected at least {} invocations",
            MIN_SAMPLE_COUNT
        );
    }

    /// Asserts that `result` is a non-negative numeric value.
    fn assert_non_negative_num(result: Value) {
        match result {
            Value::Num(v) => assert!(v >= 0.0),
            other => panic!("expected numeric result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_measures_time() {
        COUNTER_DEFAULT.store(0, Ordering::SeqCst);
        let result = block_on(timeit_builtin(handle(DEFAULT_HANDLE), Vec::new())).expect("timeit");
        assert_non_negative_num(result);
        assert_min_invocations(&COUNTER_DEFAULT);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_accepts_num_outputs_argument() {
        COUNTER_NUM_OUTPUTS.store(0, Ordering::SeqCst);
        let args = vec![Value::Int(IntValue::I32(3))];
        let _ =
            block_on(timeit_builtin(handle(OUTPUTS_HANDLE), args)).expect("timeit numOutputs");
        assert_min_invocations(&COUNTER_NUM_OUTPUTS);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_supports_zero_outputs() {
        COUNTER_ZERO_OUTPUTS.store(0, Ordering::SeqCst);
        let args = vec![Value::Int(IntValue::I32(0))];
        let _ = block_on(timeit_builtin(handle(ZERO_OUTPUTS_HANDLE), args))
            .expect("timeit zero outputs");
        assert_min_invocations(&COUNTER_ZERO_OUTPUTS);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn timeit_runs_with_wgpu_provider_registered() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let result = block_on(timeit_builtin(handle(DEFAULT_HANDLE), Vec::new()))
            .expect("timeit with wgpu");
        assert_non_negative_num(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_rejects_non_function_input() {
        let outcome = block_on(timeit_builtin(Value::Num(1.0), Vec::new()));
        assert_timeit_error_contains(outcome.unwrap_err(), "function");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_rejects_invalid_num_outputs() {
        COUNTER_INVALID.store(0, Ordering::SeqCst);
        let outcome = block_on(timeit_builtin(
            handle(INVALID_HANDLE),
            vec![Value::Num(-1.0)],
        ));
        assert_timeit_error_contains(outcome.unwrap_err(), "nonnegative");
        // Argument validation must fail before the callable is ever run.
        assert_eq!(COUNTER_INVALID.load(Ordering::SeqCst), 0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_rejects_extra_arguments() {
        let extra = vec![Value::from(1.0), Value::from(2.0)];
        let outcome = block_on(timeit_builtin(handle(DEFAULT_HANDLE), extra));
        assert_timeit_error_contains(outcome.unwrap_err(), "too many");
    }
}