#[cfg(not(target_arch = "wasm32"))]
use once_cell::sync::OnceCell;
use runmat_accelerate_api::{AccelContextHandle, AccelContextKind, WgpuContextHandle};
#[cfg(target_arch = "wasm32")]
use runmat_thread_local::runmat_thread_local;
#[cfg(target_arch = "wasm32")]
use std::cell::RefCell;
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
/// Process-wide shared wgpu context on native targets.
///
/// This is a write-once `OnceCell`: the first successful `set` stores the context and
/// later `set` attempts leave it unchanged.
#[cfg(not(target_arch = "wasm32"))]
static SHARED_WGPU_CONTEXT: OnceCell<WgpuContextHandle> = OnceCell::new();
// On wasm32 the shared context lives in a replaceable `RefCell<Option<_>>` slot declared
// through `runmat_thread_local!` (presumably thread-local storage, since wgpu handles on
// the web are typically not `Send`/`Sync` — NOTE(review): confirm). Unlike the native
// slot, it can be overwritten by a later install.
#[cfg(target_arch = "wasm32")]
runmat_thread_local! {
    static SHARED_WGPU_CONTEXT: RefCell<Option<WgpuContextHandle>> = RefCell::new(None);
}
pub fn shared_wgpu_context() -> Option<WgpuContextHandle> {
#[cfg(not(target_arch = "wasm32"))]
{
SHARED_WGPU_CONTEXT.get().cloned()
}
#[cfg(target_arch = "wasm32")]
{
SHARED_WGPU_CONTEXT.with(|cell| cell.borrow().clone())
}
}
/// Installs `context` as the shared wgpu context and forwards it to the plotting crate.
///
/// On native targets the shared slot is write-once: the first installed context is kept
/// and later calls do not replace it. So that the runtime and the plotting crate always
/// agree on a single device, the context that is actually stored (which may be an
/// earlier one) is what gets propagated. On wasm32 the slot is overwritten on every call
/// and the new context is propagated.
pub fn install_wgpu_context(context: &WgpuContextHandle) {
    #[cfg(not(target_arch = "wasm32"))]
    {
        // `get_or_init` hands back the already-stored context if there is one. A second
        // install therefore cannot give the plot crate a different device than the one
        // `shared_wgpu_context()` reports.
        let stored = SHARED_WGPU_CONTEXT.get_or_init(|| context.clone());
        propagate_to_plot_crate(stored);
    }
    #[cfg(target_arch = "wasm32")]
    {
        SHARED_WGPU_CONTEXT.with(|cell| {
            *cell.borrow_mut() = Some(context.clone());
        });
        propagate_to_plot_crate(context);
    }
}
/// Returns the shared wgpu context, exporting and installing one from the acceleration
/// provider if none has been installed yet.
///
/// # Errors
///
/// Returns a `RunMat:plot:ContextUnavailable` runtime error when no context is installed
/// and the provider does not export a plotting context.
pub fn ensure_context_from_provider() -> BuiltinResult<WgpuContextHandle> {
    if let Some(ctx) = shared_wgpu_context() {
        return Ok(ctx);
    }
    let handle =
        runmat_accelerate_api::export_context(AccelContextKind::Plotting).ok_or_else(|| {
            context_error(
                "plotting context unavailable (GPU provider did not export a shared device)",
            )
        })?;
    match handle {
        AccelContextHandle::Wgpu(ctx) => {
            install_wgpu_context(&ctx);
            // Another caller may have installed a context between the check above and
            // this install; on native the slot is write-once and keeps that earlier one.
            // Return whatever is actually shared so every caller sees the same device.
            Ok(shared_wgpu_context().unwrap_or(ctx))
        }
    }
}
/// Builds a runtime error carrying `message`, tagged with the
/// `RunMat:plot:ContextUnavailable` identifier.
fn context_error(message: impl Into<String>) -> RuntimeError {
    const IDENTIFIER: &str = "RunMat:plot:ContextUnavailable";
    let builder = build_runtime_error(message);
    builder.with_identifier(IDENTIFIER).build()
}
/// Forwards `context` to the plotting crate when plotting support is compiled in.
///
/// Active when the `gui` or `plot-core` feature is enabled, or `plot-web` on wasm32.
/// In that case it installs a `SharedWgpuContext` built from `context`'s handles. If the
/// acceleration provider reports a workgroup-size hint, it also applies that hint to the
/// plot crate's GPU tuning. In all other builds this is a no-op.
fn propagate_to_plot_crate(context: &WgpuContextHandle) {
    #[cfg(any(
        feature = "gui",
        feature = "plot-core",
        all(target_arch = "wasm32", feature = "plot-web")
    ))]
    {
        use runmat_plot::context::{
            install_shared_wgpu_context as install_plot_context, SharedWgpuContext,
        };
        use runmat_plot::gpu::tuning as plot_tuning;
        install_plot_context(SharedWgpuContext {
            instance: context.instance.clone(),
            device: context.device.clone(),
            queue: context.queue.clone(),
            adapter: context.adapter.clone(),
            adapter_info: context.adapter_info.clone(),
            limits: context.limits.clone(),
            features: context.features,
        });
        if let Some(wg) = runmat_accelerate_api::workgroup_size_hint() {
            plot_tuning::set_effective_workgroup_size(wg);
        }
    }
    // This predicate must be the exact negation of the one above. Previously it omitted
    // `plot-core`, so a build with only `plot-core` compiled both branches.
    #[cfg(not(any(
        feature = "gui",
        feature = "plot-core",
        all(target_arch = "wasm32", feature = "plot-web")
    )))]
    {
        let _ = context;
    }
}
/// Tests for shared-context installation. Compiled only with the `plot-core` feature
/// because they inspect the `runmat_plot` crate's shared-context state.
#[cfg(all(test, feature = "plot-core"))]
pub(crate) mod tests {
    use super::*;
    use pollster::FutureExt;
    use std::sync::Arc;

    /// Creates a real wgpu device, installs it via `install_wgpu_context`, and checks that
    /// both the runtime's shared slot and `runmat_plot`'s shared slot report a context.
    ///
    /// Returns early (passing vacuously) when `RUNMAT_PLOT_SKIP_GPU_TESTS` is set, when no
    /// adapter is available, or when device creation fails.
    ///
    /// NOTE(review): on wasm32 this function carries both `wasm_bindgen_test` and `#[test]`,
    /// and `pollster::block_on` cannot block on the browser main thread — confirm this
    /// test is really expected to run there.
    /// NOTE(review): on native the shared slot is write-once, so if another test installed
    /// a context first, the first assert holds regardless of this install.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn install_context_propagates_to_plot_crate() {
        // Opt-out switch for environments without usable GPU drivers.
        if std::env::var("RUNMAT_PLOT_SKIP_GPU_TESTS").is_ok() {
            return;
        }
        let instance = Arc::new(wgpu::Instance::default());
        // No adapter (e.g. headless CI) means there is nothing to test; skip silently.
        let adapter = match instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .block_on()
        {
            Some(adapter) => adapter,
            None => return,
        };
        // Capture adapter properties before the adapter is moved into an `Arc`.
        let adapter_info = adapter.get_info();
        let adapter_limits = adapter.limits();
        let adapter_features = adapter.features();
        let adapter = Arc::new(adapter);
        // Request a device with every feature/limit the adapter offers; failure skips.
        let (device, queue) = match adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("plotting-context-test-device"),
                    required_features: adapter_features,
                    required_limits: adapter_limits.clone(),
                },
                None,
            )
            .block_on()
        {
            Ok(pair) => pair,
            Err(_) => return,
        };
        let handle = WgpuContextHandle {
            instance: instance.clone(),
            device: Arc::new(device),
            queue: Arc::new(queue),
            adapter: adapter.clone(),
            adapter_info,
            limits: adapter_limits.clone(),
            features: adapter_features,
        };
        install_wgpu_context(&handle);
        // Both the runtime slot and the plot crate's slot must now be populated.
        assert!(shared_wgpu_context().is_some());
        assert!(runmat_plot::shared_wgpu_context().is_some());
    }
}