Skip to main content

runmat_runtime/
dispatcher.rs

1use crate::{build_runtime_error, create_class_object, make_cell_with_shape, RuntimeError};
2use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, GpuTensorStorage, HostTensorOwned};
3use runmat_builtins::{
4    builtin_functions, ComplexTensor, LogicalArray, NumericDType, Tensor, Value,
5};
6use std::cell::RefCell;
7
thread_local! {
    /// Class name treated as the "caller" for method/constructor access checks on this thread.
    ///
    /// `None` means no class context is active, so only `Public` members are accessible.
    static CLASS_ACCESS_CONTEXT: RefCell<Option<String>> = const { RefCell::new(None) };
}

/// Ensure builtins are registered before dispatch looks them up.
///
/// On `wasm32` this calls `crate::builtins::wasm_registry::register_all()`. It runs on every
/// dispatch, so it is presumably idempotent (NOTE(review): confirm). On other targets it is a
/// no-op.
#[cfg(target_arch = "wasm32")]
fn ensure_wasm_builtins_registered() {
    crate::builtins::wasm_registry::register_all();
}

#[cfg(not(target_arch = "wasm32"))]
fn ensure_wasm_builtins_registered() {}

/// RAII guard returned by [`push_class_access_context`].
///
/// On drop it writes back the class access context that was active when the guard was
/// created. Nested pushes unwind correctly as long as guards are dropped in reverse creation
/// order (the normal scoping order).
#[derive(Debug)]
#[must_use = "the class access context is restored as soon as this guard is dropped"]
pub struct ClassAccessContextGuard {
    previous: Option<String>,
}

impl Drop for ClassAccessContextGuard {
    fn drop(&mut self) {
        let previous = self.previous.take();
        CLASS_ACCESS_CONTEXT.with(|slot| {
            *slot.borrow_mut() = previous;
        });
    }
}

/// Replace the current thread's class access context with `class_name`.
///
/// Returns a guard that restores the previous context when dropped. Keep the guard alive
/// for as long as the new context should apply.
pub fn push_class_access_context(class_name: Option<String>) -> ClassAccessContextGuard {
    let previous =
        CLASS_ACCESS_CONTEXT.with(|slot| std::mem::replace(&mut *slot.borrow_mut(), class_name));
    ClassAccessContextGuard { previous }
}

/// Snapshot (clone) of the current thread's class access context.
fn current_class_access_context() -> Option<String> {
    CLASS_ACCESS_CONTEXT.with(|slot| slot.borrow().clone())
}

/// Public accessor for the current thread's class access context.
pub fn class_access_context() -> Option<String> {
    current_class_access_context()
}
46
47/// Return `true` when the passed value is a GPU-resident tensor handle.
48pub fn is_gpu_value(value: &Value) -> bool {
49    matches!(value, Value::GpuTensor(_))
50}
51
52/// Returns true when the value (or nested elements) contains any GPU-resident tensors.
53pub fn value_contains_gpu(value: &Value) -> bool {
54    match value {
55        Value::GpuTensor(_) => true,
56        Value::Cell(ca) => ca.data.iter().any(|ptr| value_contains_gpu(ptr)),
57        Value::Struct(sv) => sv.fields.values().any(value_contains_gpu),
58        Value::Object(obj) => obj.properties.values().any(value_contains_gpu),
59        Value::Closure(closure) => closure.captures.iter().any(value_contains_gpu),
60        Value::OutputList(values) => values.iter().any(value_contains_gpu),
61        _ => false,
62    }
63}
64
65/// Convert GPU-resident values to host tensors when an acceleration provider exists.
66/// Non-GPU inputs are passed through unchanged.
67pub async fn gather_if_needed_async(value: &Value) -> Result<Value, RuntimeError> {
68    gather_if_needed_async_impl(value).await
69}
70
71pub async fn download_handle_async(
72    provider: &dyn AccelProvider,
73    handle: &GpuTensorHandle,
74) -> anyhow::Result<HostTensorOwned> {
75    provider.download(handle).await
76}
77
78fn gather_if_needed_async_impl<'a>(
79    value: &'a Value,
80) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Value, RuntimeError>> + 'a>> {
81    Box::pin(async move {
82        match value {
83            Value::GpuTensor(handle) => {
84                // In parallel test runs, ensure the WGPU provider is reasserted for WGPU handles.
85                #[cfg(all(test, feature = "wgpu"))]
86                {
87                    if handle.device_id != 0 {
88                        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
89                        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
90                    );
91                    }
92                }
93                let provider =
94                    runmat_accelerate_api::provider_for_handle(handle).ok_or_else(|| {
95                        build_runtime_error("gather: no acceleration provider registered")
96                            .with_identifier("RunMat:gather:ProviderUnavailable")
97                            .build()
98                    })?;
99                let is_logical = runmat_accelerate_api::handle_is_logical(handle);
100                let host = download_handle_async(provider, handle)
101                    .await
102                    .map_err(|err| {
103                        build_runtime_error(format!("gather: {err}"))
104                            .with_identifier("RunMat:gather:DownloadFailed")
105                            .build()
106                    })?;
107                runmat_accelerate_api::clear_residency(handle);
108                let runmat_accelerate_api::HostTensorOwned {
109                    data,
110                    shape,
111                    storage,
112                } = host;
113                if is_logical {
114                    let bits: Vec<u8> =
115                        data.iter().map(|&v| if v != 0.0 { 1 } else { 0 }).collect();
116                    let logical = LogicalArray::new(bits, shape).map_err(|e| {
117                        build_runtime_error(format!("gather: {e}"))
118                            .with_identifier("RunMat:gather:LogicalShapeError")
119                            .build()
120                    })?;
121                    Ok(Value::LogicalArray(logical))
122                } else if storage == GpuTensorStorage::ComplexInterleaved {
123                    let mut data = data;
124                    let precision = runmat_accelerate_api::handle_precision(handle)
125                        .unwrap_or_else(|| provider.precision());
126                    if matches!(precision, runmat_accelerate_api::ProviderPrecision::F32) {
127                        for value in &mut data {
128                            *value = (*value as f32) as f64;
129                        }
130                    }
131                    let mut complex = Vec::with_capacity(data.len() / 2);
132                    for chunk in data.chunks_exact(2) {
133                        complex.push((chunk[0], chunk[1]));
134                    }
135                    let tensor = ComplexTensor::new(complex, shape).map_err(|e| {
136                        build_runtime_error(format!("gather: {e}"))
137                            .with_identifier("RunMat:gather:TensorShapeError")
138                            .build()
139                    })?;
140                    Ok(Value::ComplexTensor(tensor))
141                } else {
142                    let mut data = data;
143                    let precision = runmat_accelerate_api::handle_precision(handle)
144                        .unwrap_or_else(|| provider.precision());
145                    if matches!(precision, runmat_accelerate_api::ProviderPrecision::F32) {
146                        for value in &mut data {
147                            *value = (*value as f32) as f64;
148                        }
149                    }
150                    let dtype = match precision {
151                        runmat_accelerate_api::ProviderPrecision::F32 => NumericDType::F32,
152                        runmat_accelerate_api::ProviderPrecision::F64 => NumericDType::F64,
153                    };
154                    let tensor = Tensor::new_with_dtype(data, shape, dtype).map_err(|e| {
155                        build_runtime_error(format!("gather: {e}"))
156                            .with_identifier("RunMat:gather:TensorShapeError")
157                            .build()
158                    })?;
159                    Ok(Value::Tensor(tensor))
160                }
161            }
162            Value::Cell(ca) => {
163                let mut gathered = Vec::with_capacity(ca.data.len());
164                for ptr in &ca.data {
165                    gathered.push(gather_if_needed_async_impl(ptr).await?);
166                }
167                make_cell_with_shape(gathered, ca.shape.clone()).map_err(|err| {
168                    build_runtime_error(format!("gather: {err}"))
169                        .with_identifier("RunMat:gather:CellShapeError")
170                        .build()
171                })
172            }
173            Value::Struct(sv) => {
174                let mut gathered = sv.clone();
175                for value in gathered.fields.values_mut() {
176                    let updated = gather_if_needed_async_impl(value).await?;
177                    *value = updated;
178                }
179                Ok(Value::Struct(gathered))
180            }
181            Value::Object(obj) => {
182                let mut cloned = obj.clone();
183                for value in cloned.properties.values_mut() {
184                    *value = gather_if_needed_async_impl(value).await?;
185                }
186                Ok(Value::Object(cloned))
187            }
188            Value::Closure(closure) => {
189                let mut cloned = closure.clone();
190                for value in &mut cloned.captures {
191                    *value = gather_if_needed_async_impl(value).await?;
192                }
193                Ok(Value::Closure(cloned))
194            }
195            Value::OutputList(values) => {
196                let mut gathered = Vec::with_capacity(values.len());
197                for value in values {
198                    gathered.push(gather_if_needed_async_impl(value).await?);
199                }
200                Ok(Value::OutputList(gathered))
201            }
202            other => Ok(other.clone()),
203        }
204    })
205}
206
207#[cfg(not(target_arch = "wasm32"))]
208pub fn gather_if_needed(value: &Value) -> Result<Value, RuntimeError> {
209    futures::executor::block_on(gather_if_needed_async(value))
210}
211
212#[cfg(target_arch = "wasm32")]
213pub fn gather_if_needed(_value: &Value) -> Result<Value, RuntimeError> {
214    Err(
215        build_runtime_error("gather: synchronous gather is unavailable on wasm")
216            .with_identifier("RunMat:gather:UnavailableOnWasm")
217            .build(),
218    )
219}
220
221/// Call a registered language builtin by name.
222/// Supports function overloading by trying different argument patterns.
223/// Returns an error if no builtin with that name and compatible arguments is found.
224pub fn call_builtin(name: &str, args: &[Value]) -> Result<Value, RuntimeError> {
225    futures::executor::block_on(call_builtin_async(name, args))
226}
227
228#[async_recursion::async_recursion(?Send)]
229async fn call_builtin_async_impl(
230    name: &str,
231    args: &[Value],
232    output_count: Option<usize>,
233) -> Result<Value, RuntimeError> {
234    ensure_wasm_builtins_registered();
235
236    let _output_guard = crate::output_count::push_output_count(output_count);
237    let mut matching_builtins = Vec::new();
238
239    // Collect all builtins with the matching name
240    for b in builtin_functions() {
241        if b.name == name {
242            matching_builtins.push(b);
243        }
244    }
245
246    if matching_builtins.is_empty() {
247        if let Some(result) = try_call_registered_instance_method(name, args, output_count).await? {
248            return Ok(result);
249        }
250        if let Some(result) = try_call_registered_static_method(name, args, output_count).await? {
251            return Ok(result);
252        }
253        // Fallback: treat as class constructor if class is registered.
254        if runmat_builtins::get_class(name).is_some() {
255            return call_registered_class_constructor(name, args, output_count).await;
256        }
257        return Err(build_runtime_error(format!("Undefined function: {name}"))
258            .with_identifier("RunMat:UndefinedFunction")
259            .build());
260    }
261
262    // Partition into no-category (tests/legacy shims) and categorized (library) builtins.
263    let mut no_category: Vec<&runmat_builtins::BuiltinFunction> = Vec::new();
264    let mut categorized: Vec<&runmat_builtins::BuiltinFunction> = Vec::new();
265    for b in matching_builtins {
266        if b.category.is_empty() {
267            no_category.push(b);
268        } else {
269            categorized.push(b);
270        }
271    }
272    let matching_count = no_category.len() + categorized.len();
273
274    // Try each builtin until one succeeds. Within each group, prefer later-registered
275    // implementations to allow overrides when names collide.
276    let mut last_error = RuntimeError::new("unknown error");
277    for builtin in no_category
278        .into_iter()
279        .rev()
280        .chain(categorized.into_iter().rev())
281    {
282        let f = builtin.implementation;
283        match (f)(args).await {
284            Ok(mut result) => {
285                // Normalize certain logical scalar results to numeric 0/1 for
286                // compatibility with legacy expectations in dispatcher tests
287                // and VM shims.
288                if matches!(name, "eq" | "ne" | "gt" | "ge" | "lt" | "le") {
289                    if let Value::Bool(flag) = result {
290                        result = Value::Num(if flag { 1.0 } else { 0.0 });
291                    }
292                }
293                return Ok(result);
294            }
295            Err(err) => {
296                if should_retry_with_gpu_gather(&err, args) {
297                    match gather_args_for_retry_async(args).await {
298                        Ok(Some(gathered_args)) => match (f)(&gathered_args).await {
299                            Ok(result) => return Ok(result),
300                            Err(retry_err) => last_error = retry_err,
301                        },
302                        Ok(None) => last_error = err,
303                        Err(gather_err) => last_error = gather_err,
304                    }
305                } else {
306                    last_error = err;
307                }
308            }
309        }
310    }
311
312    // A single implementation already knows whether its inputs are invalid or
313    // whether execution failed. Preserve that error verbatim instead of
314    // presenting it as overload resolution noise.
315    if matching_count == 1 || last_error.identifier().is_some() {
316        return Err(last_error);
317    }
318
319    // If none succeeded, return the last error
320    let identifier = last_error
321        .identifier()
322        .unwrap_or("RunMat:NoMatchingOverload")
323        .to_string();
324    let mut builder = build_runtime_error(format!(
325        "No matching overload for `{}` with {} args: {}",
326        name,
327        args.len(),
328        last_error.message()
329    ))
330    .with_source(last_error);
331    builder = builder.with_identifier(identifier);
332    Err(builder.build())
333}
334
335async fn try_call_registered_instance_method(
336    method_name: &str,
337    args: &[Value],
338    output_count: Option<usize>,
339) -> Result<Option<Value>, RuntimeError> {
340    let Some(receiver) = args.first() else {
341        return Ok(None);
342    };
343    let class_name = match receiver {
344        Value::Object(obj) => obj.class_name.as_str(),
345        Value::HandleObject(handle) => handle.class_name.as_str(),
346        _ => return Ok(None),
347    };
348    let Some((method, owner)) = runmat_builtins::lookup_method(class_name, method_name) else {
349        return Ok(None);
350    };
351    if method.is_static {
352        return Ok(None);
353    }
354    let caller_class = current_class_access_context();
355    let access_allowed = match method.access {
356        runmat_builtins::Access::Public => true,
357        runmat_builtins::Access::Private => caller_class.as_deref() == Some(owner.as_str()),
358        runmat_builtins::Access::Protected => caller_class
359            .as_deref()
360            .is_some_and(|caller| runmat_builtins::is_class_or_subclass(caller, &owner)),
361    };
362    if !access_allowed {
363        return Err(build_runtime_error(format!(
364            "Method '{}' is not accessible from current context.",
365            method_name
366        ))
367        .with_identifier("RunMat:MethodPrivate")
368        .build());
369    }
370    if let Some(result) = crate::user_functions::try_call_semantic_function_by_name(
371        &method.function_name,
372        args,
373        output_count.unwrap_or(1),
374    )
375    .await
376    {
377        return result.map(Some);
378    }
379    if runmat_builtins::builtin_function_by_name(&method.function_name).is_some()
380        && method.function_name != method_name
381    {
382        return call_builtin_async_impl(&method.function_name, args, output_count)
383            .await
384            .map(Some);
385    }
386    let owner_qualified = format!("{owner}.{method_name}");
387    if owner_qualified != method.function_name {
388        if let Some(result) = crate::user_functions::try_call_semantic_function_by_name(
389            &owner_qualified,
390            args,
391            output_count.unwrap_or(1),
392        )
393        .await
394        {
395            return result.map(Some);
396        }
397        if runmat_builtins::builtin_function_by_name(&owner_qualified).is_some()
398            && owner_qualified != method_name
399        {
400            return call_builtin_async_impl(&owner_qualified, args, output_count)
401                .await
402                .map(Some);
403        }
404    }
405    Ok(None)
406}
407
408async fn try_call_registered_static_method(
409    qualified_name: &str,
410    args: &[Value],
411    output_count: Option<usize>,
412) -> Result<Option<Value>, RuntimeError> {
413    let Some((class_name, method_name)) = qualified_name.rsplit_once('.') else {
414        return Ok(None);
415    };
416    if class_name.trim().is_empty() || method_name.trim().is_empty() {
417        return Ok(None);
418    }
419    if runmat_builtins::get_class(class_name).is_none() {
420        return Ok(None);
421    }
422    let Some((method, owner)) = runmat_builtins::lookup_method(class_name, method_name) else {
423        return Ok(None);
424    };
425    if !method.is_static || method.access != runmat_builtins::Access::Public {
426        return Ok(None);
427    }
428    if let Some(result) = crate::user_functions::try_call_semantic_function_by_name(
429        &method.function_name,
430        args,
431        output_count.unwrap_or(1),
432    )
433    .await
434    {
435        return result.map(Some);
436    }
437    if runmat_builtins::builtin_function_by_name(&method.function_name).is_some()
438        && method.function_name != qualified_name
439    {
440        return call_builtin_async_impl(&method.function_name, args, output_count)
441            .await
442            .map(Some);
443    }
444    let owner_qualified = format!("{owner}.{method_name}");
445    if owner_qualified != method.function_name {
446        if let Some(result) = crate::user_functions::try_call_semantic_function_by_name(
447            &owner_qualified,
448            args,
449            output_count.unwrap_or(1),
450        )
451        .await
452        {
453            return result.map(Some);
454        }
455        if runmat_builtins::builtin_function_by_name(&owner_qualified).is_some()
456            && owner_qualified != qualified_name
457        {
458            return call_builtin_async_impl(&owner_qualified, args, output_count)
459                .await
460                .map(Some);
461        }
462    }
463    Ok(None)
464}
465
466async fn call_registered_class_constructor(
467    class_name: &str,
468    args: &[Value],
469    output_count: Option<usize>,
470) -> Result<Value, RuntimeError> {
471    let requested_outputs = output_count.unwrap_or(1);
472    let default_object = create_class_object(class_name.to_string()).await?;
473    let constructor_method_name = class_name.rsplit('.').next().unwrap_or(class_name);
474    let Some((ctor, owner)) = runmat_builtins::lookup_method(class_name, constructor_method_name)
475        .or_else(|| runmat_builtins::lookup_method(class_name, class_name))
476    else {
477        return Ok(default_object);
478    };
479    let owner_qualified = format!("{owner}.{constructor_method_name}");
480    let caller_class = current_class_access_context();
481    let ctor_access_allowed = match ctor.access {
482        runmat_builtins::Access::Public => true,
483        runmat_builtins::Access::Private => caller_class.as_deref() == Some(owner.as_str()),
484        runmat_builtins::Access::Protected => caller_class
485            .as_deref()
486            .is_some_and(|caller| runmat_builtins::is_class_or_subclass(caller, &owner)),
487    };
488    if !ctor_access_allowed {
489        return Err(build_runtime_error(format!(
490            "Constructor '{}' is not accessible from current context.",
491            class_name
492        ))
493        .with_identifier("RunMat:MethodPrivate")
494        .build());
495    }
496    let constructor_result = crate::with_constructor_receiver(default_object.clone(), async {
497        if let Some(result) = crate::user_functions::try_call_semantic_function_by_name(
498            &ctor.function_name,
499            args,
500            requested_outputs,
501        )
502        .await
503        {
504            return Ok::<Option<Value>, RuntimeError>(Some(result?));
505        }
506        if runmat_builtins::builtin_function_by_name(&ctor.function_name).is_some()
507            && ctor.function_name != class_name
508        {
509            let result = call_builtin_async_impl(&ctor.function_name, args, output_count).await?;
510            return Ok::<Option<Value>, RuntimeError>(Some(result));
511        }
512        if let Some(result) = crate::user_functions::try_call_semantic_function_by_name(
513            &owner_qualified,
514            args,
515            requested_outputs,
516        )
517        .await
518        {
519            return Ok::<Option<Value>, RuntimeError>(Some(result?));
520        }
521        if runmat_builtins::builtin_function_by_name(&owner_qualified).is_some()
522            && owner_qualified != class_name
523        {
524            let result = call_builtin_async_impl(&owner_qualified, args, output_count).await?;
525            return Ok::<Option<Value>, RuntimeError>(Some(result));
526        }
527        Ok::<Option<Value>, RuntimeError>(None)
528    })
529    .await?;
530    let Some(result) = constructor_result else {
531        return Ok(default_object);
532    };
533    normalize_constructor_result(default_object, result, requested_outputs)
534}
535
536fn normalize_constructor_result(
537    default_object: Value,
538    result: Value,
539    requested_outputs: usize,
540) -> Result<Value, RuntimeError> {
541    if requested_outputs != 1 {
542        return Ok(result);
543    }
544    match result {
545        Value::Struct(struct_value) => match default_object {
546            Value::Object(mut object) => {
547                for (field, value) in struct_value.fields {
548                    object.properties.insert(field, value);
549                }
550                Ok(Value::Object(object))
551            }
552            Value::HandleObject(handle) => {
553                enum ConstructorMergeStatus {
554                    Merged,
555                    InvalidHandle,
556                    NonObject,
557                }
558
559                let merged = runmat_gc::gc_with_value_mut(&handle.target, |target| {
560                    if let Value::Object(object) = target {
561                        if !crate::object_handle_flag_valid(object) {
562                            return ConstructorMergeStatus::InvalidHandle;
563                        }
564                        for (field, value) in struct_value.fields {
565                            runmat_gc::gc_record_handle_write(&handle.target, &value);
566                            object.properties.insert(field, value);
567                        }
568                        ConstructorMergeStatus::Merged
569                    } else {
570                        ConstructorMergeStatus::NonObject
571                    }
572                })
573                .map_err(|e| {
574                    build_runtime_error(format!("constructor result handle target invalid: {e}"))
575                        .build()
576                })?;
577                match merged {
578                    ConstructorMergeStatus::Merged => {}
579                    ConstructorMergeStatus::InvalidHandle => {
580                        return Err(build_runtime_error(
581                            "constructor result handle target is invalid",
582                        )
583                        .build());
584                    }
585                    ConstructorMergeStatus::NonObject => {
586                        return Err(build_runtime_error(
587                            "constructor result handle target is not an object",
588                        )
589                        .build());
590                    }
591                }
592                Ok(Value::HandleObject(handle))
593            }
594            _ => Ok(Value::Struct(struct_value)),
595        },
596        Value::Object(_) | Value::HandleObject(_) => Ok(result),
597        _ => Ok(default_object),
598    }
599}
600
601pub async fn call_builtin_async(name: &str, args: &[Value]) -> Result<Value, RuntimeError> {
602    call_builtin_async_impl(name, args, None).await
603}
604
605pub async fn call_builtin_async_with_outputs(
606    name: &str,
607    args: &[Value],
608    output_count: usize,
609) -> Result<Value, RuntimeError> {
610    call_builtin_async_impl(name, args, Some(output_count)).await
611}
612
613fn should_retry_with_gpu_gather(err: &RuntimeError, args: &[Value]) -> bool {
614    if !args.iter().any(value_contains_gpu) {
615        return false;
616    }
617    let lowered = err.message().to_ascii_lowercase();
618    lowered.contains("gpu")
619}
620
621async fn gather_args_for_retry_async(args: &[Value]) -> Result<Option<Vec<Value>>, RuntimeError> {
622    let mut gathered_any = false;
623    let mut gathered_args = Vec::with_capacity(args.len());
624    for arg in args {
625        if value_contains_gpu(arg) {
626            gathered_args.push(gather_if_needed_async(arg).await?);
627            gathered_any = true;
628        } else {
629            gathered_args.push(arg.clone());
630        }
631    }
632    if gathered_any {
633        Ok(Some(gathered_args))
634    } else {
635        Ok(None)
636    }
637}
638
639#[cfg(test)]
640mod tests {
641    use super::{call_builtin, gather_if_needed_async, value_contains_gpu};
642    use runmat_accelerate_api::{GpuTensorHandle, ThreadProviderGuard};
643    use runmat_builtins::{
644        register_class, Access, ClassDef, Closure, MethodDef, StructValue, Value,
645    };
646    use std::collections::HashMap;
647    use std::sync::atomic::{AtomicU64, Ordering};
648
649    static TEST_CLASS_COUNTER: AtomicU64 = AtomicU64::new(0);
650
651    fn unique_class_name(prefix: &str) -> String {
652        let id = TEST_CLASS_COUNTER.fetch_add(1, Ordering::Relaxed);
653        format!("{}_{}", prefix, id)
654    }
655
656    #[test]
657    fn value_contains_gpu_detects_nested_closure_captures() {
658        let value = Value::Closure(Closure {
659            function_name: "worker".to_string(),
660            bound_function: None,
661            captures: vec![Value::GpuTensor(GpuTensorHandle {
662                shape: vec![1],
663                device_id: 999,
664                buffer_id: 42,
665            })],
666        });
667        assert!(value_contains_gpu(&value));
668    }
669
670    #[test]
671    fn value_contains_gpu_detects_output_list_entries() {
672        let value = Value::OutputList(vec![
673            Value::Num(1.0),
674            Value::GpuTensor(GpuTensorHandle {
675                shape: vec![1],
676                device_id: 998,
677                buffer_id: 43,
678            }),
679        ]);
680        assert!(value_contains_gpu(&value));
681    }
682
683    #[test]
684    fn gather_if_needed_reports_provider_unavailable_for_nested_output_list_gpu() {
685        runmat_accelerate_api::clear_provider();
686        let _provider_guard = ThreadProviderGuard::set(None);
687        let value = Value::OutputList(vec![Value::GpuTensor(GpuTensorHandle {
688            shape: vec![1],
689            // Keep device id at zero so test-only WGPU re-registration hooks are not triggered.
690            device_id: 0,
691            buffer_id: 44,
692        })]);
693        let err = futures::executor::block_on(gather_if_needed_async(&value))
694            .expect_err("missing provider should fail nested output-list gather");
695        assert_eq!(err.identifier(), Some("RunMat:gather:ProviderUnavailable"));
696    }
697
698    #[test]
699    fn gather_if_needed_reports_provider_unavailable_for_closure_capture_gpu() {
700        runmat_accelerate_api::clear_provider();
701        let _provider_guard = ThreadProviderGuard::set(None);
702        let value = Value::Closure(Closure {
703            function_name: "worker".to_string(),
704            bound_function: None,
705            captures: vec![Value::GpuTensor(GpuTensorHandle {
706                shape: vec![1],
707                // Keep device id at zero so test-only WGPU re-registration hooks are not triggered.
708                device_id: 0,
709                buffer_id: 45,
710            })],
711        });
712        let err = futures::executor::block_on(gather_if_needed_async(&value))
713            .expect_err("missing provider should fail closure-captured gather");
714        assert_eq!(err.identifier(), Some("RunMat:gather:ProviderUnavailable"));
715    }
716
717    #[test]
718    fn constructor_fallback_uses_inherited_constructor_metadata_with_semantic_invoker() {
719        let parent_name = unique_class_name("runtime_ctor_parent");
720        let child_name = unique_class_name("runtime_ctor_child");
721        let ctor_fn_name = unique_class_name("runtime_ctor_fn");
722        let ctor_fn_name_for_resolver = ctor_fn_name.clone();
723        let ctor_fn_name_for_invoker = ctor_fn_name.clone();
724        let _resolver_guard = crate::user_functions::install_semantic_function_resolver(Some(
725            std::sync::Arc::new(move |name| (name == ctor_fn_name_for_resolver).then_some(10101)),
726        ));
727        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
728            std::sync::Arc::new(move |function, _args, requested_outputs| {
729                assert_eq!(function, 10101);
730                assert_eq!(requested_outputs, 1);
731                let mut sv = StructValue::new();
732                sv.fields.insert("x".to_string(), Value::Num(12.0));
733                Box::pin(async move { Ok(Value::Struct(sv)) })
734            }),
735        ));
736
737        let mut parent_methods = HashMap::new();
738        parent_methods.insert(
739            child_name.clone(),
740            MethodDef {
741                name: child_name.clone(),
742                is_static: true,
743                is_abstract: false,
744                is_sealed: false,
745                access: Access::Public,
746                function_name: ctor_fn_name_for_invoker,
747                implicit_class_argument: None,
748            },
749        );
750        register_class(ClassDef {
751            name: parent_name.clone(),
752            parent: None,
753            properties: HashMap::new(),
754            methods: parent_methods,
755        });
756        register_class(ClassDef {
757            name: child_name.clone(),
758            parent: Some(parent_name),
759            properties: HashMap::new(),
760            methods: HashMap::new(),
761        });
762
763        let out =
764            call_builtin(&child_name, &[]).expect("inherited static constructor should dispatch");
765        let Value::Object(obj) = out else {
766            panic!("expected object from constructor dispatch");
767        };
768        assert_eq!(obj.class_name, child_name);
769        assert_eq!(obj.properties.get("x"), Some(&Value::Num(12.0)));
770    }
771
772    #[test]
773    fn constructor_fallback_defaults_when_constructor_is_private_or_unavailable() {
774        let private_class_name = unique_class_name("runtime_ctor_private");
775        let mut private_methods = HashMap::new();
776        private_methods.insert(
777            private_class_name.clone(),
778            MethodDef {
779                name: private_class_name.clone(),
780                is_static: true,
781                is_abstract: false,
782                is_sealed: false,
783                access: Access::Private,
784                function_name: "Point.origin".to_string(),
785                implicit_class_argument: None,
786            },
787        );
788        register_class(ClassDef {
789            name: private_class_name.clone(),
790            parent: None,
791            properties: HashMap::new(),
792            methods: private_methods,
793        });
794        let err = call_builtin(&private_class_name, &[])
795            .expect_err("private constructor should enforce access before default fallback");
796        assert_eq!(err.identifier(), Some("RunMat:MethodPrivate"));
797
798        let public_class_name = unique_class_name("runtime_ctor_public_no_semantic");
799        let mut public_methods = HashMap::new();
800        public_methods.insert(
801            public_class_name.clone(),
802            MethodDef {
803                name: public_class_name.clone(),
804                is_static: true,
805                is_abstract: false,
806                is_sealed: false,
807                access: Access::Public,
808                function_name: unique_class_name("runtime_ctor_missing_body"),
809                implicit_class_argument: None,
810            },
811        );
812        register_class(ClassDef {
813            name: public_class_name.clone(),
814            parent: None,
815            properties: HashMap::new(),
816            methods: public_methods,
817        });
818
819        let out = call_builtin(&public_class_name, &[])
820            .expect("public ctor metadata without semantic body should default-construct");
821        let Value::Object(obj) = out else {
822            panic!("expected object result");
823        };
824        assert_eq!(obj.class_name, public_class_name);
825    }
826
827    #[test]
828    fn dotted_static_method_name_dispatches_to_registered_class_method() {
829        let class_name = unique_class_name("runtime_static_dispatch");
830        let fn_name = unique_class_name("runtime_static_fn");
831        register_class(ClassDef {
832            name: class_name.clone(),
833            parent: None,
834            properties: HashMap::new(),
835            methods: {
836                let mut methods = HashMap::new();
837                methods.insert(
838                    "zero".to_string(),
839                    MethodDef {
840                        name: "zero".to_string(),
841                        is_static: true,
842                        is_abstract: false,
843                        is_sealed: false,
844                        access: Access::Public,
845                        function_name: fn_name.clone(),
846                        implicit_class_argument: None,
847                    },
848                );
849                methods
850            },
851        });
852
853        let fn_name_for_resolver = fn_name.clone();
854        let _resolver_guard = crate::user_functions::install_semantic_function_resolver(Some(
855            std::sync::Arc::new(move |name| (name == fn_name_for_resolver).then_some(20202)),
856        ));
857        let _invoker_guard = crate::user_functions::install_semantic_function_invoker(Some(
858            std::sync::Arc::new(move |function, _args, requested_outputs| {
859                assert_eq!(function, 20202);
860                assert_eq!(requested_outputs, 1);
861                Box::pin(async { Ok(Value::Num(77.0)) })
862            }),
863        ));
864
865        let out = call_builtin(&format!("{class_name}.zero"), &[])
866            .expect("dotted static class method call should dispatch");
867        assert_eq!(out, Value::Num(77.0));
868    }
869}