baracuda_flashinfer/lib.rs
1//! Safe, typed Rust wrappers for NVIDIA **FlashInfer**'s inference-
2//! serving kernels — the vLLM-style serving surface for the baracuda
3//! CUDA stack.
4//!
5//! FlashInfer (`flashinfer-ai/flashinfer`, Apache-2.0) is an
6//! inference-focused attention / sampling kernel library. baracuda
7//! vendors a cherry-picked subset of it (see
8//! `crates/baracuda-kernels-sys/vendor/flashinfer/`) and exposes typed
9//! `*Plan` wrappers — the same select / can_implement / run shape used
10//! across the rest of the baracuda kernel surface.
11//!
12//! This crate is a thin, cohesive **safe facade**: the plan
13//! implementations live in [`baracuda_kernels`] (so they integrate with
14//! the shared SKU / autotuner / telemetry machinery), and the raw FFI
15//! lives in [`baracuda_flashinfer_sys`]. This crate re-groups those into
16//! a serving-oriented API and is the documented entry point.
17//!
18//! # Capabilities
19//!
20//! ## [`attention`] — paged-KV attention
21//!
22//! - [`BatchPagedDecodePlan`] — batched paged-KV **decode** (one query
23//! row per request, KV history in a paged store). The core vLLM
24//! serving primitive. f16 / bf16 / f32, head_dim ∈ {64, 128, 256}.
25//! - [`BatchPagedPrefillPlan`] — batched paged-KV **prefill** (multiple
26//! query rows per request, ragged via `q_indptr`, attending over the
27//! paged history; causal or full; optional KV-split parallelism for
28//! long-context / few-request). The prompt-ingestion primitive. f16 /
29//! bf16.
30//! - [`BatchRaggedPrefillPlan`] — prefill over a **contiguous** (non-paged)
31//! ragged KV store (`kv_indptr`), for the not-yet-paged path.
32//! - [`PagedKvAppendPlan`] — decode-time KV-cache **append** (writes the
33//! freshly-computed K/V for the current token into the paged store).
34//! - [`CascadeAttentionPlan`] — LSE-aware pairwise **merge** of partial
35//! attention states, the building block for prefix-cache / shared-prompt
36//! reuse.
37//! - [`CascadeMergeStatesPlan`] — many-way (fan-in > 2) cascade merge, for
38//! multi-level shared-prefix trees / overlapping prefix caches.
39//! - [`BatchPagedDecodeFp8Plan`] — paged decode with an **fp8** KV cache
40//! (e4m3 / e5m2), Q/O in f16/bf16. Halves KV bandwidth + footprint.
41//!
42//! ## [`sampling`] — sort-free token sampling + verification
43//!
44//! - [`TopKTopPSamplingPlan`] — top-K / top-P / min-P / combined
45//! top-K+top-P sampling directly from a probability tensor, with no
46//! global sort. Select the variant via [`SamplerKind`].
47//! - [`PerRowSamplingPlan`] — the same samplers with per-request
48//! thresholds supplied as device arrays.
49//! - [`SpeculativeSamplingPlan`] — speculative-decode accept/reject
50//! verification (`ChainSpeculativeSampling`).
51//! - [`TokenPenaltyPlan`] — repetition / frequency / presence penalty
52//! logit transform (native baracuda op; not feature-gated).
53//!
54//! # Feature gating
55//!
//! The FlashInfer-backed kernels are behind the `flashinfer` cargo feature
//! (OFF by default). With the feature off, the plan types still exist and
//! `select` / `can_implement` still validate shapes, but `run` returns
//! `Error::Unsupported`. ([`TokenPenaltyPlan`] is a native baracuda op and
//! runs regardless.) Enable the feature to compile the vendored kernels.
61//!
62//! # Example shape (paged decode)
63//!
64//! ```no_run
65//! # #[cfg(feature = "flashinfer")]
66//! # fn demo(stream: &baracuda_driver::Stream) -> baracuda_flashinfer::Result<()> {
67//! use baracuda_flashinfer::{
68//! BatchPagedDecodePlan, BatchPagedDecodeDescriptor, PagedKvCacheDescriptor,
69//! ElementKind, PlanPreference,
70//! };
71//! use half::f16;
72//!
73//! let desc = BatchPagedDecodeDescriptor {
74//! batch_size: 8,
75//! num_qo_heads: 32,
76//! sm_scale: 1.0 / (128.0_f32).sqrt(),
77//! paged_kv: PagedKvCacheDescriptor {
78//! page_size: 16,
79//! num_total_pages: 1024,
80//! num_kv_heads: 8, // GQA group size 4
81//! head_dim: 128,
82//! element: ElementKind::F16,
83//! },
84//! };
85//! let plan = BatchPagedDecodePlan::<f16>::select(stream, &desc, PlanPreference::default())?;
86//! let _ws_bytes = plan.workspace_size();
87//! // ... allocate workspace + page table, then plan.run(stream, ws, args)
88//! # Ok(())
89//! # }
90//! ```
91
92#![no_std]
93
94pub mod attention;
95pub mod sampling;
96
97/// Shared error / result type (re-exported from the baracuda kernel
98/// surface so callers don't need a separate import).
99pub use baracuda_kernels::{Error, Result};
100
101/// Support types used across the FlashInfer plan APIs.
102pub use baracuda_kernels::{
103 contiguous_stride, BackendKind, ElementKind, PlanPreference, PrecisionGuarantee, TensorMut,
104 TensorRef, Workspace,
105};
106
107/// Raw C-ABI FFI surface, for callers that need to drop below the safe
108/// layer. Prefer the typed plans above.
109pub use baracuda_flashinfer_sys as sys;
110
111/// Glob-importable common surface: `use baracuda_flashinfer::prelude::*;`.
112pub mod prelude {
113 pub use crate::attention::{
114 BatchPagedDecodeArgs, BatchPagedDecodeDescriptor, BatchPagedDecodeFp8Args,
115 BatchPagedDecodeFp8Descriptor, BatchPagedDecodeFp8Plan, BatchPagedDecodePlan,
116 BatchPagedPrefillArgs, BatchPagedPrefillDescriptor, BatchPagedPrefillPlan,
117 BatchRaggedPrefillArgs, BatchRaggedPrefillDescriptor, BatchRaggedPrefillPlan,
118 CascadeAttentionArgs, CascadeAttentionDescriptor, CascadeAttentionPlan,
119 CascadeMergeStatesArgs, CascadeMergeStatesDescriptor, CascadeMergeStatesPlan, Fp8KvDtype,
120 PagedKvAppendArgs, PagedKvAppendDescriptor, PagedKvAppendPlan, PagedKvCacheDescriptor,
121 };
122 pub use crate::sampling::{
123 PerRowSampler, PerRowSamplingArgs, PerRowSamplingDescriptor, PerRowSamplingPlan,
124 SamplerKind, SpeculativeSamplingArgs, SpeculativeSamplingDescriptor, SpeculativeSamplingPlan,
125 TokenPenaltyArgs, TokenPenaltyDescriptor, TokenPenaltyPlan, TopKTopPSamplingArgs,
126 TopKTopPSamplingDescriptor, TopKTopPSamplingPlan,
127 };
128 pub use crate::{
129 contiguous_stride, BackendKind, ElementKind, Error, PlanPreference, Result, TensorMut,
130 TensorRef, Workspace,
131 };
132}