1use runmat_builtins::{Tensor, Value};
12
13pub mod simple_provider;
14#[cfg(feature = "wgpu")]
15pub mod wgpu_backend;
16use serde::{Deserialize, Serialize};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum DeviceKind {
21 Cpu,
22 Cuda,
23 Rocm,
24 Metal,
25 Vulkan,
26 OpenCl,
27 Wgpu,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct DeviceInfo {
33 pub kind: DeviceKind,
34 pub name: String,
35 pub vendor: String,
36 pub memory_bytes: Option<u64>,
37 pub compute_capability: Option<String>,
38}
39
/// A device-side buffer whose length can be queried.
///
/// The `Send + Sync` supertraits let handles be shared across threads.
pub trait BufferHandle: Send + Sync {
    /// Length of the buffer (whether this counts elements or bytes is
    /// backend-defined — TODO confirm).
    fn len(&self) -> usize;

    /// Returns `true` exactly when [`BufferHandle::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
47
48pub trait DeviceMatrix: Send + Sync {
50 fn rows(&self) -> usize;
51 fn cols(&self) -> usize;
52 fn as_buffer(&self) -> &dyn BufferHandle;
53}
54
55pub trait AccelerateBackend: Send + Sync {
57 fn device_info(&self) -> DeviceInfo;
58
59 fn upload_matrix(&self, host: &Tensor) -> anyhow::Result<Box<dyn DeviceMatrix>>;
61 fn download_matrix(&self, dev: &dyn DeviceMatrix) -> anyhow::Result<Tensor>;
62
63 fn elem_add(
65 &self,
66 a: &dyn DeviceMatrix,
67 b: &dyn DeviceMatrix,
68 ) -> anyhow::Result<Box<dyn DeviceMatrix>>;
69 fn elem_sub(
70 &self,
71 a: &dyn DeviceMatrix,
72 b: &dyn DeviceMatrix,
73 ) -> anyhow::Result<Box<dyn DeviceMatrix>>;
74 fn elem_mul(
75 &self,
76 a: &dyn DeviceMatrix,
77 b: &dyn DeviceMatrix,
78 ) -> anyhow::Result<Box<dyn DeviceMatrix>>;
79 fn elem_div(
80 &self,
81 a: &dyn DeviceMatrix,
82 b: &dyn DeviceMatrix,
83 ) -> anyhow::Result<Box<dyn DeviceMatrix>>;
84 fn elem_pow(
85 &self,
86 a: &dyn DeviceMatrix,
87 b: &dyn DeviceMatrix,
88 ) -> anyhow::Result<Box<dyn DeviceMatrix>>;
89
90 fn matmul(
92 &self,
93 a: &dyn DeviceMatrix,
94 b: &dyn DeviceMatrix,
95 ) -> anyhow::Result<Box<dyn DeviceMatrix>>;
96 fn transpose(&self, a: &dyn DeviceMatrix) -> anyhow::Result<Box<dyn DeviceMatrix>>;
97}
98
99#[derive(Default)]
102pub struct Planner {
103 backend: Option<Box<dyn AccelerateBackend>>,
104}
105
106impl Planner {
107 pub fn new(backend: Option<Box<dyn AccelerateBackend>>) -> Self {
108 Self { backend }
109 }
110
111 pub fn device(&self) -> Option<&dyn AccelerateBackend> {
112 self.backend.as_deref()
113 }
114
115 pub fn choose_elem_add(&self, a: &Tensor, b: &Tensor) -> ExecutionTarget {
117 if let Some(bk) = &self.backend {
118 if a.data.len() >= 1 << 16 && a.rows() == b.rows() && a.cols() == b.cols() {
119 return ExecutionTarget::Gpu(bk.device_info());
120 }
121 }
122 ExecutionTarget::Cpu
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub enum ExecutionTarget {
128 Cpu,
129 Gpu(DeviceInfo),
130}
131
132pub struct Accelerator {
134 planner: Planner,
135}
136
137impl Accelerator {
138 pub fn new(planner: Planner) -> Self {
139 Self { planner }
140 }
141
142 pub fn elementwise_add(&self, a: &Value, b: &Value) -> anyhow::Result<Value> {
143 match (a, b) {
144 (Value::Tensor(ma), Value::Tensor(mb)) => match self.planner.choose_elem_add(ma, mb) {
145 ExecutionTarget::Cpu => {
146 runmat_runtime::elementwise_add(a, b).map_err(|e| anyhow::anyhow!(e))
147 }
148 ExecutionTarget::Gpu(_) => {
149 let bk = self
150 .planner
151 .device()
152 .ok_or_else(|| anyhow::anyhow!("no backend"))?;
153 let da = bk.upload_matrix(ma)?;
154 let db = bk.upload_matrix(mb)?;
155 let dc = bk.elem_add(da.as_ref(), db.as_ref())?;
156 let out = bk.download_matrix(dc.as_ref())?;
157 Ok(Value::Tensor(out))
158 }
159 },
160 (Value::GpuTensor(ga), Value::GpuTensor(gb)) => {
161 let ha = self.gather_handle(ga)?;
164 let hb = self.gather_handle(gb)?;
165 self.elementwise_add(&ha, &hb)
166 }
167 (Value::GpuTensor(ga), other) => {
168 let ha = self.gather_handle(ga)?;
169 self.elementwise_add(&ha, other)
170 }
171 (other, Value::GpuTensor(gb)) => {
172 let hb = self.gather_handle(gb)?;
173 self.elementwise_add(other, &hb)
174 }
175 _ => runmat_runtime::elementwise_add(a, b).map_err(|e| anyhow::anyhow!(e)),
176 }
177 }
178
179 fn gather_handle(&self, h: &runmat_accelerate_api::GpuTensorHandle) -> anyhow::Result<Value> {
180 if let Some(p) = runmat_accelerate_api::provider() {
181 let ht = p.download(h).map_err(|e| anyhow::anyhow!(e))?;
182 let t = Tensor::new(ht.data, ht.shape).map_err(|e| anyhow::anyhow!(e))?;
183 Ok(Value::Tensor(t))
184 } else {
185 let shape = h.shape.clone();
187 let total: usize = shape.iter().product();
188 let zeros = Tensor::new(vec![0.0; total], shape).map_err(|e| anyhow::anyhow!(e))?;
189 Ok(Value::Tensor(zeros))
190 }
191 }
192}
193
194