use once_cell::sync::Lazy;
use runmat_builtins::{
BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
};
use runmat_macros::runtime_builtin;
use runmat_time::Instant;
use std::sync::Mutex;
use std::time::Duration;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ReductionNaN, ResidencyPolicy, ShapeRequirements,
};
use crate::builtins::timing::type_resolvers::tic_type;
/// GPU metadata for `tic`.
///
/// `tic` never executes on an acceleration provider: it advertises no supported precisions and
/// no provider hooks, and declares `GatherImmediately` residency. The `Custom("timer")` op kind
/// marks it as a non-numeric, host-only operation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::tic")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "tic",
    op_kind: GpuOpKind::Custom("timer"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Stopwatch state lives on the host. Providers are never consulted for tic/toc.",
};
/// Fusion metadata for `tic`.
///
/// No elementwise or reduction template is provided, so the fusion planner has nothing to fuse;
/// per the notes below, timing builtins run eagerly on the host instead.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::tic")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "tic",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Timing builtins are executed eagerly on the host and do not participate in fusion.",
};
/// Reference point for numeric timer handles.
///
/// A handle is the number of seconds between this instant and a recorded `tic` instant (see
/// `encode_instant` / `decode_handle`). It is created lazily on first access, so an instant
/// sampled *before* that first access lies earlier than the origin.
static MONOTONIC_ORIGIN: Lazy<Instant> = Lazy::new(Instant::now);
/// Process-wide LIFO stack of `tic` start instants, pushed by `record_tic` and popped by
/// `take_latest_start` (presumably from `toc` — confirm against callers).
static STOPWATCH: Lazy<Mutex<StopwatchState>> = Lazy::new(|| Mutex::new(StopwatchState::default()));
/// Serialises tests that read or reset the global stopwatch state.
#[cfg(test)]
pub(crate) static TEST_GUARD: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
/// LIFO stack of `tic` start instants; the most recent `tic` sits on top.
#[derive(Default)]
struct StopwatchState {
    /// Start instants in the order they were recorded.
    stack: Vec<Instant>,
}

impl StopwatchState {
    /// Record `instant` as the newest pending start.
    fn push(&mut self, instant: Instant) {
        let Self { stack } = self;
        stack.push(instant);
    }

    /// Remove and return the newest pending start, or `None` when nothing is pending.
    fn pop(&mut self) -> Option<Instant> {
        let Self { stack } = self;
        stack.pop()
    }
}
/// Builtin name attached to errors raised through `tic_builtin`.
const BUILTIN_NAME: &str = "tic";
/// Output parameter of the single `timerVal = tic()` signature.
// NOTE(review): declared `Required` although the builtin summary says the handle is *optionally*
// returned (and the builtin is registered as a `sink`); confirm this matches how the descriptor
// system models optional outputs.
const TIC_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "timerVal",
    ty: BuiltinParamType::NumericScalar,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Timer handle used by toc.",
}];
/// `tic` accepts no input arguments.
const TIC_INPUTS: [BuiltinParamDescriptor; 0] = [];
/// The only call form: `timerVal = tic()`.
const TIC_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "timerVal = tic()",
    inputs: &TIC_INPUTS,
    outputs: &TIC_OUTPUT,
}];
/// Raised when the global stopwatch mutex cannot be locked (for `std::sync::Mutex` this only
/// happens when it was poisoned by a panic while held).
const TIC_ERROR_STATE_LOCK: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.TIC.STATE_LOCK",
    identifier: Some("RunMat:tic:StateLockFailed"),
    when: "Internal stopwatch state cannot be acquired.",
    message: "tic: failed to acquire stopwatch state",
};
/// Every error `tic` can report.
const TIC_ERRORS: [BuiltinErrorDescriptor; 1] = [TIC_ERROR_STATE_LOCK];
/// Descriptor attached to `tic_builtin` through its `descriptor(...)` attribute: one public,
/// fixed-output signature plus the state-lock error.
pub const TIC_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &TIC_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &TIC_ERRORS,
};
/// Build a runtime error tagged with `builtin` and, when present, `error.identifier`.
///
/// Only the identifier is taken from `error`; the user-visible text is `message`.
fn stopwatch_error_with_message(
    builtin: &str,
    message: impl Into<String>,
    error: &BuiltinErrorDescriptor,
) -> crate::RuntimeError {
    let base = crate::build_runtime_error(message).with_builtin(builtin);
    let tagged = match error.identifier {
        Some(identifier) => base.with_identifier(identifier),
        None => base,
    };
    tagged.build()
}
/// Entry point for the `tic` builtin.
///
/// Records the current instant on the global stopwatch stack and returns its numeric handle,
/// which can later be passed to `toc`.
#[runtime_builtin(
    name = "tic",
    category = "timing",
    summary = "Start a high-resolution stopwatch and optionally return a toc handle.",
    keywords = "tic,timing,profiling,benchmark",
    sink = true,
    type_resolver(tic_type),
    descriptor(crate::builtins::timing::tic::TIC_DESCRIPTOR),
    builtin_path = "crate::builtins::timing::tic"
)]
pub async fn tic_builtin() -> crate::BuiltinResult<f64> {
    let handle = record_tic(BUILTIN_NAME)?;
    Ok(handle)
}
pub(crate) fn record_tic(builtin: &str) -> Result<f64, crate::RuntimeError> {
let now = Instant::now();
{
let mut guard = STOPWATCH.lock().map_err(|_| {
stopwatch_error_with_message(
builtin,
TIC_ERROR_STATE_LOCK.message,
&TIC_ERROR_STATE_LOCK,
)
})?;
guard.push(now);
}
Ok(encode_instant(now))
}
/// Pop the most recently recorded `tic` instant, returning `None` when no `tic` is pending.
///
/// # Errors
///
/// Returns an error tagged with `builtin` and the `RunMat:tic:StateLockFailed` identifier when
/// the stopwatch mutex cannot be locked.
pub(crate) fn take_latest_start(builtin: &str) -> Result<Option<Instant>, crate::RuntimeError> {
    match STOPWATCH.lock() {
        Ok(mut state) => Ok(state.pop()),
        Err(_) => Err(stopwatch_error_with_message(
            builtin,
            TIC_ERROR_STATE_LOCK.message,
            &TIC_ERROR_STATE_LOCK,
        )),
    }
}
/// Encode `instant` as the number of seconds elapsed since the process-wide handle origin.
pub(crate) fn encode_instant(instant: Instant) -> f64 {
    let since_origin = instant.duration_since(*MONOTONIC_ORIGIN);
    since_origin.as_secs_f64()
}
pub(crate) fn decode_handle(
handle: f64,
builtin: &str,
error: &BuiltinErrorDescriptor,
) -> Result<Instant, crate::RuntimeError> {
if !handle.is_finite() || handle.is_sign_negative() {
return Err(stopwatch_error_with_message(builtin, error.message, error));
}
let duration = Duration::from_secs_f64(handle);
Ok((*MONOTONIC_ORIGIN) + duration)
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use futures::executor::block_on;
    use std::sync::MutexGuard;
    use std::time::Duration;

    const TEST_INVALID_HANDLE_ERROR: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
        code: "RM.TOC.INVALID_HANDLE",
        identifier: Some("RunMat:toc:InvalidTimerHandle"),
        when: "The timer handle is non-finite or negative.",
        message: "toc: invalid timer handle",
    };

    /// Take the shared test lock, recovering it if another test panicked while holding it, so a
    /// single failure does not turn every later test into a `PoisonError` panic.
    fn serialize_tests() -> MutexGuard<'static, ()> {
        TEST_GUARD
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Empty the global stopwatch stack, tolerating a poisoned lock for the same reason.
    fn reset_stopwatch() {
        STOPWATCH
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .stack
            .clear();
    }

    /// Busy-wait until at least `duration` has elapsed on the stopwatch clock.
    ///
    /// These tests are also registered as `wasm_bindgen_test`s, and `std::thread::sleep` panics
    /// on `wasm32-unknown-unknown` builds without thread support, so spin instead of sleeping.
    fn spin_for(duration: Duration) {
        let start = Instant::now();
        while Instant::now().duration_since(start) < duration {
            std::hint::spin_loop();
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn tic_returns_monotonic_handle() {
        let _guard = serialize_tests();
        reset_stopwatch();
        let handle = block_on(tic_builtin()).expect("tic");
        assert!(handle >= 0.0);
        assert!(take_latest_start(BUILTIN_NAME).expect("take").is_some());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn tic_handles_increase_over_time() {
        let _guard = serialize_tests();
        reset_stopwatch();
        let first = block_on(tic_builtin()).expect("tic");
        spin_for(Duration::from_millis(5));
        let second = block_on(tic_builtin()).expect("tic");
        assert!(second > first);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn decode_roundtrip_matches_handle() {
        let _guard = serialize_tests();
        reset_stopwatch();
        let handle = block_on(tic_builtin()).expect("tic");
        let decoded = decode_handle(handle, "toc", &TEST_INVALID_HANDLE_ERROR).expect("decode");
        let round_trip = encode_instant(decoded);
        let delta = (round_trip - handle).abs();
        assert!(delta < 1e-9, "delta {delta}");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn take_latest_start_pops_stack() {
        let _guard = serialize_tests();
        reset_stopwatch();
        block_on(tic_builtin()).expect("tic");
        assert!(take_latest_start(BUILTIN_NAME).expect("take").is_some());
        assert!(take_latest_start(BUILTIN_NAME)
            .expect("second take")
            .is_none());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn decode_handle_rejects_invalid_values() {
        let _guard = serialize_tests();
        assert!(decode_handle(f64::NAN, "toc", &TEST_INVALID_HANDLE_ERROR).is_err());
        assert!(decode_handle(f64::INFINITY, "toc", &TEST_INVALID_HANDLE_ERROR).is_err());
        assert!(decode_handle(f64::NEG_INFINITY, "toc", &TEST_INVALID_HANDLE_ERROR).is_err());
        assert!(decode_handle(-1.0, "toc", &TEST_INVALID_HANDLE_ERROR).is_err());
    }
}