runmat_accelerate_api/
lib.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
4pub struct GpuTensorHandle {
5    pub shape: Vec<usize>,
6    pub device_id: u32,
7    pub buffer_id: u64,
8}
9
10/// Device/provider interface that backends implement and register into the runtime layer
11pub trait AccelProvider: Send + Sync {
12    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
13    fn download(&self, h: &GpuTensorHandle) -> anyhow::Result<crate::HostTensorOwned>;
14    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
15    fn device_info(&self) -> String;
16
17    // Optional operator hooks (default to unsupported)
18    fn elem_add(
19        &self,
20        _a: &GpuTensorHandle,
21        _b: &GpuTensorHandle,
22    ) -> anyhow::Result<GpuTensorHandle> {
23        Err(anyhow::anyhow!("elem_add not supported by provider"))
24    }
25    fn elem_mul(
26        &self,
27        _a: &GpuTensorHandle,
28        _b: &GpuTensorHandle,
29    ) -> anyhow::Result<GpuTensorHandle> {
30        Err(anyhow::anyhow!("elem_mul not supported by provider"))
31    }
32    fn elem_sub(
33        &self,
34        _a: &GpuTensorHandle,
35        _b: &GpuTensorHandle,
36    ) -> anyhow::Result<GpuTensorHandle> {
37        Err(anyhow::anyhow!("elem_sub not supported by provider"))
38    }
39    fn elem_div(
40        &self,
41        _a: &GpuTensorHandle,
42        _b: &GpuTensorHandle,
43    ) -> anyhow::Result<GpuTensorHandle> {
44        Err(anyhow::anyhow!("elem_div not supported by provider"))
45    }
46    fn matmul(
47        &self,
48        _a: &GpuTensorHandle,
49        _b: &GpuTensorHandle,
50    ) -> anyhow::Result<GpuTensorHandle> {
51        Err(anyhow::anyhow!("matmul not supported by provider"))
52    }
53}
54
55static mut GLOBAL_PROVIDER: Option<&'static dyn AccelProvider> = None;
56
57/// Register a global acceleration provider.
58///
59/// # Safety
60/// - The caller must guarantee that `p` is valid for the entire program lifetime
61///   (e.g., a `'static` singleton), as the runtime stores a raw reference globally.
62/// - Concurrent callers must ensure registration happens once or is properly
63///   synchronized; this function does not enforce thread-safety for re-registration.
64pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
65    GLOBAL_PROVIDER = Some(p);
66}
67
68pub fn provider() -> Option<&'static dyn AccelProvider> {
69    unsafe { GLOBAL_PROVIDER }
70}
71
72/// Convenience: perform elementwise add via provider if possible; otherwise return None
73pub fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
74    if let Some(p) = provider() {
75        if let Ok(h) = p.elem_add(a, b) {
76            return Some(h);
77        }
78    }
79    None
80}
81
82// Minimal host tensor views to avoid depending on runmat-builtins and cycles
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub struct HostTensorOwned {
85    pub data: Vec<f64>,
86    pub shape: Vec<usize>,
87}
88
/// Borrowed, read-only view of a host tensor: a flat `f64` slice plus its shape.
///
/// The view consists of two shared slices only, so it is `Copy`; it can be passed
/// by value and compared without rebuilding it from its parts.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HostTensorView<'a> {
    /// Flat element storage borrowed from the caller. The memory layout
    /// (row- vs. column-major) is not specified by this type.
    pub data: &'a [f64],
    /// Tensor shape, one entry per dimension.
    pub shape: &'a [usize],
}