Skip to main content

runmat_runtime/builtins/timing/
timeit.rs

1//! MATLAB-compatible `timeit` builtin for RunMat.
2//!
3//! Measures the execution time of zero-input function handles by running them
4//! repeatedly and returning the median per-invocation runtime in seconds.
5
6use runmat_time::Instant;
7use std::cmp::Ordering;
8
9use runmat_builtins::Value;
10use runmat_macros::runtime_builtin;
11
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ReductionNaN, ResidencyPolicy, ShapeRequirements,
15};
16use crate::builtins::timing::type_resolvers::timeit_type;
17
/// Builtin name attached to every error raised by this module.
const BUILTIN_NAME: &str = "timeit";

/// A calibration batch is long enough once it runs for at least this many seconds.
const TARGET_BATCH_SECONDS: f64 = 0.005;
/// A batch that runs this long stops calibration and ends sampling early.
const MAX_BATCH_SECONDS: f64 = 0.25;
/// Hard cap on the number of invocations per batch (2^20).
const LOOP_COUNT_LIMIT: usize = 1_048_576;

/// Number of timing samples the sampler aims to collect.
const MIN_SAMPLE_COUNT: usize = 7;
/// Upper bound on the number of timing samples ever collected.
const MAX_SAMPLE_COUNT: usize = 21;
24
25fn timeit_error(message: impl Into<String>) -> crate::RuntimeError {
26    crate::build_runtime_error(message)
27        .with_builtin(BUILTIN_NAME)
28        .build()
29}
30
// GPU metadata registered for `timeit`. The builtin itself is a host-side helper: it
// declares no supported precisions and no provider hooks, and any GPU work happens only
// inside the function being timed (see `notes`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::timeit")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "timeit",
    // Registered under a custom "timer" op kind rather than a kernel category.
    op_kind: GpuOpKind::Custom("timer"),
    // Empty: there is no device kernel for `timeit` itself.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    // NaN handling is not user-configurable for this builtin.
    accepts_nan_mode: false,
    notes: "Host-side helper; GPU kernels execute only if invoked by the timed function.",
};
46
// Fusion metadata registered for `timeit`. No elementwise or reduction template is
// provided, so the fusion planner has nothing to fuse for this builtin (see `notes`).
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::timeit")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "timeit",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // No elementwise or reduction template: timing is not a fusible computation.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Timing helper; excluded from fusion planning.",
};
57
58#[runtime_builtin(
59    name = "timeit",
60    category = "timing",
61    summary = "Measure the execution time of a zero-argument function handle.",
62    keywords = "timeit,benchmark,timing,performance,gpu",
63    accel = "helper",
64    type_resolver(timeit_type),
65    builtin_path = "crate::builtins::timing::timeit"
66)]
67async fn timeit_builtin(func: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
68    let requested_outputs = parse_num_outputs(&rest)?;
69    let callable = prepare_callable(func, requested_outputs)?;
70
71    // Warm-up once to catch early errors and pay one-time JIT costs.
72    callable.invoke().await?;
73
74    let loop_count = determine_loop_count(&callable).await?;
75    let samples = collect_samples(&callable, loop_count).await?;
76    if samples.is_empty() {
77        return Ok(Value::Num(0.0));
78    }
79
80    Ok(Value::Num(compute_median(samples)))
81}
82
83fn parse_num_outputs(rest: &[Value]) -> Result<Option<usize>, crate::RuntimeError> {
84    match rest.len() {
85        0 => Ok(None),
86        1 => parse_non_negative_integer(&rest[0]).map(Some),
87        _ => Err(timeit_error("timeit: too many input arguments")),
88    }
89}
90
91fn parse_non_negative_integer(value: &Value) -> Result<usize, crate::RuntimeError> {
92    match value {
93        Value::Int(iv) => {
94            let raw = iv.to_i64();
95            if raw < 0 {
96                Err(timeit_error(
97                    "timeit: numOutputs must be a nonnegative integer",
98                ))
99            } else {
100                Ok(raw as usize)
101            }
102        }
103        Value::Num(n) => {
104            if !n.is_finite() {
105                return Err(timeit_error("timeit: numOutputs must be finite"));
106            }
107            if *n < 0.0 {
108                return Err(timeit_error(
109                    "timeit: numOutputs must be a nonnegative integer",
110                ));
111            }
112            let rounded = n.round();
113            if (rounded - n).abs() > f64::EPSILON {
114                return Err(timeit_error("timeit: numOutputs must be an integer value"));
115            }
116            Ok(rounded as usize)
117        }
118        _ => Err(timeit_error(
119            "timeit: numOutputs must be a scalar numeric value",
120        )),
121    }
122}
123
124async fn determine_loop_count(callable: &TimeitCallable) -> Result<usize, crate::RuntimeError> {
125    let mut loops = 1usize;
126    loop {
127        let elapsed = run_batch(callable, loops).await?;
128        if elapsed >= TARGET_BATCH_SECONDS
129            || elapsed >= MAX_BATCH_SECONDS
130            || loops >= LOOP_COUNT_LIMIT
131        {
132            return Ok(loops);
133        }
134        loops = loops.saturating_mul(2);
135        if loops == 0 {
136            return Ok(LOOP_COUNT_LIMIT);
137        }
138    }
139}
140
141async fn collect_samples(
142    callable: &TimeitCallable,
143    loop_count: usize,
144) -> Result<Vec<f64>, crate::RuntimeError> {
145    let mut samples = Vec::with_capacity(MIN_SAMPLE_COUNT);
146    while samples.len() < MIN_SAMPLE_COUNT {
147        let elapsed = run_batch(callable, loop_count).await?;
148        let per_iter = elapsed / loop_count as f64;
149        samples.push(per_iter);
150        if samples.len() >= MAX_SAMPLE_COUNT || elapsed >= MAX_BATCH_SECONDS {
151            break;
152        }
153    }
154    Ok(samples)
155}
156
157async fn run_batch(
158    callable: &TimeitCallable,
159    loop_count: usize,
160) -> Result<f64, crate::RuntimeError> {
161    let start = Instant::now();
162    for _ in 0..loop_count {
163        let value = callable.invoke().await?;
164        drop(value);
165    }
166    Ok(start.elapsed().as_secs_f64())
167}
168
/// Returns the median of `samples`, or `0.0` when there are no samples.
///
/// For an even count, the result is the mean of the two middle values. NaN samples sort
/// after every other value, so they affect the result only when they land in the middle.
fn compute_median(mut samples: Vec<f64>) -> f64 {
    let count = samples.len();
    if count == 0 {
        return 0.0;
    }
    samples.sort_by(|lhs, rhs| {
        lhs.is_nan()
            .cmp(&rhs.is_nan())
            .then_with(|| lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal))
    });
    let upper = count / 2;
    match count % 2 {
        1 => samples[upper],
        _ => (samples[upper - 1] + samples[upper]) * 0.5,
    }
}
192
193#[derive(Clone)]
194struct TimeitCallable {
195    handle: Value,
196    num_outputs: Option<usize>,
197}
198
199impl TimeitCallable {
200    async fn invoke(&self) -> Result<Value, crate::RuntimeError> {
201        // The runtime currently treats all builtin invocations as returning a single `Value`.
202        // The optional `num_outputs` flag is stored so future multi-output support can
203        // request the correct number of outputs when dispatching through `feval`.
204        // For now, we invoke the handle normally and drop whatever value is produced.
205        if let Some(0) = self.num_outputs {
206            let value =
207                crate::call_builtin_async("feval", std::slice::from_ref(&self.handle)).await?;
208            drop(value);
209            Ok(Value::Num(0.0))
210        } else {
211            Ok(crate::call_builtin_async("feval", std::slice::from_ref(&self.handle)).await?)
212        }
213    }
214}
215
216fn prepare_callable(
217    func: Value,
218    num_outputs: Option<usize>,
219) -> Result<TimeitCallable, crate::RuntimeError> {
220    match func {
221        Value::String(text) => parse_handle_string(&text).map(|handle| TimeitCallable {
222            handle: Value::String(handle),
223            num_outputs,
224        }),
225        Value::CharArray(arr) => {
226            if arr.rows != 1 {
227                Err(timeit_error(
228                    "timeit: function handle must be a string scalar or function handle",
229                ))
230            } else {
231                let text: String = arr.data.iter().collect();
232                parse_handle_string(&text).map(|handle| TimeitCallable {
233                    handle: Value::String(handle),
234                    num_outputs,
235                })
236            }
237        }
238        Value::StringArray(sa) => {
239            if sa.data.len() == 1 {
240                parse_handle_string(&sa.data[0]).map(|handle| TimeitCallable {
241                    handle: Value::String(handle),
242                    num_outputs,
243                })
244            } else {
245                Err(timeit_error(
246                    "timeit: function handle must be a string scalar or function handle",
247                ))
248            }
249        }
250        Value::FunctionHandle(name) => Ok(TimeitCallable {
251            handle: Value::String(format!("@{name}")),
252            num_outputs,
253        }),
254        Value::Closure(closure) => Ok(TimeitCallable {
255            handle: Value::Closure(closure),
256            num_outputs,
257        }),
258        other => Err(timeit_error(format!(
259            "timeit: first argument must be a function handle, got {other:?}"
260        ))),
261    }
262}
263
264fn parse_handle_string(text: &str) -> Result<String, crate::RuntimeError> {
265    let trimmed = text.trim();
266    if let Some(rest) = trimmed.strip_prefix('@') {
267        if rest.trim().is_empty() {
268            Err(timeit_error("timeit: empty function handle string"))
269        } else {
270            Ok(format!("@{}", rest.trim()))
271        }
272    } else {
273        Err(timeit_error(
274            "timeit: expected a function handle string beginning with '@'",
275        ))
276    }
277}
278
// Unit tests for `timeit`. Each test targets its own zero-argument helper builtin, which
// bumps a dedicated counter. This lets a test check how many times the handle was invoked
// without observing calls made by other tests that may run in parallel.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;
    // This explicit import takes precedence over the glob-imported `std::cmp::Ordering`;
    // the counters need atomic memory orderings.
    use std::sync::atomic::{AtomicUsize, Ordering};

    // One invocation counter per helper builtin below.
    static COUNTER_DEFAULT: AtomicUsize = AtomicUsize::new(0);
    static COUNTER_NUM_OUTPUTS: AtomicUsize = AtomicUsize::new(0);
    static COUNTER_INVALID: AtomicUsize = AtomicUsize::new(0);
    static COUNTER_ZERO_OUTPUTS: AtomicUsize = AtomicUsize::new(0);

    // Helper reached through `@__timeit_helper_counter_default`; counts calls and returns 1.
    #[runtime_builtin(
        name = "__timeit_helper_counter_default",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_default() -> crate::BuiltinResult<Value> {
        COUNTER_DEFAULT.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(1.0))
    }

    // Helper used by the numOutputs test; counts calls and returns 1.
    #[runtime_builtin(
        name = "__timeit_helper_counter_outputs",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_outputs() -> crate::BuiltinResult<Value> {
        COUNTER_NUM_OUTPUTS.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(1.0))
    }

    // Helper for the invalid-argument test; its counter must stay at zero because
    // argument validation fails before any invocation.
    #[runtime_builtin(
        name = "__timeit_helper_counter_invalid",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_invalid() -> crate::BuiltinResult<Value> {
        COUNTER_INVALID.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(1.0))
    }

    // Helper used with numOutputs = 0; counts calls and returns 0.
    #[runtime_builtin(
        name = "__timeit_helper_zero_outputs",
        type_resolver(crate::builtins::timing::type_resolvers::timeit_type),
        builtin_path = "crate::builtins::timing::timeit::tests"
    )]
    async fn helper_counter_zero_outputs() -> crate::BuiltinResult<Value> {
        COUNTER_ZERO_OUTPUTS.fetch_add(1, Ordering::SeqCst);
        Ok(Value::Num(0.0))
    }

    // Handle text for the default counting helper.
    fn default_handle() -> Value {
        Value::String("@__timeit_helper_counter_default".to_string())
    }

    // Asserts that the error message contains `needle`, ignoring ASCII case.
    fn assert_timeit_error_contains(err: crate::RuntimeError, needle: &str) {
        let message = err.message().to_ascii_lowercase();
        assert!(
            message.contains(&needle.to_ascii_lowercase()),
            "unexpected error text: {}",
            err.message()
        );
    }

    fn outputs_handle() -> Value {
        Value::String("@__timeit_helper_counter_outputs".to_string())
    }

    fn invalid_handle() -> Value {
        Value::String("@__timeit_helper_counter_invalid".to_string())
    }

    fn zero_outputs_handle() -> Value {
        Value::String("@__timeit_helper_zero_outputs".to_string())
    }

    // A successful run yields a nonnegative double and invokes the handle at least
    // once per sample.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_measures_time() {
        COUNTER_DEFAULT.store(0, Ordering::SeqCst);
        let result = block_on(timeit_builtin(default_handle(), Vec::new())).expect("timeit");
        match result {
            Value::Num(v) => assert!(v >= 0.0),
            other => panic!("expected numeric result, got {other:?}"),
        }
        assert!(
            COUNTER_DEFAULT.load(Ordering::SeqCst) >= MIN_SAMPLE_COUNT,
            "expected at least {} invocations",
            MIN_SAMPLE_COUNT
        );
    }

    // An integer numOutputs argument is accepted and timing proceeds normally.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_accepts_num_outputs_argument() {
        COUNTER_NUM_OUTPUTS.store(0, Ordering::SeqCst);
        let args = vec![Value::Int(IntValue::I32(3))];
        let _ = block_on(timeit_builtin(outputs_handle(), args)).expect("timeit numOutputs");
        assert!(
            COUNTER_NUM_OUTPUTS.load(Ordering::SeqCst) >= MIN_SAMPLE_COUNT,
            "expected at least {} invocations",
            MIN_SAMPLE_COUNT
        );
    }

    // numOutputs = 0 still invokes the handle for every sample.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_supports_zero_outputs() {
        COUNTER_ZERO_OUTPUTS.store(0, Ordering::SeqCst);
        let args = vec![Value::Int(IntValue::I32(0))];
        let _ = block_on(timeit_builtin(zero_outputs_handle(), args)).expect("timeit zero outputs");
        assert!(
            COUNTER_ZERO_OUTPUTS.load(Ordering::SeqCst) >= MIN_SAMPLE_COUNT,
            "expected at least {} invocations",
            MIN_SAMPLE_COUNT
        );
    }

    // Only built with the `wgpu` feature: timing still works after a wgpu provider is
    // registered (the registration result is deliberately ignored).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn timeit_runs_with_wgpu_provider_registered() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let result =
            block_on(timeit_builtin(default_handle(), Vec::new())).expect("timeit with wgpu");
        match result {
            Value::Num(v) => assert!(v >= 0.0),
            other => panic!("expected numeric result, got {other:?}"),
        }
    }

    // A plain number is not a callable and must be rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_rejects_non_function_input() {
        let err = block_on(timeit_builtin(Value::Num(1.0), Vec::new())).unwrap_err();
        assert_timeit_error_contains(err, "function");
    }

    // A negative numOutputs fails validation before the handle is ever invoked.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_rejects_invalid_num_outputs() {
        COUNTER_INVALID.store(0, Ordering::SeqCst);
        let err = block_on(timeit_builtin(invalid_handle(), vec![Value::Num(-1.0)])).unwrap_err();
        assert_timeit_error_contains(err, "nonnegative");
        assert_eq!(COUNTER_INVALID.load(Ordering::SeqCst), 0);
    }

    // More than one argument after the handle is an error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn timeit_rejects_extra_arguments() {
        let err = block_on(timeit_builtin(
            default_handle(),
            vec![Value::from(1.0), Value::from(2.0)],
        ))
        .unwrap_err();
        assert_timeit_error_contains(err, "too many");
    }
}