pub mod data;
pub mod distributed;
pub mod error;
pub mod format;
pub mod inference;
pub mod model;
pub mod nn;
pub mod ops;
pub mod optimizer;
pub mod quant;
pub mod trainer;
pub use nn::{Init, VarBuilder, VarMap, Weight};
pub use ops::{
AttentionOps, DeviceGrammarDfa, FlashAttentionOps, FusedFp8TrainingOps, FusedOptimizerOps,
FusedQkvOps, GrammarDfaOps, KvCacheOps, MlaOps, PagedAttentionOps, RoPEOps, SamplingOps,
var_flash_attention,
};
pub use quant::{
DecomposedQuantLinear, DecomposedQuantMethod, DecomposedQuantTensor, DequantOps, FusedQuantOps,
QuantFormat, QuantMatmulOps, QuantTensor,
};
pub use numr::dtype::DType;
pub use numr::error::{Error as NumrError, Result as NumrResult};
pub use numr::runtime::{Runtime, RuntimeClient};
pub use numr::tensor::Tensor;
pub use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
#[cfg(feature = "cuda")]
pub use numr::runtime::cuda::{CudaClient, CudaDevice, CudaRuntime};
pub use numr::autograd;
pub use numr::runtime;
pub use numr::tensor;
pub use ops::TensorOps;
pub use model::ExpertWeights;
pub use format::GgufTokenizer;
pub use model::encoder::{EmbeddingPipeline, Encoder, EncoderClient, EncoderConfig, Pooling};
pub use numr::ops::traits::IndexingOps;
pub use numr::ops::ScalarOps;
pub use numr::ops::{
ActivationOps, BinaryOps, ConvOps, NormalizationOps, TypeConversionOps, UnaryOps,
};
#[cfg(feature = "cuda")]
/// Eagerly loads the CUDA kernel modules used on the inference hot path so the
/// first forward pass does not pay JIT/module-load latency.
///
/// Loads three groups in order: core `numr` kernels (elementwise, reduction,
/// normalization, matmul, …), this crate's attention/RoPE/KV-cache kernels,
/// and the quantization kernels.
///
/// # Errors
///
/// Returns [`error::Error`] if any module fails to load; `numr` preload
/// failures are wrapped in `error::Error::Numr`.
pub fn preload_inference_modules(client: &CudaClient) -> Result<(), error::Error> {
    use numr::runtime::Device;
    use numr::runtime::cuda::kernels::kernel_names;

    // Core numr compute kernels shared by most models.
    let core_modules = [
        kernel_names::BINARY_MODULE,
        kernel_names::UNARY_MODULE,
        kernel_names::SCALAR_MODULE,
        kernel_names::REDUCE_MODULE,
        kernel_names::ACTIVATION_MODULE,
        kernel_names::SOFTMAX_MODULE,
        kernel_names::NORM_MODULE,
        kernel_names::FUSED_ADD_NORM_MODULE,
        kernel_names::CAST_MODULE,
        kernel_names::UTILITY_MODULE,
        kernel_names::MATMUL_MODULE,
        kernel_names::GEMV_MODULE,
    ];
    client
        .preload_modules(&core_modules)
        .map_err(error::Error::Numr)?;

    // Attention, positional-encoding, and KV-cache kernels from this crate.
    ops::cuda::kernels::preload_modules(
        client.context(),
        client.device().id(),
        &[
            ops::cuda::kernels::ROPE_MODULE,
            ops::cuda::kernels::DECODE_ATTENTION_MODULE,
            ops::cuda::kernels::PAGED_DECODE_ATTENTION_MODULE,
            ops::cuda::kernels::PAGED_ATTENTION_MODULE,
            ops::cuda::kernels::FLASH_V2_MODULE,
            ops::cuda::kernels::KV_CACHE_UPDATE_MODULE,
            ops::cuda::kernels::RESHAPE_AND_CACHE_MODULE,
        ],
    )?;

    // Quantization kernels (dequant, quantized matmul/GEMV, activations).
    quant::cuda::kernels::preload_modules(
        client.context(),
        client.device().id(),
        &[
            quant::cuda::kernels::DEQUANT_MODULE,
            quant::cuda::kernels::QUANT_MATMUL_MODULE,
            quant::cuda::kernels::QUANT_GEMV_MODULE,
            quant::cuda::kernels::QUANT_ACT_MODULE,
        ],
    )?;

    Ok(())
}
#[cfg(test)]
pub(crate) mod test_utils {
    //! Shared helpers for unit tests in this crate.

    use numr::runtime::cpu::{CpuClient, CpuDevice};

    /// Creates a fresh CPU device and a client bound to it, returned as
    /// `(client, device)` for use in tests.
    pub(crate) fn cpu_setup() -> (CpuClient, CpuDevice) {
        let dev = CpuDevice::new();
        (CpuClient::new(dev.clone()), dev)
    }
}