// runmat_runtime/dispatcher.rs
use runmat_builtins::{builtin_functions, LogicalArray, NumericDType, Tensor, Value};

use crate::{make_cell_with_shape, new_object_builtin};
5pub fn is_gpu_value(value: &Value) -> bool {
7 matches!(value, Value::GpuTensor(_))
8}
9
10pub fn value_contains_gpu(value: &Value) -> bool {
12 match value {
13 Value::GpuTensor(_) => true,
14 Value::Cell(ca) => ca.data.iter().any(|ptr| value_contains_gpu(ptr)),
15 Value::Struct(sv) => sv.fields.values().any(value_contains_gpu),
16 Value::Object(obj) => obj.properties.values().any(value_contains_gpu),
17 _ => false,
18 }
19}
20
/// Materializes any GPU-resident data inside `value` on the host.
///
/// - A `GpuTensor` handle is downloaded through its registered acceleration
///   provider, its residency record is cleared, and the host data is returned
///   either as a `LogicalArray` (when the handle is marked logical) or as a
///   `Tensor` whose dtype reflects the handle/provider precision.
/// - Cells, structs, and objects are gathered recursively (shape and field
///   layout preserved).
/// - Any other value is returned as a clone, unchanged.
///
/// Errors when no provider is registered for a handle, or when a download or
/// host-side reconstruction step fails; error text is the provider's message.
pub fn gather_if_needed(value: &Value) -> Result<Value, String> {
    match value {
        Value::GpuTensor(handle) => {
            // In wgpu-enabled test builds, lazily register the wgpu provider
            // before the lookup below so test handles can be resolved.
            // NOTE(review): `device_id != 0` appears to distinguish real
            // devices from a default/null device — confirm against the
            // accelerate API.
            #[cfg(all(test, feature = "wgpu"))]
            {
                if handle.device_id != 0 {
                    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
                    );
                }
            }
            let provider = runmat_accelerate_api::provider_for_handle(handle)
                .ok_or_else(|| "gather: no acceleration provider registered".to_string())?;
            // Read the logical flag from handle metadata before downloading.
            let is_logical = runmat_accelerate_api::handle_is_logical(handle);
            let host = provider.download(handle).map_err(|e| e.to_string())?;
            // The data now lives on the host; drop the device-residency record.
            runmat_accelerate_api::clear_residency(handle);
            let runmat_accelerate_api::HostTensorOwned { data, shape } = host;
            if is_logical {
                // Any nonzero element counts as true.
                let bits: Vec<u8> = data.iter().map(|&v| if v != 0.0 { 1 } else { 0 }).collect();
                let logical = LogicalArray::new(bits, shape).map_err(|e| e.to_string())?;
                Ok(Value::LogicalArray(logical))
            } else {
                let mut data = data;
                // Prefer the precision recorded on the handle; fall back to
                // the provider's default precision.
                let precision = runmat_accelerate_api::handle_precision(handle)
                    .unwrap_or_else(|| provider.precision());
                if matches!(precision, runmat_accelerate_api::ProviderPrecision::F32) {
                    // Round-trip through f32 so host values match exactly what
                    // the device held at f32 precision.
                    for value in &mut data {
                        *value = (*value as f32) as f64;
                    }
                }
                let dtype = match precision {
                    runmat_accelerate_api::ProviderPrecision::F32 => NumericDType::F32,
                    runmat_accelerate_api::ProviderPrecision::F64 => NumericDType::F64,
                };
                let tensor =
                    Tensor::new_with_dtype(data, shape, dtype).map_err(|e| e.to_string())?;
                Ok(Value::Tensor(tensor))
            }
        }
        Value::Cell(ca) => {
            // Gather each element, then rebuild the cell with its original shape.
            let mut gathered = Vec::with_capacity(ca.data.len());
            for ptr in &ca.data {
                gathered.push(gather_if_needed(ptr)?);
            }
            make_cell_with_shape(gathered, ca.shape.clone())
        }
        Value::Struct(sv) => {
            // Clone the struct, then gather each field value in place.
            let mut gathered = sv.clone();
            for value in gathered.fields.values_mut() {
                let updated = gather_if_needed(value)?;
                *value = updated;
            }
            Ok(Value::Struct(gathered))
        }
        Value::Object(obj) => {
            // Same recursive treatment for object properties.
            let mut cloned = obj.clone();
            for value in cloned.properties.values_mut() {
                *value = gather_if_needed(value)?;
            }
            Ok(Value::Object(cloned))
        }
        // Host-only values pass through unchanged.
        other => Ok(other.clone()),
    }
}
88
89pub fn call_builtin(name: &str, args: &[Value]) -> Result<Value, String> {
93 let mut matching_builtins = Vec::new();
94
95 for b in builtin_functions() {
97 if b.name == name {
98 matching_builtins.push(b);
99 }
100 }
101
102 if matching_builtins.is_empty() {
103 if let Some(cls) = runmat_builtins::get_class(name) {
105 if let Some(ctor) = cls.methods.get(name) {
107 return call_builtin(&ctor.function_name, args);
109 }
110 return new_object_builtin(name.to_string());
112 }
113 return Err(format!(
114 "{}: Undefined function: {name}",
115 "MATLAB:UndefinedFunction"
116 ));
117 }
118
119 let mut no_category: Vec<&runmat_builtins::BuiltinFunction> = Vec::new();
121 let mut categorized: Vec<&runmat_builtins::BuiltinFunction> = Vec::new();
122 for b in matching_builtins {
123 if b.category.is_empty() {
124 no_category.push(b);
125 } else {
126 categorized.push(b);
127 }
128 }
129
130 let mut last_error = String::new();
133 for builtin in no_category
134 .into_iter()
135 .rev()
136 .chain(categorized.into_iter().rev())
137 {
138 let f = builtin.implementation;
139 match (f)(args) {
140 Ok(mut result) => {
141 if matches!(name, "eq" | "ne" | "gt" | "ge" | "lt" | "le") {
145 if let Value::Bool(flag) = result {
146 result = Value::Num(if flag { 1.0 } else { 0.0 });
147 }
148 }
149 return Ok(result);
150 }
151 Err(err) => {
152 if should_retry_with_gpu_gather(&err, args) {
153 match gather_args_for_retry(args) {
154 Ok(Some(gathered_args)) => match (f)(&gathered_args) {
155 Ok(result) => return Ok(result),
156 Err(retry_err) => last_error = retry_err,
157 },
158 Ok(None) => last_error = err,
159 Err(gather_err) => last_error = gather_err,
160 }
161 } else {
162 last_error = err;
163 }
164 }
165 }
166 }
167
168 Err(format!(
170 "No matching overload for `{}` with {} args: {}",
171 name,
172 args.len(),
173 last_error
174 ))
175}
176
177fn should_retry_with_gpu_gather(err: &str, args: &[Value]) -> bool {
178 if !args.iter().any(value_contains_gpu) {
179 return false;
180 }
181 let lowered = err.to_ascii_lowercase();
182 lowered.contains("gpu")
183}
184
185fn gather_args_for_retry(args: &[Value]) -> Result<Option<Vec<Value>>, String> {
186 let mut gathered_any = false;
187 let mut gathered_args = Vec::with_capacity(args.len());
188 for arg in args {
189 if value_contains_gpu(arg) {
190 gathered_args.push(gather_if_needed(arg)?);
191 gathered_any = true;
192 } else {
193 gathered_args.push(arg.clone());
194 }
195 }
196 if gathered_any {
197 Ok(Some(gathered_args))
198 } else {
199 Ok(None)
200 }
201}