Skip to main content

runmat_runtime/builtins/diagnostics/
assert.rs

1//! MATLAB-compatible `assert` builtin that mirrors MATLAB diagnostic semantics.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6    ComplexTensor, Tensor, Value,
7};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::format::format_variadic;
11use crate::builtins::common::gpu_helpers;
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ReductionNaN, ResidencyPolicy, ShapeRequirements,
15};
16use crate::builtins::diagnostics::type_resolvers::assert_type;
17use crate::{build_runtime_error, RuntimeError};
18
/// Builtin name attached (via `with_builtin`) to every error built by this module's helpers.
const BUILTIN_NAME: &str = "assert";
20
/// Output descriptor shared by every `assert` signature: a single required numeric `out`.
const ASSERT_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Zero when the assertion passes.",
}];
28
/// Inputs for the one-argument form `assert(condition)`.
const ASSERT_INPUTS_CONDITION: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "condition",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Logical/numeric condition that must evaluate to true.",
}];
36
/// Inputs for `assert(condition, message)`: the condition plus a plain failure message.
const ASSERT_INPUTS_MESSAGE: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        // The default is stored as quoted literal text.
        default: Some("\"Assertion failed.\""),
        description: "Failure message text.",
    },
];
53
/// Inputs for `assert(condition, message, A...)`: the message is a template and the
/// trailing variadic values are its formatting arguments.
const ASSERT_INPUTS_MESSAGE_VARIADIC: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"Assertion failed.\""),
        description: "Failure message template text.",
    },
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Formatting values for the message template.",
    },
];
77
/// Inputs for `assert(condition, message_id, message)`: a message identifier precedes
/// the failure text.
const ASSERT_INPUTS_IDENTIFIER_MESSAGE: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message_id",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"RunMat:assertion:failed\""),
        description: "Message identifier.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"Assertion failed.\""),
        description: "Failure message text.",
    },
];
101
/// Inputs for `assert(condition, message_id, message, A...)`: identifier, message
/// template, and trailing formatting values.
const ASSERT_INPUTS_IDENTIFIER_MESSAGE_VARIADIC: [BuiltinParamDescriptor; 4] = [
    BuiltinParamDescriptor {
        name: "condition",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Logical/numeric condition that must evaluate to true.",
    },
    BuiltinParamDescriptor {
        name: "message_id",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"RunMat:assertion:failed\""),
        description: "Message identifier.",
    },
    BuiltinParamDescriptor {
        name: "message",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Required,
        default: Some("\"Assertion failed.\""),
        description: "Failure message template text.",
    },
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Formatting values for the message template.",
    },
];
132
/// The five documented call forms of `assert`, each pairing one input table with the
/// shared [`ASSERT_OUTPUT`] descriptor.
const ASSERT_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
    // Condition only.
    BuiltinSignatureDescriptor {
        label: "out = assert(condition)",
        inputs: &ASSERT_INPUTS_CONDITION,
        outputs: &ASSERT_OUTPUT,
    },
    // Condition plus literal message.
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message)",
        inputs: &ASSERT_INPUTS_MESSAGE,
        outputs: &ASSERT_OUTPUT,
    },
    // Condition plus message template and formatting values.
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message, A...)",
        inputs: &ASSERT_INPUTS_MESSAGE_VARIADIC,
        outputs: &ASSERT_OUTPUT,
    },
    // Condition, identifier, and literal message.
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message_id, message)",
        inputs: &ASSERT_INPUTS_IDENTIFIER_MESSAGE,
        outputs: &ASSERT_OUTPUT,
    },
    // Condition, identifier, message template, and formatting values.
    BuiltinSignatureDescriptor {
        label: "out = assert(condition, message_id, message, A...)",
        inputs: &ASSERT_INPUTS_IDENTIFIER_MESSAGE_VARIADIC,
        outputs: &ASSERT_OUTPUT,
    },
];
160
/// Default failure descriptor. Its identifier and message are what
/// `assert_default_identifier` and `assert_default_message` return.
const ASSERT_ERROR_ASSERTION_FAILED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.ASSERTION_FAILED",
    identifier: Some("RunMat:assertion:failed"),
    when: "Condition evaluates to false and no custom identifier/message override is provided.",
    message: "Assertion failed.",
};
167
/// Raised by `evaluate_condition` when the first argument is not one of the supported
/// logical/numeric value kinds.
const ASSERT_ERROR_INVALID_CONDITION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.INVALID_CONDITION",
    identifier: Some("RunMat:assertion:invalidCondition"),
    when: "First argument is not a supported logical or numeric condition input.",
    message: "assert: first input must be logical or numeric.",
};
174
/// Used for bad identifier/message arguments, and to re-tag errors coming from GPU
/// gathering and message formatting.
const ASSERT_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.INVALID_INPUT",
    identifier: Some("RunMat:assertion:invalidInput"),
    when: "Message identifier/message text or formatting payload is invalid.",
    message: "assert: invalid input argument",
};
181
/// Raised by `assert_builtin` when it is called with no arguments at all.
const ASSERT_ERROR_NOT_ENOUGH_INPUTS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.ASSERT.NOT_ENOUGH_INPUTS",
    identifier: Some("RunMat:minrhs"),
    when: "No condition argument is provided.",
    message: "Not enough input arguments.",
};
188
/// All error descriptors for `assert`, exposed through [`ASSERT_DESCRIPTOR`].
const ASSERT_ERRORS: [BuiltinErrorDescriptor; 4] = [
    ASSERT_ERROR_ASSERTION_FAILED,
    ASSERT_ERROR_INVALID_CONDITION,
    ASSERT_ERROR_INVALID_INPUT,
    ASSERT_ERROR_NOT_ENOUGH_INPUTS,
];
195
/// Public descriptor for `assert`, bundling its signatures and error catalogue. It is
/// referenced by path from the `#[runtime_builtin]` attribute on `assert_builtin`.
pub const ASSERT_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &ASSERT_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &ASSERT_ERRORS,
};
202
// GPU metadata registered for `assert`. It declares a custom "control" operation with
// no supported precisions and no provider hooks; per `notes`, GPU tensors are gathered
// to host memory before the condition is evaluated.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::diagnostics::assert")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "assert",
    op_kind: GpuOpKind::Custom("control"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Control-flow builtin; GPU tensors are gathered to host memory before evaluation.",
};
218
// Fusion metadata registered for `assert`. It declares neither an elementwise nor a
// reduction form; per `notes`, the builtin has no fusion support.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::diagnostics::assert")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "assert",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Control-flow builtin with no fusion support.",
};
229
230fn assert_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
231    assert_error_with_message(error.message, error)
232}
233
234fn assert_default_identifier() -> &'static str {
235    ASSERT_ERROR_ASSERTION_FAILED
236        .identifier
237        .expect("assert default identifier must be defined")
238}
239
240fn assert_default_message() -> &'static str {
241    ASSERT_ERROR_ASSERTION_FAILED.message
242}
243
244fn assert_error_with_message(
245    message: impl Into<String>,
246    error: &'static BuiltinErrorDescriptor,
247) -> RuntimeError {
248    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
249    if let Some(identifier) = error.identifier {
250        builder = builder.with_identifier(normalize_identifier(identifier));
251    }
252    builder.build()
253}
254
255fn assert_flow(identifier: &str, message: impl Into<String>) -> RuntimeError {
256    build_runtime_error(message)
257        .with_builtin(BUILTIN_NAME)
258        .with_identifier(normalize_identifier(identifier))
259        .build()
260}
261
262fn remap_assert_flow<F>(
263    err: RuntimeError,
264    error: &'static BuiltinErrorDescriptor,
265    message: F,
266) -> RuntimeError
267where
268    F: FnOnce(&crate::RuntimeError) -> String,
269{
270    let mut builder = build_runtime_error(message(&err))
271        .with_builtin(BUILTIN_NAME)
272        .with_source(err);
273    if let Some(identifier) = error.identifier {
274        builder = builder.with_identifier(normalize_identifier(identifier));
275    }
276    builder.build()
277}
278
279#[runtime_builtin(
280    name = "assert",
281    category = "diagnostics",
282    summary = "Throw an error when a condition is false, matching MATLAB assert semantics.",
283    keywords = "assert,diagnostics,validation,error",
284    accel = "metadata",
285    type_resolver(assert_type),
286    descriptor(crate::builtins::diagnostics::assert::ASSERT_DESCRIPTOR),
287    builtin_path = "crate::builtins::diagnostics::assert"
288)]
289async fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
290    if args.is_empty() {
291        return Err(assert_error(&ASSERT_ERROR_NOT_ENOUGH_INPUTS));
292    }
293
294    let mut iter = args.into_iter();
295    let condition_raw = iter.next().expect("checked length above");
296    let rest: Vec<Value> = iter.collect();
297
298    let condition = normalize_condition_value(condition_raw).await?;
299    match evaluate_condition(condition)? {
300        ConditionOutcome::Pass => Ok(Value::Num(0.0)),
301        ConditionOutcome::Fail => {
302            let payload = failure_payload(&rest)?;
303            Err(assert_flow(&payload.identifier, payload.message))
304        }
305    }
306}
307
308async fn normalize_condition_value(condition: Value) -> crate::BuiltinResult<Value> {
309    match condition {
310        Value::GpuTensor(handle) => {
311            let gpu_value = Value::GpuTensor(handle);
312            gpu_helpers::gather_value_async(&gpu_value)
313                .await
314                .map_err(|flow| {
315                    remap_assert_flow(flow, &ASSERT_ERROR_INVALID_INPUT, |err| {
316                        format!("assert: {}", err.message())
317                    })
318                })
319        }
320        other => Ok(other),
321    }
322}
323
/// Result of evaluating an `assert` condition.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ConditionOutcome {
    /// The condition held; `assert_builtin` returns `Value::Num(0.0)`.
    Pass,
    /// The condition did not hold; `assert_builtin` returns an error.
    Fail,
}
329
330fn evaluate_condition(value: Value) -> crate::BuiltinResult<ConditionOutcome> {
331    match value {
332        Value::Bool(flag) => Ok(if flag {
333            ConditionOutcome::Pass
334        } else {
335            ConditionOutcome::Fail
336        }),
337        Value::Int(int_value) => {
338            if int_value.to_i64() != 0 {
339                Ok(ConditionOutcome::Pass)
340            } else {
341                Ok(ConditionOutcome::Fail)
342            }
343        }
344        Value::Num(num) => {
345            if num.is_nan() || num == 0.0 {
346                Ok(ConditionOutcome::Fail)
347            } else {
348                Ok(ConditionOutcome::Pass)
349            }
350        }
351        Value::Complex(re, im) => {
352            if complex_element_passes(re, im) {
353                Ok(ConditionOutcome::Pass)
354            } else {
355                Ok(ConditionOutcome::Fail)
356            }
357        }
358        Value::LogicalArray(array) => {
359            if array.data.iter().all(|&bit| bit != 0) {
360                Ok(ConditionOutcome::Pass)
361            } else {
362                Ok(ConditionOutcome::Fail)
363            }
364        }
365        Value::Tensor(tensor) => evaluate_tensor_condition(&tensor),
366        Value::ComplexTensor(tensor) => evaluate_complex_tensor(&tensor),
367        Value::GpuTensor(_) => {
368            unreachable!("gpu tensors are gathered in normalize_condition_value")
369        }
370        _ => Err(assert_error(&ASSERT_ERROR_INVALID_CONDITION)),
371    }
372}
373
374fn evaluate_tensor_condition(tensor: &Tensor) -> crate::BuiltinResult<ConditionOutcome> {
375    if tensor.data.is_empty() {
376        return Ok(ConditionOutcome::Pass);
377    }
378    for value in &tensor.data {
379        if value.is_nan() || *value == 0.0 {
380            return Ok(ConditionOutcome::Fail);
381        }
382    }
383    Ok(ConditionOutcome::Pass)
384}
385
386fn evaluate_complex_tensor(tensor: &ComplexTensor) -> crate::BuiltinResult<ConditionOutcome> {
387    if tensor.data.is_empty() {
388        return Ok(ConditionOutcome::Pass);
389    }
390    for &(re, im) in &tensor.data {
391        if !complex_element_passes(re, im) {
392            return Ok(ConditionOutcome::Fail);
393        }
394    }
395    Ok(ConditionOutcome::Pass)
396}
397
/// Returns `true` when a complex value counts as true: neither part is NaN and at
/// least one part is nonzero.
fn complex_element_passes(re: f64, im: f64) -> bool {
    let has_nan = re.is_nan() || im.is_nan();
    let is_zero = re == 0.0 && im == 0.0;
    !has_nan && !is_zero
}
404
/// Identifier and final message text for a failed assertion, as produced by
/// `failure_payload`.
struct FailurePayload {
    /// Error identifier: either the default one or a normalized caller-supplied one.
    identifier: String,
    /// Message text after any formatting values have been applied.
    message: String,
}
409
410fn failure_payload(args: &[Value]) -> crate::BuiltinResult<FailurePayload> {
411    if args.is_empty() {
412        return Ok(FailurePayload {
413            identifier: assert_default_identifier().to_string(),
414            message: assert_default_message().to_string(),
415        });
416    }
417
418    let candidate = &args[0];
419    let treat_as_identifier = args.len() >= 2 && value_is_identifier(candidate);
420
421    if treat_as_identifier {
422        if args.len() < 2 {
423            return Err(assert_flow(
424                ASSERT_ERROR_INVALID_INPUT
425                    .identifier
426                    .expect("assert invalid-input identifier must be defined"),
427                "assert: message text must follow the message identifier.",
428            ));
429        }
430        let identifier = identifier_from_value(candidate)?;
431        let template = message_from_value(&args[1])?;
432        let formatting_args: &[Value] = if args.len() > 2 { &args[2..] } else { &[] };
433        let message = format_message(&template, formatting_args)?;
434        Ok(FailurePayload {
435            identifier,
436            message,
437        })
438    } else {
439        let template = message_from_value(candidate)?;
440        let formatting_args: &[Value] = if args.len() > 1 { &args[1..] } else { &[] };
441        let message = format_message(&template, formatting_args)?;
442        Ok(FailurePayload {
443            identifier: assert_default_identifier().to_string(),
444            message,
445        })
446    }
447}
448
449fn value_is_identifier(value: &Value) -> bool {
450    if let Some(text) = string_scalar_opt(value) {
451        is_message_identifier(&text) || looks_like_unqualified_identifier(&text)
452    } else {
453        false
454    }
455}
456
457fn identifier_from_value(value: &Value) -> crate::BuiltinResult<String> {
458    let text = string_scalar_from_value(
459        value,
460        "assert: message identifier must be a string scalar or character vector.",
461    )?;
462    if text.trim().is_empty() {
463        return Err(assert_flow(
464            ASSERT_ERROR_INVALID_INPUT
465                .identifier
466                .expect("assert invalid-input identifier must be defined"),
467            "assert: message identifier must be nonempty.",
468        ));
469    }
470    Ok(normalize_identifier(&text))
471}
472
473fn message_from_value(value: &Value) -> crate::BuiltinResult<String> {
474    string_scalar_from_value(
475        value,
476        "assert: message text must be a string scalar or character vector.",
477    )
478}
479
480fn format_message(template: &str, args: &[Value]) -> crate::BuiltinResult<String> {
481    format_variadic(template, args).map_err(|flow| {
482        remap_assert_flow(flow, &ASSERT_ERROR_INVALID_INPUT, |err| {
483            format!("assert: {}", err.message())
484        })
485    })
486}
487
488fn normalize_identifier(raw: &str) -> String {
489    let trimmed = raw.trim();
490    if trimmed.is_empty() {
491        assert_default_identifier().to_string()
492    } else if trimmed.contains(':') {
493        trimmed.to_string()
494    } else {
495        format!("RunMat:{trimmed}")
496    }
497}
498
/// Returns `true` when the trimmed text contains at least one `:` and consists only
/// of ASCII alphanumerics, `:`, `_`, and `.`.
fn is_message_identifier(text: &str) -> bool {
    let candidate = text.trim();
    // Requiring a colon also rules out the empty string.
    candidate.contains(':')
        && candidate
            .chars()
            .all(|ch| matches!(ch, ':' | '_' | '.') || ch.is_ascii_alphanumeric())
}
508
/// Returns `true` when the trimmed text is nonempty and consists only of ASCII
/// alphanumerics, `_`, and `.`; whitespace and `:` are therefore rejected.
fn looks_like_unqualified_identifier(text: &str) -> bool {
    let candidate = text.trim();
    // Whitespace is not in the allowed set, so it needs no separate check.
    !candidate.is_empty()
        && candidate
            .chars()
            .all(|ch| matches!(ch, '_' | '.') || ch.is_ascii_alphanumeric())
}
518
519fn string_scalar_from_value(value: &Value, context: &str) -> crate::BuiltinResult<String> {
520    match value {
521        Value::String(text) => Ok(text.clone()),
522        Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
523        Value::CharArray(char_array) if char_array.rows == 1 => {
524            Ok(char_array.data.iter().collect::<String>())
525        }
526        _ => Err(assert_error_with_message(
527            context,
528            &ASSERT_ERROR_INVALID_INPUT,
529        )),
530    }
531}
532
533fn string_scalar_opt(value: &Value) -> Option<String> {
534    match value {
535        Value::String(text) => Some(text.clone()),
536        Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
537        Value::CharArray(char_array) if char_array.rows == 1 => {
538            Some(char_array.data.iter().collect())
539        }
540        _ => None,
541    }
542}
543
// Unit tests for the `assert` builtin: pass/fail evaluation per value kind, message
// and identifier handling, argument validation, and GPU-resident conditions.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ComplexTensor, IntValue, LogicalArray, ResolveContext, Tensor, Type};

    // Synchronous wrapper so the tests can call the async builtin directly.
    fn assert_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(super::assert_builtin(args))
    }

    // Identity helper; returns the error unchanged.
    fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
        err
    }

    // A true logical scalar passes and yields numeric zero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_true_passes() {
        let result = assert_builtin(vec![Value::Bool(true)]).expect("assert should pass");
        assert_eq!(result, Value::Num(0.0));
    }

    // An empty numeric tensor passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_tensor_passes() {
        let tensor = Tensor::new(Vec::new(), vec![0, 3]).unwrap();
        assert_builtin(vec![Value::Tensor(tensor)]).expect("assert should pass");
    }

    // An empty logical array passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_empty_logical_passes() {
        let logical = LogicalArray::new(Vec::new(), vec![0]).unwrap();
        assert_builtin(vec![Value::LogicalArray(logical)]).expect("assert should pass");
    }

    // With no extra arguments, a failure carries the default identifier and message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_false_uses_default_message() {
        let err =
            unwrap_error(assert_builtin(vec![Value::Bool(false)]).expect_err("assert should fail"));
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
        assert_eq!(err.message(), assert_default_message());
    }

    // A tensor whose elements are all nonzero passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_handles_numeric_tensor() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        assert_builtin(vec![Value::Tensor(tensor)]).expect("assert should pass");
    }

    // A single zero element makes the whole tensor fail.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_zero_in_tensor() {
        let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
        let err = unwrap_error(
            assert_builtin(vec![Value::Tensor(tensor)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A NaN scalar is treated as a failed assertion.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_detects_nan() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Num(f64::NAN)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A complex scalar with a nonzero imaginary part passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_passes() {
        assert_builtin(vec![Value::Complex(0.0, 2.0)]).expect("assert should pass");
    }

    // A complex zero fails.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_scalar_failure() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Complex(0.0, 0.0)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A complex tensor containing a zero element fails.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_complex_tensor_failure() {
        let tensor = ComplexTensor::new(vec![(1.0, 0.0), (0.0, 0.0)], vec![2, 1]).expect("tensor");
        let err = unwrap_error(
            assert_builtin(vec![Value::ComplexTensor(tensor)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // A lone message argument replaces the default text but keeps the default identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_accepts_custom_message() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("Vector length must be positive."),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
        assert!(err.message().contains("Vector length must be positive."));
    }

    // A template containing `%d` is not identifier-like, so the trailing value formats it.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_message_formatting() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("Expected positive value, got %d."),
                Value::Int(IntValue::I32(-4)),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
        assert!(err.message().contains("Expected positive value, got -4."));
    }

    // A colon-qualified first string is used verbatim as the identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_supports_custom_identifier() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("runmat:tests:failed"),
                Value::from("Failure %d occurred."),
                Value::Int(IntValue::I32(3)),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some("runmat:tests:failed"));
        assert!(err.message().contains("Failure 3 occurred."));
    }

    // A bare identifier-looking token followed by a message gets the `RunMat:` prefix.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_unqualified_identifier_prefixed() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("customAssertionFailed"),
                Value::from("runtime failure"),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some("RunMat:customAssertionFailed"));
    }

    // A string condition is rejected with the invalid-condition identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_rejects_invalid_condition_type() {
        let err = unwrap_error(
            assert_builtin(vec![Value::from("invalid")]).expect_err("assert should error"),
        );
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_INVALID_CONDITION.identifier.unwrap())
        );
    }

    // A GPU tensor uploaded through the test provider is gathered and passes.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_passes() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = assert_builtin(vec![Value::GpuTensor(handle)]).expect("assert");
            assert_eq!(result, Value::Num(0.0));
        });
    }

    // A numeric value in the message position produces an invalid-input error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_invalid_message_type_errors() {
        let err = unwrap_error(
            assert_builtin(vec![Value::Bool(false), Value::Num(5.0)])
                .expect_err("assert should error"),
        );
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_INVALID_INPUT.identifier.unwrap())
        );
    }

    // A template with `%d` but no value to fill it surfaces as an invalid-input error;
    // the test expects the surfaced message to mention `sprintf`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_formatting_error_propagates() {
        let err = unwrap_error(
            assert_builtin(vec![
                Value::Bool(false),
                Value::from("number %d must be > 0"),
            ])
            .expect_err("assert should fail"),
        );
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_INVALID_INPUT.identifier.unwrap())
        );
        assert!(err.message().contains("sprintf"));
    }

    // A GPU tensor containing a zero fails with the default identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_gpu_tensor_failure() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 0.0, 3.0], vec![3, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let err =
                unwrap_error(assert_builtin(vec![Value::GpuTensor(handle)]).expect_err("assert"));
            assert_eq!(err.identifier(), Some(assert_default_identifier()));
        });
    }

    // A logical array with any zero element fails.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_logical_array_failure() {
        let logical = LogicalArray::new(vec![1, 0], vec![2]).unwrap();
        let err = unwrap_error(
            assert_builtin(vec![Value::LogicalArray(logical)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // Calling with no arguments reports the not-enough-inputs identifier and message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn assert_requires_condition_argument() {
        let err = unwrap_error(assert_builtin(Vec::new()).expect_err("assert should error"));
        assert_eq!(
            err.identifier(),
            Some(ASSERT_ERROR_NOT_ENOUGH_INPUTS.identifier.unwrap())
        );
        assert_eq!(err.message(), ASSERT_ERROR_NOT_ENOUGH_INPUTS.message);
    }

    // Same failure check against the wgpu provider; returns early (without asserting)
    // when the provider cannot be registered or is not available.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn assert_wgpu_tensor_failure_matches_cpu() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };

        if register_wgpu_provider(WgpuProviderOptions::default()).is_err() {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };

        let tensor = Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap();
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload");
        let err = unwrap_error(
            assert_builtin(vec![Value::GpuTensor(handle)]).expect_err("assert should fail"),
        );
        assert_eq!(err.identifier(), Some(assert_default_identifier()));
    }

    // The type resolver reports a numeric result for a boolean input.
    #[test]
    fn assert_type_is_numeric() {
        assert_eq!(
            assert_type(&[Type::Bool], &ResolveContext::new(Vec::new())),
            Type::Num
        );
    }
}