Skip to main content

ferrotorch_cubecl/
lib.rs

1//! Portable GPU backend for ferrotorch via CubeCL.
2//!
3//! CubeCL compiles a single kernel definition to CUDA PTX, AMD HIP/ROCm, and
4//! WGPU (Vulkan/Metal/DX12). This crate wraps CubeCL's runtime and dispatches
5//! real `#[cube]` kernels to the active backend — no CPU fallbacks.
6//!
7//! # Feature flags
8//!
9//! | Feature | Backend              | GPU vendors            |
10//! |---------|----------------------|------------------------|
11//! | `cuda`  | NVIDIA CUDA via PTX  | NVIDIA                 |
12//! | `wgpu`  | WGPU (Vulkan/Metal)  | AMD, Intel, Apple, ... |
13//! | `rocm`  | AMD HIP (native)     | AMD                    |
14//!
15//! Enable at least one backend feature to use GPU acceleration. Without any
16//! backend feature [`CubeRuntime::new`] returns
17//! `FerrotorchError::DeviceUnavailable` and [`CubeRuntime::auto`] returns
18//! `None`.
19//!
20//! # Example
21//!
22//! ```rust,no_run
23//! use ferrotorch_cubecl::{CubeDevice, CubeRuntime};
24//!
25//! // Auto-detect the best available backend
26//! if let Some(rt) = CubeRuntime::auto() {
27//!     println!("Using device: {:?}", rt.device());
28//! }
29//! ```
30//!
31//! ## REQ status (per `.design/ferrotorch-cubecl/lib.md`)
32//!
33//! Full evidence rows (impl + non-test production consumer + upstream
34//! cites) live in the design doc; this synopsis is a one-line summary per
35//! REQ.
36//!
37//! | REQ | Status | Evidence |
38//! |---|---|---|
39//! | REQ-1 (public module surface) | SHIPPED | `pub mod grammar/kernels/ops/quant/runtime/storage` in `lib.rs`; consumer `ferrotorch-xpu/src/lib.rs` imports `ferrotorch_cubecl::{CubeDevice, CubeRuntime, upload_f32, wrap_kernel_output}` |
40//! | REQ-2 (feature-flag wiring) | SHIPPED | `cuda`/`wgpu`/`rocm` feature gates in `Cargo.toml` + `make_client` cfg arms in `runtime.rs`; no-backend path pinned by `runtime_construction_errors_without_backend` in `ops.rs` |
41//! | REQ-3 (boundary re-exports) | SHIPPED | explicit `pub use runtime::{..}` / `storage::{..}` / `quant::{..}` / `grammar::{..}` lists in `lib.rs`; consumers `ferrotorch-xpu/src/lib.rs` + `ferrotorch-grammar/src/gpu_dispatch.rs` reach names via `ferrotorch_cubecl::Foo` |
42//! | REQ-4 (crate-internal launch helpers) | SHIPPED | `pub(crate) fn elementwise_launch_dims` + `pub(crate) fn debug_assert_handle_capacity` in `lib.rs`; consumers `kernels::run_unary`/`run_binary_handle`, `quant::dequantize_q4_0_to_gpu`, `grammar::compute_token_mask_dfa_to_gpu` |
43//! | REQ-5 (lint baseline) | SHIPPED | `#![warn(clippy::all, clippy::pedantic)]` + `#![deny(rust_2018_idioms, missing_debug_implementations)]` at top of `lib.rs`; verified by `cargo clippy -p ferrotorch-cubecl --no-default-features -- -D warnings` |
44
45#![warn(clippy::all, clippy::pedantic)]
46#![deny(rust_2018_idioms, missing_debug_implementations)]
47// Rustdoc coverage is handled by a separate workspace-wide rustdoc pass;
48// until that lands, this crate allows `missing_docs`, matching the
49// ferrotorch-gpu / ferrotorch-jit precedent.
50#![allow(missing_docs)]
51// Pedantic lints we explicitly accept across this crate. Each allow names a
52// concrete reason — the alternative would be churn-for-zero-benefit, a
53// worse API, or scope-creep into frozen files (storage.rs, runtime.rs).
54// Mirrors the ferrotorch-gpu / ferrotorch-jit baseline; add only with a
55// one-line justification.
56#![allow(
57    // Doc prose includes `CubeCL`, `GGUF`, `Q4_0`, etc. — surrounding every
58    // such word in backticks would hurt readability for technical prose.
59    clippy::doc_markdown,
60    // # Errors / # Panics sections will be added in the workspace-wide
61    // rustdoc pass; not gated on this lint baseline.
62    clippy::missing_errors_doc,
63    clippy::missing_panics_doc,
64    // Numeric ML code casts pervasively between `usize` and `u32` for buffer
65    // sizes, dimensions, and CubeCL launch arithmetic; explicit `as` is more
66    // readable than `try_into().unwrap()` cluttering hot paths.
67    clippy::cast_possible_truncation,
68    clippy::cast_possible_wrap,
69    clippy::cast_sign_loss,
70    clippy::cast_precision_loss,
71    clippy::cast_lossless,
72    // `#[must_use]` on every getter is churn for marginal value; existing
73    // callers already use the returned values.
74    clippy::must_use_candidate,
75    // Math kernels naturally use single-character names (m, k, n for matmul
76    // dims; a, b for binary operands); requiring longer names hurts
77    // readability.
78    clippy::many_single_char_names,
79    // Pre-existing pedantic warnings in this crate's frozen files
80    // (storage.rs, runtime.rs) and the dataplane-heavy quant.rs / kernels.rs
81    // bodies are tracked for the cubecl-B SAFETY substantiation pass and a
82    // workspace-wide rustdoc / format-args sweep. Allowing them keeps
83    // `-D warnings` viable now without scope-creeping into frozen files.
84    clippy::ptr_as_ptr,
85    clippy::uninlined_format_args,
86)]
87
88pub mod grammar;
89pub mod kernels;
90pub mod ops;
91pub mod quant;
92pub mod runtime;
93pub mod storage;
94
95// ---------------------------------------------------------------------------
96// Crate-internal helpers shared across modules
97// ---------------------------------------------------------------------------
98
99/// Choose a 1-D cube count and cube dim that cover `n` elements when each
100/// unit processes exactly one element.
101///
102/// 256 units per cube is a safe default across all backends (wgpu, cuda,
103/// rocm). Returned as a tuple ready to feed into `kernel::launch_unchecked`.
104///
105/// Callers: [`kernels`], [`quant`], [`grammar`]. Previously duplicated
106/// verbatim across all three modules; consolidated here so launch geometry
107/// stays consistent.
108pub(crate) fn elementwise_launch_dims(
109    n: u32,
110) -> (cubecl::prelude::CubeCount, cubecl::prelude::CubeDim) {
111    let units_per_cube: u32 = 256;
112    let num_cubes = n.div_ceil(units_per_cube).max(1);
113    (
114        cubecl::prelude::CubeCount::Static(num_cubes, 1, 1),
115        cubecl::prelude::CubeDim::new_1d(units_per_cube),
116    )
117}
118
119/// Debug-build runtime check that a cubecl `Handle` has at least
120/// `n * size_of::<T>()` bytes capacity. Release builds elide via
121/// `debug_assert!`. Use before `ArrayArg::from_raw_parts(handle, n)`
122/// for caller-provided handles where the cubecl-side `unsafe` API
123/// requires the byte capacity to match.
124///
125/// `T` carries the kernel-side element type (e.g. `f32`) so that the
126/// byte stride is computed from `size_of::<T>()`. Closes #717 / #718.
127pub(crate) fn debug_assert_handle_capacity<T>(handle: &cubecl::server::Handle, n: usize) {
128    debug_assert!(
129        handle.size() as usize >= n.saturating_mul(std::mem::size_of::<T>()),
130        "cubecl handle capacity {} bytes < required {} bytes ({} elements x {} byte stride)",
131        handle.size(),
132        n.saturating_mul(std::mem::size_of::<T>()),
133        n,
134        std::mem::size_of::<T>(),
135    );
136}
137
138// Re-export runtime types.
139pub use runtime::{CubeClient, CubeDevice, CubeRuntime};
140
141// Re-export storage handle types and upload/wrapping helpers.
142pub use storage::{CubeclStorageHandle, cubecl_handle_of, upload_f32, wrap_kernel_output};
143
144// Re-export quantized-weight dequantization API.
145pub use quant::{
146    GgufBlockKind, dequantize_q4_0_to_gpu, dequantize_q4_1_to_gpu, dequantize_q5_0_to_gpu,
147    dequantize_q5_1_to_gpu, dequantize_q8_0_to_gpu, dequantize_q8_1_to_gpu, split_q4_0_blocks,
148    split_q4_1_blocks, split_q5_0_blocks, split_q5_1_blocks, split_q8_0_blocks, split_q8_1_blocks,
149};
150
151// Re-export GPU constrained-decoding token-mask compute API.
152pub use grammar::{DfaMaskInputs, compute_token_mask_dfa_to_gpu, kernel_compute_token_mask_dfa};