1use runmat_builtins::{ComplexTensor, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::format::format_variadic;
7use crate::builtins::common::gpu_helpers;
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::builtins::diagnostics::type_resolvers::assert_type;
13use crate::{build_runtime_error, RuntimeError};
14
// Reported when `assert` is called without any condition argument.
const MIN_INPUT_IDENTIFIER: &str = "RunMat:minrhs";
const MIN_INPUT_MESSAGE: &str = "Not enough input arguments.";

// Reported for a failed assertion when the caller supplies no identifier/message.
const DEFAULT_IDENTIFIER: &str = "RunMat:assertion:failed";
const DEFAULT_MESSAGE: &str = "Assertion failed.";

// Reported when the condition is neither logical nor numeric.
const INVALID_CONDITION_IDENTIFIER: &str = "RunMat:assertion:invalidCondition";
// Reported for unusable inputs: non-text message/identifier arguments, an empty
// identifier, formatting failures, and failures while gathering a GPU condition.
const INVALID_INPUT_IDENTIFIER: &str = "RunMat:assertion:invalidInput";
21
22#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::diagnostics::assert")]
23pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
24 name: "assert",
25 op_kind: GpuOpKind::Custom("control"),
26 supported_precisions: &[],
27 broadcast: BroadcastSemantics::None,
28 provider_hooks: &[],
29 constant_strategy: ConstantStrategy::InlineLiteral,
30 residency: ResidencyPolicy::GatherImmediately,
31 nan_mode: ReductionNaN::Include,
32 two_pass_threshold: None,
33 workgroup_size: None,
34 accepts_nan_mode: false,
35 notes: "Control-flow builtin; GPU tensors are gathered to host memory before evaluation.",
36};
37
38#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::diagnostics::assert")]
39pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
40 name: "assert",
41 shape: ShapeRequirements::Any,
42 constant_strategy: ConstantStrategy::InlineLiteral,
43 elementwise: None,
44 reduction: None,
45 emits_nan: false,
46 notes: "Control-flow builtin with no fusion support.",
47};
48
49fn assert_flow(identifier: &str, message: impl Into<String>) -> RuntimeError {
50 build_runtime_error(message)
51 .with_builtin("assert")
52 .with_identifier(normalize_identifier(identifier))
53 .build()
54}
55
56fn remap_assert_flow<F>(err: RuntimeError, identifier: &str, message: F) -> RuntimeError
57where
58 F: FnOnce(&crate::RuntimeError) -> String,
59{
60 build_runtime_error(message(&err))
61 .with_builtin("assert")
62 .with_identifier(normalize_identifier(identifier))
63 .with_source(err)
64 .build()
65}
66
67#[runtime_builtin(
68 name = "assert",
69 category = "diagnostics",
70 summary = "Throw a MATLAB-style error when a logical or numeric condition evaluates to false.",
71 keywords = "assert,diagnostics,validation,error",
72 accel = "metadata",
73 type_resolver(assert_type),
74 builtin_path = "crate::builtins::diagnostics::assert"
75)]
76async fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
77 if args.is_empty() {
78 return Err(assert_flow(MIN_INPUT_IDENTIFIER, MIN_INPUT_MESSAGE));
79 }
80
81 let mut iter = args.into_iter();
82 let condition_raw = iter.next().expect("checked length above");
83 let rest: Vec<Value> = iter.collect();
84
85 let condition = normalize_condition_value(condition_raw).await?;
86 match evaluate_condition(condition)? {
87 ConditionOutcome::Pass => Ok(Value::Num(0.0)),
88 ConditionOutcome::Fail => {
89 let payload = failure_payload(&rest)?;
90 Err(assert_flow(&payload.identifier, payload.message))
91 }
92 }
93}
94
95async fn normalize_condition_value(condition: Value) -> crate::BuiltinResult<Value> {
96 match condition {
97 Value::GpuTensor(handle) => {
98 let gpu_value = Value::GpuTensor(handle);
99 gpu_helpers::gather_value_async(&gpu_value)
100 .await
101 .map_err(|flow| {
102 remap_assert_flow(flow, INVALID_INPUT_IDENTIFIER, |err| {
103 format!("assert: {}", err.message())
104 })
105 })
106 }
107 other => Ok(other),
108 }
109}
110
/// Verdict produced by evaluating an `assert` condition.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConditionOutcome {
    /// Every element was true/nonzero, or the condition had no elements.
    Pass,
    /// At least one element was false, zero, or NaN.
    Fail,
}
116
117fn evaluate_condition(value: Value) -> crate::BuiltinResult<ConditionOutcome> {
118 match value {
119 Value::Bool(flag) => Ok(if flag {
120 ConditionOutcome::Pass
121 } else {
122 ConditionOutcome::Fail
123 }),
124 Value::Int(int_value) => {
125 if int_value.to_i64() != 0 {
126 Ok(ConditionOutcome::Pass)
127 } else {
128 Ok(ConditionOutcome::Fail)
129 }
130 }
131 Value::Num(num) => {
132 if num.is_nan() || num == 0.0 {
133 Ok(ConditionOutcome::Fail)
134 } else {
135 Ok(ConditionOutcome::Pass)
136 }
137 }
138 Value::Complex(re, im) => {
139 if complex_element_passes(re, im) {
140 Ok(ConditionOutcome::Pass)
141 } else {
142 Ok(ConditionOutcome::Fail)
143 }
144 }
145 Value::LogicalArray(array) => {
146 if array.data.iter().all(|&bit| bit != 0) {
147 Ok(ConditionOutcome::Pass)
148 } else {
149 Ok(ConditionOutcome::Fail)
150 }
151 }
152 Value::Tensor(tensor) => evaluate_tensor_condition(&tensor),
153 Value::ComplexTensor(tensor) => evaluate_complex_tensor(&tensor),
154 Value::GpuTensor(_) => {
155 unreachable!("gpu tensors are gathered in normalize_condition_value")
156 }
157 _ => Err(assert_flow(
158 INVALID_CONDITION_IDENTIFIER,
159 "assert: first input must be logical or numeric.",
160 )),
161 }
162}
163
164fn evaluate_tensor_condition(tensor: &Tensor) -> crate::BuiltinResult<ConditionOutcome> {
165 if tensor.data.is_empty() {
166 return Ok(ConditionOutcome::Pass);
167 }
168 for value in &tensor.data {
169 if value.is_nan() || *value == 0.0 {
170 return Ok(ConditionOutcome::Fail);
171 }
172 }
173 Ok(ConditionOutcome::Pass)
174}
175
176fn evaluate_complex_tensor(tensor: &ComplexTensor) -> crate::BuiltinResult<ConditionOutcome> {
177 if tensor.data.is_empty() {
178 return Ok(ConditionOutcome::Pass);
179 }
180 for &(re, im) in &tensor.data {
181 if !complex_element_passes(re, im) {
182 return Ok(ConditionOutcome::Fail);
183 }
184 }
185 Ok(ConditionOutcome::Pass)
186}
187
/// A complex element is "true" when neither part is NaN and at least one part is nonzero.
fn complex_element_passes(re: f64, im: f64) -> bool {
    let has_nan = re.is_nan() || im.is_nan();
    // `-0.0 == 0.0`, so signed zeros also count as zero.
    let is_zero = re == 0.0 && im == 0.0;
    !(has_nan || is_zero)
}
194
/// Identifier and fully formatted message reported for a failed assertion.
struct FailurePayload {
    /// Normalized error identifier, e.g. `RunMat:assertion:failed`.
    identifier: String,
    /// Message text after any printf-style formatting has been applied.
    message: String,
}
199
200fn failure_payload(args: &[Value]) -> crate::BuiltinResult<FailurePayload> {
201 if args.is_empty() {
202 return Ok(FailurePayload {
203 identifier: DEFAULT_IDENTIFIER.to_string(),
204 message: DEFAULT_MESSAGE.to_string(),
205 });
206 }
207
208 let candidate = &args[0];
209 let treat_as_identifier = args.len() >= 2 && value_is_identifier(candidate);
210
211 if treat_as_identifier {
212 if args.len() < 2 {
213 return Err(assert_flow(
214 INVALID_INPUT_IDENTIFIER,
215 "assert: message text must follow the message identifier.",
216 ));
217 }
218 let identifier = identifier_from_value(candidate)?;
219 let template = message_from_value(&args[1])?;
220 let formatting_args: &[Value] = if args.len() > 2 { &args[2..] } else { &[] };
221 let message = format_message(&template, formatting_args)?;
222 Ok(FailurePayload {
223 identifier,
224 message,
225 })
226 } else {
227 let template = message_from_value(candidate)?;
228 let formatting_args: &[Value] = if args.len() > 1 { &args[1..] } else { &[] };
229 let message = format_message(&template, formatting_args)?;
230 Ok(FailurePayload {
231 identifier: DEFAULT_IDENTIFIER.to_string(),
232 message,
233 })
234 }
235}
236
237fn value_is_identifier(value: &Value) -> bool {
238 if let Some(text) = string_scalar_opt(value) {
239 is_message_identifier(&text) || looks_like_unqualified_identifier(&text)
240 } else {
241 false
242 }
243}
244
245fn identifier_from_value(value: &Value) -> crate::BuiltinResult<String> {
246 let text = string_scalar_from_value(
247 value,
248 "assert: message identifier must be a string scalar or character vector.",
249 )?;
250 if text.trim().is_empty() {
251 return Err(assert_flow(
252 INVALID_INPUT_IDENTIFIER,
253 "assert: message identifier must be nonempty.",
254 ));
255 }
256 Ok(normalize_identifier(&text))
257}
258
259fn message_from_value(value: &Value) -> crate::BuiltinResult<String> {
260 string_scalar_from_value(
261 value,
262 "assert: message text must be a string scalar or character vector.",
263 )
264}
265
266fn format_message(template: &str, args: &[Value]) -> crate::BuiltinResult<String> {
267 format_variadic(template, args).map_err(|flow| {
268 remap_assert_flow(flow, INVALID_INPUT_IDENTIFIER, |err| {
269 format!("assert: {}", err.message())
270 })
271 })
272}
273
274fn normalize_identifier(raw: &str) -> String {
275 let trimmed = raw.trim();
276 if trimmed.is_empty() {
277 DEFAULT_IDENTIFIER.to_string()
278 } else if trimmed.contains(':') {
279 trimmed.to_string()
280 } else {
281 format!("RunMat:{trimmed}")
282 }
283}
284
/// True for trimmed text that contains at least one `:` and otherwise uses only
/// ASCII letters, digits, `_`, and `.` (e.g. `runmat:tests:failed`).
fn is_message_identifier(text: &str) -> bool {
    let candidate = text.trim();
    let allowed = |ch: char| ch.is_ascii_alphanumeric() || ch == ':' || ch == '_' || ch == '.';
    !candidate.is_empty() && candidate.contains(':') && candidate.chars().all(allowed)
}
294
/// True for nonempty trimmed text made only of ASCII letters, digits, `_`, and
/// `.` (a colon-free identifier such as `customAssertionFailed`).
fn looks_like_unqualified_identifier(text: &str) -> bool {
    let candidate = text.trim();
    // Interior whitespace (ASCII or Unicode) is rejected by the whitelist itself.
    !candidate.is_empty()
        && candidate
            .chars()
            .all(|ch| ch.is_ascii_alphanumeric() || ch == '_' || ch == '.')
}
304
305fn string_scalar_from_value(value: &Value, context: &str) -> crate::BuiltinResult<String> {
306 match value {
307 Value::String(text) => Ok(text.clone()),
308 Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
309 Value::CharArray(char_array) if char_array.rows == 1 => {
310 Ok(char_array.data.iter().collect::<String>())
311 }
312 _ => Err(assert_flow(INVALID_INPUT_IDENTIFIER, context)),
313 }
314}
315
316fn string_scalar_opt(value: &Value) -> Option<String> {
317 match value {
318 Value::String(text) => Some(text.clone()),
319 Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
320 Value::CharArray(char_array) if char_array.rows == 1 => {
321 Some(char_array.data.iter().collect())
322 }
323 _ => None,
324 }
325}
326
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type};

    /// Synchronous shim over the async builtin.
    fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::assert_builtin(args))
    }

    /// Runs `assert`, requiring success, and returns its output value.
    fn expect_pass(args: Vec<Value>) -> Value {
        assert_builtin(args).expect("assert should pass")
    }

    /// Runs `assert`, requiring an error, and returns that error.
    fn expect_error(args: Vec<Value>) -> crate::RuntimeError {
        assert_builtin(args).expect_err("assert should fail")
    }

    /// Builds an `n x 1` real column tensor from `data`.
    fn column(data: Vec<f64>) -> Tensor {
        let rows = data.len();
        Tensor::new(data, vec![rows, 1]).unwrap()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_true_passes() {
        assert_eq!(expect_pass(vec![Value::Bool(true)]), Value::Num(0.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_tensor_passes() {
        let empty = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        expect_pass(vec![Value::Tensor(empty)]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_logical_passes() {
        let empty = LogicalArray::new(Vec::new(), vec![0]).unwrap();
        expect_pass(vec![Value::LogicalArray(empty)]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_false_uses_default_message() {
        let failure = expect_error(vec![Value::Bool(false)]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
        assert_eq!(failure.message(), DEFAULT_MESSAGE);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_handles_numeric_tensor() {
        expect_pass(vec![Value::Tensor(column(vec![1.0, 2.0, 3.0]))]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_zero_in_tensor() {
        let failure = expect_error(vec![Value::Tensor(column(vec![1.0, 0.0, 3.0]))]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_nan() {
        let failure = expect_error(vec![Value::Num(f64::NAN)]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_passes() {
        expect_pass(vec![Value::Complex(0.0, 2.0)]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_failure() {
        let failure = expect_error(vec![Value::Complex(0.0, 0.0)]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_tensor_failure() {
        let tensor = ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0)], vec![2, 1]).expect("tensor");
        let failure = expect_error(vec![Value::ComplexTensor(tensor)]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_accepts_custom_message() {
        let failure = expect_error(vec![
            Value::Bool(false),
            Value::from("Vector length must be positive."),
        ]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
        assert!(failure
            .message()
            .contains("Vector length must be positive."));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_message_formatting() {
        let failure = expect_error(vec![
            Value::Bool(false),
            Value::from("Expected positive value, got %d."),
            Value::Int(IntValue::I32(-4)),
        ]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
        assert!(failure
            .message()
            .contains("Expected positive value, got -4."));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_custom_identifier() {
        let failure = expect_error(vec![
            Value::Bool(false),
            Value::from("runmat:tests:failed"),
            Value::from("Failure %d occurred."),
            Value::Int(IntValue::I32(3)),
        ]);
        assert_eq!(failure.identifier(), Some("runmat:tests:failed"));
        assert!(failure.message().contains("Failure 3 occurred."));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_unqualified_identifier_prefixed() {
        let failure = expect_error(vec![
            Value::Bool(false),
            Value::from("customAssertionFailed"),
            Value::from("runtime failure"),
        ]);
        assert_eq!(failure.identifier(), Some("RunMat:customAssertionFailed"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_rejects_invalid_condition_type() {
        let failure = expect_error(vec![Value::from("invalid")]);
        assert_eq!(failure.identifier(), Some(INVALID_CONDITION_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_passes() {
        test_support::with_test_provider(|provider| {
            let host = column(vec![1.0, 2.0, 3.0]);
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            assert_eq!(expect_pass(vec![Value::GpuTensor(handle)]), Value::Num(0.0));
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_invalid_message_type_errors() {
        let failure = expect_error(vec![Value::Bool(false), Value::Num(5.0)]);
        assert_eq!(failure.identifier(), Some(INVALID_INPUT_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_formatting_error_propagates() {
        let failure = expect_error(vec![
            Value::Bool(false),
            Value::from("number %d must be > 0"),
        ]);
        assert_eq!(failure.identifier(), Some(INVALID_INPUT_IDENTIFIER));
        assert!(failure.message().contains("sprintf"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_failure() {
        test_support::with_test_provider(|provider| {
            let host = column(vec![1.0, 0.0, 3.0]);
            let view = runmat_accelerate_api::HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let failure = expect_error(vec![Value::GpuTensor(handle)]);
            assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_logical_array_failure() {
        let mask = LogicalArray::new(vec![1, 0], vec![2]).unwrap();
        let failure = expect_error(vec![Value::LogicalArray(mask)]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_requires_condition_argument() {
        let failure = expect_error(Vec::new());
        assert_eq!(failure.identifier(), Some(MIN_INPUT_IDENTIFIER));
        assert_eq!(failure.message(), MIN_INPUT_MESSAGE);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn assert_wgpu_tensor_failure_matches_cpu() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };

        // Skip silently when no wgpu adapter is available on this machine.
        if register_wgpu_provider(WgpuProviderOptions::default()).is_err() {
            return;
        }
        let provider = match runmat_accelerate_api::provider() {
            Some(provider) => provider,
            None => return,
        };

        let host = column(vec![1.0, 0.0]);
        let view = runmat_accelerate_api::HostTensorView {
            data: &host.data,
            shape: &host.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let failure = expect_error(vec![Value::GpuTensor(handle)]);
        assert_eq!(failure.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    #[test]
    fn assert_type_is_numeric() {
        let resolved = assert_type(&[Type::Bool], &ResolveContext::new(Vec::new()));
        assert_eq!(resolved, Type::Num);
    }
}