Skip to main content

runmat_runtime/builtins/timing/
pause.rs

1//! MATLAB-compatible `pause` builtin that temporarily suspends execution.
2
3use once_cell::sync::Lazy;
4use runmat_builtins::{CharArray, LogicalArray, Tensor, Value};
5use runmat_macros::runtime_builtin;
6use std::sync::RwLock;
7
8use crate::builtins::common::gpu_helpers;
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13#[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
14use crate::builtins::plotting;
15use crate::builtins::timing::type_resolvers::pause_type;
16#[cfg(not(test))]
17use crate::interaction;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
/// GPU acceleration metadata for `pause`.
///
/// `pause` has no device implementation: it advertises no supported precisions and no
/// provider hooks, and its residency policy is `GatherImmediately`. `classify_argument`
/// likewise gathers its input to the host before inspecting it.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::pause")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "pause",
    // NOTE(review): "timer" appears to be a descriptive tag only; confirm how providers use it.
    op_kind: GpuOpKind::Custom("timer"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "pause executes entirely on the host. Acceleration providers are never queried.",
};
35
/// Fusion metadata for `pause`.
///
/// Both `elementwise` and `reduction` are `None`, so there is no kernel template for a
/// fusion pipeline to absorb; the notes record that `pause` stays out of fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::pause")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "pause",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "pause suspends host execution and is excluded from fusion pipelines.",
};
46
47static PAUSE_STATE: Lazy<RwLock<PauseState>> = Lazy::new(|| RwLock::new(PauseState::default()));
48
#[cfg(test)]
use std::sync::Mutex;

/// Lock held by tests that mutate the shared `PAUSE_STATE`, so they do not interleave.
/// It is `pub(crate)`; other test modules may take it too.
#[cfg(test)]
pub(crate) static TEST_GUARD: Lazy<Mutex<()>> = Lazy::new(Mutex::default);
53
/// The mutable configuration stored behind `PAUSE_STATE`.
#[derive(Debug, Clone, Copy)]
struct PauseState {
    /// When `false`, waits requested through `pause` return immediately.
    enabled: bool,
}

impl Default for PauseState {
    /// Pausing starts enabled; `pause('off')` turns it off.
    fn default() -> Self {
        PauseState { enabled: true }
    }
}
64
/// Builtin name attached to every error raised from this module.
const BUILTIN_NAME: &str = "pause";

// Bad command text, negative/non-scalar durations, or unsupported argument types.
const ERR_INVALID_ARG: &str = "RunMat:pause:InvalidInputArgument";
const MSG_INVALID_ARG: &str = "pause: invalid input argument";

// More than one argument was supplied.
const ERR_TOO_MANY_INPUTS: &str = "RunMat:pause:TooManyInputs";
const MSG_TOO_MANY_INPUTS: &str = "pause: too many input arguments";

// The shared pause-state lock could not be acquired (i.e. it was poisoned).
const MSG_STATE_LOCK: &str = "pause: failed to acquire pause state";
71
72fn pause_error(message: impl Into<String>) -> RuntimeError {
73    build_runtime_error(message)
74        .with_builtin(BUILTIN_NAME)
75        .build()
76}
77
78fn pause_error_with_identifier(message: impl Into<String>, identifier: &str) -> RuntimeError {
79    build_runtime_error(message)
80        .with_builtin(BUILTIN_NAME)
81        .with_identifier(identifier)
82        .build()
83}
84
/// How long a `pause` call should wait.
#[derive(Debug, Clone, Copy)]
enum PauseWait {
    /// Bare `pause`, empty arguments, or `pause(Inf)`: wait for a key press.
    Default,
    /// `pause(n)`: sleep for `n` seconds (negative values are rejected before this point).
    Seconds(f64),
}

/// The request encoded by `pause`'s single optional argument.
#[derive(Debug, Clone, Copy)]
enum PauseArgument {
    /// Suspend execution.
    Wait(PauseWait),
    /// `pause('on')` / `pause('off')`: replace the global state (the old one is returned).
    SetState(bool),
    /// `pause('query')`: report the current state.
    Query,
}
97
98/// Suspend execution according to MATLAB-compatible pause semantics.
99#[runtime_builtin(
100    name = "pause",
101    category = "timing",
102    summary = "Suspend execution until a key press or specified duration.",
103    keywords = "pause,sleep,wait,delay",
104    accel = "metadata",
105    sink = true,
106    type_resolver(pause_type),
107    builtin_path = "crate::builtins::timing::pause"
108)]
109async fn pause_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
110    match args.len() {
111        0 => {
112            perform_wait(PauseWait::Default).await?;
113            Ok(empty_return_value())
114        }
115        1 => match classify_argument(&args[0]).await? {
116            PauseArgument::Wait(wait) => {
117                perform_wait(wait).await?;
118                Ok(empty_return_value())
119            }
120            PauseArgument::SetState(next_state) => {
121                let previous = set_pause_enabled(next_state)?;
122                Ok(state_value(previous))
123            }
124            PauseArgument::Query => {
125                let current = pause_enabled()?;
126                Ok(state_value(current))
127            }
128        },
129        _ => Err(pause_error_with_identifier(
130            MSG_TOO_MANY_INPUTS,
131            ERR_TOO_MANY_INPUTS,
132        )),
133    }
134}
135
136async fn perform_wait(wait: PauseWait) -> Result<(), RuntimeError> {
137    if !pause_enabled()? {
138        return Ok(());
139    }
140
141    #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
142    {
143        // MATLAB semantics: `pause` gives the UI a chance to update.
144        // In RunMat Web/WASM this is an explicit flush boundary for plotting.
145        let handle = plotting::current_figure_handle();
146        // Present before the wait.
147        let _ = plotting::render_current_scene(handle.as_u32());
148    }
149
150    match wait {
151        PauseWait::Default => wait_for_key_press().await,
152        PauseWait::Seconds(seconds) => {
153            if seconds == 0.0 {
154                // `pause(0)` is a useful yield point in simulation loops.
155                #[cfg(target_arch = "wasm32")]
156                {
157                    return wasm_sleep_seconds(0.0).await;
158                }
159                #[cfg(not(target_arch = "wasm32"))]
160                {
161                    return Ok(());
162                }
163            }
164            sleep_seconds(seconds).await?;
165            #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
166            {
167                // Present again after the wait to ensure the compositor sees the most recent frame.
168                // Some browser/driver combinations appear to delay presentation unless we yield across
169                // a timer boundary.
170                let handle = plotting::current_figure_handle();
171                let _ = plotting::render_current_scene(handle.as_u32());
172            }
173            Ok(())
174        }
175    }
176}
177
178async fn wait_for_key_press() -> Result<(), RuntimeError> {
179    #[cfg(test)]
180    {
181        Ok(())
182    }
183    #[cfg(not(test))]
184    {
185        interaction::wait_for_key_async("").await
186    }
187}
188
189async fn sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
190    #[cfg(target_arch = "wasm32")]
191    {
192        wasm_sleep_seconds(seconds).await
193    }
194    #[cfg(not(target_arch = "wasm32"))]
195    {
196        // from_secs_f64 rejects NaN/±Inf; classify_argument filters those earlier.
197        let duration = std::time::Duration::from_secs_f64(seconds);
198        std::thread::sleep(duration);
199        Ok(())
200    }
201}
202
/// Waits `seconds` using the JavaScript global `setTimeout`.
///
/// `pause` runs in both Window and WebWorker contexts, and workers have no `window`, so
/// the timer is looked up on the global object. The delay is rounded to whole
/// milliseconds. Negative and NaN delays become 0. Delays are clamped to `i32::MAX` ms,
/// because browsers store the timeout as a signed 32-bit value and overflowing delays fire
/// almost immediately.
///
/// # Errors
/// Returns a `pause` runtime error in any of these cases:
/// * `setTimeout` is missing or not callable.
/// * Calling `setTimeout` throws.
/// * The timer promise rejects.
#[cfg(target_arch = "wasm32")]
async fn wasm_sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
    use js_sys::{Function, Promise, Reflect};
    use wasm_bindgen::{JsCast, JsValue};
    use wasm_bindgen_futures::JsFuture;

    let global = js_sys::global();
    let set_timeout = Reflect::get(&global, &JsValue::from_str("setTimeout"))
        .ok()
        .and_then(|candidate| candidate.dyn_into::<Function>().ok())
        .ok_or_else(|| pause_error("pause: setTimeout unavailable"))?;

    // `max` also maps NaN to 0; after rounding and clamping the cast is exact.
    let millis = (seconds * 1000.0)
        .max(0.0)
        .round()
        .min(f64::from(i32::MAX)) as i32;
    let delay = JsValue::from_f64(f64::from(millis));

    let promise = Promise::new(&mut |resolve: Function, reject: Function| {
        let callback: JsValue = resolve.into();
        // If scheduling fails the timer never fires. Reject rather than leaving the
        // promise (and therefore this pause) pending forever.
        if let Err(err) = set_timeout.call2(&global, &callback, &delay) {
            let _ = reject.call1(&JsValue::UNDEFINED, &err);
        }
    });

    JsFuture::from(promise)
        .await
        .map_err(|err| pause_error(format!("pause: timer failed ({err:?})")))?;
    Ok(())
}
238
239async fn classify_argument(arg: &Value) -> Result<PauseArgument, RuntimeError> {
240    let host_value = gpu_helpers::gather_value_async(arg)
241        .await
242        .map_err(|e| pause_error(format!("pause: {e}")))?;
243    match host_value {
244        Value::String(text) => parse_command(&text),
245        Value::CharArray(ca) => {
246            if ca.rows == 0 || ca.data.is_empty() {
247                Ok(PauseArgument::Wait(PauseWait::Default))
248            } else if ca.rows == 1 {
249                let text: String = ca.data.iter().collect();
250                parse_command(&text)
251            } else {
252                Err(pause_error_with_identifier(
253                    MSG_INVALID_ARG,
254                    ERR_INVALID_ARG,
255                ))
256            }
257        }
258        Value::StringArray(sa) => {
259            if sa.data.is_empty() {
260                Ok(PauseArgument::Wait(PauseWait::Default))
261            } else if sa.data.len() == 1 {
262                parse_command(&sa.data[0])
263            } else {
264                Err(pause_error_with_identifier(
265                    MSG_INVALID_ARG,
266                    ERR_INVALID_ARG,
267                ))
268            }
269        }
270        Value::Num(value) => parse_numeric(value),
271        Value::Int(int_value) => parse_numeric(int_value.to_f64()),
272        Value::Bool(flag) => parse_numeric(if flag { 1.0 } else { 0.0 }),
273        Value::Tensor(tensor) => parse_tensor(tensor),
274        Value::LogicalArray(logical) => parse_logical(logical),
275        Value::GpuTensor(handle) => {
276            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
277            parse_tensor(tensor)
278        }
279        Value::Complex(_, _)
280        | Value::ComplexTensor(_)
281        | Value::Cell(_)
282        | Value::Struct(_)
283        | Value::Object(_)
284        | Value::HandleObject(_)
285        | Value::Listener(_)
286        | Value::FunctionHandle(_)
287        | Value::Closure(_)
288        | Value::ClassRef(_)
289        | Value::MException(_)
290        | Value::OutputList(_) => Err(pause_error_with_identifier(
291            MSG_INVALID_ARG,
292            ERR_INVALID_ARG,
293        )),
294    }
295}
296
297fn parse_command(raw: &str) -> Result<PauseArgument, RuntimeError> {
298    let trimmed = raw.trim();
299    if trimmed.is_empty() {
300        return Ok(PauseArgument::Wait(PauseWait::Default));
301    }
302    let lower = trimmed.to_ascii_lowercase();
303    match lower.as_str() {
304        "on" => Ok(PauseArgument::SetState(true)),
305        "off" => Ok(PauseArgument::SetState(false)),
306        "query" => Ok(PauseArgument::Query),
307        _ => Err(pause_error_with_identifier(
308            MSG_INVALID_ARG,
309            ERR_INVALID_ARG,
310        )),
311    }
312}
313
314fn parse_numeric(value: f64) -> Result<PauseArgument, RuntimeError> {
315    if !value.is_finite() {
316        if value.is_sign_positive() {
317            return Ok(PauseArgument::Wait(PauseWait::Default));
318        }
319        return Err(pause_error_with_identifier(
320            MSG_INVALID_ARG,
321            ERR_INVALID_ARG,
322        ));
323    }
324    if value < 0.0 {
325        return Err(pause_error_with_identifier(
326            MSG_INVALID_ARG,
327            ERR_INVALID_ARG,
328        ));
329    }
330    Ok(PauseArgument::Wait(PauseWait::Seconds(value)))
331}
332
333fn parse_tensor(tensor: Tensor) -> Result<PauseArgument, RuntimeError> {
334    if tensor.data.is_empty() {
335        return Ok(PauseArgument::Wait(PauseWait::Default));
336    }
337    if tensor.data.len() != 1 {
338        return Err(pause_error_with_identifier(
339            MSG_INVALID_ARG,
340            ERR_INVALID_ARG,
341        ));
342    }
343    parse_numeric(tensor.data[0])
344}
345
346fn parse_logical(logical: LogicalArray) -> Result<PauseArgument, RuntimeError> {
347    if logical.data.is_empty() {
348        return Ok(PauseArgument::Wait(PauseWait::Default));
349    }
350    if logical.data.len() != 1 {
351        return Err(pause_error_with_identifier(
352            MSG_INVALID_ARG,
353            ERR_INVALID_ARG,
354        ));
355    }
356    let scalar = if logical.data[0] != 0 { 1.0 } else { 0.0 };
357    parse_numeric(scalar)
358}
359
360fn empty_return_value() -> Value {
361    Value::Tensor(Tensor::zeros(vec![0, 0]))
362}
363
364fn state_value(enabled: bool) -> Value {
365    let text = if enabled { "on" } else { "off" };
366    Value::CharArray(CharArray::new_row(text))
367}
368
369fn pause_enabled() -> Result<bool, RuntimeError> {
370    PAUSE_STATE
371        .read()
372        .map(|guard| guard.enabled)
373        .map_err(|_| pause_error(MSG_STATE_LOCK))
374}
375
376fn set_pause_enabled(next: bool) -> Result<bool, RuntimeError> {
377    let mut guard = PAUSE_STATE
378        .write()
379        .map_err(|_| pause_error(MSG_STATE_LOCK))?;
380    let previous = guard.enabled;
381    guard.enabled = next;
382    Ok(previous)
383}
384
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_provider;

    /// Serializes tests that touch the global pause state.
    ///
    /// Poisoning is tolerated on purpose. Previously some tests used `.unwrap()` here, so
    /// one failing test poisoned the lock and made every later `.unwrap()` test fail with
    /// an unrelated `PoisonError`. The guarded `()` carries no state a panic could corrupt.
    fn lock_state() -> std::sync::MutexGuard<'static, ()> {
        TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Forces the global pause state, recovering from a poisoned lock.
    fn reset_state(enabled: bool) {
        let mut guard = PAUSE_STATE.write().unwrap_or_else(|e| e.into_inner());
        guard.enabled = enabled;
    }

    fn char_array_to_string(value: Value) -> String {
        match value {
            Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
            other => panic!("expected char array, got {other:?}"),
        }
    }

    fn assert_pause_error_identifier(err: crate::RuntimeError, identifier: &str) {
        assert_eq!(
            err.identifier(),
            Some(identifier),
            "message: {}",
            err.message()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn query_returns_on_by_default() {
        let _guard = lock_state();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::from("query")])).expect("pause query");
        assert_eq!(char_array_to_string(result), "on");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_off_returns_previous_state() {
        let _guard = lock_state();
        reset_state(true);
        let previous = block_on(pause_builtin(vec![Value::from("off")])).expect("pause off");
        assert_eq!(char_array_to_string(previous), "on");
        assert!(!pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_on_restores_state() {
        let _guard = lock_state();
        reset_state(false);
        let previous = block_on(pause_builtin(vec![Value::from("on")])).expect("pause on");
        assert_eq!(char_array_to_string(previous), "off");
        assert!(pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_default_returns_empty_tensor() {
        let _guard = lock_state();
        reset_state(true);
        let result = block_on(pause_builtin(Vec::new())).expect("pause()");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_zero_is_accepted() {
        let _guard = lock_state();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(0.0)])).expect("pause(0)");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integer_scalar_is_accepted() {
        let _guard = lock_state();
        reset_state(true);
        let result =
            block_on(pause_builtin(vec![Value::Int(IntValue::I32(0))])).expect("pause(int)");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_negative_zero_is_treated_as_zero() {
        let _guard = lock_state();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(-0.0)])).expect("pause(-0)");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn negative_duration_raises_error() {
        let _guard = lock_state();
        reset_state(true);
        let err = block_on(pause_builtin(vec![Value::Num(-0.1)])).unwrap_err();
        assert_pause_error_identifier(err, ERR_INVALID_ARG);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_scalar_tensor_is_rejected() {
        let _guard = lock_state();
        reset_state(true);
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = block_on(pause_builtin(vec![Value::Tensor(tensor)])).unwrap_err();
        assert_pause_error_identifier(err, ERR_INVALID_ARG);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_tensor_behaves_like_default_pause() {
        let _guard = lock_state();
        reset_state(true);
        let empty = Tensor::zeros(vec![0, 0]);
        let result = block_on(pause_builtin(vec![Value::Tensor(empty)])).expect("pause([])");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_is_accepted() {
        let _guard = lock_state();
        reset_state(true);
        let logical = LogicalArray::new(vec![1u8], vec![1, 1]).unwrap();
        let result =
            block_on(pause_builtin(vec![Value::LogicalArray(logical)])).expect("pause(true)");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn infinite_duration_behaves_like_default() {
        let _guard = lock_state();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(f64::INFINITY)])).expect("pause(Inf)");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_gpu_duration_gathered() {
        let _guard = lock_state();
        reset_state(true);
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
            match result {
                Value::Tensor(t) => assert_eq!(t.data.len(), 0),
                other => panic!("expected empty tensor, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn pause_wgpu_duration_gathered() {
        let _guard = lock_state();
        reset_state(true);
        if wgpu_provider::register_wgpu_provider(wgpu_provider::WgpuProviderOptions::default())
            .is_err()
        {
            return;
        }
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result =
            block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
        match result {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_command_raises_error() {
        let _guard = lock_state();
        reset_state(true);
        let err = block_on(pause_builtin(vec![Value::from("invalid")])).unwrap_err();
        assert_pause_error_identifier(err, ERR_INVALID_ARG);
    }
}
608}