use crate::mlir::amx::{fma_opcode_and_flags, validate_amx_dtypes, z_row_stride};
use morok_dtype::DType;
use morok_ir::{WmmaMetadata, WmmaUpcastAxes};
/// Builds a minimal `WmmaMetadata` fixture for the "AppleAMX" device with the requested
/// input/output dtypes.
///
/// All other fields are fixed: name `"test_wmma"`, dims `(16, 16, 1)`, one thread, a
/// `(1, 1)` tile grid, and empty upcast/reduce axis lists.
fn make_metadata(dtype_in: DType, dtype_out: DType) -> WmmaMetadata {
    let upcast_axes = WmmaUpcastAxes { a: Vec::new(), b: Vec::new(), c: Vec::new() };
    let reduce_axes = Vec::new();
    WmmaMetadata {
        name: String::from("test_wmma"),
        device: String::from("AppleAMX"),
        dims: (16, 16, 1),
        threads: 1,
        tile_grid: (1, 1),
        dtype_in,
        dtype_out,
        upcast_axes,
        reduce_axes,
    }
}
/// Pins the value `z_row_stride` returns for each (input, output) dtype pair.
///
/// The expected values equal the byte width of the input dtype: 8 for f64, 4 for f32,
/// and 2 for the 16-bit inputs (including f16 inputs with an f32 output).
#[test]
fn test_z_row_stride() {
    let stride = |input: DType, output: DType| z_row_stride(&input, &output);

    assert_eq!(stride(DType::Float64, DType::Float64), 8);
    assert_eq!(stride(DType::Float32, DType::Float32), 4);
    assert_eq!(stride(DType::Float16, DType::Float16), 2);
    assert_eq!(stride(DType::Float16, DType::Float32), 2);
    assert_eq!(stride(DType::Int16, DType::Int16), 2);
}
/// Pins the `(opcode, flags)` pair `fma_opcode_and_flags` selects for each supported
/// dtype combination.
#[test]
fn test_fma_opcode_and_flags() {
    // Every combination below is expected to be supported, so `unwrap` failing is a bug.
    let lower = |dtype_in: DType, dtype_out: DType| {
        fma_opcode_and_flags(&make_metadata(dtype_in, dtype_out)).unwrap()
    };

    // Same-dtype input/output: distinct opcode per dtype, no extra flags.
    assert_eq!(lower(DType::Float64, DType::Float64), (10, 0));
    assert_eq!(lower(DType::Float32, DType::Float32), (12, 0));
    assert_eq!(lower(DType::Float16, DType::Float16), (15, 0));
    assert_eq!(lower(DType::Int16, DType::Int16), (14, 0));

    // f16 inputs with an f32 output reuse the f16 opcode (15) but set bit 62 in the flags.
    assert_eq!(lower(DType::Float16, DType::Float32), (15, 1 << 62));
}
/// `validate_amx_dtypes` must return `Err` for dtype pairs outside the supported set:
/// f32 inputs with an f16 output, i32, and bf16.
#[test]
fn test_validate_amx_dtypes_unsupported() {
    let rejected = |dtype_in: DType, dtype_out: DType| {
        validate_amx_dtypes(&dtype_in, &dtype_out).is_err()
    };

    assert!(rejected(DType::Float32, DType::Float16));
    assert!(rejected(DType::Int32, DType::Int32));
    assert!(rejected(DType::BFloat16, DType::BFloat16));
}
/// `validate_amx_dtypes` must return `Ok` for every dtype pair this file exercises as
/// supported: f32, f64, f16, f16 inputs with an f32 output, and i16.
#[test]
fn test_validate_amx_dtypes_supported() {
    assert!(validate_amx_dtypes(&DType::Float32, &DType::Float32).is_ok());
    assert!(validate_amx_dtypes(&DType::Float64, &DType::Float64).is_ok());
    assert!(validate_amx_dtypes(&DType::Float16, &DType::Float16).is_ok());
    assert!(validate_amx_dtypes(&DType::Float16, &DType::Float32).is_ok());
    // Previously fused onto the line above; kept as its own statement so the case is visible.
    assert!(validate_amx_dtypes(&DType::Int16, &DType::Int16).is_ok());
}