Safe, typed Rust wrappers for NVIDIA FlashInfer’s inference-serving kernels: the vLLM-style serving surface for the baracuda CUDA stack.
FlashInfer (flashinfer-ai/flashinfer, Apache-2.0) is an
inference-focused attention / sampling kernel library. baracuda
vendors a cherry-picked subset of it (see
crates/baracuda-kernels-sys/vendor/flashinfer/) and exposes typed
*Plan wrappers — the same select / can_implement / run shape used
across the rest of the baracuda kernel surface.
This crate is a thin, cohesive safe facade: the plan
implementations live in baracuda_kernels (so they integrate with
the shared SKU / autotuner / telemetry machinery), and the raw FFI
lives in baracuda_flashinfer_sys. This crate re-groups those into
a serving-oriented API and is the documented entry point.
§Capabilities
§attention — paged-KV attention
- [BatchPagedDecodePlan]: batched paged-KV decode (one query row per request, KV history in a paged store). The core vLLM serving primitive. f16 / bf16 / f32, head_dim ∈ {64, 128, 256}.
- [BatchPagedPrefillPlan]: batched paged-KV prefill (multiple query rows per request, ragged via q_indptr, attending over the paged history; causal or full; optional KV-split parallelism for long-context / few-request). The prompt-ingestion primitive. f16 / bf16.
- [BatchRaggedPrefillPlan]: prefill over a contiguous (non-paged) ragged KV store (kv_indptr), for the not-yet-paged path.
- [PagedKvAppendPlan]: decode-time KV-cache append (writes the freshly-computed K/V for the current token into the paged store).
- [CascadeAttentionPlan]: LSE-aware pairwise merge of partial attention states, the building block for prefix-cache / shared-prompt reuse.
- [CascadeMergeStatesPlan]: many-way (fan-in > 2) cascade merge, for multi-level shared-prefix trees / overlapping prefix caches.
- [BatchPagedDecodeFp8Plan]: paged decode with an fp8 KV cache (e4m3 / e5m2), Q/O in f16/bf16. Halves KV bandwidth and footprint.
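To make the "LSE-aware merge of partial attention states" concrete, here is a minimal CPU reference sketch of the underlying arithmetic. It is not the crate's kernel or API: `PartialState`, `merge_pair`, `merge_many`, and `attend` are hypothetical names introduced for illustration, and the sketch assumes every log-sum-exp value is finite. Each partial state holds a softmax-normalized output plus the log-sum-exp (LSE) of the scores it covered; two states are combined by re-weighting each output with `exp(lse - max_lse)`.

```rust
/// Partial attention result for one query row (hypothetical type):
/// the softmax-normalized output vector and the log-sum-exp of the
/// scores it was computed over.
#[derive(Debug, Clone, PartialEq)]
struct PartialState {
    out: Vec<f32>,
    lse: f32,
}

/// Pairwise merge: re-weight each partial output by exp(lse - max)
/// so the result equals attention over the union of both KV ranges.
/// Assumes both LSE values are finite.
fn merge_pair(a: &PartialState, b: &PartialState) -> PartialState {
    assert_eq!(a.out.len(), b.out.len(), "head_dim mismatch");
    let m = a.lse.max(b.lse);
    let wa = (a.lse - m).exp();
    let wb = (b.lse - m).exp();
    let denom = wa + wb;
    let out: Vec<f32> = a
        .out
        .iter()
        .zip(&b.out)
        .map(|(x, y)| (wa * x + wb * y) / denom)
        .collect();
    PartialState { out, lse: m + denom.ln() }
}

/// Many-way merge expressed as a left fold of pairwise merges.
/// Returns None for an empty input.
fn merge_many(states: &[PartialState]) -> Option<PartialState> {
    let (first, rest) = states.split_first()?;
    Some(rest.iter().fold(first.clone(), |acc, s| merge_pair(&acc, s)))
}

/// Plain softmax attention for one query row over pre-computed scores,
/// used here only as a reference to check the merge against.
/// Assumes `scores` and `values` are non-empty and equally long.
fn attend(scores: &[f32], values: &[Vec<f32>]) -> PartialState {
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let denom: f32 = weights.iter().sum();
    let mut out = vec![0.0_f32; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w * x / denom;
        }
    }
    PartialState { out, lse: m + denom.ln() }
}
```

Splitting a KV range into a shared prefix and a per-request suffix, attending over each separately, and merging the two states should reproduce attention over the full range up to floating-point error; that property is what makes prefix reuse possible.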
§sampling — sort-free token sampling + verification
- [TopKTopPSamplingPlan]: top-K / top-P / min-P / combined top-K+top-P sampling directly from a probability tensor, with no global sort. Select the variant via [SamplerKind].
- [PerRowSamplingPlan]: the same samplers with per-request thresholds supplied as device arrays.
- [SpeculativeSamplingPlan]: speculative-decode accept/reject verification (ChainSpeculativeSampling).
- [TokenPenaltyPlan]: repetition / frequency / presence penalty logit transform (native baracuda op; not feature-gated).
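As an illustration of what a repetition / frequency / presence penalty transform computes, the following CPU sketch applies the conventions commonly used by serving stacks. It is an assumption that the baracuda op follows the same formulas and ordering; `Penalties` and `apply_penalties` are hypothetical names, not part of this crate.

```rust
use std::collections::HashMap;

/// Hypothetical parameter bundle. A `repetition` of 1.0 and zero
/// `frequency` / `presence` leave the logits unchanged.
#[derive(Debug, Clone, Copy)]
struct Penalties {
    repetition: f32,
    frequency: f32,
    presence: f32,
}

/// Apply the three penalties in place to one row of logits, given the
/// token ids generated so far for that request.
fn apply_penalties(logits: &mut [f32], history: &[u32], p: Penalties) {
    // Count how often each token already appeared.
    let mut counts: HashMap<u32, u32> = HashMap::new();
    for &tok in history {
        *counts.entry(tok).or_insert(0) += 1;
    }
    for (&tok, &count) in &counts {
        // Ignore ids outside the vocabulary rather than panicking.
        if let Some(logit) = logits.get_mut(tok as usize) {
            // Repetition: shrink positive logits, push negative ones
            // further down, so a seen token always becomes less likely.
            *logit = if *logit > 0.0 {
                *logit / p.repetition
            } else {
                *logit * p.repetition
            };
            // Frequency scales with the count; presence is a flat
            // offset applied once the token has appeared at all.
            *logit -= p.frequency * count as f32 + p.presence;
        }
    }
}
```

Tokens that never appeared in `history` are left untouched, so the transform only costs work proportional to the number of distinct tokens seen.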
§Feature gating
The FlashInfer-backed kernels are behind the flashinfer cargo feature
(OFF by default). With the feature off the plan types still exist and
select / can_implement still validate shapes, but run returns
Error::Unsupported. ([TokenPenaltyPlan] is a native baracuda op and
runs regardless.) Enable the feature to compile the vendored kernels.
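The gating contract described above can be mimicked with a small self-contained mock. This is a sketch of the pattern only: `MockDecodePlan`, the `InvalidShape` variant, and all signatures here are hypothetical. Only the `select` / `can_implement` / `run` names, the `flashinfer` feature name, and `Error::Unsupported` come from the text above.

```rust
/// Hypothetical error type; only the `Unsupported` name mirrors the docs.
#[derive(Debug, PartialEq)]
enum Error {
    InvalidShape(&'static str),
    Unsupported(&'static str),
}

/// Hypothetical stand-in for a feature-gated plan.
#[derive(Debug)]
struct MockDecodePlan {
    head_dim: usize,
}

impl MockDecodePlan {
    /// Shape validation is pure Rust, so it works with the feature off.
    fn can_implement(head_dim: usize) -> Result<(), Error> {
        if matches!(head_dim, 64 | 128 | 256) {
            Ok(())
        } else {
            Err(Error::InvalidShape("head_dim must be 64, 128, or 256"))
        }
    }

    /// Selection validates the shape and builds the plan either way.
    fn select(head_dim: usize) -> Result<Self, Error> {
        Self::can_implement(head_dim)?;
        Ok(Self { head_dim })
    }

    /// Only the launch depends on the compiled kernels. `cfg!` evaluates
    /// to `false` when the `flashinfer` feature is not enabled.
    fn run(&self) -> Result<(), Error> {
        if cfg!(feature = "flashinfer") {
            Ok(()) // a real plan would launch the vendored kernel here
        } else {
            Err(Error::Unsupported("built without the `flashinfer` feature"))
        }
    }
}
```

Because validation stays available without the feature, callers can exercise shape checks in builds that never compile the vendored kernels and treat `Unsupported` from `run` as a build-configuration problem rather than a shape problem.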
§Example shape (paged decode)
use baracuda_flashinfer::{
    BatchPagedDecodePlan, BatchPagedDecodeDescriptor, PagedKvCacheDescriptor,
    ElementKind, PlanPreference,
};
use half::f16;

let desc = BatchPagedDecodeDescriptor {
    batch_size: 8,
    num_qo_heads: 32,
    sm_scale: 1.0 / (128.0_f32).sqrt(), // 1 / sqrt(head_dim)
    paged_kv: PagedKvCacheDescriptor {
        page_size: 16,
        num_total_pages: 1024,
        num_kv_heads: 8, // GQA group size 4 (32 query heads / 8 KV heads)
        head_dim: 128,
        element: ElementKind::F16,
    },
};
// `stream` is a CUDA stream handle created elsewhere.
let plan = BatchPagedDecodePlan::<f16>::select(stream, &desc, PlanPreference::default())?;
let _ws_bytes = plan.workspace_size();
// ... allocate workspace + page table, then plan.run(stream, ws, args)
Re-exports§
pub use baracuda_flashinfer_sys as sys;
Modules§
- attention: Paged-KV attention plans (decode, append, cascade merge).
- prelude: Glob-importable common surface: use baracuda_flashinfer::prelude::*;.
- sampling: Sort-free token sampling plans.
Structs§
- PlanPreference: Support types used across the FlashInfer plan APIs. Hints that influence kernel selection inside a plan’s select method.
- PrecisionGuarantee: Support types used across the FlashInfer plan APIs. Numerical guarantees a kernel provides.
- TensorMut: Support types used across the FlashInfer plan APIs. Mutable view of a device-resident rank-N tensor.
- TensorRef: Support types used across the FlashInfer plan APIs. Read-only view of a device-resident rank-N tensor.
Enums§
- BackendKind: Support types used across the FlashInfer plan APIs. Which underlying compute backend served a kernel SKU.
- ElementKind: Support types used across the FlashInfer plan APIs. Runtime tag for an Element or IntElement.
- Error: Shared error / result type (re-exported from the baracuda kernel surface so callers don’t need a separate import). Errors raised by the safe CUTLASS wrapper.
- Workspace: Support types used across the FlashInfer plan APIs. Caller-supplied workspace for a launch.
Functions§
- contiguous_stride: Support types used across the FlashInfer plan APIs. Compute the row-major contiguous stride for the given shape.
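For readers unfamiliar with the term, a row-major contiguous stride can be computed as in the sketch below. The function name `row_major_stride`, the const-generic signature, and the `usize` element type are assumptions made for illustration; the crate's actual `contiguous_stride` signature is not shown on this page. Strides are counted in elements, not bytes.

```rust
/// Row-major contiguous strides for a rank-N shape (hypothetical helper).
/// The innermost dimension has stride 1, and each outer stride is the
/// product of all dimension sizes to its right.
fn row_major_stride<const N: usize>(shape: [usize; N]) -> [usize; N] {
    let mut stride = [1_usize; N];
    // Walk from the second-to-last dimension outward.
    for i in (0..N.saturating_sub(1)).rev() {
        stride[i] = stride[i + 1] * shape[i + 1];
    }
    stride
}
```

For a shape of `[2, 3, 4]` this yields strides `[12, 4, 1]`, so element `(i, j, k)` lives at linear offset `12 * i + 4 * j + k`.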
Type Aliases§
- Result: Shared error / result type (re-exported from the baracuda kernel surface so callers don’t need a separate import). Crate-local result alias.