runmat_runtime/
dispatcher.rs

1use runmat_builtins::{builtin_functions, LogicalArray, NumericDType, Tensor, Value};
2
3use crate::{make_cell_with_shape, new_object_builtin};
4
5/// Return `true` when the passed value is a GPU-resident tensor handle.
6pub fn is_gpu_value(value: &Value) -> bool {
7    matches!(value, Value::GpuTensor(_))
8}
9
10/// Returns true when the value (or nested elements) contains any GPU-resident tensors.
11pub fn value_contains_gpu(value: &Value) -> bool {
12    match value {
13        Value::GpuTensor(_) => true,
14        Value::Cell(ca) => ca.data.iter().any(|ptr| value_contains_gpu(ptr)),
15        Value::Struct(sv) => sv.fields.values().any(value_contains_gpu),
16        Value::Object(obj) => obj.properties.values().any(value_contains_gpu),
17        _ => false,
18    }
19}
20
21/// Convert GPU-resident values to host tensors when an acceleration provider exists.
22/// Non-GPU inputs are passed through unchanged.
23pub fn gather_if_needed(value: &Value) -> Result<Value, String> {
24    match value {
25        Value::GpuTensor(handle) => {
26            // In parallel test runs, ensure the WGPU provider is reasserted for WGPU handles.
27            #[cfg(all(test, feature = "wgpu"))]
28            {
29                if handle.device_id != 0 {
30                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
31                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
32                    );
33                }
34            }
35            let provider = runmat_accelerate_api::provider_for_handle(handle)
36                .ok_or_else(|| "gather: no acceleration provider registered".to_string())?;
37            let is_logical = runmat_accelerate_api::handle_is_logical(handle);
38            let host = provider.download(handle).map_err(|e| e.to_string())?;
39            runmat_accelerate_api::clear_residency(handle);
40            let runmat_accelerate_api::HostTensorOwned { data, shape } = host;
41            if is_logical {
42                let bits: Vec<u8> = data.iter().map(|&v| if v != 0.0 { 1 } else { 0 }).collect();
43                let logical = LogicalArray::new(bits, shape).map_err(|e| e.to_string())?;
44                Ok(Value::LogicalArray(logical))
45            } else {
46                let mut data = data;
47                let precision = runmat_accelerate_api::handle_precision(handle)
48                    .unwrap_or_else(|| provider.precision());
49                if matches!(precision, runmat_accelerate_api::ProviderPrecision::F32) {
50                    for value in &mut data {
51                        *value = (*value as f32) as f64;
52                    }
53                }
54                let dtype = match precision {
55                    runmat_accelerate_api::ProviderPrecision::F32 => NumericDType::F32,
56                    runmat_accelerate_api::ProviderPrecision::F64 => NumericDType::F64,
57                };
58                let tensor =
59                    Tensor::new_with_dtype(data, shape, dtype).map_err(|e| e.to_string())?;
60                Ok(Value::Tensor(tensor))
61            }
62        }
63        Value::Cell(ca) => {
64            let mut gathered = Vec::with_capacity(ca.data.len());
65            for ptr in &ca.data {
66                gathered.push(gather_if_needed(ptr)?);
67            }
68            make_cell_with_shape(gathered, ca.shape.clone())
69        }
70        Value::Struct(sv) => {
71            let mut gathered = sv.clone();
72            for value in gathered.fields.values_mut() {
73                let updated = gather_if_needed(value)?;
74                *value = updated;
75            }
76            Ok(Value::Struct(gathered))
77        }
78        Value::Object(obj) => {
79            let mut cloned = obj.clone();
80            for value in cloned.properties.values_mut() {
81                *value = gather_if_needed(value)?;
82            }
83            Ok(Value::Object(cloned))
84        }
85        other => Ok(other.clone()),
86    }
87}
88
89/// Call a registered language builtin by name.
90/// Supports function overloading by trying different argument patterns.
91/// Returns an error if no builtin with that name and compatible arguments is found.
92pub fn call_builtin(name: &str, args: &[Value]) -> Result<Value, String> {
93    let mut matching_builtins = Vec::new();
94
95    // Collect all builtins with the matching name
96    for b in builtin_functions() {
97        if b.name == name {
98            matching_builtins.push(b);
99        }
100    }
101
102    if matching_builtins.is_empty() {
103        // Fallback: treat as class constructor if class is registered
104        if let Some(cls) = runmat_builtins::get_class(name) {
105            // Prefer explicit constructor method with the same name as class (static)
106            if let Some(ctor) = cls.methods.get(name) {
107                // Dispatch to constructor builtin; pass args through
108                return call_builtin(&ctor.function_name, args);
109            }
110            // Otherwise default-construct object
111            return new_object_builtin(name.to_string());
112        }
113        return Err(format!(
114            "{}: Undefined function: {name}",
115            "MATLAB:UndefinedFunction"
116        ));
117    }
118
119    // Partition into no-category (tests/legacy shims) and categorized (library) builtins.
120    let mut no_category: Vec<&runmat_builtins::BuiltinFunction> = Vec::new();
121    let mut categorized: Vec<&runmat_builtins::BuiltinFunction> = Vec::new();
122    for b in matching_builtins {
123        if b.category.is_empty() {
124            no_category.push(b);
125        } else {
126            categorized.push(b);
127        }
128    }
129
130    // Try each builtin until one succeeds. Within each group, prefer later-registered
131    // implementations to allow overrides when names collide.
132    let mut last_error = String::new();
133    for builtin in no_category
134        .into_iter()
135        .rev()
136        .chain(categorized.into_iter().rev())
137    {
138        let f = builtin.implementation;
139        match (f)(args) {
140            Ok(mut result) => {
141                // Normalize certain logical scalar results to numeric 0/1 for
142                // compatibility with legacy expectations in dispatcher tests
143                // and VM shims.
144                if matches!(name, "eq" | "ne" | "gt" | "ge" | "lt" | "le") {
145                    if let Value::Bool(flag) = result {
146                        result = Value::Num(if flag { 1.0 } else { 0.0 });
147                    }
148                }
149                return Ok(result);
150            }
151            Err(err) => {
152                if should_retry_with_gpu_gather(&err, args) {
153                    match gather_args_for_retry(args) {
154                        Ok(Some(gathered_args)) => match (f)(&gathered_args) {
155                            Ok(result) => return Ok(result),
156                            Err(retry_err) => last_error = retry_err,
157                        },
158                        Ok(None) => last_error = err,
159                        Err(gather_err) => last_error = gather_err,
160                    }
161                } else {
162                    last_error = err;
163                }
164            }
165        }
166    }
167
168    // If none succeeded, return the last error
169    Err(format!(
170        "No matching overload for `{}` with {} args: {}",
171        name,
172        args.len(),
173        last_error
174    ))
175}
176
177fn should_retry_with_gpu_gather(err: &str, args: &[Value]) -> bool {
178    if !args.iter().any(value_contains_gpu) {
179        return false;
180    }
181    let lowered = err.to_ascii_lowercase();
182    lowered.contains("gpu")
183}
184
185fn gather_args_for_retry(args: &[Value]) -> Result<Option<Vec<Value>>, String> {
186    let mut gathered_any = false;
187    let mut gathered_args = Vec::with_capacity(args.len());
188    for arg in args {
189        if value_contains_gpu(arg) {
190            gathered_args.push(gather_if_needed(arg)?);
191            gathered_any = true;
192        } else {
193            gathered_args.push(arg.clone());
194        }
195    }
196    if gathered_any {
197        Ok(Some(gathered_args))
198    } else {
199        Ok(None)
200    }
201}