runmat_accelerate_api/
lib.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
4pub struct GpuTensorHandle {
5 pub shape: Vec<usize>,
6 pub device_id: u32,
7 pub buffer_id: u64,
8}
9
10pub trait AccelProvider: Send + Sync {
12 fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
13 fn download(&self, h: &GpuTensorHandle) -> anyhow::Result<crate::HostTensorOwned>;
14 fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
15 fn device_info(&self) -> String;
16
17 fn elem_add(
19 &self,
20 _a: &GpuTensorHandle,
21 _b: &GpuTensorHandle,
22 ) -> anyhow::Result<GpuTensorHandle> {
23 Err(anyhow::anyhow!("elem_add not supported by provider"))
24 }
25 fn elem_mul(
26 &self,
27 _a: &GpuTensorHandle,
28 _b: &GpuTensorHandle,
29 ) -> anyhow::Result<GpuTensorHandle> {
30 Err(anyhow::anyhow!("elem_mul not supported by provider"))
31 }
32 fn elem_sub(
33 &self,
34 _a: &GpuTensorHandle,
35 _b: &GpuTensorHandle,
36 ) -> anyhow::Result<GpuTensorHandle> {
37 Err(anyhow::anyhow!("elem_sub not supported by provider"))
38 }
39 fn elem_div(
40 &self,
41 _a: &GpuTensorHandle,
42 _b: &GpuTensorHandle,
43 ) -> anyhow::Result<GpuTensorHandle> {
44 Err(anyhow::anyhow!("elem_div not supported by provider"))
45 }
46 fn matmul(
47 &self,
48 _a: &GpuTensorHandle,
49 _b: &GpuTensorHandle,
50 ) -> anyhow::Result<GpuTensorHandle> {
51 Err(anyhow::anyhow!("matmul not supported by provider"))
52 }
53}
54
55static mut GLOBAL_PROVIDER: Option<&'static dyn AccelProvider> = None;
56
57pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
65 GLOBAL_PROVIDER = Some(p);
66}
67
68pub fn provider() -> Option<&'static dyn AccelProvider> {
69 unsafe { GLOBAL_PROVIDER }
70}
71
72pub fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
74 if let Some(p) = provider() {
75 if let Ok(h) = p.elem_add(a, b) {
76 return Some(h);
77 }
78 }
79 None
80}
81
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
84pub struct HostTensorOwned {
85 pub data: Vec<f64>,
86 pub shape: Vec<usize>,
87}
88
/// Borrowed host-side tensor, as passed to `AccelProvider::upload`.
///
/// The view only holds two shared slices, so it derives `Clone` and `Copy`, letting
/// callers pass it around freely. It also derives `PartialEq`, which compares
/// element-wise. Nothing here checks that `data.len()` agrees with `shape`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HostTensorView<'a> {
    /// Borrowed element values.
    pub data: &'a [f64],
    /// Borrowed tensor extents — presumably one entry per dimension; TODO confirm.
    pub shape: &'a [usize],
}