// boostr — crate root (lib.rs)

//! # boostr
//!
//! **ML framework built on numr — attention, quantization, model architectures.**
//!
//! boostr extends numr's foundational numerical computing with ML-specific operations,
//! quantized tensor support, and model building blocks. It uses numr's runtime, tensors,
//! and ops directly — no reimplementation, no wrappers.
//!
//! ## Relationship to numr
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────┐
//! │                  boostr ◄── YOU ARE HERE                │
//! │   (attention, RoPE, MoE, quantization, model loaders)   │
//! ├─────────────────────────────────────────────────────────┤
//! │                          numr                           │
//! │     (tensors, ops, runtime, autograd, linalg, FFT)      │
//! └─────────────────────────────────────────────────────────┘
//! ```
//!
//! ## Design
//!
//! - **Extension traits**: ML ops (AttentionOps, RoPEOps) implemented on numr's clients
//! - **QuantTensor**: Separate type for block-quantized data (GGUF formats)
//! - **impl_generic**: Composite ops composed from numr primitives, same on all backends
//! - **Custom kernels**: Dequant, quantized matmul, fused attention (SIMD/PTX/WGSL)
27
28pub mod data;
29pub mod distributed;
30pub mod error;
31pub mod format;
32pub mod inference;
33pub mod model;
34pub mod nn;
35pub mod ops;
36pub mod optimizer;
37pub mod quant;
38pub mod trainer;
39
40// Re-export primary boostr traits
41pub use nn::{Init, VarBuilder, VarMap, Weight};
42pub use ops::{
43    AttentionOps, DeviceGrammarDfa, FlashAttentionOps, FusedFp8TrainingOps, FusedOptimizerOps,
44    FusedQkvOps, GrammarDfaOps, KvCacheOps, MlaOps, PagedAttentionOps, RoPEOps, SamplingOps,
45    var_flash_attention,
46};
47pub use quant::{
48    DecomposedQuantLinear, DecomposedQuantMethod, DecomposedQuantTensor, DequantOps, FusedQuantOps,
49    QuantFormat, QuantMatmulOps, QuantTensor,
50};
51
52// Re-export numr types that users will commonly need
53pub use numr::dtype::DType;
54pub use numr::error::{Error as NumrError, Result as NumrResult};
55pub use numr::runtime::{Runtime, RuntimeClient};
56pub use numr::tensor::Tensor;
57
58// Re-export runtime types for convenience (blazr uses boostr::CpuRuntime, etc.)
59pub use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
60#[cfg(feature = "cuda")]
61pub use numr::runtime::cuda::{CudaClient, CudaDevice, CudaRuntime};
62
63// Re-export numr modules for path-based access (e.g., boostr::runtime::Device)
64pub use numr::autograd;
65pub use numr::runtime;
66pub use numr::tensor;
67
68// Re-export TensorOps as a trait alias that blazr uses for client bounds
69pub use ops::TensorOps;
70
71// Re-export MoE expert weight types for blazr's expert offloading
72pub use model::ExpertWeights;
73
74// Re-export embedding pipeline and GGUF tokenizer for sentence embedding use cases
75pub use format::GgufTokenizer;
76pub use model::encoder::{EmbeddingPipeline, Encoder, EncoderClient, EncoderConfig, Pooling};
77
78// Re-export IndexingOps for KV cache bounds
79pub use numr::ops::traits::IndexingOps;
80
81// Re-export ScalarOps for blazr's temperature scaling
82pub use numr::ops::ScalarOps;
83
84// Re-export numr ops needed by blazr's Mamba2 inference path
85pub use numr::ops::{
86    ActivationOps, BinaryOps, ConvOps, NormalizationOps, TypeConversionOps, UnaryOps,
87};
88
/// Pre-load all CUDA PTX modules needed for LLaMA inference.
///
/// This front-loads all PTX→SASS JIT compilation during warmup,
/// eliminating ~300ms of latency on the first decode token.
/// Call this once after creating the [`CudaClient`], before any real inference.
///
/// # Errors
///
/// Returns an error if any PTX module fails to load or JIT-compile
/// (numr core failures are wrapped as `error::Error::Numr`).
#[cfg(feature = "cuda")]
pub fn preload_inference_modules(client: &CudaClient) -> Result<(), error::Error> {
    use numr::runtime::Device;
    use numr::runtime::cuda::kernels::kernel_names;

    // numr core modules used by LLaMA inference.
    client
        .preload_modules(&[
            kernel_names::BINARY_MODULE,
            kernel_names::UNARY_MODULE,
            kernel_names::SCALAR_MODULE,
            kernel_names::REDUCE_MODULE,
            kernel_names::ACTIVATION_MODULE,
            kernel_names::SOFTMAX_MODULE,
            kernel_names::NORM_MODULE,
            kernel_names::FUSED_ADD_NORM_MODULE,
            kernel_names::CAST_MODULE,
            kernel_names::UTILITY_MODULE,
            kernel_names::MATMUL_MODULE,
            kernel_names::GEMV_MODULE,
        ])
        .map_err(error::Error::Numr)?;

    // boostr ops modules (attention, RoPE, KV-cache kernels).
    ops::cuda::kernels::preload_modules(
        client.context(),
        client.device().id(),
        &[
            ops::cuda::kernels::ROPE_MODULE,
            ops::cuda::kernels::DECODE_ATTENTION_MODULE,
            ops::cuda::kernels::PAGED_DECODE_ATTENTION_MODULE,
            ops::cuda::kernels::PAGED_ATTENTION_MODULE,
            ops::cuda::kernels::FLASH_V2_MODULE,
            ops::cuda::kernels::KV_CACHE_UPDATE_MODULE,
            ops::cuda::kernels::RESHAPE_AND_CACHE_MODULE,
        ],
    )?;

    // boostr quant modules (for GGUF inference).
    quant::cuda::kernels::preload_modules(
        client.context(),
        client.device().id(),
        &[
            quant::cuda::kernels::DEQUANT_MODULE,
            quant::cuda::kernels::QUANT_MATMUL_MODULE,
            quant::cuda::kernels::QUANT_GEMV_MODULE,
            quant::cuda::kernels::QUANT_ACT_MODULE,
        ],
    )?;

    Ok(())
}
146
#[cfg(test)]
pub(crate) mod test_utils {
    use numr::runtime::cpu::{CpuClient, CpuDevice};

    /// Create a CPU client and device for use in unit tests.
    ///
    /// The device is cloned into the client so callers get both handles.
    pub(crate) fn cpu_setup() -> (CpuClient, CpuDevice) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (client, device)
    }
}
157}