1use once_cell::sync::Lazy;
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 CharArray, LogicalArray, Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10use std::sync::RwLock;
11
12use crate::builtins::common::gpu_helpers;
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17#[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
18use crate::builtins::plotting;
19use crate::builtins::timing::type_resolvers::pause_type;
20#[cfg(not(test))]
21use crate::interaction;
22use crate::{build_runtime_error, BuiltinResult, RuntimeError};
23
// GPU metadata for `pause`. The builtin is a host-side timer: it advertises no precisions or
// provider hooks, and its residency policy asks for arguments to be gathered immediately.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::pause")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "pause",
    op_kind: GpuOpKind::Custom("timer"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "pause executes entirely on the host. Acceleration providers are never queried.",
};
39
// Fusion metadata for `pause`: it declares no elementwise or reduction kernel, so there is
// nothing for the fusion planner to fuse.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::pause")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "pause",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "pause suspends host execution and is excluded from fusion pipelines.",
};
50
// Global `pause on` / `pause off` switch shared by every call to the builtin. It is initialized
// from `PauseState::default()`.
static PAUSE_STATE: Lazy<RwLock<PauseState>> = Lazy::new(|| RwLock::new(PauseState::default()));

#[cfg(test)]
use std::sync::Mutex;
// Test-only lock used to serialize tests that mutate `PAUSE_STATE`. It is `pub(crate)`,
// presumably so tests in other modules can share it — NOTE(review): confirm other users.
#[cfg(test)]
pub(crate) static TEST_GUARD: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
57
/// Process-wide configuration toggled by `pause('on')` / `pause('off')`.
#[derive(Debug, Clone, Copy)]
struct PauseState {
    /// When `false`, every wait requested through `pause` returns immediately.
    enabled: bool,
}

impl Default for PauseState {
    /// Pausing starts out switched on.
    fn default() -> Self {
        PauseState { enabled: true }
    }
}
68
/// Builtin name attached to every error built by `pause_error_with_message`.
const BUILTIN_NAME: &str = "pause";
70
71const PAUSE_OUTPUT_EMPTY: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
72 name: "out",
73 ty: BuiltinParamType::NumericArray,
74 arity: BuiltinParamArity::Required,
75 default: None,
76 description: "Empty array when pausing or changing state.",
77}];
78
79const PAUSE_OUTPUT_STATE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
80 name: "state",
81 ty: BuiltinParamType::StringScalar,
82 arity: BuiltinParamArity::Required,
83 default: None,
84 description: "Previous pause state ('on' or 'off').",
85}];
86
87const PAUSE_INPUTS_NONE: [BuiltinParamDescriptor; 0] = [];
88const PAUSE_INPUTS_DURATION: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
89 name: "duration",
90 ty: BuiltinParamType::Any,
91 arity: BuiltinParamArity::Required,
92 default: Some("0"),
93 description: "Duration scalar or command-like scalar value accepted by pause.",
94}];
95const PAUSE_INPUTS_COMMAND: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
96 name: "command",
97 ty: BuiltinParamType::StringScalar,
98 arity: BuiltinParamArity::Required,
99 default: None,
100 description: "One of 'on', 'off', or 'query'.",
101}];
102
// Call forms advertised in the builtin descriptor. They mirror the branches in `pause_builtin`:
// no argument, one duration-like argument, or one command string.
const PAUSE_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
    BuiltinSignatureDescriptor {
        label: "out = pause()",
        inputs: &PAUSE_INPUTS_NONE,
        outputs: &PAUSE_OUTPUT_EMPTY,
    },
    BuiltinSignatureDescriptor {
        label: "out = pause(duration)",
        inputs: &PAUSE_INPUTS_DURATION,
        outputs: &PAUSE_OUTPUT_EMPTY,
    },
    BuiltinSignatureDescriptor {
        label: "state = pause(command)",
        inputs: &PAUSE_INPUTS_COMMAND,
        outputs: &PAUSE_OUTPUT_STATE,
    },
];
120
121const PAUSE_ERROR_INVALID_ARG: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
122 code: "RM.PAUSE.INVALID_ARG",
123 identifier: Some("RunMat:pause:InvalidInputArgument"),
124 when: "Input argument is malformed, unsupported, non-scalar where scalar is required, or a negative/non-finite duration.",
125 message: "pause: invalid input argument",
126};
127
// Raised by `pause_builtin` when two or more arguments are supplied.
const PAUSE_ERROR_TOO_MANY_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.PAUSE.TOO_MANY_INPUTS",
    identifier: Some("RunMat:pause:TooManyInputs"),
    when: "More than one input argument is supplied.",
    message: "pause: too many input arguments",
};

// Raised when `PAUSE_STATE.read()` / `.write()` fails. For a `std::sync::RwLock` that only
// happens when the lock has been poisoned by a panic while it was held.
const PAUSE_ERROR_STATE_LOCK: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.PAUSE.STATE_LOCK",
    identifier: Some("RunMat:pause:StateLockFailed"),
    when: "Internal pause-state lock cannot be acquired.",
    message: "pause: failed to acquire pause state",
};

// Raised by `classify_argument` when `gpu_helpers::gather_value_async` fails.
const PAUSE_ERROR_GATHER_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.PAUSE.GPU_GATHER_FAILED",
    identifier: Some("RunMat:pause:GpuGatherFailed"),
    when: "Gathering a GPU argument to host fails during argument classification.",
    message: "pause: failed to gather gpu input",
};

// Error catalogue published through `PAUSE_DESCRIPTOR`.
const PAUSE_ERRORS: [BuiltinErrorDescriptor; 4] = [
    PAUSE_ERROR_INVALID_ARG,
    PAUSE_ERROR_TOO_MANY_INPUTS,
    PAUSE_ERROR_STATE_LOCK,
    PAUSE_ERROR_GATHER_FAILED,
];
155
/// Public descriptor for `pause`, referenced by the `runtime_builtin` attribute on
/// `pause_builtin`; bundles the advertised signatures and error catalogue.
pub const PAUSE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &PAUSE_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &PAUSE_ERRORS,
};
162
163fn pause_error_with_message(
164 message: impl Into<String>,
165 error: &'static BuiltinErrorDescriptor,
166) -> RuntimeError {
167 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
168 if let Some(identifier) = error.identifier {
169 builder = builder.with_identifier(identifier);
170 }
171 builder.build()
172}
173
/// Parsed form of the single optional `pause` argument.
#[derive(Debug, Clone, Copy)]
enum PauseArgument {
    /// Block execution (skipped when the global switch is off).
    Wait(PauseWait),
    /// `pause('on')` / `pause('off')`: set the switch; the builtin returns the previous state.
    SetState(bool),
    /// `pause('query')`: report the current state without changing it.
    Query,
}

/// How long a `PauseArgument::Wait` should block.
#[derive(Debug, Clone, Copy)]
enum PauseWait {
    /// Wait for a key press (no argument, empty input, blank text, or `+Inf`).
    Default,
    /// Sleep for this many seconds; `parse_numeric` only produces finite, non-negative values.
    Seconds(f64),
}
186
187#[runtime_builtin(
189 name = "pause",
190 category = "timing",
191 summary = "Pause execution until keypress or specified duration elapses.",
192 keywords = "pause,sleep,wait,delay",
193 accel = "metadata",
194 sink = true,
195 type_resolver(pause_type),
196 descriptor(crate::builtins::timing::pause::PAUSE_DESCRIPTOR),
197 builtin_path = "crate::builtins::timing::pause"
198)]
199async fn pause_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
200 match args.len() {
201 0 => {
202 perform_wait(PauseWait::Default).await?;
203 Ok(empty_return_value())
204 }
205 1 => match classify_argument(&args[0]).await? {
206 PauseArgument::Wait(wait) => {
207 perform_wait(wait).await?;
208 Ok(empty_return_value())
209 }
210 PauseArgument::SetState(next_state) => {
211 let previous = set_pause_enabled(next_state)?;
212 Ok(state_value(previous))
213 }
214 PauseArgument::Query => {
215 let current = pause_enabled()?;
216 Ok(state_value(current))
217 }
218 },
219 _ => Err(pause_error_with_message(
220 PAUSE_ERROR_TOO_MANY_INPUTS.message,
221 &PAUSE_ERROR_TOO_MANY_INPUTS,
222 )),
223 }
224}
225
226async fn perform_wait(wait: PauseWait) -> Result<(), RuntimeError> {
227 if !pause_enabled()? {
228 return Ok(());
229 }
230
231 #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
232 {
233 let handle = plotting::current_figure_handle();
236 let _ = plotting::render_current_scene(handle.as_u32());
238 }
239
240 match wait {
241 PauseWait::Default => wait_for_key_press().await,
242 PauseWait::Seconds(seconds) => {
243 if seconds == 0.0 {
244 #[cfg(target_arch = "wasm32")]
246 {
247 return wasm_sleep_seconds(0.0).await;
248 }
249 #[cfg(not(target_arch = "wasm32"))]
250 {
251 return Ok(());
252 }
253 }
254 sleep_seconds(seconds).await?;
255 #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
256 {
257 let handle = plotting::current_figure_handle();
261 let _ = plotting::render_current_scene(handle.as_u32());
262 }
263 Ok(())
264 }
265 }
266}
267
268async fn wait_for_key_press() -> Result<(), RuntimeError> {
269 #[cfg(test)]
270 {
271 Ok(())
272 }
273 #[cfg(not(test))]
274 {
275 interaction::wait_for_key_async("").await
276 }
277}
278
279async fn sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
280 #[cfg(target_arch = "wasm32")]
281 {
282 wasm_sleep_seconds(seconds).await
283 }
284 #[cfg(not(target_arch = "wasm32"))]
285 {
286 let duration = std::time::Duration::from_secs_f64(seconds);
288 std::thread::sleep(duration);
289 Ok(())
290 }
291}
292
/// Sleep for `seconds` using the global JavaScript `setTimeout`, awaited through a `Promise`.
///
/// The delay is rounded to whole milliseconds and clamped to `i32::MAX` ms, because larger
/// `setTimeout` delays overflow and fire immediately. If calling `setTimeout` throws, the
/// promise is rejected so the await fails with a runtime error instead of hanging forever.
///
/// # Errors
/// Returns a runtime error tagged with the `pause` builtin name when `setTimeout` is missing or
/// not callable, or when the timer promise rejects.
#[cfg(target_arch = "wasm32")]
async fn wasm_sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
    use js_sys::{Function, Promise, Reflect};
    use wasm_bindgen::{JsCast, JsValue};
    use wasm_bindgen_futures::JsFuture;

    let timer_error =
        |message: String| build_runtime_error(message).with_builtin(BUILTIN_NAME).build();

    let global = js_sys::global();
    let set_timeout = Reflect::get(&global, &JsValue::from_str("setTimeout"))
        .ok()
        .and_then(|value| value.dyn_into::<Function>().ok())
        .ok_or_else(|| timer_error("pause: setTimeout unavailable".to_string()))?;

    let millis = (seconds * 1000.0)
        .max(0.0)
        .round()
        .min(f64::from(i32::MAX));

    let promise = Promise::new(&mut |resolve: Function, reject: Function| {
        if let Err(err) = set_timeout.call2(&global, &resolve, &JsValue::from_f64(millis)) {
            // Reject rather than leave the promise pending forever.
            let _ = reject.call1(&JsValue::UNDEFINED, &err);
        }
    });

    JsFuture::from(promise)
        .await
        .map_err(|err| timer_error(format!("pause: timer failed ({err:?})")))?;
    Ok(())
}
328
329async fn classify_argument(arg: &Value) -> Result<PauseArgument, RuntimeError> {
330 let host_value = gpu_helpers::gather_value_async(arg)
331 .await
332 .map_err(|e| pause_error_with_message(format!("pause: {e}"), &PAUSE_ERROR_GATHER_FAILED))?;
333 match host_value {
334 Value::String(text) => parse_command(&text),
335 Value::CharArray(ca) => {
336 if ca.rows == 0 || ca.data.is_empty() {
337 Ok(PauseArgument::Wait(PauseWait::Default))
338 } else if ca.rows == 1 {
339 let text: String = ca.data.iter().collect();
340 parse_command(&text)
341 } else {
342 Err(pause_error_with_message(
343 PAUSE_ERROR_INVALID_ARG.message,
344 &PAUSE_ERROR_INVALID_ARG,
345 ))
346 }
347 }
348 Value::StringArray(sa) => {
349 if sa.data.is_empty() {
350 Ok(PauseArgument::Wait(PauseWait::Default))
351 } else if sa.data.len() == 1 {
352 parse_command(&sa.data[0])
353 } else {
354 Err(pause_error_with_message(
355 PAUSE_ERROR_INVALID_ARG.message,
356 &PAUSE_ERROR_INVALID_ARG,
357 ))
358 }
359 }
360 Value::Num(value) => parse_numeric(value),
361 Value::Int(int_value) => parse_numeric(int_value.to_f64()),
362 Value::Bool(flag) => parse_numeric(if flag { 1.0 } else { 0.0 }),
363 Value::Tensor(tensor) => parse_tensor(tensor),
364 Value::LogicalArray(logical) => parse_logical(logical),
365 Value::GpuTensor(handle) => {
366 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
367 parse_tensor(tensor)
368 }
369 Value::Complex(_, _)
370 | Value::ComplexTensor(_)
371 | Value::Cell(_)
372 | Value::Struct(_)
373 | Value::Object(_)
374 | Value::HandleObject(_)
375 | Value::Listener(_)
376 | Value::FunctionHandle(_)
377 | Value::ExternalFunctionHandle(_)
378 | Value::MethodFunctionHandle(_)
379 | Value::BoundFunctionHandle { .. }
380 | Value::Closure(_)
381 | Value::ClassRef(_)
382 | Value::MException(_)
383 | Value::OutputList(_) => Err(pause_error_with_message(
384 PAUSE_ERROR_INVALID_ARG.message,
385 &PAUSE_ERROR_INVALID_ARG,
386 )),
387 }
388}
389
390fn parse_command(raw: &str) -> Result<PauseArgument, RuntimeError> {
391 let trimmed = raw.trim();
392 if trimmed.is_empty() {
393 return Ok(PauseArgument::Wait(PauseWait::Default));
394 }
395 let lower = trimmed.to_ascii_lowercase();
396 match lower.as_str() {
397 "on" => Ok(PauseArgument::SetState(true)),
398 "off" => Ok(PauseArgument::SetState(false)),
399 "query" => Ok(PauseArgument::Query),
400 _ => Err(pause_error_with_message(
401 PAUSE_ERROR_INVALID_ARG.message,
402 &PAUSE_ERROR_INVALID_ARG,
403 )),
404 }
405}
406
407fn parse_numeric(value: f64) -> Result<PauseArgument, RuntimeError> {
408 if !value.is_finite() {
409 if value.is_sign_positive() {
410 return Ok(PauseArgument::Wait(PauseWait::Default));
411 }
412 return Err(pause_error_with_message(
413 PAUSE_ERROR_INVALID_ARG.message,
414 &PAUSE_ERROR_INVALID_ARG,
415 ));
416 }
417 if value < 0.0 {
418 return Err(pause_error_with_message(
419 PAUSE_ERROR_INVALID_ARG.message,
420 &PAUSE_ERROR_INVALID_ARG,
421 ));
422 }
423 Ok(PauseArgument::Wait(PauseWait::Seconds(value)))
424}
425
426fn parse_tensor(tensor: Tensor) -> Result<PauseArgument, RuntimeError> {
427 if tensor.data.is_empty() {
428 return Ok(PauseArgument::Wait(PauseWait::Default));
429 }
430 if tensor.data.len() != 1 {
431 return Err(pause_error_with_message(
432 PAUSE_ERROR_INVALID_ARG.message,
433 &PAUSE_ERROR_INVALID_ARG,
434 ));
435 }
436 parse_numeric(tensor.data[0])
437}
438
439fn parse_logical(logical: LogicalArray) -> Result<PauseArgument, RuntimeError> {
440 if logical.data.is_empty() {
441 return Ok(PauseArgument::Wait(PauseWait::Default));
442 }
443 if logical.data.len() != 1 {
444 return Err(pause_error_with_message(
445 PAUSE_ERROR_INVALID_ARG.message,
446 &PAUSE_ERROR_INVALID_ARG,
447 ));
448 }
449 let scalar = if logical.data[0] != 0 { 1.0 } else { 0.0 };
450 parse_numeric(scalar)
451}
452
453fn empty_return_value() -> Value {
454 Value::Tensor(Tensor::zeros(vec![0, 0]))
455}
456
457fn state_value(enabled: bool) -> Value {
458 let text = if enabled { "on" } else { "off" };
459 Value::CharArray(CharArray::new_row(text))
460}
461
462fn pause_enabled() -> Result<bool, RuntimeError> {
463 PAUSE_STATE.read().map(|guard| guard.enabled).map_err(|_| {
464 pause_error_with_message(PAUSE_ERROR_STATE_LOCK.message, &PAUSE_ERROR_STATE_LOCK)
465 })
466}
467
468fn set_pause_enabled(next: bool) -> Result<bool, RuntimeError> {
469 let mut guard = PAUSE_STATE.write().map_err(|_| {
470 pause_error_with_message(PAUSE_ERROR_STATE_LOCK.message, &PAUSE_ERROR_STATE_LOCK)
471 })?;
472 let previous = guard.enabled;
473 guard.enabled = next;
474 Ok(previous)
475}
476
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray, Tensor};

    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_provider;

    /// Acquire the lock that serializes tests sharing the global pause state.
    ///
    /// Poisoning is always recovered from. Otherwise a single failing test would turn every
    /// later test that shares `TEST_GUARD` into a spurious `PoisonError` failure.
    fn lock_test_guard() -> std::sync::MutexGuard<'static, ()> {
        TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Force the global pause switch to `enabled`, ignoring lock poisoning.
    fn reset_state(enabled: bool) {
        let mut guard = PAUSE_STATE.write().unwrap_or_else(|e| e.into_inner());
        guard.enabled = enabled;
    }

    /// Unwrap a single-row char array into a `String`, panicking on anything else.
    fn char_array_to_string(value: Value) -> String {
        match value {
            Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
            other => panic!("expected char array, got {other:?}"),
        }
    }

    /// Assert that `value` is the empty numeric array returned by the waiting forms of `pause`.
    fn assert_empty_tensor(value: Value) {
        match value {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    /// Assert that `err` carries `identifier`, showing the error message on failure.
    fn assert_pause_error_identifier(err: crate::RuntimeError, identifier: &str) {
        assert_eq!(
            err.identifier(),
            Some(identifier),
            "message: {}",
            err.message()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn query_returns_on_by_default() {
        let _guard = lock_test_guard();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::from("query")])).expect("pause query");
        assert_eq!(char_array_to_string(result), "on");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_off_returns_previous_state() {
        let _guard = lock_test_guard();
        reset_state(true);
        let previous = block_on(pause_builtin(vec![Value::from("off")])).expect("pause off");
        assert_eq!(char_array_to_string(previous), "on");
        assert!(!pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_on_restores_state() {
        let _guard = lock_test_guard();
        reset_state(false);
        let previous = block_on(pause_builtin(vec![Value::from("on")])).expect("pause on");
        assert_eq!(char_array_to_string(previous), "off");
        assert!(pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn commands_are_case_insensitive_and_trimmed() {
        let _guard = lock_test_guard();
        reset_state(true);
        let current = block_on(pause_builtin(vec![Value::from("QUERY")])).expect("pause QUERY");
        assert_eq!(char_array_to_string(current), "on");
        let previous = block_on(pause_builtin(vec![Value::from(" Off ")])).expect("pause Off");
        assert_eq!(char_array_to_string(previous), "on");
        assert!(!pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn blank_command_behaves_like_default_pause() {
        let _guard = lock_test_guard();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::from("   ")])).expect("pause('   ')");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn disabled_pause_skips_timed_wait() {
        let _guard = lock_test_guard();
        reset_state(false);
        // With pausing off this must return immediately instead of sleeping for a minute.
        let result =
            block_on(pause_builtin(vec![Value::Num(60.0)])).expect("pause(60) while off");
        assert_empty_tensor(result);
        assert!(!pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_default_returns_empty_tensor() {
        let _guard = lock_test_guard();
        reset_state(true);
        let result = block_on(pause_builtin(Vec::new())).expect("pause()");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_zero_is_accepted() {
        let _guard = lock_test_guard();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(0.0)])).expect("pause(0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integer_scalar_is_accepted() {
        let _guard = lock_test_guard();
        reset_state(true);
        let result =
            block_on(pause_builtin(vec![Value::Int(IntValue::I32(0))])).expect("pause(int)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_negative_zero_is_treated_as_zero() {
        let _guard = lock_test_guard();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(-0.0)])).expect("pause(-0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn negative_duration_raises_error() {
        let _guard = lock_test_guard();
        reset_state(true);
        let err = block_on(pause_builtin(vec![Value::Num(-0.1)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn negative_infinity_raises_error() {
        let _guard = lock_test_guard();
        reset_state(true);
        let err = block_on(pause_builtin(vec![Value::Num(f64::NEG_INFINITY)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_scalar_tensor_is_rejected() {
        let _guard = lock_test_guard();
        reset_state(true);
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = block_on(pause_builtin(vec![Value::Tensor(tensor)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_tensor_behaves_like_default_pause() {
        let _guard = lock_test_guard();
        reset_state(true);
        let empty = Tensor::zeros(vec![0, 0]);
        let result = block_on(pause_builtin(vec![Value::Tensor(empty)])).expect("pause([])");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_is_accepted() {
        let _guard = lock_test_guard();
        reset_state(true);
        let logical = LogicalArray::new(vec![1u8], vec![1, 1]).unwrap();
        let result =
            block_on(pause_builtin(vec![Value::LogicalArray(logical)])).expect("pause(true)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn infinite_duration_behaves_like_default() {
        let _guard = lock_test_guard();
        reset_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(f64::INFINITY)])).expect("pause(Inf)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_gpu_duration_gathered() {
        let _guard = lock_test_guard();
        reset_state(true);
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
            assert_empty_tensor(result);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn pause_wgpu_duration_gathered() {
        let _guard = lock_test_guard();
        reset_state(true);
        if wgpu_provider::register_wgpu_provider(wgpu_provider::WgpuProviderOptions::default())
            .is_err()
        {
            return;
        }
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result =
            block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_command_raises_error() {
        let _guard = lock_test_guard();
        reset_state(true);
        let err = block_on(pause_builtin(vec![Value::from("invalid")])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn too_many_inputs_raises_error() {
        let _guard = lock_test_guard();
        reset_state(true);
        let err = block_on(pause_builtin(vec![Value::Num(0.0), Value::Num(0.0)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_TOO_MANY_INPUTS.identifier.unwrap());
    }
}