use runmat_accelerate_api::{ApiDeviceInfo, ProviderPrecision};
use runmat_builtins::{IntValue, StructValue, Value};
use runmat_macros::runtime_builtin;
use crate::builtins::acceleration::gpu::type_resolvers::gpudevice_type;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ReductionNaN, ResidencyPolicy, ShapeRequirements,
};
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
/// Returned when `runmat_accelerate_api::provider()` yields no provider.
pub(crate) const ERR_NO_PROVIDER: &str = "gpuDevice: no acceleration provider registered";

/// Returned for a reset request (`"reset"` keyword or an empty array).
const ERR_RESET_NOT_SUPPORTED: &str = "gpuDevice: reset is not supported by the active provider";
/// Returned when a numeric/logical argument is not a usable positive index.
const ERR_INVALID_INDEX: &str = "gpuDevice: device index must be a positive integer";
/// Returned when a single argument is neither a reset request nor an index.
const ERR_UNSUPPORTED_ARGUMENT: &str = "gpuDevice: unsupported input argument";
/// Returned when two or more arguments are passed.
const ERR_TOO_MANY_INPUTS: &str = "gpuDevice: too many input arguments";
/// Wraps `message` in a `RuntimeError` attributed to the `gpuDevice` builtin.
fn gpu_device_error(message: impl Into<String>) -> RuntimeError {
    let builtin_name = "gpuDevice";
    build_runtime_error(message).with_builtin(builtin_name).build()
}
// GPU planning metadata for `gpuDevice`, registered through `register_gpu_spec`.
// The builtin only reads provider metadata (`device_info_struct` / `precision`),
// so it declares no supported precisions and no provider hooks.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::acceleration::gpu::gpudevice")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "gpuDevice",
    op_kind: GpuOpKind::Custom("device-info"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // NOTE(review): presumably means the result is never left device-resident;
    // confirm against `ResidencyPolicy` docs.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Pure metadata query; does not enqueue GPU kernels. Returns an error when no provider is registered.",
};
// Fusion metadata for `gpuDevice`, registered through `register_fusion_spec`.
// No elementwise or reduction template is supplied (both are `None`), so there
// is nothing for the fusion planner to fuse; see `notes` below.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::acceleration::gpu::gpudevice"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "gpuDevice",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Not eligible for fusion; the builtin returns a host-resident struct.",
};
// Entry point for `gpuDevice`.
// - No arguments: return the active device info struct.
// - One argument: delegate to `handle_single_argument` (reset request or index).
// - Two or more: `ERR_TOO_MANY_INPUTS`.
#[runtime_builtin(
    name = "gpuDevice",
    category = "acceleration/gpu",
    summary = "Query metadata about the active GPU provider and return it as a MATLAB struct.",
    keywords = "gpu,gpuDevice,device,info,accelerate",
    examples = "info = gpuDevice();",
    type_resolver(gpudevice_type),
    builtin_path = "crate::builtins::acceleration::gpu::gpudevice"
)]
async fn gpu_device_builtin(args: Vec<Value>) -> crate::BuiltinResult<Value> {
    if args.len() > 1 {
        return Err(gpu_device_error(ERR_TOO_MANY_INPUTS).into());
    }
    match args.first() {
        Some(only_arg) => handle_single_argument(only_arg),
        None => active_device_struct().map(Value::Struct),
    }
}
/// Builds the device-info struct for the currently registered acceleration provider.
///
/// # Errors
/// Returns `ERR_NO_PROVIDER` when `runmat_accelerate_api::provider()` is `None`.
pub(crate) fn active_device_struct() -> BuiltinResult<StructValue> {
    match runmat_accelerate_api::provider() {
        Some(active) => {
            let device_info = active.device_info_struct();
            let active_precision = active.precision();
            Ok(build_struct(&device_info, active_precision))
        }
        None => Err(gpu_device_error(ERR_NO_PROVIDER).into()),
    }
}
/// Converts provider device info into the MATLAB-facing struct.
///
/// The struct always has these fields:
/// - `device_id` (0-based).
/// - `index`: `device_id + 1`, saturating at `u32::MAX`.
/// - `name`, `vendor`.
/// - `precision`: `"double"` or `"single"`.
/// - `supports_double`.
///
/// `backend` and `memory_bytes` are added only when the provider reports them.
/// Fields are inserted in that order.
fn build_struct(info: &ApiDeviceInfo, precision: ProviderPrecision) -> StructValue {
    let (precision_label, double_capable) = match precision {
        ProviderPrecision::F64 => ("double", true),
        ProviderPrecision::F32 => ("single", false),
    };
    let one_based = info.device_id.saturating_add(1);

    let mut out = StructValue::new();
    out.insert("device_id", Value::Int(IntValue::U32(info.device_id)));
    out.insert("index", Value::Int(IntValue::U32(one_based)));
    out.insert("name", Value::String(info.name.clone()));
    out.insert("vendor", Value::String(info.vendor.clone()));
    if let Some(backend_name) = &info.backend {
        out.insert("backend", Value::String(backend_name.clone()));
    }
    if let Some(total_bytes) = info.memory_bytes {
        out.insert("memory_bytes", Value::Int(IntValue::U64(total_bytes)));
    }
    out.insert("precision", Value::String(precision_label.to_string()));
    out.insert("supports_double", Value::Bool(double_capable));
    out
}
/// Returns `true` when `value` is text equal to `keyword`.
///
/// "Text" means a string scalar or a single-row char array. The comparison
/// ignores surrounding whitespace and ASCII case. All other values yield `false`.
fn is_keyword(value: &Value, keyword: &str) -> bool {
    fn same_word(candidate: &str, keyword: &str) -> bool {
        candidate.trim().eq_ignore_ascii_case(keyword)
    }

    match value {
        Value::String(text) => same_word(text, keyword),
        Value::CharArray(chars) if chars.rows == 1 => {
            let text: String = chars.data.iter().collect();
            same_word(&text, keyword)
        }
        _ => false,
    }
}
/// Handles `gpuDevice(arg)`.
///
/// Reset requests are rejected. A valid index returns the active device struct
/// only when it equals the active device's 1-based `index`. Any other index
/// reports that device as unavailable.
///
/// # Errors
/// - `ERR_RESET_NOT_SUPPORTED` for a reset request.
/// - `ERR_INVALID_INDEX` when the index cannot be parsed.
/// - `ERR_UNSUPPORTED_ARGUMENT` for unrecognised argument types.
/// - An availability error when the index does not match.
/// - Any error from `active_device_struct`.
fn handle_single_argument(arg: &Value) -> BuiltinResult<Value> {
    if is_reset_arg(arg) {
        return Err(gpu_device_error(ERR_RESET_NOT_SUPPORTED).into());
    }

    let requested = match parse_device_index(arg)? {
        Some(idx) => idx,
        None => return Err(gpu_device_error(ERR_UNSUPPORTED_ARGUMENT).into()),
    };

    let device = active_device_struct()?;
    // Fall back to 1 if the struct somehow lacks a readable `index` field.
    let active = struct_device_index(&device).unwrap_or(1);
    if requested != active {
        return Err(gpu_device_error(format!(
            "gpuDevice: GPU device with index {} not available",
            requested
        ))
        .into());
    }
    Ok(Value::Struct(device))
}
/// Reads the 1-based `index` field back out of a device-info struct.
///
/// Two field types are accepted:
/// - An integer value.
/// - A finite, non-negative double that is integral within `1e-9`.
///
/// Returns `None` when the field is missing, has another type, is not integral,
/// or does not fit in a `u32`.
fn struct_device_index(info: &StructValue) -> Option<u32> {
    match info.fields.get("index")? {
        Value::Int(intv) => {
            let idx = intv.to_i64();
            if (0..=i64::from(u32::MAX)).contains(&idx) {
                Some(idx as u32)
            } else {
                None
            }
        }
        Value::Num(n) if n.is_finite() && *n >= 0.0 => {
            let rounded = n.round();
            // A bare `as u32` would saturate values above u32::MAX to u32::MAX;
            // reject them instead, matching the integer arm above.
            if (rounded - n).abs() <= 1e-9 && rounded <= f64::from(u32::MAX) {
                Some(rounded as u32)
            } else {
                None
            }
        }
        _ => None,
    }
}
/// Returns `true` when `value` is a reset request.
///
/// A reset request is either the `"reset"` keyword (see `is_keyword`) or an
/// empty numeric/logical array.
fn is_reset_arg(value: &Value) -> bool {
    match value {
        Value::Tensor(t) => t.data.is_empty(),
        Value::LogicalArray(la) => la.data.is_empty(),
        other => is_keyword(other, "reset"),
    }
}
/// Interprets `value` as a 1-based device index.
///
/// Returns:
/// - `Ok(Some(idx))` for integer, double, logical, or single-element
///   numeric/logical inputs that form a valid index. Logical `true` maps to 1.
/// - `Ok(None)` for empty arrays and for value types this function does not
///   recognise.
/// - `Err(ERR_INVALID_INDEX)` for non-positive, non-integral, logical `false`,
///   or multi-element inputs.
fn parse_device_index(value: &Value) -> BuiltinResult<Option<u32>> {
    let invalid = || -> BuiltinResult<Option<u32>> {
        Err(gpu_device_error(ERR_INVALID_INDEX).into())
    };
    let from_truthiness = |truthy: bool| {
        if truthy {
            Ok(Some(1))
        } else {
            invalid()
        }
    };

    match value {
        Value::Int(i) => int_to_index(i.to_i64()),
        Value::Num(n) => num_to_index(*n),
        Value::Bool(b) => from_truthiness(*b),
        Value::Tensor(t) => match &t.data[..] {
            [] => Ok(None),
            [single] => num_to_index(*single),
            _ => invalid(),
        },
        Value::LogicalArray(la) => match &la.data[..] {
            [] => Ok(None),
            [single] => from_truthiness(*single != 0),
            _ => invalid(),
        },
        _ => Ok(None),
    }
}
/// Accepts `raw` as a device index when it lies in `1..=u32::MAX`.
///
/// # Errors
/// `ERR_INVALID_INDEX` for zero, negative, or too-large values.
fn int_to_index(raw: i64) -> BuiltinResult<Option<u32>> {
    if (1..=i64::from(u32::MAX)).contains(&raw) {
        // In range, so the narrowing cast is lossless.
        Ok(Some(raw as u32))
    } else {
        Err(gpu_device_error(ERR_INVALID_INDEX).into())
    }
}
/// Accepts a double as a device index.
///
/// The value must be finite and within `1e-9` of an integer. It is then
/// validated by `int_to_index`.
///
/// # Errors
/// `ERR_INVALID_INDEX` for NaN, infinities, fractional values, or out-of-range
/// integers.
fn num_to_index(raw: f64) -> BuiltinResult<Option<u32>> {
    let nearest = raw.round();
    let integral = raw.is_finite() && (nearest - raw).abs() <= 1e-9;
    if integral {
        // `as i64` saturates for huge magnitudes; `int_to_index` then rejects them.
        int_to_index(nearest as i64)
    } else {
        Err(gpu_device_error(ERR_INVALID_INDEX).into())
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::{ResolveContext, Type};

    /// Drives the async builtin to completion on the current thread.
    fn call(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        block_on(gpu_device_builtin(args))
    }

    /// Reads an integer field from a device-info struct.
    ///
    /// Panics with the field name when the field is missing or not an integer.
    fn int_field(info: &StructValue, field: &str) -> i64 {
        match info.fields.get(field) {
            Some(Value::Int(v)) => v.to_i64(),
            other => panic!("expected integer field `{field}`, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_returns_struct() {
        test_support::with_test_provider(|_| {
            let value = call(Vec::new()).expect("gpuDevice");
            match value {
                Value::Struct(s) => {
                    assert!(s.fields.contains_key("device_id"));
                    assert!(s.fields.contains_key("index"));
                    assert!(s.fields.contains_key("name"));
                    assert!(s.fields.contains_key("vendor"));
                    assert!(s.fields.contains_key("precision"));
                    assert!(s.fields.contains_key("supports_double"));
                    // The 1-based `index` mirrors the 0-based `device_id`.
                    assert_eq!(int_field(&s, "index"), int_field(&s, "device_id") + 1);
                    // `supports_double` must agree with the reported precision label.
                    match (s.fields.get("precision"), s.fields.get("supports_double")) {
                        (Some(Value::String(p)), Some(Value::Bool(d))) => {
                            assert_eq!(
                                p.as_str() == "double",
                                *d,
                                "precision {p} vs supports_double {d}"
                            );
                        }
                        other => panic!("unexpected precision fields: {other:?}"),
                    }
                }
                other => panic!("expected struct, got {other:?}"),
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_accepts_current_index() {
        test_support::with_test_provider(|_| {
            let tensor_scalar =
                runmat_builtins::Tensor::new(vec![1.0], vec![1, 1]).expect("scalar tensor");
            let logical_scalar = runmat_builtins::LogicalArray::new(vec![1u8], vec![1]).unwrap();
            let cases = vec![
                Value::Int(IntValue::I32(1)),
                Value::Num(1.0),
                Value::Bool(true),
                Value::Tensor(tensor_scalar),
                Value::LogicalArray(logical_scalar),
            ];
            for case in cases {
                let value = call(vec![case]).expect("gpuDevice");
                assert!(matches!(value, Value::Struct(_)));
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_out_of_range_index_errors() {
        test_support::with_test_provider(|_| {
            let err = call(vec![Value::Num(2.0)]).unwrap_err().to_string();
            assert!(
                err.contains("gpuDevice: GPU device with index 2 not available"),
                "unexpected error: {err}"
            );
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_unsupported_argument_errors() {
        test_support::with_test_provider(|_| {
            let err = call(vec![Value::from("status")]).unwrap_err().to_string();
            assert_eq!(err, ERR_UNSUPPORTED_ARGUMENT);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_too_many_inputs_errors() {
        test_support::with_test_provider(|_| {
            let err = call(vec![Value::Num(1.0), Value::Num(1.0)])
                .unwrap_err()
                .to_string();
            assert_eq!(err, ERR_TOO_MANY_INPUTS);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_reset_argument_reports_not_supported() {
        test_support::with_test_provider(|_| {
            let err = call(vec![Value::from(" RESET ")]).unwrap_err().to_string();
            assert_eq!(err, ERR_RESET_NOT_SUPPORTED);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_reset_char_array_argument_reports_not_supported() {
        test_support::with_test_provider(|_| {
            let chars = runmat_builtins::CharArray::new("reset".chars().collect(), 1, 5).unwrap();
            let err = call(vec![Value::CharArray(chars)]).unwrap_err().to_string();
            assert_eq!(err, ERR_RESET_NOT_SUPPORTED);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_empty_array_argument_reports_not_supported() {
        test_support::with_test_provider(|_| {
            let empty = runmat_builtins::Tensor::zeros(vec![0, 0]);
            let err = call(vec![Value::Tensor(empty)]).unwrap_err().to_string();
            assert_eq!(err, ERR_RESET_NOT_SUPPORTED);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_device_invalid_index_rejected() {
        test_support::with_test_provider(|_| {
            let cases = vec![
                Value::Num(0.0),
                Value::Int(IntValue::I32(0)),
                Value::Num(-1.0),
                Value::Num(1.5),
                Value::Bool(false),
                Value::LogicalArray(
                    runmat_builtins::LogicalArray::new(vec![0u8], vec![1]).unwrap(),
                ),
                Value::Tensor(runmat_builtins::Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
            ];
            for case in cases {
                let err = call(vec![case]).unwrap_err().to_string();
                assert_eq!(err, ERR_INVALID_INDEX);
            }
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpudevice_type_is_struct() {
        assert!(matches!(
            gpudevice_type(&[Type::Num], &ResolveContext::new(Vec::new())),
            Type::Struct { .. }
        ));
    }

    // NOTE(review): this test registers the wgpu provider and never restores the
    // previous provider, unlike the `with_test_provider` tests above. Confirm it
    // cannot leak into or race with other tests.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn gpu_device_wgpu_reports_metadata() {
        use runmat_accelerate::backend::wgpu::provider as wgpu_provider;
        let _ =
            wgpu_provider::register_wgpu_provider(wgpu_provider::WgpuProviderOptions::default());
        let value = call(Vec::new()).expect("gpuDevice");
        match value {
            Value::Struct(info) => {
                let name = info
                    .fields
                    .get("name")
                    .and_then(|v| match v {
                        Value::String(s) => Some(s),
                        _ => None,
                    })
                    .expect("name field");
                assert!(
                    !name.is_empty(),
                    "expected non-empty adapter name from wgpu provider"
                );
                let backend = info.fields.get("backend").and_then(|v| match v {
                    Value::String(s) => Some(s),
                    _ => None,
                });
                assert!(backend.is_some(), "expected backend field to be present");
                if let Some(Value::Int(memory)) = info.fields.get("memory_bytes") {
                    assert!(
                        memory.to_i64() > 0,
                        "expected positive memory_bytes, got {:?}",
                        memory
                    );
                }
                if let Some(Value::Bool(supports_double)) = info.fields.get("supports_double") {
                    if *supports_double {
                        assert_eq!(
                            info.fields.get("precision"),
                            Some(&Value::String("double".to_string()))
                        );
                    }
                }
            }
            other => panic!("expected struct, got {other:?}"),
        }
    }
}