1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 ComplexTensor, Tensor, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::format::format_variadic;
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ReductionNaN, ResidencyPolicy, ShapeRequirements,
15};
16use crate::builtins::diagnostics::type_resolvers::assert_type;
17use crate::{build_runtime_error, RuntimeError};
18
/// Builtin name attached (via `with_builtin`) to every runtime error built in this module.
const BUILTIN_NAME: &str = "assert";
20
/// Output descriptor shared by every `assert` signature: a single numeric `out` value.
const ASSERT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Zero when the assertion passes.",
}];
28
/// Inputs for the `assert(condition)` signature: only the condition itself.
const ASSERT_INPUTS_CONDITION: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "condition",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Logical/numeric condition that must evaluate to true.",
}];
36
/// Inputs for the `assert(condition, message)` signature.
const ASSERT_INPUTS_MESSAGE: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        // NOTE(review): a default is listed although the arity is `Required` — confirm how
        // descriptor consumers interpret this combination.
        default: Some("\"Assertion failed.\""),
        description: "Failure message text.",
    },
];
53
/// Inputs for the `assert(condition, message, A...)` signature, where `message` is a
/// template and the trailing variadic values are its formatting arguments.
const ASSERT_INPUTS_MESSAGE_VARIADIC: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"Assertion failed.\""),
        description: "Failure message template text.",
    },
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Formatting values for the message template.",
    },
];
77
/// Inputs for the `assert(condition, message_id, message)` signature.
const ASSERT_INPUTS_IDENTIFIER_MESSAGE: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message_id",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"RunMat:assertion:failed\""),
        description: "Message identifier.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"Assertion failed.\""),
        description: "Failure message text.",
    },
];
101
/// Inputs for the `assert(condition, message_id, message, A...)` signature: identifier,
/// message template, and trailing formatting values.
const ASSERT_INPUTS_IDENTIFIER_MESSAGE_VARIADIC: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message_id",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"RunMat:assertion:failed\""),
        description: "Message identifier.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"Assertion failed.\""),
        description: "Failure message template text.",
    },
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Formatting values for the message template.",
    },
];
132
/// The five call forms advertised for `assert`. Each entry pairs one of the input lists
/// above with the shared single-value output descriptor.
const ASSERT_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
    BuiltinSignatureDescriptor {
        label: "out = assert(condition)",
        inputs: &ASSERT_INPUTS_CONDITION,
        outputs: &ASSERT_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message)",
        inputs: &ASSERT_INPUTS_MESSAGE,
        outputs: &ASSERT_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message, A...)",
        inputs: &ASSERT_INPUTS_MESSAGE_VARIADIC,
        outputs: &ASSERT_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message_id, message)",
        inputs: &ASSERT_INPUTS_IDENTIFIER_MESSAGE,
        outputs: &ASSERT_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message_id, message, A...)",
        inputs: &ASSERT_INPUTS_IDENTIFIER_MESSAGE_VARIADIC,
        outputs: &ASSERT_OUTPUT,
    },
];
160
/// Default failure descriptor. Its identifier and message are what
/// `assert_default_identifier` and `assert_default_message` return.
const ASSERT_ERROR_ASSERTION_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.ASSERTION_FAILED",
    identifier: Some("RunMat:assertion:failed"),
    when: "Condition evaluates to false and no custom identifier/message override is provided.",
    message: "Assertion failed.",
};
167
/// Raised by `evaluate_condition` when the first argument is not one of the logical or
/// numeric value kinds it handles.
const ASSERT_ERROR_INVALID_CONDITION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.INVALID_CONDITION",
    identifier: Some("RunMat:assertion:invalidCondition"),
    when: "First argument is not a supported logical or numeric condition input.",
    message: "assert: first input must be logical or numeric.",
};
174
/// Descriptor used for malformed identifier/message arguments and for errors remapped
/// from message formatting or GPU gathering.
const ASSERT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.INVALID_INPUT",
    identifier: Some("RunMat:assertion:invalidInput"),
    when: "Message identifier/message text or formatting payload is invalid.",
    message: "assert: invalid input argument",
};
181
/// Raised by `assert_builtin` when it is called with an empty argument list.
const ASSERT_ERROR_NOT_ENOUGH_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.NOT_ENOUGH_INPUTS",
    identifier: Some("RunMat:minrhs"),
    when: "No condition argument is provided.",
    message: "Not enough input arguments.",
};
188
/// Full error catalogue exposed through `ASSERT_DESCRIPTOR.errors`.
const ASSERT_ERRORS: [BuiltinErrorDescriptor; 4] = [
    ASSERT_ERROR_ASSERTION_FAILED,
    ASSERT_ERROR_INVALID_CONDITION,
    ASSERT_ERROR_INVALID_INPUT,
    ASSERT_ERROR_NOT_ENOUGH_INPUTS,
];
195
/// Public descriptor for `assert`, bundling its signatures and error catalogue.
///
/// Referenced by path from the `descriptor(...)` argument of the `runtime_builtin`
/// attribute on `assert_builtin`.
pub const ASSERT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &ASSERT_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &ASSERT_ERRORS,
};
202
// GPU metadata for `assert`. It is declared as a custom "control" operation with no
// supported precisions and no provider hooks; per `notes`, GPU tensors are gathered to
// host memory before the condition is evaluated.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::diagnostics::assert")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "assert",
    op_kind: GpuOpKind::Custom("control"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Control-flow builtin; GPU tensors are gathered to host memory before evaluation.",
};
218
// Fusion metadata for `assert`. No elementwise or reduction template is provided, so
// (per `notes`) the builtin does not take part in fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::diagnostics::assert")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "assert",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Control-flow builtin with no fusion support.",
};
229
230fn assert_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
231 assert_error_with_message(error.message, error)
232}
233
234fn assert_default_identifier() -> &'static str {
235 ASSERT_ERROR_ASSERTION_FAILED
236 .identifier
237 .expect("assert default identifier must be defined")
238}
239
240fn assert_default_message() -> &'static str {
241 ASSERT_ERROR_ASSERTION_FAILED.message
242}
243
244fn assert_error_with_message(
245 message: impl Into<String>,
246 error: &'static BuiltinErrorDescriptor,
247) -> RuntimeError {
248 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
249 if let Some(identifier) = error.identifier {
250 builder = builder.with_identifier(normalize_identifier(identifier));
251 }
252 builder.build()
253}
254
255fn assert_flow(identifier: &str, message: impl Into<String>) -> RuntimeError {
256 build_runtime_error(message)
257 .with_builtin(BUILTIN_NAME)
258 .with_identifier(normalize_identifier(identifier))
259 .build()
260}
261
262fn remap_assert_flow<F>(
263 err: RuntimeError,
264 error: &'static BuiltinErrorDescriptor,
265 message: F,
266) -> RuntimeError
267where
268 F: FnOnce(&crate::RuntimeError) -> String,
269{
270 let mut builder = build_runtime_error(message(&err))
271 .with_builtin(BUILTIN_NAME)
272 .with_source(err);
273 if let Some(identifier) = error.identifier {
274 builder = builder.with_identifier(normalize_identifier(identifier));
275 }
276 builder.build()
277}
278
279#[runtime_builtin(
280 name = "assert",
281 category = "diagnostics",
282 summary = "Throw an error when a condition is false, matching MATLAB assert semantics.",
283 keywords = "assert,diagnostics,validation,error",
284 accel = "metadata",
285 type_resolver(assert_type),
286 descriptor(crate::builtins::diagnostics::assert::ASSERT_DESCRIPTOR),
287 builtin_path = "crate::builtins::diagnostics::assert"
288)]
289async fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
290 if args.is_empty() {
291 return Err(assert_error(&ASSERT_ERROR_NOT_ENOUGH_INPUTS));
292 }
293
294 let mut iter = args.into_iter();
295 let condition_raw = iter.next().expect("checked length above");
296 let rest: Vec<Value> = iter.collect();
297
298 let condition = normalize_condition_value(condition_raw).await?;
299 match evaluate_condition(condition)? {
300 ConditionOutcome::Pass => Ok(Value::Num(0.0)),
301 ConditionOutcome::Fail => {
302 let payload = failure_payload(&rest)?;
303 Err(assert_flow(&payload.identifier, payload.message))
304 }
305 }
306}
307
308async fn normalize_condition_value(condition: Value) -> crate::BuiltinResult<Value> {
309 match condition {
310 Value::GpuTensor(handle) => {
311 let gpu_value = Value::GpuTensor(handle);
312 gpu_helpers::gather_value_async(&gpu_value)
313 .await
314 .map_err(|flow| {
315 remap_assert_flow(flow, &ASSERT_ERROR_INVALID_INPUT, |err| {
316 format!("assert: {}", err.message())
317 })
318 })
319 }
320 other => Ok(other),
321 }
322}
323
/// Result of evaluating an `assert` condition.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ConditionOutcome {
    /// The condition held; `assert_builtin` returns normally.
    Pass,
    /// The condition did not hold; `assert_builtin` raises the failure error.
    Fail,
}
329
330fn evaluate_condition(value: Value) -> crate::BuiltinResult<ConditionOutcome> {
331 match value {
332 Value::Bool(flag) => Ok(if flag {
333 ConditionOutcome::Pass
334 } else {
335 ConditionOutcome::Fail
336 }),
337 Value::Int(int_value) => {
338 if int_value.to_i64() != 0 {
339 Ok(ConditionOutcome::Pass)
340 } else {
341 Ok(ConditionOutcome::Fail)
342 }
343 }
344 Value::Num(num) => {
345 if num.is_nan() || num == 0.0 {
346 Ok(ConditionOutcome::Fail)
347 } else {
348 Ok(ConditionOutcome::Pass)
349 }
350 }
351 Value::Complex(re, im) => {
352 if complex_element_passes(re, im) {
353 Ok(ConditionOutcome::Pass)
354 } else {
355 Ok(ConditionOutcome::Fail)
356 }
357 }
358 Value::LogicalArray(array) => {
359 if array.data.iter().all(|&bit| bit != 0) {
360 Ok(ConditionOutcome::Pass)
361 } else {
362 Ok(ConditionOutcome::Fail)
363 }
364 }
365 Value::Tensor(tensor) => evaluate_tensor_condition(&tensor),
366 Value::ComplexTensor(tensor) => evaluate_complex_tensor(&tensor),
367 Value::GpuTensor(_) => {
368 unreachable!("gpu tensors are gathered in normalize_condition_value")
369 }
370 _ => Err(assert_error(&ASSERT_ERROR_INVALID_CONDITION)),
371 }
372}
373
374fn evaluate_tensor_condition(tensor: &Tensor) -> crate::BuiltinResult<ConditionOutcome> {
375 if tensor.data.is_empty() {
376 return Ok(ConditionOutcome::Pass);
377 }
378 for value in &tensor.data {
379 if value.is_nan() || *value == 0.0 {
380 return Ok(ConditionOutcome::Fail);
381 }
382 }
383 Ok(ConditionOutcome::Pass)
384}
385
386fn evaluate_complex_tensor(tensor: &ComplexTensor) -> crate::BuiltinResult<ConditionOutcome> {
387 if tensor.data.is_empty() {
388 return Ok(ConditionOutcome::Pass);
389 }
390 for &(re, im) in &tensor.data {
391 if !complex_element_passes(re, im) {
392 return Ok(ConditionOutcome::Fail);
393 }
394 }
395 Ok(ConditionOutcome::Pass)
396}
397
/// Returns `true` when the complex number `re + im*i` is non-zero and neither component
/// is `NaN`.
fn complex_element_passes(re: f64, im: f64) -> bool {
    let has_nan = re.is_nan() || im.is_nan();
    let is_zero = re == 0.0 && im == 0.0;
    !has_nan && !is_zero
}
404
/// Identifier and message text used to build the error raised when an assertion fails.
struct FailurePayload {
    /// Error identifier, either caller-supplied (normalized) or the default one.
    identifier: String,
    /// Message text, already formatted.
    message: String,
}
409
410fn failure_payload(args: &[Value]) -> crate::BuiltinResult<FailurePayload> {
411 if args.is_empty() {
412 return Ok(FailurePayload {
413 identifier: assert_default_identifier().to_string(),
414 message: assert_default_message().to_string(),
415 });
416 }
417
418 let candidate = &args[0];
419 let treat_as_identifier = args.len() >= 2 && value_is_identifier(candidate);
420
421 if treat_as_identifier {
422 if args.len() < 2 {
423 return Err(assert_flow(
424 ASSERT_ERROR_INVALID_INPUT
425 .identifier
426 .expect("assert invalid-input identifier must be defined"),
427 "assert: message text must follow the message identifier.",
428 ));
429 }
430 let identifier = identifier_from_value(candidate)?;
431 let template = message_from_value(&args[1])?;
432 let formatting_args: &[Value] = if args.len() > 2 { &args[2..] } else { &[] };
433 let message = format_message(&template, formatting_args)?;
434 Ok(FailurePayload {
435 identifier,
436 message,
437 })
438 } else {
439 let template = message_from_value(candidate)?;
440 let formatting_args: &[Value] = if args.len() > 1 { &args[1..] } else { &[] };
441 let message = format_message(&template, formatting_args)?;
442 Ok(FailurePayload {
443 identifier: assert_default_identifier().to_string(),
444 message,
445 })
446 }
447}
448
449fn value_is_identifier(value: &Value) -> bool {
450 if let Some(text) = string_scalar_opt(value) {
451 is_message_identifier(&text) || looks_like_unqualified_identifier(&text)
452 } else {
453 false
454 }
455}
456
457fn identifier_from_value(value: &Value) -> crate::BuiltinResult<String> {
458 let text = string_scalar_from_value(
459 value,
460 "assert: message identifier must be a string scalar or character vector.",
461 )?;
462 if text.trim().is_empty() {
463 return Err(assert_flow(
464 ASSERT_ERROR_INVALID_INPUT
465 .identifier
466 .expect("assert invalid-input identifier must be defined"),
467 "assert: message identifier must be nonempty.",
468 ));
469 }
470 Ok(normalize_identifier(&text))
471}
472
473fn message_from_value(value: &Value) -> crate::BuiltinResult<String> {
474 string_scalar_from_value(
475 value,
476 "assert: message text must be a string scalar or character vector.",
477 )
478}
479
480fn format_message(template: &str, args: &[Value]) -> crate::BuiltinResult<String> {
481 format_variadic(template, args).map_err(|flow| {
482 remap_assert_flow(flow, &ASSERT_ERROR_INVALID_INPUT, |err| {
483 format!("assert: {}", err.message())
484 })
485 })
486}
487
488fn normalize_identifier(raw: &str) -> String {
489 let trimmed = raw.trim();
490 if trimmed.is_empty() {
491 assert_default_identifier().to_string()
492 } else if trimmed.contains(':') {
493 trimmed.to_string()
494 } else {
495 format!("RunMat:{trimmed}")
496 }
497}
498
/// Returns `true` when the trimmed text contains at least one `:` and consists solely of
/// ASCII letters, digits, `:`, `_` and `.`.
fn is_message_identifier(text: &str) -> bool {
    let candidate = text.trim();
    let allowed = |ch: char| ch.is_ascii_alphanumeric() || ch == ':' || ch == '_' || ch == '.';
    // An empty candidate cannot contain ':', so it is rejected here as well.
    candidate.contains(':') && candidate.chars().all(allowed)
}
508
/// Returns `true` when the trimmed text is non-empty and consists solely of ASCII letters,
/// digits, `_` and `.` (so it contains no whitespace and no `:`).
fn looks_like_unqualified_identifier(text: &str) -> bool {
    let candidate = text.trim();
    // Whitespace is not in the allowed set, so interior spaces are rejected by the scan.
    !candidate.is_empty()
        && candidate
            .chars()
            .all(|ch| matches!(ch, '_' | '.') || ch.is_ascii_alphanumeric())
}
518
519fn string_scalar_from_value(value: &Value, context: &str) -> crate::BuiltinResult<String> {
520 match value {
521 Value::String(text) => Ok(text.clone()),
522 Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
523 Value::CharArray(char_array) if char_array.rows == 1 => {
524 Ok(char_array.data.iter().collect::<String>())
525 }
526 _ => Err(assert_error_with_message(
527 context,
528 &ASSERT_ERROR_INVALID_INPUT,
529 )),
530 }
531}
532
533fn string_scalar_opt(value: &Value) -> Option<String> {
534 match value {
535 Value::String(text) => Some(text.clone()),
536 Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
537 Value::CharArray(char_array) if char_array.rows == 1 => {
538 Some(char_array.data.iter().collect())
539 }
540 _ => None,
541 }
542}
543
// Unit tests for the `assert` builtin: passing and failing conditions for each supported
// value kind, custom identifiers/messages, formatting, and argument validation.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type};

    // Synchronous wrapper that drives the async builtin to completion.
    fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::assert_builtin(args))
    }

    // Identity helper; returns the error unchanged.
    fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
        err
    }

    // A `true` condition passes and yields numeric zero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_true_passes() {
        let result = assert_builtin(vec![Value::Bool(true)]).expect("assert should pass");
        assert_eq!(result, Value::Num(0.0));
    }

    // An empty numeric tensor is treated as passing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_tensor_passes() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        assert_builtin(vec![Value::Tensor(tensor)]).expect("assert should pass");
    }

    // An empty logical array is treated as passing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_logical_passes() {
        let logical = LogicalArray::new(Vec::new(), vec![0]).unwrap();
        assert_builtin(vec![Value::LogicalArray(logical)]).expect("assert should pass");
    }

    // A bare `false` fails with the default identifier and message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_false_uses_default_message() {
        let err =
            unwrap_error(assert_builtin(vec![Value::Bool(false)]).expect_err("assert should fail"));
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
        assert_eq!(err.message(), assert_default_message());
    }

    // A tensor of non-zero values passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_handles_numeric_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        assert_builtin(vec![Value::Tensor(tensor)]).expect("assert should pass");
    }

    // A single zero element makes the whole tensor fail.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_zero_in_tensor() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let err = unwrap_error(
            assert_builtin(vec![Value::Tensor(tensor)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A scalar NaN counts as a failed condition.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_nan() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Num(f64::NAN)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A complex scalar with a non-zero imaginary part passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_passes() {
        assert_builtin(vec![Value::Complex(0.0, 2.0)]).expect("assert should pass");
    }

    // Complex zero fails.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_failure() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Complex(0.0, 0.0)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A complex tensor containing a zero element fails.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_tensor_failure() {
        let tensor = ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0)], vec![2, 1]).expect("tensor");
        let err = unwrap_error(
            assert_builtin(vec![Value::ComplexTensor(tensor)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A lone message argument replaces the text but keeps the default identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_accepts_custom_message() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("Vector length must be positive."),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
        assert!(err.message().contains("Vector length must be positive."));
    }

    // Trailing arguments are substituted into the message template.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_message_formatting() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("Expected positive value, got %d."),
                Value::Int(IntValue::I32(-4)),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
        assert!(err.message().contains("Expected positive value, got -4."));
    }

    // A colon-qualified first string is used as the error identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_custom_identifier() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("runmat:tests:failed"),
                Value::from("Failure %d occurred."),
                Value::Int(IntValue::I32(3)),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some("runmat:tests:failed"));
        assert!(err.message().contains("Failure 3 occurred."));
    }

    // An identifier without a colon is prefixed with `RunMat:`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_unqualified_identifier_prefixed() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("customAssertionFailed"),
                Value::from("runtime failure"),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some("RunMat:customAssertionFailed"));
    }

    // A string condition is rejected with the invalid-condition identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_rejects_invalid_condition_type() {
        let err = unwrap_error(
            assert_builtin(vec![Value::from("invalid")]).expect_err("assert should error"),
        );
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_INVALID_CONDITION.identifier.unwrap())
        );
    }

    // A GPU tensor of non-zero values passes after being gathered.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_passes() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = assert_builtin(vec![Value::GpuTensor(handle)]).expect("assert");
            assert_eq!(result, Value::Num(0.0));
        });
    }

    // A numeric message argument is rejected with the invalid-input identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_invalid_message_type_errors() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Bool(false), Value::Num(5.0)])
                .expect_err("assert should error"),
        );
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_INVALID_INPUT.identifier.unwrap())
        );
    }

    // A template with a placeholder but no formatting value surfaces the formatter's
    // error (expected to mention "sprintf") under the invalid-input identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_formatting_error_propagates() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("number %d must be > 0"),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_INVALID_INPUT.identifier.unwrap())
        );
        assert!(err.message().contains("sprintf"));
    }

    // A GPU tensor containing a zero fails with the default identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_failure() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let err =
                unwrap_error(assert_builtin(vec![Value::GpuTensor(handle)]).expect_err("assert"));
            assert_eq!(err.identifier(), Some(assert_default_identifier()));
        });
    }

    // A logical array containing a zero element fails.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_logical_array_failure() {
        let logical = LogicalArray::new(vec![1, 0], vec![2]).unwrap();
        let err = unwrap_error(
            assert_builtin(vec![Value::LogicalArray(logical)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // Calling with no arguments reports the not-enough-inputs error verbatim.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_requires_condition_argument() {
        let err = unwrap_error(assert_builtin(Vec::new()).expect_err("assert should error"));
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_NOT_ENOUGH_INPUTS.identifier.unwrap())
        );
        assert_eq!(err.message(), ASSERT_ERROR_NOT_ENOUGH_INPUTS.message);
    }

    // Same failure check against the wgpu provider; returns early (without asserting)
    // when the provider cannot be registered or obtained.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn assert_wgpu_tensor_failure_matches_cpu() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };

        if register_wgpu_provider(WgpuProviderOptions::default()).is_err() {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };

        let tensor = Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap();
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let err = unwrap_error(
            assert_builtin(vec![Value::GpuTensor(handle)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // The type resolver reports a numeric result for a boolean input.
    // NOTE(review): unlike the other tests, this one has no wasm `cfg_attr` — confirm
    // whether that omission is intentional.
    #[test]
    fn assert_type_is_numeric() {
        assert_eq!(
            assert_type(&[Type::Bool], &ResolveContext::new(Vec::new())),
            Type::Num
        );
    }
}