Skip to main content

oxibonsai_model/
lib.rs

1//! # oxibonsai-model
2//!
3//! Qwen3 Transformer implementation for 1-bit Bonsai inference.
4//!
5//! This crate implements the full autoregressive forward pass for the
6//! Qwen3 architecture family (8B, 4B, 1.7B) using 1-bit quantised
7//! weights. The forward pass pipeline is:
8//!
9//! 1. **Token embedding** — FP32 lookup from a `[vocab_size x hidden_size]` table
10//! 2. **N Transformer blocks**, each containing:
11//!    - Pre-attention **RMSNorm**
12//!    - **Grouped Query Attention** (GQA) with rotary position embeddings
13//!    - Pre-FFN **RMSNorm**
14//!    - **SwiGLU MLP** (gate + up + down projections)
15//! 3. **Final RMSNorm**
16//! 4. **LM head** projection to vocabulary logits
17//!
18//! All linear projections in the Transformer blocks use Q1\_0\_g128 1-bit
19//! weights dispatched through [`oxibonsai_kernels::OneBitKernel`].
20//!
21//! ## Model Registry
22//!
23//! [`ModelVariant`] auto-detects the architecture from configuration
24//! dimensions and provides parameter counts and expected file sizes.
25
// Public submodules of the crate, declared in alphabetical order.
//
// Only some of these modules have items re-exported at the crate root by the
// `pub use` lines further down. No crate-root re-export is visible here for
// `block`, `export`, `gradient`, `lora`, `lora_trainer`, `model_config_builder`,
// `model_variants`, `optimizer`, `pipeline_parallel`, `quantize`,
// `quantize_int8`, `quantize_ternary` or `tensor_parallel`; their items are
// reached through the module path instead.
26pub mod block;
27pub mod calibration;
28pub mod checkpoint;
29pub mod chunked_prefill;
30pub mod compression;
31pub mod convert;
32pub mod disk_cache;
33pub mod dynamic_quant;
34pub mod error;
35pub mod export;
36pub mod gguf_loader;
37pub mod gradient;
38pub mod gradient_checkpoint;
39pub mod kv_cache;
40pub mod kv_cache_fp16;
41pub mod kv_cache_quant;
42pub mod layers;
43pub mod lora;
44pub mod lora_trainer;
45pub mod losses;
46pub mod lr_schedulers;
47pub mod model;
48pub mod model_config_builder;
49pub mod model_merge;
50pub mod model_registry;
51pub mod model_variants;
52pub mod multi_gpu;
53pub mod optimizer;
54pub mod paged_kv_cache;
55pub mod pipeline_parallel;
56pub mod prefix_cache;
57pub mod pruning;
58pub mod quantize;
59pub mod quantize_int8;
60pub mod quantize_ternary;
61pub mod smoothquant;
62pub mod tensor_parallel;
63pub mod weight_tying;
64
65pub use calibration::{
66    simulate_calibration, validate_calibration, CalibMethod, CalibSummary, CalibValidation,
67    CalibrationDb, LayerCalibStats,
68};
69pub use checkpoint::{Checkpoint, CheckpointError, CheckpointMetadata, CheckpointTensor};
70pub use chunked_prefill::{
71    create_prefill_chunks, peak_memory_estimate, ChunkedPrefillConfig, PrefillAction, PrefillChunk,
72    PrefillMemoryEstimate, PrefillPriority, PrefillScheduler,
73};
74pub use compression::{
75    compress_model, estimate_compressed_size, CompressionConfig, CompressionError,
76    CompressionResult, CompressionStage, StageStats,
77};
78pub use disk_cache::{
79    CacheEntry, CacheFileInfo, CacheManager, DiskCache, DiskCacheError, CACHE_MAGIC, CACHE_VERSION,
80};
81pub use dynamic_quant::{
82    compute_scale, compute_smooth_factors, dynamic_quantize_int4, dynamic_quantize_int8,
83    dynamic_quantize_int8_per_row, quantization_mae, smooth_activations, smooth_weights,
84    w8a8_matvec, CalibStats, DynQuantError, DynQuantFormat, DynQuantTensor, DynamicScaleMode,
85    SmoothQuantConfig,
86};
87pub use error::{ModelError, ModelResult};
88pub use gguf_loader::{
89    estimate_memory_bytes, fits_in_budget, load_tensor_metadata, validate_gguf_file, LoadConfig,
90    LoadError, LoadStats, TensorChunkIter, TensorEntry,
91};
// `Checkpoint` and `CheckpointError` are renamed on re-export because the
// `checkpoint` module already re-exports items with those names at the crate
// root; without the `Gradient*` aliases the two `pub use` lines would collide.
92pub use gradient_checkpoint::{
93    Checkpoint as GradientCheckpoint, CheckpointBudget, CheckpointError as GradientCheckpointError,
94    CheckpointSegment, CheckpointStrategy, CheckpointedActivation, CheckpointedNetwork,
95    CheckpointedPipeline, LinearSegment, Recomputable,
96};
97pub use kv_cache::KvCache;
98pub use kv_cache_fp16::KvCacheFp16;
99pub use kv_cache_quant::{
100    dequantize_row_i8, quant_error_mae, quantize_row_i8, Fp8KvCache, Fp8KvFormat, Fp8KvLayer,
101    QuantKvError, QuantizedKvCache, QuantizedKvLayer,
102};
103pub use layers::attention_sink::{
104    AttentionSinkCache, AttentionSinkConfig, AttentionSinkLayer, SinkError, SinkSlot,
105};
106pub use layers::cross_attention::{
107    causal_cross_attention, compute_attention_weights, cross_attention_forward,
108    single_head_cross_attention, CrossAttentionConfig, CrossAttnError,
109};
110pub use layers::flash_decode::{
111    flash_decode_multi_head, flash_decode_single_head, flash_vs_naive_error, FlashDecodeConfig,
112    FlashDecodeError,
113};
114pub use layers::mixture_of_depths::{
115    mixture_of_depths_forward, ModConfig, ModError, ModRouter, ModStats,
116};
117pub use layers::rope_scaling::{
118    apply_rope_with_freqs, compute_rope_frequencies, dynamic_ntk_base, llama31_frequencies,
119    FreqStats, RopeScalingError, RopeScalingStrategy,
120};
121pub use layers::sparse_attention::{
122    memory_reduction, sparse_attention_forward, sparse_vs_dense_error, SparseAttentionMask,
123    SparseAttnError, SparsePattern,
124};
125pub use layers::yarn_rope::{
126    apply_rope, apply_yarn_rope, LongRopeConfig, YarnConfig, YarnError, YarnFreqTable,
127};
128pub use losses::{
129    contrastive_loss, cross_entropy, cross_entropy_grad, cross_entropy_single, distillation_loss,
130    focal_loss, huber_loss, kl_divergence, label_smoothed_cross_entropy, log_softmax, mse,
131    ntp_loss, softmax, LossError,
132};
133pub use lr_schedulers::{
134    CyclicLr, LinearWarmupCosineDecay, OneCycleLr, PlateauMode, PolynomialDecay, ReduceOnPlateau,
135};
136pub use model::BonsaiModel;
137pub use model_merge::{
138    dare_merge, linear_merge, merge_models, merge_models_with_stats, merge_tensors, slerp,
139    task_vector_merge, ties_merge, MergeConfig, MergeError, MergeMethod, MergeStats, WeightTensor,
140};
141pub use model_registry::ModelVariant;
142pub use multi_gpu::{
143    merge_column_shards, partition_weights_column, partition_weights_row, CollectiveResult,
144    DeviceId, DeviceInfo, DeviceMesh, NcclCollectives,
145};
146pub use paged_kv_cache::{
147    BlockPool, BlockTable, KvPage, PagedKvCache, PagedKvError, DEFAULT_BLOCK_SIZE,
148};
149pub use prefix_cache::{
150    CacheBlock, CacheSession, PrefixAwarePrefill, PrefixCache, PrefixCacheStats,
151};
152pub use pruning::{
153    compute_importance, model_sparsity_report, prune_model, prune_tensor, prune_tensor_inplace,
154    ImportanceMetric, ImportanceScores, ModelSparsitySummary, PruningConfig, PruningError,
155    PruningGranularity, ScoreStats, SparsityReport,
156};
157pub use smoothquant::{
158    quantize_fp8_e4m3_smooth, quantize_fp8_e5m2_smooth, SmoothQuantCalibrator, SmoothQuantError,
159};
160pub use weight_tying::{TiedEmbedding, TyingError};
161
162pub use convert::onnx::{convert_onnx_to_gguf, DequantError as OnnxDequantError, OnnxImportError};
163pub use convert::ConvertStats;
164pub use layers::linear_kquant_ext::{LinearQ5K, LinearQ6K};
165pub use layers::linear_kquant_full::{LinearQ2K, LinearQ3K, LinearQ4K, LinearQ8K};
166pub use layers::linear_standard::{LinearQ4_0, LinearQ8_0};