1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
//! # baracuda-cutlass
//!
//! Safe Rust wrapper for compiled CUTLASS kernels in the baracuda
//! ecosystem. Plan-based GEMM and grouped-GEMM API with caller-supplied
//! workspace, typed device-buffer arguments, and capture-safe launches.
//!
//! See the crate `README.md` for the v0 scope and the design rationale.
//! See [`baracuda-cutlass-kernels-sys`] for the underlying compiled
//! template instantiations.
//!
//! ## Quick start
//!
//! ```rust,no_run
//! use baracuda_cutlass::{
//! EpilogueKind, GemmArgs, GemmDescriptor, GemmPlan, LayoutSku,
//! MatrixMut, MatrixRef, PlanPreference, Workspace,
//! };
//! use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
//! use half::f16;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let ctx = Context::new(&Device::get(0)?)?;
//! let stream = Stream::new(&ctx)?;
//!
//! let m = 128i32; let n = 128i32; let k = 128i32;
//! let dev_a: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * k) as usize)?;
//! let dev_b: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (k * n) as usize)?;
//! let mut dev_d: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * n) as usize)?;
//!
//! let desc = GemmDescriptor {
//! m, n, k,
//! layout: LayoutSku::Rcr,
//! epilogue: EpilogueKind::Identity,
//! };
//! let plan = GemmPlan::<f16>::select(&stream, &desc, PlanPreference::default())?;
//! let args = GemmArgs::<f16> {
//! a: MatrixRef { data: dev_a.as_slice(), rows: m, cols: k, ld: k as i64 },
//! b: MatrixRef { data: dev_b.as_slice(), rows: k, cols: n, ld: k as i64 },
//! c: None,
//! d: MatrixMut { data: dev_d.as_slice_mut(), rows: m, cols: n, ld: n as i64 },
//! bias: None,
//! alpha: 1.0,
//! beta: 0.0,
//! };
//! plan.can_implement(&args)?;
//! plan.run(&stream, Workspace::None, args)?;
//! # Ok(()) }
//! ```
//!
//! [`baracuda-cutlass-kernels-sys`]: https://docs.rs/baracuda-cutlass-kernels-sys
pub use ;
pub use ;
pub use ;