Crate oxibonsai_model

oxibonsai-model

Qwen3 Transformer implementation for 1-bit Bonsai inference.

This crate implements the full autoregressive forward pass for the Qwen3 architecture family (8B, 4B, 1.7B) using 1-bit quantised weights. The forward pass pipeline is:

  1. Token embedding — FP32 lookup from a [vocab_size x hidden_size] table
  2. N Transformer blocks, each containing:
    • Pre-attention RMSNorm
    • Grouped Query Attention (GQA) with rotary position embeddings
    • Pre-FFN RMSNorm
    • SwiGLU MLP (gate + up + down projections)
  3. Final RMSNorm
  4. LM head projection to vocabulary logits
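The per-block arithmetic behind the pre-norm FFN half of step 2 can be illustrated with a minimal dense FP32 sketch of RMSNorm and a SwiGLU residual sub-layer. This is not the crate's implementation. The function names are hypothetical, the weights are plain row-major `f32` slices rather than Q1_0_g128 tensors, the `eps` of 1e-6 is an assumed value, and GQA/RoPE attention is omitted for brevity.

```rust
/// RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

/// SiLU (swish) activation: v * sigmoid(v).
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// Dense matrix-vector product; `m` is row-major with shape [rows x cols].
fn matvec(m: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    assert_eq!(m.len(), rows * cols);
    assert_eq!(x.len(), cols);
    (0..rows)
        .map(|r| {
            m[r * cols..(r + 1) * cols]
                .iter()
                .zip(x)
                .map(|(a, b)| a * b)
                .sum::<f32>()
        })
        .collect()
}

/// SwiGLU MLP: down(silu(gate * x) elementwise-times (up * x)).
/// `gate` and `up` have shape [inter x hidden]; `down` has shape [hidden x inter].
fn swiglu_mlp(x: &[f32], gate: &[f32], up: &[f32], down: &[f32], hidden: usize, inter: usize) -> Vec<f32> {
    let g = matvec(gate, inter, hidden, x);
    let u = matvec(up, inter, hidden, x);
    let h: Vec<f32> = g.iter().zip(&u).map(|(gi, ui)| silu(*gi) * ui).collect();
    matvec(down, hidden, inter, &h)
}

/// Pre-FFN sub-layer with a residual connection: x + MLP(RMSNorm(x)).
fn ffn_sublayer(x: &[f32], norm_w: &[f32], gate: &[f32], up: &[f32], down: &[f32], inter: usize) -> Vec<f32> {
    let hidden = x.len();
    let normed = rms_norm(x, norm_w, 1e-6);
    let mlp_out = swiglu_mlp(&normed, gate, up, down, hidden, inter);
    x.iter().zip(&mlp_out).map(|(a, b)| a + b).collect()
}

fn main() {
    // Toy sizes: hidden = 2, intermediate = 3.
    let x = [3.0_f32, 4.0];
    let norm_w = [1.0_f32, 1.0];
    let gate = [0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6];
    let up = [0.6_f32, 0.5, 0.4, 0.3, 0.2, 0.1];
    let down = [0.1_f32, 0.0, -0.1, 0.2, 0.1, 0.0];
    let y = ffn_sublayer(&x, &norm_w, &gate, &up, &down, 3);
    println!("{y:?}");
}
```

Because the norm is applied before the MLP and the input is added back afterwards, each block only ever rescales a copy of the residual stream; the attention half of the block follows the same pre-norm plus residual pattern.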

All linear projections in the Transformer blocks use Q1_0_g128 1-bit weights dispatched through oxibonsai_kernels::OneBitKernel.
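To show why 1-bit weights make these projections cheap, here is a hedged sketch of a Q1_0_g128-style matrix-vector product. The storage layout is an assumption made for illustration: 128 sign bits packed LSB-first into 16 bytes plus one FP32 scale per group, with each weight decoded as +scale or -scale. The real on-disk format and the `OneBitKernel` dispatch may differ, and every name below is hypothetical.

```rust
const GROUP: usize = 128;

/// Hypothetical 1-bit group: bit i set means weight i is +scale, cleared means -scale.
struct Q1Group {
    bits: [u8; GROUP / 8],
    scale: f32,
}

/// Quantise 128 FP32 weights to signs plus one shared scale.
/// The mean absolute value is the least-squares optimal scale for sign quantisation.
fn quantize_group(w: &[f32]) -> Q1Group {
    assert_eq!(w.len(), GROUP);
    let scale = w.iter().map(|v| v.abs()).sum::<f32>() / GROUP as f32;
    let mut bits = [0u8; GROUP / 8];
    for (i, &v) in w.iter().enumerate() {
        if v >= 0.0 {
            bits[i / 8] |= 1 << (i % 8);
        }
    }
    Q1Group { bits, scale }
}

/// Dot product of one group with 128 activations: only adds and subtracts,
/// followed by a single multiply by the group scale.
fn dot_group(g: &Q1Group, x: &[f32]) -> f32 {
    assert_eq!(x.len(), GROUP);
    let mut acc = 0.0f32;
    for (i, &xi) in x.iter().enumerate() {
        if (g.bits[i / 8] >> (i % 8)) & 1 == 1 {
            acc += xi;
        } else {
            acc -= xi;
        }
    }
    acc * g.scale
}

/// Matrix-vector product where each row is a sequence of 1-bit groups.
fn matvec_q1(rows: &[Vec<Q1Group>], x: &[f32]) -> Vec<f32> {
    rows.iter()
        .map(|row| {
            row.iter()
                .zip(x.chunks_exact(GROUP))
                .map(|(g, xs)| dot_group(g, xs))
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    // One row of 256 weights (two groups): first group positive, second negative.
    let w: Vec<f32> = (0..256).map(|i| if i < 128 { 0.5 } else { -0.25 }).collect();
    let row: Vec<Q1Group> = w.chunks_exact(GROUP).map(quantize_group).collect();
    let x = vec![1.0f32; 256];
    let y = matvec_q1(&[row], &x);
    println!("{y:?}"); // 128 * 0.5 - 128 * 0.25 = 32
}
```

The inner loop replaces 128 multiplies with sign-selected additions; an optimised kernel would instead process the packed bits with SIMD masks, but the arithmetic is the same.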

Model Registry

ModelVariant auto-detects the architecture from configuration dimensions and provides parameter counts and expected file sizes.
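A minimal sketch of dimension-based variant detection and a file-size estimate follows. The enum, the function names, and the dimension table are hypothetical and are not the crate's `ModelVariant` API; the (hidden size, layer count) pairs are assumed from the public Qwen3 8B, 4B, and 1.7B configurations. The size formula simply counts 128 packed sign bits plus one scale per group, with the scale width left as a parameter because it is an assumption here.

```rust
/// Hypothetical stand-in for the crate's ModelVariant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Variant {
    Qwen3_8B,
    Qwen3_4B,
    Qwen3_1_7B,
    Unknown,
}

/// Detect the variant from (hidden_size, num_layers).
/// The dimension table is assumed from the public Qwen3 configs.
fn detect_variant(hidden_size: usize, num_layers: usize) -> Variant {
    match (hidden_size, num_layers) {
        (4096, 36) => Variant::Qwen3_8B,
        (2560, 36) => Variant::Qwen3_4B,
        (2048, 28) => Variant::Qwen3_1_7B,
        _ => Variant::Unknown,
    }
}

/// Expected bytes for `n_weights` stored as 1-bit groups of 128 with one
/// `scale_bits`-wide scale per group (the scale width is an assumption).
fn q1_g128_bytes(n_weights: u64, scale_bits: u64) -> u64 {
    let groups = (n_weights + 127) / 128; // round up to whole groups
    groups * (128 / 8 + scale_bits / 8)
}

fn main() {
    let v = detect_variant(2560, 36);
    // A 2560 x 2560 projection with a 16-bit scale per group.
    let bytes = q1_g128_bytes(2560 * 2560, 16);
    println!("{v:?}: {bytes} bytes");
}
```

Matching on a tuple of dimensions keeps detection independent of file names; an unrecognised shape falls through to `Unknown` rather than guessing.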

Re-exports

pub use calibration::simulate_calibration;
pub use calibration::validate_calibration;
pub use calibration::CalibMethod;
pub use calibration::CalibSummary;
pub use calibration::CalibValidation;
pub use calibration::CalibrationDb;
pub use calibration::LayerCalibStats;
pub use checkpoint::Checkpoint;
pub use checkpoint::CheckpointError;
pub use checkpoint::CheckpointMetadata;
pub use checkpoint::CheckpointTensor;
pub use chunked_prefill::create_prefill_chunks;
pub use chunked_prefill::peak_memory_estimate;
pub use chunked_prefill::ChunkedPrefillConfig;
pub use chunked_prefill::PrefillAction;
pub use chunked_prefill::PrefillChunk;
pub use chunked_prefill::PrefillMemoryEstimate;
pub use chunked_prefill::PrefillPriority;
pub use chunked_prefill::PrefillScheduler;
pub use compression::compress_model;
pub use compression::estimate_compressed_size;
pub use compression::CompressionConfig;
pub use compression::CompressionError;
pub use compression::CompressionResult;
pub use compression::CompressionStage;
pub use compression::StageStats;
pub use disk_cache::CacheEntry;
pub use disk_cache::CacheFileInfo;
pub use disk_cache::CacheManager;
pub use disk_cache::DiskCache;
pub use disk_cache::DiskCacheError;
pub use disk_cache::CACHE_MAGIC;
pub use disk_cache::CACHE_VERSION;
pub use dynamic_quant::compute_scale;
pub use dynamic_quant::compute_smooth_factors;
pub use dynamic_quant::dynamic_quantize_int4;
pub use dynamic_quant::dynamic_quantize_int8;
pub use dynamic_quant::dynamic_quantize_int8_per_row;
pub use dynamic_quant::quantization_mae;
pub use dynamic_quant::smooth_activations;
pub use dynamic_quant::smooth_weights;
pub use dynamic_quant::w8a8_matvec;
pub use dynamic_quant::CalibStats;
pub use dynamic_quant::DynQuantError;
pub use dynamic_quant::DynQuantFormat;
pub use dynamic_quant::DynQuantTensor;
pub use dynamic_quant::DynamicScaleMode;
pub use dynamic_quant::SmoothQuantConfig;
pub use error::ModelError;
pub use error::ModelResult;
pub use gguf_loader::estimate_memory_bytes;
pub use gguf_loader::fits_in_budget;
pub use gguf_loader::load_tensor_metadata;
pub use gguf_loader::validate_gguf_file;
pub use gguf_loader::LoadConfig;
pub use gguf_loader::LoadError;
pub use gguf_loader::LoadStats;
pub use gguf_loader::TensorChunkIter;
pub use gguf_loader::TensorEntry;
pub use gradient_checkpoint::Checkpoint as GradientCheckpoint;
pub use gradient_checkpoint::CheckpointBudget;
pub use gradient_checkpoint::CheckpointError as GradientCheckpointError;
pub use gradient_checkpoint::CheckpointSegment;
pub use gradient_checkpoint::CheckpointStrategy;
pub use gradient_checkpoint::CheckpointedActivation;
pub use gradient_checkpoint::CheckpointedNetwork;
pub use gradient_checkpoint::CheckpointedPipeline;
pub use gradient_checkpoint::LinearSegment;
pub use gradient_checkpoint::Recomputable;
pub use kv_cache::KvCache;
pub use kv_cache_fp16::KvCacheFp16;
pub use kv_cache_quant::dequantize_row_i8;
pub use kv_cache_quant::quant_error_mae;
pub use kv_cache_quant::quantize_row_i8;
pub use kv_cache_quant::Fp8KvCache;
pub use kv_cache_quant::Fp8KvFormat;
pub use kv_cache_quant::Fp8KvLayer;
pub use kv_cache_quant::QuantKvError;
pub use kv_cache_quant::QuantizedKvCache;
pub use kv_cache_quant::QuantizedKvLayer;
pub use layers::attention_sink::AttentionSinkCache;
pub use layers::attention_sink::AttentionSinkConfig;
pub use layers::attention_sink::AttentionSinkLayer;
pub use layers::attention_sink::SinkError;
pub use layers::attention_sink::SinkSlot;
pub use layers::cross_attention::causal_cross_attention;
pub use layers::cross_attention::compute_attention_weights;
pub use layers::cross_attention::cross_attention_forward;
pub use layers::cross_attention::single_head_cross_attention;
pub use layers::cross_attention::CrossAttentionConfig;
pub use layers::cross_attention::CrossAttnError;
pub use layers::flash_decode::flash_decode_multi_head;
pub use layers::flash_decode::flash_decode_single_head;
pub use layers::flash_decode::flash_vs_naive_error;
pub use layers::flash_decode::FlashDecodeConfig;
pub use layers::flash_decode::FlashDecodeError;
pub use layers::mixture_of_depths::mixture_of_depths_forward;
pub use layers::mixture_of_depths::ModConfig;
pub use layers::mixture_of_depths::ModError;
pub use layers::mixture_of_depths::ModRouter;
pub use layers::mixture_of_depths::ModStats;
pub use layers::rope_scaling::apply_rope_with_freqs;
pub use layers::rope_scaling::compute_rope_frequencies;
pub use layers::rope_scaling::dynamic_ntk_base;
pub use layers::rope_scaling::llama31_frequencies;
pub use layers::rope_scaling::FreqStats;
pub use layers::rope_scaling::RopeScalingError;
pub use layers::rope_scaling::RopeScalingStrategy;
pub use layers::sparse_attention::memory_reduction;
pub use layers::sparse_attention::sparse_attention_forward;
pub use layers::sparse_attention::sparse_vs_dense_error;
pub use layers::sparse_attention::SparseAttentionMask;
pub use layers::sparse_attention::SparseAttnError;
pub use layers::sparse_attention::SparsePattern;
pub use layers::yarn_rope::apply_rope;
pub use layers::yarn_rope::apply_yarn_rope;
pub use layers::yarn_rope::LongRopeConfig;
pub use layers::yarn_rope::YarnConfig;
pub use layers::yarn_rope::YarnError;
pub use layers::yarn_rope::YarnFreqTable;
pub use losses::contrastive_loss;
pub use losses::cross_entropy;
pub use losses::cross_entropy_grad;
pub use losses::cross_entropy_single;
pub use losses::distillation_loss;
pub use losses::focal_loss;
pub use losses::huber_loss;
pub use losses::kl_divergence;
pub use losses::label_smoothed_cross_entropy;
pub use losses::log_softmax;
pub use losses::mse;
pub use losses::ntp_loss;
pub use losses::softmax;
pub use losses::LossError;
pub use lr_schedulers::CyclicLr;
pub use lr_schedulers::LinearWarmupCosineDecay;
pub use lr_schedulers::OneCycleLr;
pub use lr_schedulers::PlateauMode;
pub use lr_schedulers::PolynomialDecay;
pub use lr_schedulers::ReduceOnPlateau;
pub use model::BonsaiModel;
pub use model_merge::dare_merge;
pub use model_merge::linear_merge;
pub use model_merge::merge_models;
pub use model_merge::merge_models_with_stats;
pub use model_merge::merge_tensors;
pub use model_merge::slerp;
pub use model_merge::task_vector_merge;
pub use model_merge::ties_merge;
pub use model_merge::MergeConfig;
pub use model_merge::MergeError;
pub use model_merge::MergeMethod;
pub use model_merge::MergeStats;
pub use model_merge::WeightTensor;
pub use model_registry::ModelVariant;
pub use multi_gpu::merge_column_shards;
pub use multi_gpu::partition_weights_column;
pub use multi_gpu::partition_weights_row;
pub use multi_gpu::CollectiveResult;
pub use multi_gpu::DeviceId;
pub use multi_gpu::DeviceInfo;
pub use multi_gpu::DeviceMesh;
pub use multi_gpu::NcclCollectives;
pub use paged_kv_cache::BlockPool;
pub use paged_kv_cache::BlockTable;
pub use paged_kv_cache::KvPage;
pub use paged_kv_cache::PagedKvCache;
pub use paged_kv_cache::PagedKvError;
pub use paged_kv_cache::DEFAULT_BLOCK_SIZE;
pub use prefix_cache::CacheBlock;
pub use prefix_cache::CacheSession;
pub use prefix_cache::PrefixAwarePrefill;
pub use prefix_cache::PrefixCache;
pub use prefix_cache::PrefixCacheStats;
pub use pruning::compute_importance;
pub use pruning::model_sparsity_report;
pub use pruning::prune_model;
pub use pruning::prune_tensor;
pub use pruning::prune_tensor_inplace;
pub use pruning::ImportanceMetric;
pub use pruning::ImportanceScores;
pub use pruning::ModelSparsitySummary;
pub use pruning::PruningConfig;
pub use pruning::PruningError;
pub use pruning::PruningGranularity;
pub use pruning::ScoreStats;
pub use pruning::SparsityReport;
pub use smoothquant::quantize_fp8_e4m3_smooth;
pub use smoothquant::quantize_fp8_e5m2_smooth;
pub use smoothquant::SmoothQuantCalibrator;
pub use smoothquant::SmoothQuantError;
pub use weight_tying::TiedEmbedding;
pub use weight_tying::TyingError;
pub use convert::onnx::convert_onnx_to_gguf;
pub use convert::onnx::DequantError as OnnxDequantError;
pub use convert::onnx::OnnxImportError;
pub use convert::ConvertStats;
pub use layers::linear_kquant_ext::LinearQ5K;
pub use layers::linear_kquant_ext::LinearQ6K;
pub use layers::linear_kquant_full::LinearQ2K;
pub use layers::linear_kquant_full::LinearQ3K;
pub use layers::linear_kquant_full::LinearQ4K;
pub use layers::linear_kquant_full::LinearQ8K;
pub use layers::linear_standard::LinearQ4_0;
pub use layers::linear_standard::LinearQ8_0;

Modules

block
Auto-generated module structure
calibration
Post-Training Quantization (PTQ) calibration pipeline.
checkpoint
Model checkpoint format for saving and restoring training state.
chunked_prefill
Chunked prefill: process long prompts in smaller chunks.
compression
Model compression pipeline: prune → quantize → report.
convert
HuggingFace safetensors → OxiBonsai GGUF conversion.
disk_cache
On-disk model cache for fast model reloading.
dynamic_quant
Dynamic activation quantization for W8A8 / W4A8 inference.
error
Error types for the model crate.
export
Model weight export utilities.
gguf_loader
Production-quality GGUF model loader with validation, streaming, and memory budgeting.
gradient
Forward-mode automatic differentiation for 1D tensors.
gradient_checkpoint
Gradient checkpointing: trade compute for memory in training.
kv_cache
KV Cache for autoregressive generation.
kv_cache_fp16
FP16 KV cache — halves memory usage by storing keys/values in half precision.
kv_cache_quant
Quantized KV cache: INT8 and FP8 per-row quantization for keys and values.
layers
Transformer layer implementations.
lora
LoRA (Low-Rank Adaptation) adapter support for Bonsai/Qwen3 models.
lora_trainer
LoRA fine-tuning training scaffold.
losses
Loss functions for training language models.
lr_schedulers
Advanced learning rate schedulers for LLM training.
model
Auto-generated module structure
model_config_builder
Builder pattern for constructing Qwen3Config values.
model_merge
Model merging utilities: linear interpolation, SLERP, TIES, and task-vector merging.
model_registry
Multi-model support: auto-detect Bonsai model variant from GGUF metadata.
model_variants
Full architectural specification for all Bonsai model variants.
multi_gpu
Multi-GPU / multi-device inference utilities.
optimizer
Optimizer implementations for LoRA fine-tuning.
paged_kv_cache
PagedAttention / vLLM-style paged KV cache.
pipeline_parallel
Pipeline parallelism utilities for OxiBonsai.
prefix_cache
Prefix KV-cache — share key/value tensors across requests with a common prefix.
pruning
Weight importance analysis and structured/unstructured pruning.
quantize
FP32 → Q1_0_g128 quantization and related utilities.
quantize_int8
INT8 (8-bit symmetric) quantization for weight tensors.
quantize_ternary
Ternary quantization helpers for GGUF export.
smoothquant
SmoothQuant per-channel FP8 calibrator and channel-aware quantization.
tensor_parallel
Tensor parallelism utilities for OxiBonsai.
weight_tying
Weight tying: share input embedding weights with the LM head.
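Several modules above (`kv_cache_quant`, `quantize_int8`, `dynamic_quant`) rely on symmetric per-row INT8 quantization. The sketch below shows the basic idea under the common convention of mapping each row's maximum absolute value to 127; the names are hypothetical and this is not the crate's `quantize_row_i8`.

```rust
/// One quantized row: int8 codes plus a single per-row scale.
struct QuantRow {
    codes: Vec<i8>,
    scale: f32,
}

/// Symmetric per-row INT8 quantization: scale = max|x| / 127, q = round(x / scale).
fn quantize_row_sym(row: &[f32]) -> QuantRow {
    let max_abs = row.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    // Avoid division by zero for an all-zero row.
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let codes = row
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    QuantRow { codes, scale }
}

/// Reconstruct approximate FP32 values from codes and scale.
fn dequantize_row_sym(q: &QuantRow) -> Vec<f32> {
    q.codes.iter().map(|&c| c as f32 * q.scale).collect()
}

/// Mean absolute reconstruction error.
fn mae(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum::<f32>() / a.len() as f32
}

fn main() {
    let row = [0.9f32, -0.35, 0.02, -1.27, 0.5];
    let q = quantize_row_sym(&row);
    let back = dequantize_row_sym(&q);
    println!("codes = {:?}, scale = {}, mae = {}", q.codes, q.scale, mae(&row, &back));
}
```

Using a per-row scale bounds each element's reconstruction error by half a quantization step of that row, so one outlier only degrades precision within its own row.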