1use once_cell::sync::Lazy;
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 CharArray, LogicalArray, Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10use std::sync::RwLock;
11
12use crate::builtins::common::gpu_helpers;
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17#[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
18use crate::builtins::plotting;
19use crate::builtins::timing::type_resolvers::pause_type;
20#[cfg(not(test))]
21use crate::interaction;
22use crate::{build_runtime_error, BuiltinResult, RuntimeError};
23
// GPU metadata registered for the `pause` builtin.
//
// The spec lists no supported precisions and no provider hooks; as the
// `notes` string states, `pause` runs entirely on the host.
// NOTE(review): the exact meaning of each field is defined by
// `BuiltinGpuSpec`, which is declared outside this file — confirm there.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::pause")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "pause",
    op_kind: GpuOpKind::Custom("timer"),
    // Empty: no precision is advertised for this builtin.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    // Empty: no provider hooks are declared.
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "pause executes entirely on the host. Acceleration providers are never queried.",
};
39
// Fusion metadata registered for the `pause` builtin.
//
// No elementwise or reduction description is provided; per the `notes`
// string, `pause` is excluded from fusion pipelines.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::pause")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "pause",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Neither an elementwise nor a reduction fusion form is declared.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "pause suspends host execution and is excluded from fusion pipelines.",
};
50
/// Global on/off switch for `pause`, created lazily on first access and
/// initialised from `PauseState::default()`.
///
/// Read through `pause_enabled` and updated through `set_pause_enabled`.
static PAUSE_STATE: Lazy<RwLock<PauseState>> = Lazy::new(|| RwLock::new(PauseState::default()));
52
#[cfg(test)]
use std::sync::Mutex;
/// Test-only mutex that tests lock before touching the shared `PAUSE_STATE`,
/// so that tests mutating the global state do not interleave.
#[cfg(test)]
pub(crate) static TEST_GUARD: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
57
/// Process-wide configuration for the `pause` builtin.
#[derive(Debug, Clone, Copy)]
struct PauseState {
    /// When `false`, wait requests return immediately without blocking.
    enabled: bool,
}

impl Default for PauseState {
    /// Pausing starts out enabled.
    fn default() -> Self {
        PauseState { enabled: true }
    }
}
68
/// Builtin name attached to every error produced by `pause_error_with_message`.
const BUILTIN_NAME: &str = "pause";
70
/// Output description used by the waiting signatures (`pause()` and
/// `pause(duration)`): a single numeric array documented as empty.
const PAUSE_OUTPUT_EMPTY: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Empty array when pausing or changing state.",
}];

/// Output description used by the `pause(command)` signature: a string
/// scalar holding `'on'` or `'off'`.
const PAUSE_OUTPUT_STATE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "state",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Previous pause state ('on' or 'off').",
}];
86
/// Input list for the zero-argument form `pause()`.
const PAUSE_INPUTS_NONE: [BuiltinParamDescriptor; 0] = [];
/// Input list for `pause(duration)`.
///
/// The parameter type is `Any` because the same positional slot also accepts
/// command-like values (see the description string).
// NOTE(review): the parameter is marked `Required` yet carries a default of
// "0" — confirm how descriptor consumers interpret that combination.
const PAUSE_INPUTS_DURATION: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "duration",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: Some("0"),
    description: "Duration scalar or command-like scalar value accepted by pause.",
}];
/// Input list for `pause(command)` where the command is `'on'`, `'off'`, or
/// `'query'`.
const PAUSE_INPUTS_COMMAND: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "command",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "One of 'on', 'off', or 'query'.",
}];
102
/// The three documented call forms of `pause`: no argument, a duration, and
/// a state command. Each entry pairs a label with its input and output
/// descriptor lists.
const PAUSE_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
    // Wait for a key press.
    BuiltinSignatureDescriptor {
        label: "out = pause()",
        inputs: &PAUSE_INPUTS_NONE,
        outputs: &PAUSE_OUTPUT_EMPTY,
    },
    // Wait for a number of seconds.
    BuiltinSignatureDescriptor {
        label: "out = pause(duration)",
        inputs: &PAUSE_INPUTS_DURATION,
        outputs: &PAUSE_OUTPUT_EMPTY,
    },
    // Change or query the pause state.
    BuiltinSignatureDescriptor {
        label: "state = pause(command)",
        inputs: &PAUSE_INPUTS_COMMAND,
        outputs: &PAUSE_OUTPUT_STATE,
    },
];
120
/// Raised for arguments that cannot be interpreted as a duration or command
/// (see the `when` text for the full list of conditions).
const PAUSE_ERROR_INVALID_ARG: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.PAUSE.INVALID_ARG",
    identifier: Some("RunMat:pause:InvalidInputArgument"),
    when: "Input argument is malformed, unsupported, non-scalar where scalar is required, or a negative/non-finite duration.",
    message: "pause: invalid input argument",
};

/// Raised by `pause_builtin` when more than one argument is supplied.
const PAUSE_ERROR_TOO_MANY_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.PAUSE.TOO_MANY_INPUTS",
    identifier: Some("RunMat:pause:TooManyInputs"),
    when: "More than one input argument is supplied.",
    message: "pause: too many input arguments",
};

/// Raised by `pause_enabled` / `set_pause_enabled` when locking
/// `PAUSE_STATE` fails.
const PAUSE_ERROR_STATE_LOCK: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.PAUSE.STATE_LOCK",
    identifier: Some("RunMat:pause:StateLockFailed"),
    when: "Internal pause-state lock cannot be acquired.",
    message: "pause: failed to acquire pause state",
};

/// Raised by `classify_argument` when gathering the argument to the host
/// fails.
const PAUSE_ERROR_GATHER_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.PAUSE.GPU_GATHER_FAILED",
    identifier: Some("RunMat:pause:GpuGatherFailed"),
    when: "Gathering a GPU argument to host fails during argument classification.",
    message: "pause: failed to gather gpu input",
};

/// Every error descriptor declared for `pause`, exposed through
/// `PAUSE_DESCRIPTOR`.
const PAUSE_ERRORS: [BuiltinErrorDescriptor; 4] = [
    PAUSE_ERROR_INVALID_ARG,
    PAUSE_ERROR_TOO_MANY_INPUTS,
    PAUSE_ERROR_STATE_LOCK,
    PAUSE_ERROR_GATHER_FAILED,
];
155
/// Public descriptor for `pause`, bundling its signatures and error table.
///
/// Referenced by path from the `descriptor(...)` argument of the
/// `#[runtime_builtin]` attribute on `pause_builtin`.
pub const PAUSE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &PAUSE_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &PAUSE_ERRORS,
};
162
163fn pause_error_with_message(
164 message: impl Into<String>,
165 error: &'static BuiltinErrorDescriptor,
166) -> RuntimeError {
167 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
168 if let Some(identifier) = error.identifier {
169 builder = builder.with_identifier(identifier);
170 }
171 builder.build()
172}
173
/// The request decoded from the single argument passed to `pause`.
#[derive(Debug, Clone, Copy)]
enum PauseArgument {
    /// Block according to the contained [`PauseWait`].
    Wait(PauseWait),
    /// Enable (`true`) or disable (`false`) pausing; produced by `'on'` / `'off'`.
    SetState(bool),
    /// Report the current pause state; produced by `'query'`.
    Query,
}
180
/// How `perform_wait` should block.
#[derive(Debug, Clone, Copy)]
enum PauseWait {
    /// Wait for a key press (see `wait_for_key_press`).
    Default,
    /// Sleep for the given number of seconds.
    Seconds(f64),
}
186
187#[runtime_builtin(
189 name = "pause",
190 category = "timing",
191 summary = "Pause execution until keypress or specified duration elapses.",
192 keywords = "pause,sleep,wait,delay",
193 accel = "metadata",
194 sink = true,
195 type_resolver(pause_type),
196 descriptor(crate::builtins::timing::pause::PAUSE_DESCRIPTOR),
197 builtin_path = "crate::builtins::timing::pause"
198)]
199async fn pause_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
200 match args.len() {
201 0 => {
202 perform_wait(PauseWait::Default).await?;
203 Ok(empty_return_value())
204 }
205 1 => match classify_argument(&args[0]).await? {
206 PauseArgument::Wait(wait) => {
207 perform_wait(wait).await?;
208 Ok(empty_return_value())
209 }
210 PauseArgument::SetState(next_state) => {
211 let previous = set_pause_enabled(next_state)?;
212 Ok(state_value(previous))
213 }
214 PauseArgument::Query => {
215 let current = pause_enabled()?;
216 Ok(state_value(current))
217 }
218 },
219 _ => Err(pause_error_with_message(
220 PAUSE_ERROR_TOO_MANY_INPUTS.message,
221 &PAUSE_ERROR_TOO_MANY_INPUTS,
222 )),
223 }
224}
225
226async fn perform_wait(wait: PauseWait) -> Result<(), RuntimeError> {
227 if !pause_enabled()? {
228 return Ok(());
229 }
230
231 #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
232 {
233 let handle = plotting::current_figure_handle();
236 let _ = plotting::render_current_scene(handle.as_u32());
238 }
239
240 match wait {
241 PauseWait::Default => wait_for_key_press().await,
242 PauseWait::Seconds(seconds) => {
243 if seconds == 0.0 {
244 #[cfg(target_arch = "wasm32")]
246 {
247 return wasm_sleep_seconds(0.0).await;
248 }
249 #[cfg(not(target_arch = "wasm32"))]
250 {
251 return Ok(());
252 }
253 }
254 sleep_seconds(seconds).await?;
255 #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
256 {
257 let handle = plotting::current_figure_handle();
261 let _ = plotting::render_current_scene(handle.as_u32());
262 }
263 Ok(())
264 }
265 }
266}
267
268async fn wait_for_key_press() -> Result<(), RuntimeError> {
269 #[cfg(test)]
270 {
271 Ok(())
272 }
273 #[cfg(not(test))]
274 {
275 interaction::wait_for_key_async("").await
276 }
277}
278
279async fn sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
280 #[cfg(target_arch = "wasm32")]
281 {
282 wasm_sleep_seconds(seconds).await
283 }
284 #[cfg(not(target_arch = "wasm32"))]
285 {
286 let duration = std::time::Duration::from_secs_f64(seconds);
288 std::thread::sleep(duration);
289 Ok(())
290 }
291}
292
/// Sleeps for `seconds` on `wasm32` by awaiting a promise that is resolved
/// from a `setTimeout` callback looked up on the JavaScript global object.
///
/// The delay is converted to whole milliseconds, floored at zero and capped
/// at `i32::MAX`.
///
/// # Errors
/// Fails when `setTimeout` is missing or not a function, when invoking it
/// throws, or when the timer promise is rejected.
#[cfg(target_arch = "wasm32")]
async fn wasm_sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
    use js_sys::{Function, Promise, Reflect};
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;

    let global = js_sys::global();
    let set_timeout = Reflect::get(&global, &wasm_bindgen::JsValue::from_str("setTimeout"))
        .map_err(|_| build_runtime_error("pause: setTimeout unavailable").build())?
        .dyn_into::<Function>()
        .map_err(|_| build_runtime_error("pause: setTimeout unavailable").build())?;

    let millis = (seconds * 1000.0).max(0.0).round();
    let millis_i32 = if millis > i32::MAX as f64 {
        i32::MAX
    } else {
        millis as i32
    };

    let mut executor = |resolve: Function, reject: Function| {
        let callback = wasm_bindgen::JsValue::from(resolve);
        let delay = wasm_bindgen::JsValue::from_f64(f64::from(millis_i32));
        // If scheduling the timer throws, reject the promise; otherwise it
        // would stay pending forever and the pause would never return.
        if let Err(err) = set_timeout.call2(&global, &callback, &delay) {
            let _ = reject.call1(&wasm_bindgen::JsValue::UNDEFINED, &err);
        }
    };
    let promise = Promise::new(&mut executor);

    let _ = JsFuture::from(promise)
        .await
        .map_err(|err| build_runtime_error(format!("pause: timer failed ({err:?})")).build())?;
    Ok(())
}
328
329async fn classify_argument(arg: &Value) -> Result<PauseArgument, RuntimeError> {
330 let host_value = gpu_helpers::gather_value_async(arg)
331 .await
332 .map_err(|e| pause_error_with_message(format!("pause: {e}"), &PAUSE_ERROR_GATHER_FAILED))?;
333 match host_value {
334 Value::String(text) => parse_command(&text),
335 Value::CharArray(ca) => {
336 if ca.rows == 0 || ca.data.is_empty() {
337 Ok(PauseArgument::Wait(PauseWait::Default))
338 } else if ca.rows == 1 {
339 let text: String = ca.data.iter().collect();
340 parse_command(&text)
341 } else {
342 Err(pause_error_with_message(
343 PAUSE_ERROR_INVALID_ARG.message,
344 &PAUSE_ERROR_INVALID_ARG,
345 ))
346 }
347 }
348 Value::StringArray(sa) => {
349 if sa.data.is_empty() {
350 Ok(PauseArgument::Wait(PauseWait::Default))
351 } else if sa.data.len() == 1 {
352 parse_command(&sa.data[0])
353 } else {
354 Err(pause_error_with_message(
355 PAUSE_ERROR_INVALID_ARG.message,
356 &PAUSE_ERROR_INVALID_ARG,
357 ))
358 }
359 }
360 Value::Num(value) => parse_numeric(value),
361 Value::Int(int_value) => parse_numeric(int_value.to_f64()),
362 Value::Bool(flag) => parse_numeric(if flag { 1.0 } else { 0.0 }),
363 Value::Tensor(tensor) => parse_tensor(tensor),
364 Value::LogicalArray(logical) => parse_logical(logical),
365 Value::GpuTensor(handle) => {
366 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
367 parse_tensor(tensor)
368 }
369 Value::Complex(_, _)
370 | Value::ComplexTensor(_)
371 | Value::Symbolic(_)
372 | Value::SparseTensor(_)
373 | Value::Cell(_)
374 | Value::Struct(_)
375 | Value::Object(_)
376 | Value::HandleObject(_)
377 | Value::Listener(_)
378 | Value::FunctionHandle(_)
379 | Value::ExternalFunctionHandle(_)
380 | Value::MethodFunctionHandle(_)
381 | Value::BoundFunctionHandle { .. }
382 | Value::Closure(_)
383 | Value::ClassRef(_)
384 | Value::MException(_)
385 | Value::OutputList(_) => Err(pause_error_with_message(
386 PAUSE_ERROR_INVALID_ARG.message,
387 &PAUSE_ERROR_INVALID_ARG,
388 )),
389 }
390}
391
392fn parse_command(raw: &str) -> Result<PauseArgument, RuntimeError> {
393 let trimmed = raw.trim();
394 if trimmed.is_empty() {
395 return Ok(PauseArgument::Wait(PauseWait::Default));
396 }
397 let lower = trimmed.to_ascii_lowercase();
398 match lower.as_str() {
399 "on" => Ok(PauseArgument::SetState(true)),
400 "off" => Ok(PauseArgument::SetState(false)),
401 "query" => Ok(PauseArgument::Query),
402 _ => Err(pause_error_with_message(
403 PAUSE_ERROR_INVALID_ARG.message,
404 &PAUSE_ERROR_INVALID_ARG,
405 )),
406 }
407}
408
409fn parse_numeric(value: f64) -> Result<PauseArgument, RuntimeError> {
410 if !value.is_finite() {
411 if value.is_sign_positive() {
412 return Ok(PauseArgument::Wait(PauseWait::Default));
413 }
414 return Err(pause_error_with_message(
415 PAUSE_ERROR_INVALID_ARG.message,
416 &PAUSE_ERROR_INVALID_ARG,
417 ));
418 }
419 if value < 0.0 {
420 return Err(pause_error_with_message(
421 PAUSE_ERROR_INVALID_ARG.message,
422 &PAUSE_ERROR_INVALID_ARG,
423 ));
424 }
425 Ok(PauseArgument::Wait(PauseWait::Seconds(value)))
426}
427
428fn parse_tensor(tensor: Tensor) -> Result<PauseArgument, RuntimeError> {
429 if tensor.data.is_empty() {
430 return Ok(PauseArgument::Wait(PauseWait::Default));
431 }
432 if tensor.data.len() != 1 {
433 return Err(pause_error_with_message(
434 PAUSE_ERROR_INVALID_ARG.message,
435 &PAUSE_ERROR_INVALID_ARG,
436 ));
437 }
438 parse_numeric(tensor.data[0])
439}
440
441fn parse_logical(logical: LogicalArray) -> Result<PauseArgument, RuntimeError> {
442 if logical.data.is_empty() {
443 return Ok(PauseArgument::Wait(PauseWait::Default));
444 }
445 if logical.data.len() != 1 {
446 return Err(pause_error_with_message(
447 PAUSE_ERROR_INVALID_ARG.message,
448 &PAUSE_ERROR_INVALID_ARG,
449 ));
450 }
451 let scalar = if logical.data[0] != 0 { 1.0 } else { 0.0 };
452 parse_numeric(scalar)
453}
454
455fn empty_return_value() -> Value {
456 Value::Tensor(Tensor::zeros(vec![0, 0]))
457}
458
459fn state_value(enabled: bool) -> Value {
460 let text = if enabled { "on" } else { "off" };
461 Value::CharArray(CharArray::new_row(text))
462}
463
464fn pause_enabled() -> Result<bool, RuntimeError> {
465 PAUSE_STATE.read().map(|guard| guard.enabled).map_err(|_| {
466 pause_error_with_message(PAUSE_ERROR_STATE_LOCK.message, &PAUSE_ERROR_STATE_LOCK)
467 })
468}
469
470fn set_pause_enabled(next: bool) -> Result<bool, RuntimeError> {
471 let mut guard = PAUSE_STATE.write().map_err(|_| {
472 pause_error_with_message(PAUSE_ERROR_STATE_LOCK.message, &PAUSE_ERROR_STATE_LOCK)
473 })?;
474 let previous = guard.enabled;
475 guard.enabled = next;
476 Ok(previous)
477}
478
/// Unit tests for the `pause` builtin.
///
/// Every test mutates the process-wide pause state, so each one starts by
/// calling `lock_and_reset`, which serialises the tests through `TEST_GUARD`
/// and puts the state into a known value. The lock is always taken in a
/// poison-tolerant way: a single failing test must not poison the mutex and
/// make every later test fail with an unrelated lock error.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_provider;

    /// Forces the global pause switch to `enabled`, recovering from a
    /// poisoned lock if an earlier test panicked.
    fn reset_state(enabled: bool) {
        let mut guard = PAUSE_STATE.write().unwrap_or_else(|e| e.into_inner());
        guard.enabled = enabled;
    }

    /// Takes the test mutex (ignoring poisoning) and resets the pause state.
    /// The returned guard must be held for the duration of the test.
    fn lock_and_reset(enabled: bool) -> std::sync::MutexGuard<'static, ()> {
        let guard = TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner());
        reset_state(enabled);
        guard
    }

    /// Extracts the text of a single-row char array, panicking otherwise.
    fn char_array_to_string(value: Value) -> String {
        match value {
            Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
            other => panic!("expected char array, got {other:?}"),
        }
    }

    /// Asserts that `value` is a tensor without any elements.
    fn assert_empty_tensor(value: Value) {
        match value {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    /// Asserts that `err` carries the expected error identifier.
    fn assert_pause_error_identifier(err: crate::RuntimeError, identifier: &str) {
        assert_eq!(
            err.identifier(),
            Some(identifier),
            "message: {}",
            err.message()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn query_returns_on_by_default() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::from("query")])).expect("pause query");
        assert_eq!(char_array_to_string(result), "on");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_off_returns_previous_state() {
        let _guard = lock_and_reset(true);
        let previous = block_on(pause_builtin(vec![Value::from("off")])).expect("pause off");
        assert_eq!(char_array_to_string(previous), "on");
        assert!(!pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_on_restores_state() {
        let _guard = lock_and_reset(false);
        let previous = block_on(pause_builtin(vec![Value::from("on")])).expect("pause on");
        assert_eq!(char_array_to_string(previous), "off");
        assert!(pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_default_returns_empty_tensor() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(Vec::new())).expect("pause()");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_zero_is_accepted() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::Num(0.0)])).expect("pause(0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integer_scalar_is_accepted() {
        let _guard = lock_and_reset(true);
        let result =
            block_on(pause_builtin(vec![Value::Int(IntValue::I32(0))])).expect("pause(int)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_negative_zero_is_treated_as_zero() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::Num(-0.0)])).expect("pause(-0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn negative_duration_raises_error() {
        let _guard = lock_and_reset(true);
        let err = block_on(pause_builtin(vec![Value::Num(-0.1)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_scalar_tensor_is_rejected() {
        let _guard = lock_and_reset(true);
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = block_on(pause_builtin(vec![Value::Tensor(tensor)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_tensor_behaves_like_default_pause() {
        let _guard = lock_and_reset(true);
        let empty = Tensor::zeros(vec![0, 0]);
        let result = block_on(pause_builtin(vec![Value::Tensor(empty)])).expect("pause([])");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_is_accepted() {
        let _guard = lock_and_reset(true);
        let logical = LogicalArray::new(vec![1u8], vec![1, 1]).unwrap();
        let result =
            block_on(pause_builtin(vec![Value::LogicalArray(logical)])).expect("pause(true)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn infinite_duration_behaves_like_default() {
        let _guard = lock_and_reset(true);
        let result = block_on(pause_builtin(vec![Value::Num(f64::INFINITY)])).expect("pause(Inf)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_gpu_duration_gathered() {
        let _guard = lock_and_reset(true);
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
            assert_empty_tensor(result);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn pause_wgpu_duration_gathered() {
        let _guard = lock_and_reset(true);
        // Skip silently when no wgpu provider can be registered.
        if wgpu_provider::register_wgpu_provider(wgpu_provider::WgpuProviderOptions::default())
            .is_err()
        {
            return;
        }
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result =
            block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_command_raises_error() {
        let _guard = lock_and_reset(true);
        let err = block_on(pause_builtin(vec![Value::from("invalid")])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn too_many_inputs_raises_error() {
        let _guard = lock_and_reset(true);
        let err = block_on(pause_builtin(vec![Value::Num(0.0), Value::Num(0.0)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_TOO_MANY_INPUTS.identifier.unwrap());
    }
}