1pub mod data;
29pub mod distributed;
30pub mod error;
31pub mod format;
32pub mod inference;
33pub mod model;
34pub mod nn;
35pub mod ops;
36pub mod optimizer;
37pub mod quant;
38pub mod trainer;
39
40pub use nn::{Init, VarBuilder, VarMap, Weight};
42pub use ops::{
43 AttentionOps, DeviceGrammarDfa, FlashAttentionOps, FusedFp8TrainingOps, FusedOptimizerOps,
44 FusedQkvOps, GrammarDfaOps, KvCacheOps, MlaOps, PagedAttentionOps, RoPEOps, SamplingOps,
45 var_flash_attention,
46};
47pub use quant::{
48 DecomposedQuantLinear, DecomposedQuantMethod, DecomposedQuantTensor, DequantOps, FusedQuantOps,
49 QuantFormat, QuantMatmulOps, QuantTensor,
50};
51
52pub use numr::dtype::DType;
54pub use numr::error::{Error as NumrError, Result as NumrResult};
55pub use numr::runtime::{Runtime, RuntimeClient};
56pub use numr::tensor::Tensor;
57
58pub use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
60#[cfg(feature = "cuda")]
61pub use numr::runtime::cuda::{CudaClient, CudaDevice, CudaRuntime};
62
63pub use numr::autograd;
65pub use numr::runtime;
66pub use numr::tensor;
67
68pub use ops::TensorOps;
70
71pub use model::ExpertWeights;
73
74pub use format::GgufTokenizer;
76pub use model::encoder::{EmbeddingPipeline, Encoder, EncoderClient, EncoderConfig, Pooling};
77
78pub use numr::ops::traits::IndexingOps;
80
81pub use numr::ops::ScalarOps;
83
84pub use numr::ops::{
86 ActivationOps, BinaryOps, ConvOps, NormalizationOps, TypeConversionOps, UnaryOps,
87};
88
#[cfg(feature = "cuda")]
/// Eagerly loads every CUDA kernel module the inference path relies on.
///
/// Loading all modules up front moves the (presumably one-time) module-load
/// cost out of the first inference call — NOTE(review): exact caching
/// semantics live in `numr`/`ops`/`quant`; confirm there.
///
/// # Errors
///
/// Returns [`error::Error`] if any of the underlying preload calls fails:
/// the `numr` core-kernel preload is wrapped via [`error::Error::Numr`],
/// while the `ops` and `quant` preloads convert through `?`.
pub fn preload_inference_modules(client: &CudaClient) -> Result<(), error::Error> {
    use numr::runtime::Device;
    use numr::runtime::cuda::kernels::kernel_names;

    // Core numr compute kernels (elementwise, reductions, matmul, casts, ...).
    let core = [
        kernel_names::BINARY_MODULE,
        kernel_names::UNARY_MODULE,
        kernel_names::SCALAR_MODULE,
        kernel_names::REDUCE_MODULE,
        kernel_names::ACTIVATION_MODULE,
        kernel_names::SOFTMAX_MODULE,
        kernel_names::NORM_MODULE,
        kernel_names::FUSED_ADD_NORM_MODULE,
        kernel_names::CAST_MODULE,
        kernel_names::UTILITY_MODULE,
        kernel_names::MATMUL_MODULE,
        kernel_names::GEMV_MODULE,
    ];
    client.preload_modules(&core).map_err(error::Error::Numr)?;

    // Attention / KV-cache kernels provided by this crate's `ops` backend.
    let attention = [
        ops::cuda::kernels::ROPE_MODULE,
        ops::cuda::kernels::DECODE_ATTENTION_MODULE,
        ops::cuda::kernels::PAGED_DECODE_ATTENTION_MODULE,
        ops::cuda::kernels::PAGED_ATTENTION_MODULE,
        ops::cuda::kernels::FLASH_V2_MODULE,
        ops::cuda::kernels::KV_CACHE_UPDATE_MODULE,
        ops::cuda::kernels::RESHAPE_AND_CACHE_MODULE,
    ];
    ops::cuda::kernels::preload_modules(client.context(), client.device().id(), &attention)?;

    // Quantization kernels (dequant, quantized matmul/GEMV, quantized activations).
    let quantized = [
        quant::cuda::kernels::DEQUANT_MODULE,
        quant::cuda::kernels::QUANT_MATMUL_MODULE,
        quant::cuda::kernels::QUANT_GEMV_MODULE,
        quant::cuda::kernels::QUANT_ACT_MODULE,
    ];
    quant::cuda::kernels::preload_modules(client.context(), client.device().id(), &quantized)?;

    Ok(())
}
146
#[cfg(test)]
pub(crate) mod test_utils {
    //! Shared fixtures for this crate's unit tests.

    use numr::runtime::cpu::{CpuClient, CpuDevice};

    /// Creates a fresh CPU device and a client bound to it, returning both
    /// as `(client, device)` so tests can exercise either handle.
    pub(crate) fn cpu_setup() -> (CpuClient, CpuDevice) {
        let dev = CpuDevice::new();
        (CpuClient::new(dev.clone()), dev)
    }
}