1use once_cell::sync::Lazy;
4use runmat_builtins::{CharArray, LogicalArray, Tensor, Value};
5use runmat_macros::runtime_builtin;
6use std::sync::RwLock;
7
8use crate::builtins::common::gpu_helpers;
9use crate::builtins::common::spec::{
10 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11 ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13#[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
14use crate::builtins::plotting;
15use crate::builtins::timing::type_resolvers::pause_type;
16#[cfg(not(test))]
17use crate::interaction;
18use crate::{build_runtime_error, BuiltinResult, RuntimeError};
19
// Acceleration metadata for `pause`. As `notes` states, the builtin executes entirely on the
// host: no precisions or provider hooks are advertised and the residency policy is
// `GatherImmediately`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::pause")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "pause",
    op_kind: GpuOpKind::Custom("timer"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "pause executes entirely on the host. Acceleration providers are never queried.",
};
35
// Fusion metadata for `pause`: no elementwise or reduction template is provided, and per
// `notes` the builtin is excluded from fusion pipelines.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::pause")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "pause",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "pause suspends host execution and is excluded from fusion pipelines.",
};
46
/// Process-wide `pause('on')` / `pause('off')` state, read by `pause_enabled` and updated by
/// `set_pause_enabled`. Initialized lazily from `PauseState::default()` (pausing enabled).
static PAUSE_STATE: Lazy<RwLock<PauseState>> = Lazy::new(|| RwLock::new(PauseState::default()));

#[cfg(test)]
use std::sync::Mutex;
/// Test-only lock that serializes tests mutating `PAUSE_STATE`. It is `pub(crate)`, presumably
/// so tests in other modules can share it — confirm against callers.
#[cfg(test)]
pub(crate) static TEST_GUARD: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
53
/// Global configuration toggled by `pause('on')` / `pause('off')`.
#[derive(Debug, Clone, Copy)]
struct PauseState {
    /// When `false`, every wait request returns immediately without suspending.
    enabled: bool,
}

impl Default for PauseState {
    /// Pausing starts out enabled.
    fn default() -> Self {
        let enabled = true;
        PauseState { enabled }
    }
}
64
/// Builtin name attached by `pause_error` / `pause_error_with_identifier`.
const BUILTIN_NAME: &str = "pause";
/// Identifier for rejected arguments: unknown command text, negative or non-scalar durations,
/// multi-row char arrays, multi-element string arrays, and unsupported value kinds.
const ERR_INVALID_ARG: &str = "RunMat:pause:InvalidInputArgument";
/// Identifier for calls that pass more than one argument.
const ERR_TOO_MANY_INPUTS: &str = "RunMat:pause:TooManyInputs";
/// Message paired with `ERR_INVALID_ARG`.
const MSG_INVALID_ARG: &str = "pause: invalid input argument";
/// Message paired with `ERR_TOO_MANY_INPUTS`.
const MSG_TOO_MANY_INPUTS: &str = "pause: too many input arguments";
/// Message used when the `PAUSE_STATE` lock cannot be acquired (std locks only fail when poisoned).
const MSG_STATE_LOCK: &str = "pause: failed to acquire pause state";
71
72fn pause_error(message: impl Into<String>) -> RuntimeError {
73 build_runtime_error(message)
74 .with_builtin(BUILTIN_NAME)
75 .build()
76}
77
78fn pause_error_with_identifier(message: impl Into<String>, identifier: &str) -> RuntimeError {
79 build_runtime_error(message)
80 .with_builtin(BUILTIN_NAME)
81 .with_identifier(identifier)
82 .build()
83}
84
/// What a single-argument `pause` call asks for.
#[derive(Clone, Copy, Debug)]
enum PauseArgument {
    /// Suspend execution (`pause(n)`, `pause([])`, `pause('')`, ...).
    Wait(PauseWait),
    /// `pause('on')` (`true`) or `pause('off')` (`false`).
    SetState(bool),
    /// `pause('query')`.
    Query,
}

/// How long a wait request lasts.
#[derive(Clone, Copy, Debug)]
enum PauseWait {
    /// Wait for a key press.
    Default,
    /// Wait for the given number of seconds.
    Seconds(f64),
}
97
98#[runtime_builtin(
100 name = "pause",
101 category = "timing",
102 summary = "Suspend execution until a key press or specified duration.",
103 keywords = "pause,sleep,wait,delay",
104 accel = "metadata",
105 sink = true,
106 type_resolver(pause_type),
107 builtin_path = "crate::builtins::timing::pause"
108)]
109async fn pause_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
110 match args.len() {
111 0 => {
112 perform_wait(PauseWait::Default).await?;
113 Ok(empty_return_value())
114 }
115 1 => match classify_argument(&args[0]).await? {
116 PauseArgument::Wait(wait) => {
117 perform_wait(wait).await?;
118 Ok(empty_return_value())
119 }
120 PauseArgument::SetState(next_state) => {
121 let previous = set_pause_enabled(next_state)?;
122 Ok(state_value(previous))
123 }
124 PauseArgument::Query => {
125 let current = pause_enabled()?;
126 Ok(state_value(current))
127 }
128 },
129 _ => Err(pause_error_with_identifier(
130 MSG_TOO_MANY_INPUTS,
131 ERR_TOO_MANY_INPUTS,
132 )),
133 }
134}
135
/// Carries out a wait request, honoring the global `pause('on'|'off')` state.
///
/// * If pausing is disabled, returns `Ok(())` immediately.
/// * `PauseWait::Default` waits via `wait_for_key_press`.
/// * `PauseWait::Seconds(0.0)` (or `-0.0`) returns immediately on native targets; on wasm32 it
///   still awaits a 0 ms `wasm_sleep_seconds` timer.
/// * Any other duration is delegated to `sleep_seconds`.
///
/// On wasm32 builds with the `plot-web` feature, the current figure is rendered before waiting
/// and again after a non-zero timed wait. Render results are discarded, so a render failure
/// never aborts the pause.
async fn perform_wait(wait: PauseWait) -> Result<(), RuntimeError> {
    // `pause('off')` turns every wait into a no-op.
    if !pause_enabled()? {
        return Ok(());
    }

    // Render before suspending; the result is intentionally ignored.
    #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
    {
        let handle = plotting::current_figure_handle();
        let _ = plotting::render_current_scene(handle.as_u32());
    }

    match wait {
        PauseWait::Default => wait_for_key_press().await,
        PauseWait::Seconds(seconds) => {
            // `==` also matches -0.0.
            if seconds == 0.0 {
                // wasm32: still await a zero-length timer (a yield to the JS event loop).
                #[cfg(target_arch = "wasm32")]
                {
                    return wasm_sleep_seconds(0.0).await;
                }
                // Native: nothing to wait for.
                #[cfg(not(target_arch = "wasm32"))]
                {
                    return Ok(());
                }
            }
            sleep_seconds(seconds).await?;
            // Render again once the timed wait has elapsed; the result is intentionally ignored.
            #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
            {
                let handle = plotting::current_figure_handle();
                let _ = plotting::render_current_scene(handle.as_u32());
            }
            Ok(())
        }
    }
}
177
/// Waits for a key press by delegating to `interaction::wait_for_key_async("")`.
///
/// Under `cfg(test)` this returns `Ok(())` immediately so unit tests never wait on user input
/// (the `interaction` import is itself compiled only outside tests).
async fn wait_for_key_press() -> Result<(), RuntimeError> {
    #[cfg(test)]
    {
        Ok(())
    }
    #[cfg(not(test))]
    {
        interaction::wait_for_key_async("").await
    }
}
188
189async fn sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
190 #[cfg(target_arch = "wasm32")]
191 {
192 wasm_sleep_seconds(seconds).await
193 }
194 #[cfg(not(target_arch = "wasm32"))]
195 {
196 let duration = std::time::Duration::from_secs_f64(seconds);
198 std::thread::sleep(duration);
199 Ok(())
200 }
201}
202
/// Sleeps for `seconds` seconds on wasm32 by awaiting a promise resolved by JS `setTimeout`.
///
/// The delay is rounded to whole milliseconds and clamped to `0..=i32::MAX` ms. A NaN input
/// becomes 0 ms because `f64::max` returns the non-NaN operand.
///
/// # Errors
/// Returns a `pause` runtime error if `setTimeout` is missing or not callable on the global
/// object, or if scheduling or running the timer fails.
#[cfg(target_arch = "wasm32")]
async fn wasm_sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
    use js_sys::{Function, Promise, Reflect};
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;

    let global = js_sys::global();
    // Errors go through `pause_error`, like the rest of this builtin, so they carry the
    // `pause` builtin tag.
    let set_timeout = Reflect::get(&global, &wasm_bindgen::JsValue::from_str("setTimeout"))
        .map_err(|_| pause_error("pause: setTimeout unavailable"))?
        .dyn_into::<Function>()
        .map_err(|_| pause_error("pause: setTimeout unavailable"))?;

    let millis = (seconds * 1000.0)
        .max(0.0)
        .round()
        .min(f64::from(i32::MAX)) as i32;

    let promise = Promise::new(&mut |resolve, reject| {
        let resolve: Function = resolve.unchecked_into();
        let scheduled = set_timeout.call2(
            &global,
            &resolve.into(),
            &wasm_bindgen::JsValue::from_f64(f64::from(millis)),
        );
        // If the timer cannot be scheduled, reject rather than leaving the promise (and thus
        // this `pause` call) pending forever.
        if let Err(err) = scheduled {
            let _ = reject.call1(&wasm_bindgen::JsValue::UNDEFINED, &err);
        }
    });

    JsFuture::from(promise)
        .await
        .map_err(|err| pause_error(format!("pause: timer failed ({err:?})")))?;
    Ok(())
}
238
239async fn classify_argument(arg: &Value) -> Result<PauseArgument, RuntimeError> {
240 let host_value = gpu_helpers::gather_value_async(arg)
241 .await
242 .map_err(|e| pause_error(format!("pause: {e}")))?;
243 match host_value {
244 Value::String(text) => parse_command(&text),
245 Value::CharArray(ca) => {
246 if ca.rows == 0 || ca.data.is_empty() {
247 Ok(PauseArgument::Wait(PauseWait::Default))
248 } else if ca.rows == 1 {
249 let text: String = ca.data.iter().collect();
250 parse_command(&text)
251 } else {
252 Err(pause_error_with_identifier(
253 MSG_INVALID_ARG,
254 ERR_INVALID_ARG,
255 ))
256 }
257 }
258 Value::StringArray(sa) => {
259 if sa.data.is_empty() {
260 Ok(PauseArgument::Wait(PauseWait::Default))
261 } else if sa.data.len() == 1 {
262 parse_command(&sa.data[0])
263 } else {
264 Err(pause_error_with_identifier(
265 MSG_INVALID_ARG,
266 ERR_INVALID_ARG,
267 ))
268 }
269 }
270 Value::Num(value) => parse_numeric(value),
271 Value::Int(int_value) => parse_numeric(int_value.to_f64()),
272 Value::Bool(flag) => parse_numeric(if flag { 1.0 } else { 0.0 }),
273 Value::Tensor(tensor) => parse_tensor(tensor),
274 Value::LogicalArray(logical) => parse_logical(logical),
275 Value::GpuTensor(handle) => {
276 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
277 parse_tensor(tensor)
278 }
279 Value::Complex(_, _)
280 | Value::ComplexTensor(_)
281 | Value::Cell(_)
282 | Value::Struct(_)
283 | Value::Object(_)
284 | Value::HandleObject(_)
285 | Value::Listener(_)
286 | Value::FunctionHandle(_)
287 | Value::Closure(_)
288 | Value::ClassRef(_)
289 | Value::MException(_)
290 | Value::OutputList(_) => Err(pause_error_with_identifier(
291 MSG_INVALID_ARG,
292 ERR_INVALID_ARG,
293 )),
294 }
295}
296
297fn parse_command(raw: &str) -> Result<PauseArgument, RuntimeError> {
298 let trimmed = raw.trim();
299 if trimmed.is_empty() {
300 return Ok(PauseArgument::Wait(PauseWait::Default));
301 }
302 let lower = trimmed.to_ascii_lowercase();
303 match lower.as_str() {
304 "on" => Ok(PauseArgument::SetState(true)),
305 "off" => Ok(PauseArgument::SetState(false)),
306 "query" => Ok(PauseArgument::Query),
307 _ => Err(pause_error_with_identifier(
308 MSG_INVALID_ARG,
309 ERR_INVALID_ARG,
310 )),
311 }
312}
313
314fn parse_numeric(value: f64) -> Result<PauseArgument, RuntimeError> {
315 if !value.is_finite() {
316 if value.is_sign_positive() {
317 return Ok(PauseArgument::Wait(PauseWait::Default));
318 }
319 return Err(pause_error_with_identifier(
320 MSG_INVALID_ARG,
321 ERR_INVALID_ARG,
322 ));
323 }
324 if value < 0.0 {
325 return Err(pause_error_with_identifier(
326 MSG_INVALID_ARG,
327 ERR_INVALID_ARG,
328 ));
329 }
330 Ok(PauseArgument::Wait(PauseWait::Seconds(value)))
331}
332
333fn parse_tensor(tensor: Tensor) -> Result<PauseArgument, RuntimeError> {
334 if tensor.data.is_empty() {
335 return Ok(PauseArgument::Wait(PauseWait::Default));
336 }
337 if tensor.data.len() != 1 {
338 return Err(pause_error_with_identifier(
339 MSG_INVALID_ARG,
340 ERR_INVALID_ARG,
341 ));
342 }
343 parse_numeric(tensor.data[0])
344}
345
346fn parse_logical(logical: LogicalArray) -> Result<PauseArgument, RuntimeError> {
347 if logical.data.is_empty() {
348 return Ok(PauseArgument::Wait(PauseWait::Default));
349 }
350 if logical.data.len() != 1 {
351 return Err(pause_error_with_identifier(
352 MSG_INVALID_ARG,
353 ERR_INVALID_ARG,
354 ));
355 }
356 let scalar = if logical.data[0] != 0 { 1.0 } else { 0.0 };
357 parse_numeric(scalar)
358}
359
360fn empty_return_value() -> Value {
361 Value::Tensor(Tensor::zeros(vec![0, 0]))
362}
363
364fn state_value(enabled: bool) -> Value {
365 let text = if enabled { "on" } else { "off" };
366 Value::CharArray(CharArray::new_row(text))
367}
368
369fn pause_enabled() -> Result<bool, RuntimeError> {
370 PAUSE_STATE
371 .read()
372 .map(|guard| guard.enabled)
373 .map_err(|_| pause_error(MSG_STATE_LOCK))
374}
375
376fn set_pause_enabled(next: bool) -> Result<bool, RuntimeError> {
377 let mut guard = PAUSE_STATE
378 .write()
379 .map_err(|_| pause_error(MSG_STATE_LOCK))?;
380 let previous = guard.enabled;
381 guard.enabled = next;
382 Ok(previous)
383}
384
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_provider;

    /// Takes the shared test lock and sets the global pause state for the calling test.
    ///
    /// Every test goes through this helper so lock poisoning is handled uniformly. Otherwise a
    /// panic in one test poisons `TEST_GUARD`, and every later `.lock().unwrap()` fails with a
    /// `PoisonError` unrelated to the test being run. Bind the result to `_guard` (not `_`) so
    /// the lock is held for the whole test.
    #[must_use = "the guard must be held for the duration of the test"]
    fn lock_and_reset(enabled: bool) -> std::sync::MutexGuard<'static, ()> {
        let guard = TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner());
        reset_state(enabled);
        guard
    }

    /// Overwrites the global pause state, tolerating a poisoned lock.
    fn reset_state(enabled: bool) {
        let mut guard = PAUSE_STATE.write().unwrap_or_else(|e| e.into_inner());
        guard.enabled = enabled;
    }

    fn char_array_to_string(value: Value) -> String {
        match value {
            Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
            other => panic!("expected char array, got {other:?}"),
        }
    }

    /// Asserts that a waiting `pause` call produced its empty-tensor result.
    fn assert_empty_tensor(value: Value) {
        match value {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    fn assert_pause_error_identifier(err: crate::RuntimeError, identifier: &str) {
        assert_eq!(
            err.identifier(),
            Some(identifier),
            "message: {}",
            err.message()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn query_returns_on_by_default() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::from("query")])).expect("pause query");
        assert_eq!(char_array_to_string(result), "on");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_off_returns_previous_state() {
        let _guard = lock_and_reset(true);
        let previous = block_on(pause_builtin(vec![Value::from("off")])).expect("pause off");
        assert_eq!(char_array_to_string(previous), "on");
        assert!(!pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_on_restores_state() {
        let _guard = lock_and_reset(false);
        let previous = block_on(pause_builtin(vec![Value::from("on")])).expect("pause on");
        assert_eq!(char_array_to_string(previous), "off");
        assert!(pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_default_returns_empty_tensor() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(Vec::new())).expect("pause()");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_zero_is_accepted() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::Num(0.0)])).expect("pause(0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integer_scalar_is_accepted() {
        let _guard = lock_and_reset(true);
        let result =
            block_on(pause_builtin(vec![Value::Int(IntValue::I32(0))])).expect("pause(int)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_negative_zero_is_treated_as_zero() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::Num(-0.0)])).expect("pause(-0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn negative_duration_raises_error() {
        let _guard = lock_and_reset(true);
        let err = block_on(pause_builtin(vec![Value::Num(-0.1)])).unwrap_err();
        assert_pause_error_identifier(err, ERR_INVALID_ARG);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_scalar_tensor_is_rejected() {
        let _guard = lock_and_reset(true);
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = block_on(pause_builtin(vec![Value::Tensor(tensor)])).unwrap_err();
        assert_pause_error_identifier(err, ERR_INVALID_ARG);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_tensor_behaves_like_default_pause() {
        let _guard = lock_and_reset(true);
        let empty = Tensor::zeros(vec![0, 0]);
        let result = block_on(pause_builtin(vec![Value::Tensor(empty)])).expect("pause([])");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_is_accepted() {
        let _guard = lock_and_reset(true);
        let logical = LogicalArray::new(vec![1u8], vec![1, 1]).unwrap();
        let result =
            block_on(pause_builtin(vec![Value::LogicalArray(logical)])).expect("pause(true)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn infinite_duration_behaves_like_default() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::Num(f64::INFINITY)])).expect("pause(Inf)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_gpu_duration_gathered() {
        let _guard = lock_and_reset(true);
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
            assert_empty_tensor(result);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn pause_wgpu_duration_gathered() {
        let _guard = lock_and_reset(true);
        if wgpu_provider::register_wgpu_provider(wgpu_provider::WgpuProviderOptions::default())
            .is_err()
        {
            return;
        }
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result =
            block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_command_raises_error() {
        let _guard = lock_and_reset(true);
        let err = block_on(pause_builtin(vec![Value::from("invalid")])).unwrap_err();
        assert_pause_error_identifier(err, ERR_INVALID_ARG);
    }
}