// runmat-runtime 0.4.1
//
// Core runtime for RunMat with builtins, BLAS/LAPACK integration, and execution APIs.
use crate::{build_runtime_error, make_cell_with_shape, new_object_builtin, RuntimeError};
use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, GpuTensorStorage, HostTensorOwned};
use runmat_builtins::{
    builtin_functions, ComplexTensor, LogicalArray, NumericDType, Tensor, Value,
};

/// Report whether `value` is itself a GPU tensor handle.
///
/// Only the top level is inspected; nested containers (cells, structs, objects)
/// are not searched — `value_contains_gpu` does that.
pub fn is_gpu_value(value: &Value) -> bool {
    matches!(*value, Value::GpuTensor(_))
}

/// Report whether `value` holds a GPU tensor anywhere inside it.
///
/// Cell elements, struct fields, and object properties are searched recursively,
/// stopping at the first GPU tensor found. Every other variant reports `false`.
pub fn value_contains_gpu(value: &Value) -> bool {
    match value {
        Value::Cell(cell) => cell.data.iter().any(|elem| value_contains_gpu(elem)),
        Value::Struct(st) => st.fields.values().any(value_contains_gpu),
        Value::Object(obj) => obj.properties.values().any(value_contains_gpu),
        Value::GpuTensor(_) => true,
        _ => false,
    }
}

/// Bring `value` onto the host, downloading every GPU-resident tensor it contains.
///
/// Cells, structs, and objects are walked recursively. Leaf values that are not GPU
/// tensors are cloned unchanged.
///
/// # Errors
/// Fails when a GPU handle has no registered acceleration provider, when the download
/// itself fails, or when the downloaded data cannot be shaped into a host value.
pub async fn gather_if_needed_async(value: &Value) -> Result<Value, RuntimeError> {
    let gathered = gather_if_needed_async_impl(value).await?;
    Ok(gathered)
}

/// Download the contents of `handle` from `provider` into a host-owned tensor.
///
/// This is a thin async wrapper over the provider's `download`. Any error the
/// provider reports is returned unchanged.
pub async fn download_handle_async(
    provider: &dyn AccelProvider,
    handle: &GpuTensorHandle,
) -> anyhow::Result<HostTensorOwned> {
    let host = provider.download(handle).await?;
    Ok(host)
}

fn gather_if_needed_async_impl<'a>(
    value: &'a Value,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, RuntimeError>> + 'a>> {
    Box::pin(async move {
        match value {
            Value::GpuTensor(handle) => {
                // In parallel test runs, ensure the WGPU provider is reasserted for WGPU handles.
                #[cfg(all(test, feature = "wgpu"))]
                {
                    if handle.device_id != 0 {
                        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                    );
                    }
                }
                let provider =
                    runmat_accelerate_api::provider_for_handle(handle).ok_or_else(|| {
                        build_runtime_error("gather: no acceleration provider registered")
                            .with_identifier("RunMat:gather:ProviderUnavailable")
                            .build()
                    })?;
                let is_logical = runmat_accelerate_api::handle_is_logical(handle);
                let host = download_handle_async(provider, handle)
                    .await
                    .map_err(|err| {
                        build_runtime_error(format!("gather: {err}"))
                            .with_identifier("RunMat:gather:DownloadFailed")
                            .build()
                    })?;
                runmat_accelerate_api::clear_residency(handle);
                let runmat_accelerate_api::HostTensorOwned {
                    data,
                    shape,
                    storage,
                } = host;
                if is_logical {
                    let bits: Vec<u8> =
                        data.iter().map(|&v| if v != 0.0 { 1 } else { 0 }).collect();
                    let logical = LogicalArray::new(bits, shape).map_err(|e| {
                        build_runtime_error(format!("gather: {e}"))
                            .with_identifier("RunMat:gather:LogicalShapeError")
                            .build()
                    })?;
                    Ok(Value::LogicalArray(logical))
                } else if storage == GpuTensorStorage::ComplexInterleaved {
                    let mut data = data;
                    let precision = runmat_accelerate_api::handle_precision(handle)
                        .unwrap_or_else(|| provider.precision());
                    if matches!(precision, runmat_accelerate_api::ProviderPrecision::F32) {
                        for value in &mut data {
                            *value = (*value as f32) as f64;
                        }
                    }
                    let mut complex = Vec::with_capacity(data.len() / 2);
                    for chunk in data.chunks_exact(2) {
                        complex.push((chunk[0], chunk[1]));
                    }
                    let tensor = ComplexTensor::new(complex, shape).map_err(|e| {
                        build_runtime_error(format!("gather: {e}"))
                            .with_identifier("RunMat:gather:TensorShapeError")
                            .build()
                    })?;
                    Ok(Value::ComplexTensor(tensor))
                } else {
                    let mut data = data;
                    let precision = runmat_accelerate_api::handle_precision(handle)
                        .unwrap_or_else(|| provider.precision());
                    if matches!(precision, runmat_accelerate_api::ProviderPrecision::F32) {
                        for value in &mut data {
                            *value = (*value as f32) as f64;
                        }
                    }
                    let dtype = match precision {
                        runmat_accelerate_api::ProviderPrecision::F32 => NumericDType::F32,
                        runmat_accelerate_api::ProviderPrecision::F64 => NumericDType::F64,
                    };
                    let tensor = Tensor::new_with_dtype(data, shape, dtype).map_err(|e| {
                        build_runtime_error(format!("gather: {e}"))
                            .with_identifier("RunMat:gather:TensorShapeError")
                            .build()
                    })?;
                    Ok(Value::Tensor(tensor))
                }
            }
            Value::Cell(ca) => {
                let mut gathered = Vec::with_capacity(ca.data.len());
                for ptr in &ca.data {
                    gathered.push(gather_if_needed_async_impl(ptr).await?);
                }
                make_cell_with_shape(gathered, ca.shape.clone()).map_err(|err| {
                    build_runtime_error(format!("gather: {err}"))
                        .with_identifier("RunMat:gather:CellShapeError")
                        .build()
                })
            }
            Value::Struct(sv) => {
                let mut gathered = sv.clone();
                for value in gathered.fields.values_mut() {
                    let updated = gather_if_needed_async_impl(value).await?;
                    *value = updated;
                }
                Ok(Value::Struct(gathered))
            }
            Value::Object(obj) => {
                let mut cloned = obj.clone();
                for value in cloned.properties.values_mut() {
                    *value = gather_if_needed_async_impl(value).await?;
                }
                Ok(Value::Object(cloned))
            }
            other => Ok(other.clone()),
        }
    })
}

/// Blocking counterpart of `gather_if_needed_async` for native (non-wasm) targets.
///
/// The async gather is driven to completion on the calling thread with
/// `futures::executor::block_on`.
#[cfg(not(target_arch = "wasm32"))]
pub fn gather_if_needed(value: &Value) -> Result<Value, RuntimeError> {
    let pending = gather_if_needed_async(value);
    futures::executor::block_on(pending)
}

/// Synchronous gather is not supported on `wasm32`.
///
/// This always returns an error tagged `RunMat:gather:UnavailableOnWasm`; on this
/// target callers must use `gather_if_needed_async` instead.
#[cfg(target_arch = "wasm32")]
pub fn gather_if_needed(_value: &Value) -> Result<Value, RuntimeError> {
    let unavailable = build_runtime_error("gather: synchronous gather is unavailable on wasm")
        .with_identifier("RunMat:gather:UnavailableOnWasm")
        .build();
    Err(unavailable)
}

/// Call a registered language builtin by name, blocking until it completes.
///
/// Every overload registered under `name` is tried in turn until one accepts `args`.
/// An error is returned when none does, or when neither a builtin nor a class of
/// that name exists.
pub fn call_builtin(name: &str, args: &[Value]) -> Result<Value, RuntimeError> {
    let pending = call_builtin_async(name, args);
    futures::executor::block_on(pending)
}

/// Dispatch `name` to a registered builtin, trying each same-named overload in turn.
///
/// * `output_count` is pushed via `crate::output_count::push_output_count`. The returned
///   guard is held until this function returns.
/// * If no builtin is registered under `name` but a class is, the class's same-named
///   constructor method is dispatched with `args`. When there is no such method, a
///   default object is constructed instead.
/// * Uncategorized builtins are tried before categorized ones. Within each group, later
///   entries in `builtin_functions()` order are tried first, so that later registrations
///   can override earlier ones.
/// * A failure is retried once with GPU arguments gathered to the host. This happens only
///   when the error message mentions "gpu" and some argument holds GPU data.
/// * On every success path, including the retry, a `Value::Bool` returned by
///   `eq`/`ne`/`gt`/`ge`/`lt`/`le` is converted to numeric 0/1.
///
/// # Errors
/// * `RunMat:UndefinedFunction` when no builtin or class matches `name`.
/// * Otherwise, the last overload's error wrapped as "No matching overload ...".
///   The wrapper keeps that error's identifier, or uses `RunMat:NoMatchingOverload`
///   when it had none.
#[async_recursion::async_recursion(?Send)]
async fn call_builtin_async_impl(
    name: &str,
    args: &[Value],
    output_count: Option<usize>,
) -> Result<Value, RuntimeError> {
    let _output_guard = crate::output_count::push_output_count(output_count);

    // Split same-named builtins into no-category (tests/legacy shims) and categorized
    // (library) implementations.
    let (no_category, categorized): (Vec<&runmat_builtins::BuiltinFunction>, Vec<_>) =
        builtin_functions()
            .into_iter()
            .filter(|b| b.name == name)
            .partition(|b| b.category.is_empty());

    if no_category.is_empty() && categorized.is_empty() {
        // Fallback: treat as class constructor if class is registered
        if let Some(cls) = runmat_builtins::get_class(name) {
            // Prefer an explicit constructor method named like the class; pass args through.
            if let Some(ctor) = cls.methods.get(name) {
                return call_builtin_async_impl(&ctor.function_name, args, output_count).await;
            }
            // Otherwise default-construct the object.
            return new_object_builtin(name.to_string()).await;
        }
        return Err(build_runtime_error(format!("Undefined function: {name}"))
            .with_identifier("RunMat:UndefinedFunction")
            .build());
    }

    // Comparison builtins may yield a logical scalar. Legacy dispatcher tests and VM
    // shims expect numeric 0/1, so normalize on *every* success path, including the
    // GPU-gather retry below.
    let is_comparison = matches!(name, "eq" | "ne" | "gt" | "ge" | "lt" | "le");
    let normalize = |result: Value| match result {
        Value::Bool(flag) if is_comparison => Value::Num(if flag { 1.0 } else { 0.0 }),
        other => other,
    };

    let mut last_error = RuntimeError::new("unknown error");
    for builtin in no_category
        .into_iter()
        .rev()
        .chain(categorized.into_iter().rev())
    {
        let f = builtin.implementation;
        let err = match (f)(args).await {
            Ok(result) => return Ok(normalize(result)),
            Err(err) => err,
        };
        if !should_retry_with_gpu_gather(&err, args) {
            last_error = err;
            continue;
        }
        last_error = match gather_args_for_retry_async(args).await {
            Ok(Some(gathered_args)) => match (f)(&gathered_args).await {
                Ok(result) => return Ok(normalize(result)),
                Err(retry_err) => retry_err,
            },
            Ok(None) => err,
            Err(gather_err) => gather_err,
        };
    }

    // No overload accepted the arguments: surface the last failure, keeping its identifier.
    let identifier = last_error
        .identifier()
        .unwrap_or("RunMat:NoMatchingOverload")
        .to_string();
    let builder = build_runtime_error(format!(
        "No matching overload for `{}` with {} args: {}",
        name,
        args.len(),
        last_error.message()
    ))
    .with_source(last_error)
    .with_identifier(identifier);
    Err(builder.build())
}

/// Asynchronously dispatch builtin `name` with `args`, without requesting an output count.
pub async fn call_builtin_async(name: &str, args: &[Value]) -> Result<Value, RuntimeError> {
    let no_output_hint: Option<usize> = None;
    call_builtin_async_impl(name, args, no_output_hint).await
}

/// Asynchronously dispatch builtin `name` with `args`, requesting `output_count` outputs.
pub async fn call_builtin_async_with_outputs(
    name: &str,
    args: &[Value],
    output_count: usize,
) -> Result<Value, RuntimeError> {
    let requested = Some(output_count);
    call_builtin_async_impl(name, args, requested).await
}

/// Decide whether a failed builtin call deserves a retry on host copies of its arguments.
///
/// This is true only when at least one argument holds GPU data and the error message
/// mentions "gpu" (ASCII case-insensitive). The argument scan runs first, so the
/// message is not inspected when no GPU data is present.
fn should_retry_with_gpu_gather(err: &RuntimeError, args: &[Value]) -> bool {
    let has_gpu_arg = args.iter().any(value_contains_gpu);
    has_gpu_arg && err.message().to_ascii_lowercase().contains("gpu")
}

async fn gather_args_for_retry_async(args: &[Value]) -> Result<Option<Vec<Value>>, RuntimeError> {
    let mut gathered_any = false;
    let mut gathered_args = Vec::with_capacity(args.len());
    for arg in args {
        if value_contains_gpu(arg) {
            gathered_args.push(gather_if_needed_async(arg).await?);
            gathered_any = true;
        } else {
            gathered_args.push(arg.clone());
        }
    }
    if gathered_any {
        Ok(Some(gathered_args))
    } else {
        Ok(None)
    }
}