1use std::fmt;
27
/// Errors reported by a compute backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendError {
    /// The requested operation is not supported by this backend; the
    /// message describes what was asked for.
    Unsupported(String),
    /// The underlying device reported a failure.
    DeviceError(String),
    /// A caller-supplied argument was invalid.
    InvalidArgument(String),
    /// A device memory allocation failed.
    OutOfMemory,
    /// A compute call was made before `init` completed.
    NotInitialized,
}
44
45impl fmt::Display for BackendError {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 Self::Unsupported(msg) => write!(f, "unsupported operation: {msg}"),
49 Self::DeviceError(msg) => write!(f, "device error: {msg}"),
50 Self::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"),
51 Self::OutOfMemory => write!(f, "out of device memory"),
52 Self::NotInitialized => write!(f, "backend not initialized"),
53 }
54 }
55}
56
// Marker impl so BackendError can be boxed as `dyn std::error::Error`
// and composed with error-handling crates; Display/Debug supply the text.
impl std::error::Error for BackendError {}
58
/// Convenience alias used by every fallible backend operation.
pub type BackendResult<T> = Result<T, BackendError>;
61
/// Transpose mode for matrix operands, following BLAS conventions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendTranspose {
    /// Use the operand as-is ("N").
    NoTrans,
    /// Use the transpose of the operand ("T").
    Trans,
    /// Use the conjugate transpose of the operand ("C").
    ConjTrans,
}
74
75impl fmt::Display for BackendTranspose {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 match self {
78 Self::NoTrans => write!(f, "N"),
79 Self::Trans => write!(f, "T"),
80 Self::ConjTrans => write!(f, "C"),
81 }
82 }
83}
84
/// Reduction applied along one axis of a tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum of elements.
    Sum,
    /// Maximum element.
    Max,
    /// Minimum element.
    Min,
    /// Arithmetic mean of elements.
    Mean,
}
97
98impl fmt::Display for ReduceOp {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 match self {
101 Self::Sum => write!(f, "sum"),
102 Self::Max => write!(f, "max"),
103 Self::Min => write!(f, "min"),
104 Self::Mean => write!(f, "mean"),
105 }
106 }
107}
108
/// Element-wise unary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Rectified linear unit.
    Relu,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Exponential.
    Exp,
    /// Natural logarithm.
    Log,
    /// Square root.
    Sqrt,
    /// Absolute value.
    Abs,
    /// Negation.
    Neg,
}
129
130impl fmt::Display for UnaryOp {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 match self {
133 Self::Relu => write!(f, "relu"),
134 Self::Sigmoid => write!(f, "sigmoid"),
135 Self::Tanh => write!(f, "tanh"),
136 Self::Exp => write!(f, "exp"),
137 Self::Log => write!(f, "log"),
138 Self::Sqrt => write!(f, "sqrt"),
139 Self::Abs => write!(f, "abs"),
140 Self::Neg => write!(f, "neg"),
141 }
142 }
143}
144
/// Element-wise binary operation between two tensors of equal length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Sub,
    /// Element-wise multiplication.
    Mul,
    /// Element-wise division.
    Div,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
}
161
162impl fmt::Display for BinaryOp {
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 match self {
165 Self::Add => write!(f, "add"),
166 Self::Sub => write!(f, "sub"),
167 Self::Mul => write!(f, "mul"),
168 Self::Div => write!(f, "div"),
169 Self::Max => write!(f, "max"),
170 Self::Min => write!(f, "min"),
171 }
172 }
173}
174
175pub trait ComputeBackend: Send + Sync + fmt::Debug {
199 fn name(&self) -> &str;
201
202 fn init(&mut self) -> BackendResult<()>;
207
208 fn is_initialized(&self) -> bool;
210
211 #[allow(clippy::too_many_arguments)]
221 fn gemm(
222 &self,
223 trans_a: BackendTranspose,
224 trans_b: BackendTranspose,
225 m: usize,
226 n: usize,
227 k: usize,
228 alpha: f64,
229 a_ptr: u64,
230 lda: usize,
231 b_ptr: u64,
232 ldb: usize,
233 beta: f64,
234 c_ptr: u64,
235 ldc: usize,
236 ) -> BackendResult<()>;
237
238 #[allow(clippy::too_many_arguments)]
251 fn conv2d_forward(
252 &self,
253 input_ptr: u64,
254 input_shape: &[usize],
255 filter_ptr: u64,
256 filter_shape: &[usize],
257 output_ptr: u64,
258 output_shape: &[usize],
259 stride: &[usize],
260 padding: &[usize],
261 ) -> BackendResult<()>;
262
263 #[allow(clippy::too_many_arguments)]
277 fn attention(
278 &self,
279 q_ptr: u64,
280 k_ptr: u64,
281 v_ptr: u64,
282 o_ptr: u64,
283 batch: usize,
284 heads: usize,
285 seq_q: usize,
286 seq_kv: usize,
287 head_dim: usize,
288 scale: f64,
289 causal: bool,
290 ) -> BackendResult<()>;
291
292 fn reduce(
296 &self,
297 op: ReduceOp,
298 input_ptr: u64,
299 output_ptr: u64,
300 shape: &[usize],
301 axis: usize,
302 ) -> BackendResult<()>;
303
304 fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>;
308
309 fn binary(
313 &self,
314 op: BinaryOp,
315 a_ptr: u64,
316 b_ptr: u64,
317 output_ptr: u64,
318 n: usize,
319 ) -> BackendResult<()>;
320
321 #[allow(clippy::too_many_arguments)]
338 fn batched_gemm(
339 &self,
340 trans_a: BackendTranspose,
341 trans_b: BackendTranspose,
342 m: usize,
343 n: usize,
344 k: usize,
345 alpha: f64,
346 a_ptr: u64,
347 lda: usize,
348 stride_a: usize,
349 b_ptr: u64,
350 ldb: usize,
351 stride_b: usize,
352 beta: f64,
353 c_ptr: u64,
354 ldc: usize,
355 stride_c: usize,
356 batch_count: usize,
357 ) -> BackendResult<()> {
358 let elem_bytes: u64 = 4; for b in 0..batch_count {
362 let b64 = b as u64;
363 self.gemm(
364 trans_a,
365 trans_b,
366 m,
367 n,
368 k,
369 alpha,
370 a_ptr + b64 * stride_a as u64 * elem_bytes,
371 lda,
372 b_ptr + b64 * stride_b as u64 * elem_bytes,
373 ldb,
374 beta,
375 c_ptr + b64 * stride_c as u64 * elem_bytes,
376 ldc,
377 )?;
378 }
379 Ok(())
380 }
381
382 fn synchronize(&self) -> BackendResult<()>;
386
387 fn alloc(&self, bytes: usize) -> BackendResult<u64>;
392
393 fn free(&self, ptr: u64) -> BackendResult<()>;
395
396 fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;
401
402 fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;
407}
408
#[cfg(test)]
mod tests {
    //! Unit tests for the backend error type, the small display enums, and
    //! the default `batched_gemm` implementation (via a mock backend).
    use super::*;

    // Each BackendError variant renders its expected Display string.
    #[test]
    fn backend_error_display() {
        assert_eq!(
            BackendError::Unsupported("foo".into()).to_string(),
            "unsupported operation: foo"
        );
        assert_eq!(
            BackendError::DeviceError("bar".into()).to_string(),
            "device error: bar"
        );
        assert_eq!(
            BackendError::InvalidArgument("baz".into()).to_string(),
            "invalid argument: baz"
        );
        assert_eq!(
            BackendError::OutOfMemory.to_string(),
            "out of device memory"
        );
        assert_eq!(
            BackendError::NotInitialized.to_string(),
            "backend not initialized"
        );
    }

    // BackendError can be boxed as a trait object, proving the
    // std::error::Error impl exists.
    #[test]
    fn backend_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(BackendError::DeviceError("test".into()));
        assert!(err.to_string().contains("test"));
    }

    // Transpose codes follow the BLAS single-letter convention, and the
    // derived PartialEq distinguishes variants.
    #[test]
    fn backend_transpose_display_and_values() {
        assert_eq!(BackendTranspose::NoTrans.to_string(), "N");
        assert_eq!(BackendTranspose::Trans.to_string(), "T");
        assert_eq!(BackendTranspose::ConjTrans.to_string(), "C");

        assert_eq!(BackendTranspose::NoTrans, BackendTranspose::NoTrans);
        assert_ne!(BackendTranspose::NoTrans, BackendTranspose::Trans);
    }

    // Every ReduceOp variant has the expected lowercase name.
    #[test]
    fn reduce_op_display_and_coverage() {
        let ops = [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean];
        let names = ["sum", "max", "min", "mean"];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    // Every UnaryOp variant has the expected lowercase name.
    #[test]
    fn unary_op_display_and_coverage() {
        let ops = [
            UnaryOp::Relu,
            UnaryOp::Sigmoid,
            UnaryOp::Tanh,
            UnaryOp::Exp,
            UnaryOp::Log,
            UnaryOp::Sqrt,
            UnaryOp::Abs,
            UnaryOp::Neg,
        ];
        let names = [
            "relu", "sigmoid", "tanh", "exp", "log", "sqrt", "abs", "neg",
        ];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    // Every BinaryOp variant has the expected lowercase name.
    #[test]
    fn binary_op_display_and_coverage() {
        let ops = [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Max,
            BinaryOp::Min,
        ];
        let names = ["add", "sub", "mul", "div", "max", "min"];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    use std::sync::atomic::{AtomicUsize, Ordering};

    // Minimal ComputeBackend that records how many times gemm is invoked;
    // all other operations are no-ops. Used to exercise the default
    // batched_gemm implementation.
    #[derive(Debug)]
    struct MockBackend {
        // Atomic because the trait requires Sync and gemm takes &self.
        gemm_call_count: AtomicUsize,
    }

    impl MockBackend {
        fn new() -> Self {
            Self {
                gemm_call_count: AtomicUsize::new(0),
            }
        }
    }

    impl ComputeBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }
        fn init(&mut self) -> BackendResult<()> {
            Ok(())
        }
        fn is_initialized(&self) -> bool {
            true
        }
        // Counts calls; ignores all arguments.
        fn gemm(
            &self,
            _trans_a: BackendTranspose,
            _trans_b: BackendTranspose,
            _m: usize,
            _n: usize,
            _k: usize,
            _alpha: f64,
            _a_ptr: u64,
            _lda: usize,
            _b_ptr: u64,
            _ldb: usize,
            _beta: f64,
            _c_ptr: u64,
            _ldc: usize,
        ) -> BackendResult<()> {
            self.gemm_call_count.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }
        fn conv2d_forward(
            &self,
            _: u64,
            _: &[usize],
            _: u64,
            _: &[usize],
            _: u64,
            _: &[usize],
            _: &[usize],
            _: &[usize],
        ) -> BackendResult<()> {
            Ok(())
        }
        fn attention(
            &self,
            _: u64,
            _: u64,
            _: u64,
            _: u64,
            _: usize,
            _: usize,
            _: usize,
            _: usize,
            _: usize,
            _: f64,
            _: bool,
        ) -> BackendResult<()> {
            Ok(())
        }
        fn reduce(&self, _: ReduceOp, _: u64, _: u64, _: &[usize], _: usize) -> BackendResult<()> {
            Ok(())
        }
        fn unary(&self, _: UnaryOp, _: u64, _: u64, _: usize) -> BackendResult<()> {
            Ok(())
        }
        fn binary(&self, _: BinaryOp, _: u64, _: u64, _: u64, _: usize) -> BackendResult<()> {
            Ok(())
        }
        fn synchronize(&self) -> BackendResult<()> {
            Ok(())
        }
        fn alloc(&self, _: usize) -> BackendResult<u64> {
            Ok(0)
        }
        fn free(&self, _: u64) -> BackendResult<()> {
            Ok(())
        }
        fn copy_htod(&self, _: u64, _: &[u8]) -> BackendResult<()> {
            Ok(())
        }
        fn copy_dtoh(&self, _: &mut [u8], _: u64) -> BackendResult<()> {
            Ok(())
        }
    }

    // batch_count == 0 must succeed without invoking gemm at all.
    #[test]
    fn batched_gemm_zero_batch_is_noop() {
        let backend = MockBackend::new();
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            0, );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_call_count.load(Ordering::Relaxed), 0);
    }

    // The default batched_gemm delegates once per batch.
    #[test]
    fn batched_gemm_default_calls_gemm_n_times() {
        let backend = MockBackend::new();
        let batch_count = 7;
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::Trans,
            8,
            8,
            8,
            1.0,
            1000,
            8,
            64,
            2000,
            8,
            64,
            0.0,
            3000,
            8,
            64,
            batch_count,
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_call_count.load(Ordering::Relaxed), batch_count);
    }

    // Edge case: a single batch makes exactly one gemm call.
    #[test]
    fn batched_gemm_single_batch() {
        let backend = MockBackend::new();
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            16,
            16,
            16,
            1.0,
            0,
            16,
            256,
            0,
            16,
            256,
            1.0,
            0,
            16,
            256,
            1,
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_call_count.load(Ordering::Relaxed), 1);
    }

    // Smoke-tests the derived Hash (HashSet membership) and Copy/PartialEq
    // on the small operation enums.
    #[test]
    fn enum_clone_and_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(ReduceOp::Sum);
        set.insert(ReduceOp::Max);
        assert!(set.contains(&ReduceOp::Sum));
        assert!(!set.contains(&ReduceOp::Min));

        let op = UnaryOp::Relu;
        let cloned = op;
        assert_eq!(op, cloned);

        let bop = BinaryOp::Add;
        let bcloned = bop;
        assert_eq!(bop, bcloned);

        let trans = BackendTranspose::ConjTrans;
        let tcloned = trans;
        assert_eq!(trans, tcloned);
    }
}