pub struct CudaNf4TransformerBlock { /* private fields */ }
CUDA-accelerated transformer block with NF4-quantized frozen weights.
Stores the 7 projection weights (Q, K, V, O, gate, up, down) as packed NF4 (4-bit) values plus per-block scales instead of fp32, for roughly 8x compression. Norm weights remain fp32 because their size is negligible.
VRAM Savings (Qwen3-4B example)
| Component | fp32 | NF4 |
|---|---|---|
| Frozen weights (36L × 7 projections) | 16.0 GB | 2.1 GB |
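The footprint figures above follow from simple arithmetic: each weight takes 4 bits, and each block adds one scale. The sketch below works that out. The block size and the scale width are not stated in this page, so `block_size` and `scale_bytes` are hypothetical parameters chosen for illustration, not values taken from the crate.

```rust
/// Bytes needed for `num_weights` values stored as fp32.
fn fp32_bytes(num_weights: u64) -> u64 {
    num_weights * 4
}

/// Bytes needed for `num_weights` values stored as packed NF4 plus one
/// scale per block. `block_size` and `scale_bytes` are assumptions for
/// illustration; the crate's real layout may differ.
fn nf4_bytes(num_weights: u64, block_size: u64, scale_bytes: u64) -> u64 {
    let packed = num_weights.div_ceil(2); // two 4-bit codes per byte
    let blocks = num_weights.div_ceil(block_size); // one scale per block
    packed + blocks * scale_bytes
}

/// fp32 size divided by NF4 size for the same number of weights.
fn compression_ratio(num_weights: u64, block_size: u64, scale_bytes: u64) -> f64 {
    fp32_bytes(num_weights) as f64 / nf4_bytes(num_weights, block_size, scale_bytes) as f64
}
```

For 4.0 billion weights (16.0 GB as fp32), `nf4_bytes(4_000_000_000, 64, 2)` gives 2.125 GB, which is close to the 2.1 GB in the table. That only shows such parameters are plausible; it does not confirm what the crate uses. The per-block scales are why the ratio lands somewhat below a flat 8x.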
Forward Only
NF4 blocks are frozen, so no backward pass is needed. LoRA adapters (fp32) hold the trainable parameters separately. The forward pass uses fused dequant+GEMM kernels that read NF4 directly without materializing fp32 weights.
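To illustrate what "fused" means here, the following is a CPU reference sketch of a matrix-vector product that decodes each NF4 code inside the inner loop, so no fp32 weight matrix is ever built. It is not the crate's CUDA kernel. The codebook is the standard 16-value NF4 table; the nibble order (low nibble first), the row-major layout, and the `block_size` parameter are assumptions made for this example.

```rust
/// Standard NF4 codebook: 16 values in [-1, 1].
const NF4_LUT: [f32; 16] = [
    -1.0, -0.6961928, -0.52507305, -0.3949175,
    -0.28444138, -0.18477343, -0.091050036, 0.0,
    0.0795803, 0.1609302, 0.2461123, 0.33791524,
    0.44070983, 0.562617, 0.72295684, 1.0,
];

/// Computes y = W x, where W (rows x cols, row-major) is stored as packed
/// NF4 codes with one scale per `block_size` consecutive weights.
/// Assumed layout: weight 2i sits in the low nibble of byte i,
/// weight 2i+1 in the high nibble.
fn nf4_gemv(
    packed: &[u8],
    scales: &[f32],
    rows: usize,
    cols: usize,
    block_size: usize,
    x: &[f32],
) -> Vec<f32> {
    assert_eq!(x.len(), cols);
    assert_eq!((rows * cols) % 2, 0);
    assert_eq!(packed.len(), rows * cols / 2);
    assert_eq!(scales.len(), (rows * cols).div_ceil(block_size));

    let mut y = vec![0.0f32; rows];
    for r in 0..rows {
        let mut acc = 0.0f32;
        for c in 0..cols {
            let idx = r * cols + c;
            let byte = packed[idx / 2];
            let code = if idx % 2 == 0 { byte & 0x0F } else { byte >> 4 };
            // Dequantize one weight on the fly and use it immediately.
            let w = NF4_LUT[code as usize] * scales[idx / block_size];
            acc += w * x[c];
        }
        y[r] = acc;
    }
    y
}
```

A GPU kernel would apply the same idea per thread or per tile, reading the packed bytes and scales from device memory and accumulating in registers.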
Implementations
impl CudaNf4TransformerBlock
pub fn new(
    config: &TransformerConfig,
    layer_idx: usize,
    ctx: Arc<CudaContext>,
    input_norm_weight: &[f32],
    post_attn_norm_weight: &[f32],
    w_q: &[f32],
    w_k: &[f32],
    w_v: &[f32],
    w_o: &[f32],
    w_gate: &[f32],
    w_up: &[f32],
    w_down: &[f32],
    _max_seq_len: usize,
    q_lora: Option<(&[f32], &[f32])>,
    v_lora: Option<(&[f32], &[f32])>,
    lora_scale: f32,
    lora_rank: usize,
    q_norm: Option<&[f32]>,
    k_norm: Option<&[f32]>,
) -> Result<Self>
Create a new NF4 transformer block from fp32 CPU tensors.
Quantizes all 7 projection weights to NF4 on CPU, then uploads the packed data and scales to GPU. Norm weights are uploaded as fp32.
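As a rough picture of the CPU-side quantization step, the sketch below applies the usual NF4 recipe: take the absolute maximum of each block as its scale, map each normalized weight to the nearest codebook entry, and pack two 4-bit codes per byte. This is an assumption about the general technique, not the crate's actual code; `quantize_nf4`, `nearest_code`, the `block_size` parameter, and the low-nibble-first packing are all hypothetical.

```rust
/// Standard NF4 codebook: 16 values in [-1, 1].
const NF4_LUT: [f32; 16] = [
    -1.0, -0.6961928, -0.52507305, -0.3949175,
    -0.28444138, -0.18477343, -0.091050036, 0.0,
    0.0795803, 0.1609302, 0.2461123, 0.33791524,
    0.44070983, 0.562617, 0.72295684, 1.0,
];

/// Index of the codebook entry closest to `x` (expected in [-1, 1]).
fn nearest_code(x: f32) -> u8 {
    let mut best = 0u8;
    let mut best_err = f32::INFINITY;
    for (i, &v) in NF4_LUT.iter().enumerate() {
        let err = (x - v).abs();
        if err < best_err {
            best_err = err;
            best = i as u8;
        }
    }
    best
}

/// Quantizes `weights` to packed NF4 with one absmax scale per block.
/// Returns (packed codes, per-block scales).
fn quantize_nf4(weights: &[f32], block_size: usize) -> (Vec<u8>, Vec<f32>) {
    assert!(block_size > 0 && block_size % 2 == 0);
    assert_eq!(weights.len() % block_size, 0);

    let mut packed = Vec::with_capacity(weights.len() / 2);
    let mut scales = Vec::with_capacity(weights.len() / block_size);
    for block in weights.chunks(block_size) {
        let absmax = block.iter().fold(0.0f32, |m, w| m.max(w.abs()));
        // An all-zero block keeps scale 0 and maps every weight to code 7 (0.0).
        let inv = if absmax > 0.0 { 1.0 / absmax } else { 0.0 };
        scales.push(absmax);
        for pair in block.chunks(2) {
            let lo = nearest_code(pair[0] * inv);
            let hi = nearest_code(pair[1] * inv);
            packed.push(lo | (hi << 4)); // low nibble holds the first weight
        }
    }
    (packed, scales)
}
```

The packed bytes and the scales are the two buffers that would then be uploaded to the GPU, while the norm weights stay fp32.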
pub fn set_fp16_weights(&mut self, stream: &CudaStream) -> Result<()>
Casts the fp32 weights to fp16 and drops the fp32 copies (PMAT-470/472). Frees ~2.6 GB of VRAM.
impl CudaNf4TransformerBlock
pub fn download_lora_weights(
    &self,
) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>)>
Download LoRA weights from GPU to CPU for checkpoint saving.
Returns (A_q, B_q, A_v, B_v) as flat f32 vectors. B matrices are returned WITH the baked-in scale (caller can divide by lora_scale if they need the unscaled version).
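If the unscaled B matrices are needed, the division mentioned above can be done with a small helper like the one below. `unscale_lora_b` is a hypothetical function written for this example and is not part of the crate; it assumes `lora_scale` is the same value that was passed to `new()`.

```rust
/// Removes a baked-in LoRA scale from a flat B matrix.
/// Hypothetical helper for illustration; not part of the crate.
fn unscale_lora_b(b_scaled: &[f32], lora_scale: f32) -> Vec<f32> {
    assert!(lora_scale != 0.0, "lora_scale must be non-zero");
    b_scaled.iter().map(|v| v / lora_scale).collect()
}
```

A caller would apply this to the `B_q` and `B_v` elements of the returned tuple and leave `A_q` and `A_v` unchanged.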
pub fn upload_lora_weights(
    &mut self,
    a_q: &[f32],
    b_q: &[f32],
    a_v: &[f32],
    b_v: &[f32],
) -> Result<()>
Upload LoRA weights from CPU to GPU for checkpoint resume (ENT-276).
Overwrites the current LoRA adapter buffers with trained weights
restored from a checkpoint. Call after new() to replace the fresh
random init with previously trained adapters.
Auto Trait Implementations
impl Freeze for CudaNf4TransformerBlock
impl RefUnwindSafe for CudaNf4TransformerBlock
impl Send for CudaNf4TransformerBlock
impl Sync for CudaNf4TransformerBlock
impl Unpin for CudaNf4TransformerBlock
impl UnsafeUnpin for CudaNf4TransformerBlock
impl UnwindSafe for CudaNf4TransformerBlock
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> FmtForward for T
fn fmt_binary(self) -> FmtBinary<Self>
where
    Self: Binary,
Causes self to use its Binary implementation when Debug-formatted.
fn fmt_display(self) -> FmtDisplay<Self>
where
    Self: Display,
Causes self to use its Display implementation when Debug-formatted.
fn fmt_lower_exp(self) -> FmtLowerExp<Self>
where
    Self: LowerExp,
Causes self to use its LowerExp implementation when Debug-formatted.
fn fmt_lower_hex(self) -> FmtLowerHex<Self>
where
    Self: LowerHex,
Causes self to use its LowerHex implementation when Debug-formatted.
fn fmt_octal(self) -> FmtOctal<Self>
where
    Self: Octal,
Causes self to use its Octal implementation when Debug-formatted.
fn fmt_pointer(self) -> FmtPointer<Self>
where
    Self: Pointer,
Causes self to use its Pointer implementation when Debug-formatted.
fn fmt_upper_exp(self) -> FmtUpperExp<Self>
where
    Self: UpperExp,
Causes self to use its UpperExp implementation when Debug-formatted.
fn fmt_upper_hex(self) -> FmtUpperHex<Self>
where
    Self: UpperHex,
Causes self to use its UpperHex implementation when Debug-formatted.
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> Pipe for T
where
    T: ?Sized,
fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> R
where
    Self: Sized,
fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> R
where
    R: 'a,
Borrows self and passes that borrow into the pipe function.
fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> R
where
    R: 'a,
Mutably borrows self and passes that borrow into the pipe function.
fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
fn pipe_borrow_mut<'a, B, R>(
    &'a mut self,
    func: impl FnOnce(&'a mut B) -> R,
) -> R
fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
Borrows self, then passes self.as_ref() into the pipe function.
fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
Mutably borrows self, then passes self.as_mut() into the pipe function.
fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
Borrows self, then passes self.deref() into the pipe function.
impl<T> Pointable for T
impl<T> PolicyExt for T
where
    T: ?Sized,
impl<T> Tap for T
fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
Immutable access to the Borrow<B> of a value.
fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
Mutable access to the BorrowMut<B> of a value.
fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
Immutable access to the AsRef<R> view of a value.
fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
Mutable access to the AsMut<R> view of a value.
fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
Immutable access to the Deref::Target of a value.
fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
Mutable access to the Deref::Target of a value.
fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
Calls .tap() only in debug builds, and is erased in release builds.
fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
Calls .tap_mut() only in debug builds, and is erased in release builds.
fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
Calls .tap_borrow() only in debug builds, and is erased in release builds.
fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
Calls .tap_borrow_mut() only in debug builds, and is erased in release builds.
fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
Calls .tap_ref() only in debug builds, and is erased in release builds.
fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
Calls .tap_ref_mut() only in debug builds, and is erased in release builds.
fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
Calls .tap_deref() only in debug builds, and is erased in release builds.