1#![warn(missing_docs)]
31#![warn(clippy::all)]
32#![allow(clippy::module_name_repetitions)]
33#![allow(clippy::missing_safety_doc)]
34
35pub mod context;
36pub mod conv;
37pub mod elementwise;
38pub mod error;
39pub mod matmul;
40pub mod reduce;
41pub mod softmax;
42
43pub use context::CudaContext;
44pub use error::CudaDispatchError as CudaError;
45
46use std::collections::HashMap;
47
48use oxionnx_core::graph::{Node, OpKind};
49use oxionnx_core::{OnnxError, Tensor};
50
51pub fn try_cuda_dispatch(
58 node: &Node,
59 weights: &HashMap<String, Tensor>,
60 intermediates: &HashMap<String, Tensor>,
61 ctx: &CudaContext,
62) -> Result<Option<Vec<Tensor>>, OnnxError> {
63 let resolve = |name: &str| -> Option<&Tensor> {
64 if name.is_empty() {
65 None
66 } else {
67 intermediates.get(name).or_else(|| weights.get(name))
68 }
69 };
70
71 match &node.op {
72 OpKind::MatMul | OpKind::Gemm => {
76 let a = resolve(&node.inputs[0]);
77 let b = resolve(&node.inputs[1]);
78 if let (Some(a), Some(b)) = (a, b) {
79 let is_gemm = matches!(node.op, OpKind::Gemm);
81 let alpha = if is_gemm {
82 node.attrs.f("alpha", 1.0)
83 } else {
84 1.0
85 };
86 let beta = if is_gemm {
87 node.attrs.f("beta", 1.0)
88 } else {
89 0.0
90 };
91 let trans_a = is_gemm && node.attrs.i("transA", 0) != 0;
92 let trans_b = is_gemm && node.attrs.i("transB", 0) != 0;
93
94 let an = a.ndim();
95 let bn = b.ndim();
96 if an >= 2 && bn >= 2 {
97 let m = if trans_a {
99 a.shape[an - 1]
100 } else {
101 a.shape[an - 2]
102 };
103 let k = if trans_a {
104 a.shape[an - 2]
105 } else {
106 a.shape[an - 1]
107 };
108 let n = if trans_b {
109 b.shape[bn - 2]
110 } else {
111 b.shape[bn - 1]
112 };
113 let batch: usize = a.shape[..an - 2].iter().product::<usize>().max(1);
114
115 let a_data = if trans_a {
117 transpose_2d_batched(&a.data, batch, a.shape[an - 2], a.shape[an - 1])
118 } else {
119 a.data.clone()
120 };
121 let b_data = if trans_b {
122 transpose_2d_batched(&b.data, batch, b.shape[bn - 2], b.shape[bn - 1])
123 } else {
124 b.data.clone()
125 };
126
127 let slice_a = m * k;
128 let slice_b = k * n;
129 let slice_c = m * n;
130
131 let mut out = Vec::with_capacity(batch * slice_c);
132 for i in 0..batch {
133 let a_start = i * slice_a;
134 let b_start = i * slice_b;
135 let mut c = matmul::cuda_matmul(
136 ctx,
137 &a_data[a_start..a_start + slice_a],
138 &b_data[b_start..b_start + slice_b],
139 m,
140 k,
141 n,
142 )
143 .map_err(OnnxError::from)?;
144
145 if (alpha - 1.0).abs() > f32::EPSILON {
147 for v in &mut c {
148 *v *= alpha;
149 }
150 }
151 out.append(&mut c);
152 }
153
154 if is_gemm && beta.abs() > f32::EPSILON {
156 if let Some(bias) = node.inputs.get(2).and_then(|n| resolve(n)) {
157 apply_gemm_bias(&mut out, &bias.data, m, n, beta);
158 }
159 }
160
161 let out_shape = if an > 2 {
162 let mut s = a.shape[..an - 2].to_vec();
163 s.push(m);
164 s.push(n);
165 s
166 } else {
167 vec![m, n]
168 };
169 return Ok(Some(vec![Tensor::new(out, out_shape)]));
170 }
171 }
172 Ok(None)
173 }
174
175 OpKind::Conv => {
179 let input = resolve(&node.inputs[0]);
180 let weight = resolve(&node.inputs[1]);
181 let bias = node.inputs.get(2).and_then(|n| resolve(n));
182 if let (Some(input), Some(weight)) = (input, weight) {
183 let attrs = &node.attrs;
184 let strides_v = attrs.ints("strides");
185 let strides = [
186 strides_v.first().copied().unwrap_or(1) as usize,
187 strides_v.get(1).copied().unwrap_or(1) as usize,
188 ];
189 let pads_v = attrs.ints("pads");
190 let pads = [
191 pads_v.first().copied().unwrap_or(0) as usize,
192 pads_v.get(1).copied().unwrap_or(0) as usize,
193 pads_v.get(2).copied().unwrap_or(0) as usize,
194 pads_v.get(3).copied().unwrap_or(0) as usize,
195 ];
196 let dilations_v = attrs.ints("dilations");
197 let dilations = [
198 dilations_v.first().copied().unwrap_or(1) as usize,
199 dilations_v.get(1).copied().unwrap_or(1) as usize,
200 ];
201 let group = attrs.i("group", 1) as usize;
202
203 let conv_params = conv::ConvParams {
204 strides,
205 pads,
206 dilations,
207 group,
208 };
209
210 match conv::cuda_conv(ctx, input, weight, bias, &conv_params)
211 .map_err(OnnxError::from)?
212 {
213 Some(tensor) => return Ok(Some(vec![tensor])),
214 None => return Ok(None),
215 }
216 }
217 Ok(None)
218 }
219
220 OpKind::Relu
224 | OpKind::Sigmoid
225 | OpKind::Gelu
226 | OpKind::Tanh
227 | OpKind::Exp
228 | OpKind::Sqrt
229 | OpKind::Abs
230 | OpKind::Neg
231 | OpKind::Log
232 | OpKind::Ceil
233 | OpKind::Floor
234 | OpKind::HardSigmoid
235 | OpKind::HardSwish
236 | OpKind::SiLU
237 | OpKind::Softplus
238 | OpKind::LeakyRelu => {
239 let input = resolve(&node.inputs[0]);
240 if let Some(input) = input {
241 let op_name = node.op.as_str();
242 let out = elementwise::cuda_elementwise(ctx, &input.data, op_name)
243 .map_err(OnnxError::from)?;
244 return Ok(Some(vec![Tensor::new(out, input.shape.clone())]));
245 }
246 Ok(None)
247 }
248
249 OpKind::Add | OpKind::Sub | OpKind::Mul | OpKind::Div => {
253 let a = resolve(&node.inputs[0]);
254 let b = resolve(&node.inputs[1]);
255 if let (Some(a), Some(b)) = (a, b) {
256 if a.shape == b.shape {
258 let op_name = node.op.as_str();
259 let out = elementwise::cuda_binary_elementwise(ctx, &a.data, &b.data, op_name)
260 .map_err(OnnxError::from)?;
261 return Ok(Some(vec![Tensor::new(out, a.shape.clone())]));
262 }
263 }
264 Ok(None)
265 }
266
267 OpKind::ReduceSum | OpKind::ReduceMax => {
271 let input = resolve(&node.inputs[0]);
272 if let Some(input) = input {
273 let axes = node.attrs.ints("axes");
274 if axes.len() == 1 {
275 let axis = axes[0] as usize;
276 let op_name = node.op.as_str();
277 match reduce::cuda_reduce(ctx, &input.data, &input.shape, axis, op_name)
278 .map_err(OnnxError::from)?
279 {
280 Some(out) => {
281 let mut out_shape = input.shape.clone();
282 if axis < out_shape.len() {
283 out_shape[axis] = 1;
284 }
285 return Ok(Some(vec![Tensor::new(out, out_shape)]));
286 }
287 None => return Ok(None),
288 }
289 }
290 }
291 Ok(None)
292 }
293
294 OpKind::Softmax => {
298 let input = resolve(&node.inputs[0]);
299 if let Some(input) = input {
300 match softmax::cuda_softmax(ctx, &input.data, &input.shape)
301 .map_err(OnnxError::from)?
302 {
303 Some(out) => {
304 return Ok(Some(vec![Tensor::new(out, input.shape.clone())]));
305 }
306 None => return Ok(None),
307 }
308 }
309 Ok(None)
310 }
311
312 _ => Ok(None),
313 }
314}
315
/// Transposes every `rows x cols` row-major matrix in a contiguous batch.
///
/// `data` is read as `batch` consecutive matrices of `rows * cols` elements;
/// each one is written to the same offset of the result as a `cols x rows`
/// row-major matrix. The result has the same length as `data`, and any
/// elements beyond `batch * rows * cols` are left at zero.
///
/// # Panics
///
/// Panics if `data` holds fewer than `batch * rows * cols` elements.
fn transpose_2d_batched(data: &[f32], batch: usize, rows: usize, cols: usize) -> Vec<f32> {
    let plane = rows * cols;
    let mut transposed = vec![0.0_f32; data.len()];
    for offset in (0..batch).map(|b| b * plane) {
        // Walk the source plane linearly and scatter into the swapped position.
        for flat in 0..plane {
            let (r, c) = (flat / cols, flat % cols);
            transposed[offset + c * rows + r] = data[offset + flat];
        }
    }
    transposed
}
334
/// Adds `beta * bias` in place to a row-major GEMM result.
///
/// `out` is treated as a sequence of rows of length `n`; rows are grouped into
/// consecutive `m x n` matrices, so a batched result re-applies the same bias
/// to every matrix. The bias layout is inferred from its length, checked in
/// this order:
///
/// * `n` elements — one value per column, added to every row;
/// * `m * n` elements — a full matrix, added element-wise to each `m x n` block;
/// * `1` element — a scalar, added to every element;
/// * `m` elements — one value per row (a column vector).
///
/// When `m == n` a bias of that length is therefore always read as per-column.
/// A bias of any other length is ignored. Nothing is done (and nothing
/// panics) when `n` is zero or when `out` or `bias` is empty.
fn apply_gemm_bias(out: &mut [f32], bias: &[f32], m: usize, n: usize, beta: f32) {
    // Guards the row arithmetic below against division by zero.
    if n == 0 || out.is_empty() || bias.is_empty() {
        return;
    }

    if bias.len() == n {
        for row in out.chunks_exact_mut(n) {
            for (dst, &b) in row.iter_mut().zip(bias) {
                *dst += beta * b;
            }
        }
    } else if bias.len() == m * n {
        // `bias` is non-empty here, so `m >= 1` and the modulo is safe.
        for (row_idx, row) in out.chunks_exact_mut(n).enumerate() {
            let bias_base = (row_idx % m) * n;
            let bias_row = &bias[bias_base..bias_base + n];
            for (dst, &b) in row.iter_mut().zip(bias_row) {
                *dst += beta * b;
            }
        }
    } else if bias.len() == 1 {
        let delta = beta * bias[0];
        for dst in out.iter_mut() {
            *dst += delta;
        }
    } else if bias.len() == m {
        for (row_idx, row) in out.chunks_exact_mut(n).enumerate() {
            let delta = beta * bias[row_idx % m];
            for dst in row.iter_mut() {
                *dst += delta;
            }
        }
    }
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use oxionnx_core::graph::{Attributes, Node, OpKind};

    /// Builds a node named `test_node` with default attributes and the given
    /// operator, input names and output names.
    fn make_node(op: OpKind, inputs: &[&str], outputs: &[&str]) -> Node {
        Node {
            name: "test_node".to_string(),
            op,
            attrs: Attributes::default(),
            inputs: inputs.iter().map(|&name| name.to_owned()).collect(),
            outputs: outputs.iter().map(|&name| name.to_owned()).collect(),
        }
    }

    /// Assembles the fixtures for an operator that has no CUDA arm.
    ///
    /// No `CudaContext` is created here, so the dispatcher itself is not
    /// invoked; the test only checks that the fixtures can be constructed.
    #[test]
    fn dispatch_unknown_op_returns_none() {
        let node = make_node(OpKind::Identity, &["x"], &["y"]);
        let weights: HashMap<String, Tensor> = HashMap::new();
        let intermediates: HashMap<String, Tensor> =
            HashMap::from([("x".to_string(), Tensor::new(vec![1.0f32], vec![1]))]);

        // Touch every fixture so none of them is reported as unused.
        let _ = (&node, &weights, &intermediates);
    }

    /// Creating a context must not panic, whatever the outcome is.
    #[test]
    fn cuda_context_try_new_no_panic() {
        let _ctx = CudaContext::try_new();
    }

    /// The `Display` output of a PTX error includes its payload.
    #[test]
    fn cuda_error_displays_correctly() {
        let err = CudaError::Ptx("bad ptx".to_string());
        let s = err.to_string();
        assert!(
            s.contains("bad ptx"),
            "Expected error message to contain 'bad ptx', got: {s}"
        );
    }

    /// A shape error converts into `OnnxError::Internal` and keeps its message.
    #[test]
    fn cuda_error_maps_to_onnx_internal() {
        let onnx_err = OnnxError::from(CudaError::Shape {
            op: "Conv",
            msg: "wrong shape".to_string(),
        });
        match onnx_err {
            OnnxError::Internal(msg) => assert!(
                msg.contains("wrong shape"),
                "Expected 'wrong shape' in: {msg}"
            ),
            other => panic!("Expected OnnxError::Internal, got: {other:?}"),
        }
    }
}