use archmage::{ScalarToken, rite};
// Single `scalar` tier: the function keeps its explicit `ScalarToken` parameter and is
// invoked under its own unsuffixed name. Sums every element of `data`.
#[rite(scalar)]
fn scalar_only_sum(_t: ScalarToken, data: &[f32]) -> f32 {
    data.iter().copied().sum()
}
// Uses the underscore-prefixed tier spelling `_scalar`. The function takes a
// `ScalarToken`, is called by its unsuffixed name, and sums every element of `data`.
#[rite(_scalar)]
fn scalar_underscore_sum(_t: ScalarToken, data: &[f32]) -> f32 {
    data.iter().sum::<f32>()
}
/// A lone `scalar` tier is callable with an explicit `ScalarToken`.
#[test]
fn single_scalar_tokenful() {
    let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let total = scalar_only_sum(ScalarToken, &input);
    assert_eq!(total, 10.0);
}
/// The `_scalar` spelling produces a callable function taking a `ScalarToken`.
#[test]
fn single_scalar_underscore_accepted() {
    let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let total = scalar_underscore_sum(ScalarToken, &input);
    assert_eq!(total, 10.0);
}
// Single `default` tier: no token parameter, called by its unsuffixed name.
// Sums every element of `data`.
#[rite(default)]
fn default_only_sum(data: &[f32]) -> f32 {
    data.iter().copied().sum()
}
// Uses the underscore-prefixed tier spelling `_default`. It is tokenless, called by
// its unsuffixed name, and sums every element of `data`.
#[rite(_default)]
fn default_underscore_sum(data: &[f32]) -> f32 {
    data.iter().sum::<f32>()
}
/// A lone `default` tier is callable without any token argument.
#[test]
fn single_default_tokenless() {
    let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let total = default_only_sum(&input);
    assert_eq!(total, 10.0);
}
/// The `_default` spelling produces a callable, tokenless function.
#[test]
fn single_default_underscore_accepted() {
    let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let total = default_underscore_sum(&input);
    assert_eq!(total, 10.0);
}
// Tokenless function spread over several SIMD tiers (no `scalar` tier listed).
// Returns the sum of squares of the four elements. None of the variants generated
// from this definition is called by the tests in this file.
#[rite(v3, neon, wasm128)]
fn multi_tokenless_square(data: &[f32; 4]) -> f32 {
    data.iter().map(|&v| v * v).sum()
}
// Hand-written scalar counterpart of `multi_tokenless_square`: the same sum of squares,
// but taking a `ScalarToken`. NOTE(review): presumably named with the `_scalar` suffix to
// mirror the multi-tier naming convention; confirm whether the macro relies on it.
#[rite(scalar)]
fn multi_tokenless_square_scalar(_t: ScalarToken, data: &[f32; 4]) -> f32 {
    data.iter().copied().map(|v| v * v).sum()
}
/// The scalar sum-of-squares fallback returns 1 + 4 + 9 + 16.
#[test]
fn tokenless_multi_tier_scalar_variant_callable() {
    let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let squares = multi_tokenless_square_scalar(ScalarToken, &input);
    assert_eq!(squares, 30.0);
}
// Multi-tier function that includes a `scalar` tier. Its scalar variant is
// reachable as `multi_with_scalar_scalar`. Sums the four elements.
#[rite(v3, neon, wasm128, scalar)]
fn multi_with_scalar(token: ScalarToken, data: &[f32; 4]) -> f32 {
    // Explicitly discard the token so the non-underscored parameter name does not
    // raise an unused-variable warning.
    let _ = token;
    data.iter().copied().sum()
}
/// Listing `scalar` among several tiers yields a callable `_scalar`-suffixed variant.
#[test]
fn multi_with_scalar_variant_callable() {
    let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let total = multi_with_scalar_scalar(ScalarToken, &input);
    assert_eq!(total, 10.0);
}
// Tokenless multi-tier function that includes a `default` tier. Its default
// variant is reachable as `multi_with_default_default`. Sums the four elements.
#[rite(v3, neon, wasm128, default)]
fn multi_with_default(data: &[f32; 4]) -> f32 {
    data.iter().copied().sum()
}
/// Listing `default` among several tiers yields a callable `_default`-suffixed variant.
#[test]
fn multi_with_default_variant_callable() {
    let input: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let total = multi_with_default_default(&input);
    assert_eq!(total, 10.0);
}
// Multi-tier function listing both `scalar` and `default`. The tests call both
// `multi_with_both_scalar` and `multi_with_both_default`, each with a `ScalarToken`
// argument. Sums the four elements.
#[rite(v3, neon, wasm128, scalar, default)]
fn multi_with_both(_t: ScalarToken, data: &[f32; 4]) -> f32 {
    data.iter().sum::<f32>()
}
/// The `_scalar` variant of a scalar+default multi-tier function is callable.
#[test]
fn multi_with_both_scalar_callable() {
    let input: [f32; 4] = [5.0, 10.0, 15.0, 20.0];
    let total = multi_with_both_scalar(ScalarToken, &input);
    assert_eq!(total, 50.0);
}
/// The `_default` variant of a scalar+default multi-tier function is callable.
#[test]
fn multi_with_both_default_callable() {
    let input: [f32; 4] = [5.0, 10.0, 15.0, 20.0];
    let total = multi_with_both_default(ScalarToken, &input);
    assert_eq!(total, 50.0);
}
// Lists `scalar` before the SIMD tiers, to check that tier order in the attribute
// does not affect suffix generation. Doubles its input.
#[rite(scalar, v3, neon)]
fn order_scalar_first(_t: ScalarToken, x: i32) -> i32 {
    2 * x
}
/// The scalar variant exists even when `scalar` is listed first.
#[test]
fn tier_order_does_not_matter() {
    let doubled = order_scalar_first_scalar(ScalarToken, 21);
    assert_eq!(doubled, 42);
}
// Identity function used only to check the generated `_scalar` and `_default` names.
#[rite(v3, neon, wasm128, scalar, default)]
fn suffix_convention_test(_t: ScalarToken, x: f32) -> f32 {
    core::convert::identity(x)
}
/// Both suffixed variants exist and pass their argument through unchanged.
#[test]
fn suffix_names_match_convention() {
    let scalar_out = suffix_convention_test_scalar(ScalarToken, 1.0);
    assert_eq!(scalar_out, 1.0);
    let default_out = suffix_convention_test_default(ScalarToken, 2.0);
    assert_eq!(default_out, 2.0);
}
// Scalar-tier function with a `&mut` parameter: appends `n` to the end of `out`.
#[rite(scalar)]
fn scalar_mutates(_t: ScalarToken, out: &mut Vec<f32>, n: f32) {
    out.extend([n]);
}
// Default-tier (tokenless) function with a `&mut` parameter: appends `n` to the end of `out`.
#[rite(default)]
fn default_mutates(out: &mut Vec<f32>, n: f32) {
    out.extend([n]);
}
/// Scalar and default functions can both mutate the same caller-owned buffer, in call order.
#[test]
fn scalar_and_default_support_mutation() {
    let mut collected: Vec<f32> = Vec::new();
    scalar_mutates(ScalarToken, &mut collected, 1.0);
    default_mutates(&mut collected, 2.0);
    assert_eq!(collected, [1.0, 2.0]);
}
// Generic scalar-tier function. It folds `data` left to right with `+`, starting from
// `T::default()`, so an empty slice yields `T::default()`.
#[rite(scalar)]
fn scalar_generic<T: Copy + core::ops::Add<Output = T> + Default>(
    _t: ScalarToken,
    data: &[T],
) -> T {
    data.iter().fold(T::default(), |total, &item| total + item)
}
/// The generic scalar function works for both floating-point and integer element types.
#[test]
fn scalar_rite_with_generics() {
    let float_total = scalar_generic::<f32>(ScalarToken, &[1.0, 2.0, 3.0]);
    assert_eq!(float_total, 6.0);
    let int_total = scalar_generic::<i32>(ScalarToken, &[10, 20, 30]);
    assert_eq!(int_total, 60);
}