Skip to main content

runmat_runtime/builtins/diagnostics/
assert.rs

1//! MATLAB-compatible `assert` builtin that mirrors MATLAB diagnostic semantics.
2
3use runmat_builtins::{ComplexTensor, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::format::format_variadic;
7use crate::builtins::common::gpu_helpers;
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::builtins::diagnostics::type_resolvers::assert_type;
13use crate::{build_runtime_error, RuntimeError};
14
/// Identifier attached to a failed assertion when the caller supplies none.
const DEFAULT_IDENTIFIER: &str = "RunMat:assertion:failed";
/// Message used for a failed assertion when the caller supplies no message text.
const DEFAULT_MESSAGE: &str = "Assertion failed.";
/// Identifier reported when the first argument is neither logical nor numeric.
const INVALID_CONDITION_IDENTIFIER: &str = "RunMat:assertion:invalidCondition";
/// Identifier reported for malformed identifier/message arguments, and for
/// gather or formatting failures that are remapped by this builtin.
const INVALID_INPUT_IDENTIFIER: &str = "RunMat:assertion:invalidInput";
/// Identifier reported when `assert` is called without any arguments.
const MIN_INPUT_IDENTIFIER: &str = "RunMat:minrhs";
/// Message reported when `assert` is called without any arguments.
const MIN_INPUT_MESSAGE: &str = "Not enough input arguments.";
21
// GPU metadata registered for the `assert` builtin. It declares a custom
// "control" op kind with no supported precisions and no provider hooks, and
// uses the `GatherImmediately` residency policy; the `notes` field states that
// GPU tensors are gathered to host memory before evaluation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::diagnostics::assert")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "assert",
    op_kind: GpuOpKind::Custom("control"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Control-flow builtin; GPU tensors are gathered to host memory before evaluation.",
};
37
// Fusion metadata registered for the `assert` builtin. No elementwise or
// reduction template is provided; the `notes` field records that this
// control-flow builtin has no fusion support.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::diagnostics::assert")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "assert",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Control-flow builtin with no fusion support.",
};
48
49fn assert_flow(identifier: &str, message: impl Into<String>) -> RuntimeError {
50    build_runtime_error(message)
51        .with_builtin("assert")
52        .with_identifier(normalize_identifier(identifier))
53        .build()
54}
55
56fn remap_assert_flow<F>(err: RuntimeError, identifier: &str, message: F) -> RuntimeError
57where
58    F: FnOnce(&crate::RuntimeError) -> String,
59{
60    build_runtime_error(message(&err))
61        .with_builtin("assert")
62        .with_identifier(normalize_identifier(identifier))
63        .with_source(err)
64        .build()
65}
66
67#[runtime_builtin(
68    name = "assert",
69    category = "diagnostics",
70    summary = "Throw a MATLAB-style error when a logical or numeric condition evaluates to false.",
71    keywords = "assert,diagnostics,validation,error",
72    accel = "metadata",
73    type_resolver(assert_type),
74    builtin_path = "crate::builtins::diagnostics::assert"
75)]
76async fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
77    if args.is_empty() {
78        return Err(assert_flow(MIN_INPUT_IDENTIFIER, MIN_INPUT_MESSAGE));
79    }
80
81    let mut iter = args.into_iter();
82    let condition_raw = iter.next().expect("checked length above");
83    let rest: Vec<Value> = iter.collect();
84
85    let condition = normalize_condition_value(condition_raw).await?;
86    match evaluate_condition(condition)? {
87        ConditionOutcome::Pass => Ok(Value::Num(0.0)),
88        ConditionOutcome::Fail => {
89            let payload = failure_payload(&rest)?;
90            Err(assert_flow(&payload.identifier, payload.message))
91        }
92    }
93}
94
95async fn normalize_condition_value(condition: Value) -> crate::BuiltinResult<Value> {
96    match condition {
97        Value::GpuTensor(handle) => {
98            let gpu_value = Value::GpuTensor(handle);
99            gpu_helpers::gather_value_async(&gpu_value)
100                .await
101                .map_err(|flow| {
102                    remap_assert_flow(flow, INVALID_INPUT_IDENTIFIER, |err| {
103                        format!("assert: {}", err.message())
104                    })
105                })
106        }
107        other => Ok(other),
108    }
109}
110
/// Result of evaluating the condition argument of `assert`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ConditionOutcome {
    /// The condition holds; `assert` returns normally.
    Pass,
    /// The condition does not hold; `assert` raises an error.
    Fail,
}
116
117fn evaluate_condition(value: Value) -> crate::BuiltinResult<ConditionOutcome> {
118    match value {
119        Value::Bool(flag) => Ok(if flag {
120            ConditionOutcome::Pass
121        } else {
122            ConditionOutcome::Fail
123        }),
124        Value::Int(int_value) => {
125            if int_value.to_i64() != 0 {
126                Ok(ConditionOutcome::Pass)
127            } else {
128                Ok(ConditionOutcome::Fail)
129            }
130        }
131        Value::Num(num) => {
132            if num.is_nan() || num == 0.0 {
133                Ok(ConditionOutcome::Fail)
134            } else {
135                Ok(ConditionOutcome::Pass)
136            }
137        }
138        Value::Complex(re, im) => {
139            if complex_element_passes(re, im) {
140                Ok(ConditionOutcome::Pass)
141            } else {
142                Ok(ConditionOutcome::Fail)
143            }
144        }
145        Value::LogicalArray(array) => {
146            if array.data.iter().all(|&bit| bit != 0) {
147                Ok(ConditionOutcome::Pass)
148            } else {
149                Ok(ConditionOutcome::Fail)
150            }
151        }
152        Value::Tensor(tensor) => evaluate_tensor_condition(&tensor),
153        Value::ComplexTensor(tensor) => evaluate_complex_tensor(&tensor),
154        Value::GpuTensor(_) => {
155            unreachable!("gpu tensors are gathered in normalize_condition_value")
156        }
157        _ => Err(assert_flow(
158            INVALID_CONDITION_IDENTIFIER,
159            "assert: first input must be logical or numeric.",
160        )),
161    }
162}
163
164fn evaluate_tensor_condition(tensor: &Tensor) -> crate::BuiltinResult<ConditionOutcome> {
165    if tensor.data.is_empty() {
166        return Ok(ConditionOutcome::Pass);
167    }
168    for value in &tensor.data {
169        if value.is_nan() || *value == 0.0 {
170            return Ok(ConditionOutcome::Fail);
171        }
172    }
173    Ok(ConditionOutcome::Pass)
174}
175
176fn evaluate_complex_tensor(tensor: &ComplexTensor) -> crate::BuiltinResult<ConditionOutcome> {
177    if tensor.data.is_empty() {
178        return Ok(ConditionOutcome::Pass);
179    }
180    for &(re, im) in &tensor.data {
181        if !complex_element_passes(re, im) {
182            return Ok(ConditionOutcome::Fail);
183        }
184    }
185    Ok(ConditionOutcome::Pass)
186}
187
/// Returns `true` when the complex number `re + im*i` counts as a passing
/// condition: neither part is NaN and at least one part is nonzero.
fn complex_element_passes(re: f64, im: f64) -> bool {
    let has_nan = re.is_nan() || im.is_nan();
    let is_zero = re == 0.0 && im == 0.0;
    !(has_nan || is_zero)
}
194
/// Identifier and rendered message describing a failed assertion.
struct FailurePayload {
    /// Error identifier to attach to the raised error.
    identifier: String,
    /// Fully formatted message text.
    message: String,
}
199
200fn failure_payload(args: &[Value]) -> crate::BuiltinResult<FailurePayload> {
201    if args.is_empty() {
202        return Ok(FailurePayload {
203            identifier: DEFAULT_IDENTIFIER.to_string(),
204            message: DEFAULT_MESSAGE.to_string(),
205        });
206    }
207
208    let candidate = &args[0];
209    let treat_as_identifier = args.len() >= 2 && value_is_identifier(candidate);
210
211    if treat_as_identifier {
212        if args.len() < 2 {
213            return Err(assert_flow(
214                INVALID_INPUT_IDENTIFIER,
215                "assert: message text must follow the message identifier.",
216            ));
217        }
218        let identifier = identifier_from_value(candidate)?;
219        let template = message_from_value(&args[1])?;
220        let formatting_args: &[Value] = if args.len() > 2 { &args[2..] } else { &[] };
221        let message = format_message(&template, formatting_args)?;
222        Ok(FailurePayload {
223            identifier,
224            message,
225        })
226    } else {
227        let template = message_from_value(candidate)?;
228        let formatting_args: &[Value] = if args.len() > 1 { &args[1..] } else { &[] };
229        let message = format_message(&template, formatting_args)?;
230        Ok(FailurePayload {
231            identifier: DEFAULT_IDENTIFIER.to_string(),
232            message,
233        })
234    }
235}
236
237fn value_is_identifier(value: &Value) -> bool {
238    if let Some(text) = string_scalar_opt(value) {
239        is_message_identifier(&text) || looks_like_unqualified_identifier(&text)
240    } else {
241        false
242    }
243}
244
245fn identifier_from_value(value: &Value) -> crate::BuiltinResult<String> {
246    let text = string_scalar_from_value(
247        value,
248        "assert: message identifier must be a string scalar or character vector.",
249    )?;
250    if text.trim().is_empty() {
251        return Err(assert_flow(
252            INVALID_INPUT_IDENTIFIER,
253            "assert: message identifier must be nonempty.",
254        ));
255    }
256    Ok(normalize_identifier(&text))
257}
258
259fn message_from_value(value: &Value) -> crate::BuiltinResult<String> {
260    string_scalar_from_value(
261        value,
262        "assert: message text must be a string scalar or character vector.",
263    )
264}
265
266fn format_message(template: &str, args: &[Value]) -> crate::BuiltinResult<String> {
267    format_variadic(template, args).map_err(|flow| {
268        remap_assert_flow(flow, INVALID_INPUT_IDENTIFIER, |err| {
269            format!("assert: {}", err.message())
270        })
271    })
272}
273
274fn normalize_identifier(raw: &str) -> String {
275    let trimmed = raw.trim();
276    if trimmed.is_empty() {
277        DEFAULT_IDENTIFIER.to_string()
278    } else if trimmed.contains(':') {
279        trimmed.to_string()
280    } else {
281        format!("RunMat:{trimmed}")
282    }
283}
284
/// Returns `true` when `text` (ignoring surrounding whitespace) is a
/// colon-qualified message identifier such as `pkg:component:mnemonic`.
///
/// The text must contain at least one `:` and every colon-separated field
/// must start with an ASCII letter, followed only by ASCII alphanumerics,
/// `_` or `.`. Degenerate inputs with empty fields (`":"`, `"a::b"`,
/// `"Note:"`) or fields that start with a digit (`"12:30"`) are rejected so
/// that they are treated as ordinary message text instead.
fn is_message_identifier(text: &str) -> bool {
    let trimmed = text.trim();
    if !trimmed.contains(':') {
        return false;
    }
    trimmed.split(':').all(|field| {
        let mut chars = field.chars();
        // An empty field yields `None` here and is therefore rejected.
        matches!(chars.next(), Some(first) if first.is_ascii_alphabetic())
            && chars.all(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '.'))
    })
}
294
/// Returns `true` when `text` (ignoring surrounding whitespace) is a
/// non-empty run of ASCII alphanumerics, `_` and `.` — i.e. a single word
/// that could serve as an identifier without a `:` qualifier.
fn looks_like_unqualified_identifier(text: &str) -> bool {
    let candidate = text.trim();
    // Whitespace is not in the allowed set, so embedded blanks are rejected
    // by the same scan.
    !candidate.is_empty()
        && candidate
            .chars()
            .all(|ch| matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '.'))
}
304
305fn string_scalar_from_value(value: &Value, context: &str) -> crate::BuiltinResult<String> {
306    match value {
307        Value::String(text) => Ok(text.clone()),
308        Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
309        Value::CharArray(char_array) if char_array.rows == 1 => {
310            Ok(char_array.data.iter().collect::<String>())
311        }
312        _ => Err(assert_flow(INVALID_INPUT_IDENTIFIER, context)),
313    }
314}
315
316fn string_scalar_opt(value: &Value) -> Option<String> {
317    match value {
318        Value::String(text) => Some(text.clone()),
319        Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
320        Value::CharArray(char_array) if char_array.rows == 1 => {
321            Some(char_array.data.iter().collect())
322        }
323        _ => None,
324    }
325}
326
// Unit tests for the `assert` builtin: passing and failing conditions for each
// supported value kind, custom identifiers/messages, argument validation, and
// GPU-resident conditions.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type};

    // Synchronous wrapper around the async builtin so tests can call it directly.
    fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::assert_builtin(args))
    }

    // Identity helper: returns the error unchanged.
    fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
        err
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_true_passes() {
        let result = assert_builtin(vec![Value::Bool(true)]).expect("assert should pass");
        assert_eq!(result, Value::Num(0.0));
    }

    // Empty numeric tensors are treated as a passing condition.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_tensor_passes() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        assert_builtin(vec![Value::Tensor(tensor)]).expect("assert should pass");
    }

    // Empty logical arrays are treated as a passing condition.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_logical_passes() {
        let logical = LogicalArray::new(Vec::new(), vec![0]).unwrap();
        assert_builtin(vec![Value::LogicalArray(logical)]).expect("assert should pass");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_false_uses_default_message() {
        let err =
            unwrap_error(assert_builtin(vec![Value::Bool(false)]).expect_err("assert should fail"));
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
        assert_eq!(err.message(), DEFAULT_MESSAGE);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_handles_numeric_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        assert_builtin(vec![Value::Tensor(tensor)]).expect("assert should pass");
    }

    // A single zero element makes the whole tensor condition fail.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_zero_in_tensor() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let err = unwrap_error(
            assert_builtin(vec![Value::Tensor(tensor)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    // NaN is treated as a failing condition.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_nan() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Num(f64::NAN)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    // A complex scalar with a nonzero imaginary part passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_passes() {
        assert_builtin(vec![Value::Complex(0.0, 2.0)]).expect("assert should pass");
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_failure() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Complex(0.0, 0.0)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_tensor_failure() {
        let tensor = ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0)], vec![2, 1]).expect("tensor");
        let err = unwrap_error(
            assert_builtin(vec![Value::ComplexTensor(tensor)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    // A lone message argument keeps the default identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_accepts_custom_message() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("Vector length must be positive."),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
        assert!(err.message().contains("Vector length must be positive."));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_message_formatting() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("Expected positive value, got %d."),
                Value::Int(IntValue::I32(-4)),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
        assert!(err.message().contains("Expected positive value, got -4."));
    }

    // Identifier + template + formatting argument form.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_custom_identifier() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("runmat:tests:failed"),
                Value::from("Failure %d occurred."),
                Value::Int(IntValue::I32(3)),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some("runmat:tests:failed"));
        assert!(err.message().contains("Failure 3 occurred."));
    }

    // An identifier without a colon gets the `RunMat:` prefix.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_unqualified_identifier_prefixed() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("customAssertionFailed"),
                Value::from("runtime failure"),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some("RunMat:customAssertionFailed"));
    }

    // A string is not an acceptable condition.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_rejects_invalid_condition_type() {
        let err = unwrap_error(
            assert_builtin(vec![Value::from("invalid")]).expect_err("assert should error"),
        );
        assert_eq!(err.identifier(), Some(INVALID_CONDITION_IDENTIFIER));
    }

    // A tensor uploaded through the test provider is gathered and evaluated.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_passes() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = assert_builtin(vec![Value::GpuTensor(handle)]).expect("assert");
            assert_eq!(result, Value::Num(0.0));
        });
    }

    // A numeric value in the message position is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_invalid_message_type_errors() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Bool(false), Value::Num(5.0)])
                .expect_err("assert should error"),
        );
        assert_eq!(err.identifier(), Some(INVALID_INPUT_IDENTIFIER));
    }

    // A template with a conversion but no argument surfaces the formatter's error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_formatting_error_propagates() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("number %d must be > 0"),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(INVALID_INPUT_IDENTIFIER));
        assert!(err.message().contains("sprintf"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_failure() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let err =
                unwrap_error(assert_builtin(vec![Value::GpuTensor(handle)]).expect_err("assert"));
            assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_logical_array_failure() {
        let logical = LogicalArray::new(vec![1, 0], vec![2]).unwrap();
        let err = unwrap_error(
            assert_builtin(vec![Value::LogicalArray(logical)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    // Calling the builtin with no arguments reports the minrhs error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_requires_condition_argument() {
        let err = unwrap_error(assert_builtin(Vec::new()).expect_err("assert should error"));
        assert_eq!(err.identifier(), Some(MIN_INPUT_IDENTIFIER));
        assert_eq!(err.message(), MIN_INPUT_MESSAGE);
    }

    // Only built with the `wgpu` feature; returns early (without asserting)
    // when the provider cannot be registered or is unavailable.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn assert_wgpu_tensor_failure_matches_cpu() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };

        if register_wgpu_provider(WgpuProviderOptions::default()).is_err() {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };

        let tensor = Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap();
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let err = unwrap_error(
            assert_builtin(vec![Value::GpuTensor(handle)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(DEFAULT_IDENTIFIER));
    }

    // The type resolver reports a numeric result for a boolean input.
    #[test]
    fn assert_type_is_numeric() {
        assert_eq!(
            assert_type(&[Type::Bool], &ResolveContext::new(Vec::new())),
            Type::Num
        );
    }
}