1use once_cell::sync::Lazy;
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 CharArray, LogicalArray, Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10use std::sync::RwLock;
11
12use crate::builtins::common::gpu_helpers;
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17#[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
18use crate::builtins::plotting;
19use crate::builtins::timing::type_resolvers::pause_type;
20#[cfg(not(test))]
21use crate::interaction;
22use crate::{build_runtime_error, BuiltinResult, RuntimeError};
23
/// GPU metadata for `pause`.
///
/// `pause` has no device kernel: it is tagged with a custom `"timer"` op kind and declares no
/// supported precisions or provider hooks. Residency is `GatherImmediately`, matching the note
/// that the builtin runs entirely on the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::timing::pause")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "pause",
    op_kind: GpuOpKind::Custom("timer"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "pause executes entirely on the host. Acceleration providers are never queried.",
};
39
/// Fusion metadata for `pause`.
///
/// No elementwise or reduction template is provided (`elementwise`/`reduction` are `None`), so
/// `pause` contributes nothing fusible; the note records that it suspends the host instead.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::timing::pause")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "pause",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "pause suspends host execution and is excluded from fusion pipelines.",
};
50
51static PAUSE_STATE: Lazy<RwLock<PauseState>> = Lazy::new(|| RwLock::new(PauseState::default()));
52
53#[cfg(test)]
54use std::sync::Mutex;
55#[cfg(test)]
56pub(crate) static TEST_GUARD: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
57
/// Global `pause` configuration toggled by `pause('on')` and `pause('off')`.
#[derive(Debug, Clone, Copy)]
struct PauseState {
    /// When `false`, the waiting forms of `pause` return without suspending execution.
    enabled: bool,
}

impl Default for PauseState {
    /// Pausing starts out enabled.
    fn default() -> Self {
        PauseState { enabled: true }
    }
}
68
69const BUILTIN_NAME: &str = "pause";
70
71const PAUSE_OUTPUT_EMPTY: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
72 name: "out",
73 ty: BuiltinParamType::NumericArray,
74 arity: BuiltinParamArity::Required,
75 default: None,
76 description: "Empty array when pausing or changing state.",
77}];
78
79const PAUSE_OUTPUT_STATE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
80 name: "state",
81 ty: BuiltinParamType::StringScalar,
82 arity: BuiltinParamArity::Required,
83 default: None,
84 description: "Previous pause state ('on' or 'off').",
85}];
86
87const PAUSE_INPUTS_NONE: [BuiltinParamDescriptor; 0] = [];
88const PAUSE_INPUTS_DURATION: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
89 name: "duration",
90 ty: BuiltinParamType::Any,
91 arity: BuiltinParamArity::Required,
92 default: Some("0"),
93 description: "Duration scalar or command-like scalar value accepted by pause.",
94}];
95const PAUSE_INPUTS_COMMAND: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
96 name: "command",
97 ty: BuiltinParamType::StringScalar,
98 arity: BuiltinParamArity::Required,
99 default: None,
100 description: "One of 'on', 'off', or 'query'.",
101}];
102
103const PAUSE_SIGNATURES: [BuiltinSignatureDescriptor; 3] = [
104 BuiltinSignatureDescriptor {
105 label: "out = pause()",
106 inputs: &PAUSE_INPUTS_NONE,
107 outputs: &PAUSE_OUTPUT_EMPTY,
108 },
109 BuiltinSignatureDescriptor {
110 label: "out = pause(duration)",
111 inputs: &PAUSE_INPUTS_DURATION,
112 outputs: &PAUSE_OUTPUT_EMPTY,
113 },
114 BuiltinSignatureDescriptor {
115 label: "state = pause(command)",
116 inputs: &PAUSE_INPUTS_COMMAND,
117 outputs: &PAUSE_OUTPUT_STATE,
118 },
119];
120
121const PAUSE_ERROR_INVALID_ARG: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
122 code: "RM.PAUSE.INVALID_ARG",
123 identifier: Some("RunMat:pause:InvalidInputArgument"),
124 when: "Input argument is malformed, unsupported, non-scalar where scalar is required, or a negative/non-finite duration.",
125 message: "pause: invalid input argument",
126};
127
128const PAUSE_ERROR_TOO_MANY_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
129 code: "RM.PAUSE.TOO_MANY_INPUTS",
130 identifier: Some("RunMat:pause:TooManyInputs"),
131 when: "More than one input argument is supplied.",
132 message: "pause: too many input arguments",
133};
134
135const PAUSE_ERROR_STATE_LOCK: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
136 code: "RM.PAUSE.STATE_LOCK",
137 identifier: Some("RunMat:pause:StateLockFailed"),
138 when: "Internal pause-state lock cannot be acquired.",
139 message: "pause: failed to acquire pause state",
140};
141
142const PAUSE_ERROR_GATHER_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
143 code: "RM.PAUSE.GPU_GATHER_FAILED",
144 identifier: Some("RunMat:pause:GpuGatherFailed"),
145 when: "Gathering a GPU argument to host fails during argument classification.",
146 message: "pause: failed to gather gpu input",
147};
148
149const PAUSE_ERRORS: [BuiltinErrorDescriptor; 4] = [
150 PAUSE_ERROR_INVALID_ARG,
151 PAUSE_ERROR_TOO_MANY_INPUTS,
152 PAUSE_ERROR_STATE_LOCK,
153 PAUSE_ERROR_GATHER_FAILED,
154];
155
156pub const PAUSE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
157 signatures: &PAUSE_SIGNATURES,
158 output_mode: BuiltinOutputMode::Fixed,
159 completion_policy: BuiltinCompletionPolicy::Public,
160 errors: &PAUSE_ERRORS,
161};
162
163fn pause_error_with_message(
164 message: impl Into<String>,
165 error: &'static BuiltinErrorDescriptor,
166) -> RuntimeError {
167 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
168 if let Some(identifier) = error.identifier {
169 builder = builder.with_identifier(identifier);
170 }
171 builder.build()
172}
173
/// Interpretation of the single argument passed to `pause`.
#[derive(Debug, Clone, Copy)]
enum PauseArgument {
    /// Suspend execution: numeric durations, `+Inf`, empty arrays, or empty text.
    Wait(PauseWait),
    /// `pause('on')` (`true`) or `pause('off')` (`false`).
    SetState(bool),
    /// `pause('query')`: report the current state without changing it.
    Query,
}

/// How a waiting `pause` call suspends execution.
#[derive(Debug, Clone, Copy)]
enum PauseWait {
    /// Wait for a key press, exactly like `pause()` with no arguments.
    Default,
    /// Sleep for this many seconds; the argument parsers only produce finite, non-negative values.
    Seconds(f64),
}
186
187#[runtime_builtin(
189 name = "pause",
190 category = "timing",
191 summary = "Pause execution until keypress or specified duration elapses.",
192 keywords = "pause,sleep,wait,delay",
193 accel = "metadata",
194 sink = true,
195 type_resolver(pause_type),
196 descriptor(crate::builtins::timing::pause::PAUSE_DESCRIPTOR),
197 builtin_path = "crate::builtins::timing::pause"
198)]
199async fn pause_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
200 match args.len() {
201 0 => {
202 perform_wait(PauseWait::Default).await?;
203 Ok(empty_return_value())
204 }
205 1 => match classify_argument(&args[0]).await? {
206 PauseArgument::Wait(wait) => {
207 perform_wait(wait).await?;
208 Ok(empty_return_value())
209 }
210 PauseArgument::SetState(next_state) => {
211 let previous = set_pause_enabled(next_state)?;
212 Ok(state_value(previous))
213 }
214 PauseArgument::Query => {
215 let current = pause_enabled()?;
216 Ok(state_value(current))
217 }
218 },
219 _ => Err(pause_error_with_message(
220 PAUSE_ERROR_TOO_MANY_INPUTS.message,
221 &PAUSE_ERROR_TOO_MANY_INPUTS,
222 )),
223 }
224}
225
226async fn perform_wait(wait: PauseWait) -> Result<(), RuntimeError> {
227 if !pause_enabled()? {
228 return Ok(());
229 }
230
231 #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
232 {
233 let handle = plotting::current_figure_handle();
236 let _ = plotting::render_current_scene(handle.as_u32());
238 }
239
240 match wait {
241 PauseWait::Default => wait_for_key_press().await,
242 PauseWait::Seconds(seconds) => {
243 if seconds == 0.0 {
244 #[cfg(target_arch = "wasm32")]
246 {
247 return wasm_sleep_seconds(0.0).await;
248 }
249 #[cfg(not(target_arch = "wasm32"))]
250 {
251 return Ok(());
252 }
253 }
254 sleep_seconds(seconds).await?;
255 #[cfg(all(target_arch = "wasm32", feature = "plot-web"))]
256 {
257 let handle = plotting::current_figure_handle();
261 let _ = plotting::render_current_scene(handle.as_u32());
262 }
263 Ok(())
264 }
265 }
266}
267
268async fn wait_for_key_press() -> Result<(), RuntimeError> {
269 #[cfg(test)]
270 {
271 Ok(())
272 }
273 #[cfg(not(test))]
274 {
275 interaction::wait_for_key_async("").await
276 }
277}
278
279async fn sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
280 #[cfg(target_arch = "wasm32")]
281 {
282 wasm_sleep_seconds(seconds).await
283 }
284 #[cfg(not(target_arch = "wasm32"))]
285 {
286 let duration = std::time::Duration::from_secs_f64(seconds);
288 std::thread::sleep(duration);
289 Ok(())
290 }
291}
292
/// Sleeps for `seconds` on wasm32 by awaiting a promise resolved from JavaScript `setTimeout`.
///
/// The delay is rounded to whole milliseconds, clamped below at 0 (NaN also becomes 0), and
/// clamped above at `i32::MAX`. That is the largest delay `setTimeout` honours; larger values
/// overflow and fire immediately.
///
/// # Errors
/// * `setTimeout` is missing from the global object or is not callable.
/// * Scheduling the timer throws. The promise is then rejected so the await fails instead of
///   never settling.
#[cfg(target_arch = "wasm32")]
async fn wasm_sleep_seconds(seconds: f64) -> Result<(), RuntimeError> {
    use js_sys::{Function, Promise, Reflect};
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;

    let global = js_sys::global();
    let set_timeout = Reflect::get(&global, &wasm_bindgen::JsValue::from_str("setTimeout"))
        .map_err(|_| build_runtime_error("pause: setTimeout unavailable").build())?
        .dyn_into::<Function>()
        .map_err(|_| build_runtime_error("pause: setTimeout unavailable").build())?;

    let millis = (seconds * 1000.0).max(0.0).round();
    let millis_i32 = if millis > i32::MAX as f64 {
        i32::MAX
    } else {
        millis as i32
    };

    let promise = Promise::new(&mut |resolve, reject| {
        let scheduled = set_timeout.call2(
            &global,
            &resolve.into(),
            &wasm_bindgen::JsValue::from_f64(millis_i32 as f64),
        );
        if let Err(err) = scheduled {
            // Previously ignored: the promise never settled and the await below hung forever.
            let _ = reject.call1(&wasm_bindgen::JsValue::UNDEFINED, &err);
        }
    });

    JsFuture::from(promise)
        .await
        .map_err(|err| build_runtime_error(format!("pause: timer failed ({err:?})")).build())?;
    Ok(())
}
328
329async fn classify_argument(arg: &Value) -> Result<PauseArgument, RuntimeError> {
330 let host_value = gpu_helpers::gather_value_async(arg)
331 .await
332 .map_err(|e| pause_error_with_message(format!("pause: {e}"), &PAUSE_ERROR_GATHER_FAILED))?;
333 match host_value {
334 Value::String(text) => parse_command(&text),
335 Value::CharArray(ca) => {
336 if ca.rows == 0 || ca.data.is_empty() {
337 Ok(PauseArgument::Wait(PauseWait::Default))
338 } else if ca.rows == 1 {
339 let text: String = ca.data.iter().collect();
340 parse_command(&text)
341 } else {
342 Err(pause_error_with_message(
343 PAUSE_ERROR_INVALID_ARG.message,
344 &PAUSE_ERROR_INVALID_ARG,
345 ))
346 }
347 }
348 Value::StringArray(sa) => {
349 if sa.data.is_empty() {
350 Ok(PauseArgument::Wait(PauseWait::Default))
351 } else if sa.data.len() == 1 {
352 parse_command(&sa.data[0])
353 } else {
354 Err(pause_error_with_message(
355 PAUSE_ERROR_INVALID_ARG.message,
356 &PAUSE_ERROR_INVALID_ARG,
357 ))
358 }
359 }
360 Value::Num(value) => parse_numeric(value),
361 Value::Int(int_value) => parse_numeric(int_value.to_f64()),
362 Value::Bool(flag) => parse_numeric(if flag { 1.0 } else { 0.0 }),
363 Value::Tensor(tensor) => parse_tensor(tensor),
364 Value::LogicalArray(logical) => parse_logical(logical),
365 Value::GpuTensor(handle) => {
366 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
367 parse_tensor(tensor)
368 }
369 Value::Complex(_, _)
370 | Value::ComplexTensor(_)
371 | Value::SparseTensor(_)
372 | Value::Cell(_)
373 | Value::Struct(_)
374 | Value::Object(_)
375 | Value::HandleObject(_)
376 | Value::Listener(_)
377 | Value::FunctionHandle(_)
378 | Value::ExternalFunctionHandle(_)
379 | Value::MethodFunctionHandle(_)
380 | Value::BoundFunctionHandle { .. }
381 | Value::Closure(_)
382 | Value::ClassRef(_)
383 | Value::MException(_)
384 | Value::OutputList(_) => Err(pause_error_with_message(
385 PAUSE_ERROR_INVALID_ARG.message,
386 &PAUSE_ERROR_INVALID_ARG,
387 )),
388 }
389}
390
391fn parse_command(raw: &str) -> Result<PauseArgument, RuntimeError> {
392 let trimmed = raw.trim();
393 if trimmed.is_empty() {
394 return Ok(PauseArgument::Wait(PauseWait::Default));
395 }
396 let lower = trimmed.to_ascii_lowercase();
397 match lower.as_str() {
398 "on" => Ok(PauseArgument::SetState(true)),
399 "off" => Ok(PauseArgument::SetState(false)),
400 "query" => Ok(PauseArgument::Query),
401 _ => Err(pause_error_with_message(
402 PAUSE_ERROR_INVALID_ARG.message,
403 &PAUSE_ERROR_INVALID_ARG,
404 )),
405 }
406}
407
408fn parse_numeric(value: f64) -> Result<PauseArgument, RuntimeError> {
409 if !value.is_finite() {
410 if value.is_sign_positive() {
411 return Ok(PauseArgument::Wait(PauseWait::Default));
412 }
413 return Err(pause_error_with_message(
414 PAUSE_ERROR_INVALID_ARG.message,
415 &PAUSE_ERROR_INVALID_ARG,
416 ));
417 }
418 if value < 0.0 {
419 return Err(pause_error_with_message(
420 PAUSE_ERROR_INVALID_ARG.message,
421 &PAUSE_ERROR_INVALID_ARG,
422 ));
423 }
424 Ok(PauseArgument::Wait(PauseWait::Seconds(value)))
425}
426
427fn parse_tensor(tensor: Tensor) -> Result<PauseArgument, RuntimeError> {
428 if tensor.data.is_empty() {
429 return Ok(PauseArgument::Wait(PauseWait::Default));
430 }
431 if tensor.data.len() != 1 {
432 return Err(pause_error_with_message(
433 PAUSE_ERROR_INVALID_ARG.message,
434 &PAUSE_ERROR_INVALID_ARG,
435 ));
436 }
437 parse_numeric(tensor.data[0])
438}
439
440fn parse_logical(logical: LogicalArray) -> Result<PauseArgument, RuntimeError> {
441 if logical.data.is_empty() {
442 return Ok(PauseArgument::Wait(PauseWait::Default));
443 }
444 if logical.data.len() != 1 {
445 return Err(pause_error_with_message(
446 PAUSE_ERROR_INVALID_ARG.message,
447 &PAUSE_ERROR_INVALID_ARG,
448 ));
449 }
450 let scalar = if logical.data[0] != 0 { 1.0 } else { 0.0 };
451 parse_numeric(scalar)
452}
453
454fn empty_return_value() -> Value {
455 Value::Tensor(Tensor::zeros(vec![0, 0]))
456}
457
458fn state_value(enabled: bool) -> Value {
459 let text = if enabled { "on" } else { "off" };
460 Value::CharArray(CharArray::new_row(text))
461}
462
463fn pause_enabled() -> Result<bool, RuntimeError> {
464 PAUSE_STATE.read().map(|guard| guard.enabled).map_err(|_| {
465 pause_error_with_message(PAUSE_ERROR_STATE_LOCK.message, &PAUSE_ERROR_STATE_LOCK)
466 })
467}
468
469fn set_pause_enabled(next: bool) -> Result<bool, RuntimeError> {
470 let mut guard = PAUSE_STATE.write().map_err(|_| {
471 pause_error_with_message(PAUSE_ERROR_STATE_LOCK.message, &PAUSE_ERROR_STATE_LOCK)
472 })?;
473 let previous = guard.enabled;
474 guard.enabled = next;
475 Ok(previous)
476}
477
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, LogicalArray, Tensor};
    use std::sync::MutexGuard;

    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_provider;

    /// Overwrites the global pause state, recovering the lock if an earlier test poisoned it.
    fn reset_state(enabled: bool) {
        let mut guard = PAUSE_STATE.write().unwrap_or_else(|e| e.into_inner());
        guard.enabled = enabled;
    }

    /// Acquires `TEST_GUARD` and sets the global pause state for the calling test.
    ///
    /// All tests lock through this helper so poisoning is handled uniformly. A poisoned guard
    /// is recovered instead of unwrapped: each test resets the state it relies on, and one
    /// panicking test must not cascade into failures in every test that runs after it.
    fn lock_state(enabled: bool) -> MutexGuard<'static, ()> {
        let guard = TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner());
        reset_state(enabled);
        guard
    }

    fn char_array_to_string(value: Value) -> String {
        match value {
            Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect(),
            other => panic!("expected char array, got {other:?}"),
        }
    }

    fn assert_empty_tensor(value: Value) {
        match value {
            Value::Tensor(t) => assert_eq!(t.data.len(), 0),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    fn assert_pause_error_identifier(err: crate::RuntimeError, identifier: &str) {
        assert_eq!(
            err.identifier(),
            Some(identifier),
            "message: {}",
            err.message()
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn query_returns_on_by_default() {
        let _guard = lock_state(PauseState::default().enabled);
        let result = block_on(pause_builtin(vec![Value::from("query")])).expect("pause query");
        assert_eq!(char_array_to_string(result), "on");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_off_returns_previous_state() {
        let _guard = lock_state(true);
        let previous = block_on(pause_builtin(vec![Value::from("off")])).expect("pause off");
        assert_eq!(char_array_to_string(previous), "on");
        assert!(!pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_on_restores_state() {
        let _guard = lock_state(false);
        let previous = block_on(pause_builtin(vec![Value::from("on")])).expect("pause on");
        assert_eq!(char_array_to_string(previous), "off");
        assert!(pause_enabled().unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn commands_are_case_insensitive_and_trimmed() {
        let _guard = lock_state(true);
        let result =
            block_on(pause_builtin(vec![Value::from("  QUERY ")])).expect("pause('  QUERY ')");
        assert_eq!(char_array_to_string(result), "on");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_default_returns_empty_tensor() {
        let _guard = lock_state(true);
        let result = block_on(pause_builtin(Vec::new())).expect("pause()");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_zero_is_accepted() {
        let _guard = lock_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(0.0)])).expect("pause(0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn integer_scalar_is_accepted() {
        let _guard = lock_state(true);
        let result =
            block_on(pause_builtin(vec![Value::Int(IntValue::I32(0))])).expect("pause(int)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_negative_zero_is_treated_as_zero() {
        let _guard = lock_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(-0.0)])).expect("pause(-0)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn negative_duration_raises_error() {
        let _guard = lock_state(true);
        let err = block_on(pause_builtin(vec![Value::Num(-0.1)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn negative_infinity_is_rejected() {
        let _guard = lock_state(true);
        let err = block_on(pause_builtin(vec![Value::Num(f64::NEG_INFINITY)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn non_scalar_tensor_is_rejected() {
        let _guard = lock_state(true);
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let err = block_on(pause_builtin(vec![Value::Tensor(tensor)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_tensor_behaves_like_default_pause() {
        let _guard = lock_state(true);
        let empty = Tensor::zeros(vec![0, 0]);
        let result = block_on(pause_builtin(vec![Value::Tensor(empty)])).expect("pause([])");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_scalar_is_accepted() {
        // `true` classifies as a 1-second wait. Pausing is disabled so the argument is still
        // classified but the test does not actually sleep.
        let _guard = lock_state(false);
        let logical = LogicalArray::new(vec![1u8], vec![1, 1]).unwrap();
        let result =
            block_on(pause_builtin(vec![Value::LogicalArray(logical)])).expect("pause(true)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn infinite_duration_behaves_like_default() {
        let _guard = lock_state(true);
        let result = block_on(pause_builtin(vec![Value::Num(f64::INFINITY)])).expect("pause(Inf)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn pause_gpu_duration_gathered() {
        let _guard = lock_state(true);
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
            assert_empty_tensor(result);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn pause_wgpu_duration_gathered() {
        let _guard = lock_state(true);
        if wgpu_provider::register_wgpu_provider(wgpu_provider::WgpuProviderOptions::default())
            .is_err()
        {
            return;
        }
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let result =
            block_on(pause_builtin(vec![Value::GpuTensor(handle)])).expect("pause(gpuScalar)");
        assert_empty_tensor(result);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_command_raises_error() {
        let _guard = lock_state(true);
        let err = block_on(pause_builtin(vec![Value::from("invalid")])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_INVALID_ARG.identifier.unwrap());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn too_many_inputs_raises_error() {
        let _guard = lock_state(true);
        let err = block_on(pause_builtin(vec![Value::Num(0.0), Value::Num(0.0)])).unwrap_err();
        assert_pause_error_identifier(err, PAUSE_ERROR_TOO_MANY_INPUTS.identifier.unwrap());
    }
}