pub struct CudaTransformerBlock { /* private fields */ }
CUDA-accelerated transformer block
All operations run on GPU with minimal CPU<->GPU transfers.
Implementations§
impl CudaTransformerBlock
pub fn new(
    config: &TransformerConfig,
    layer_idx: usize,
    ctx: Arc<CudaContext>,
    input_norm_weight: &[f32],
    post_attn_norm_weight: &[f32],
    w_q: &[f32],
    w_k: &[f32],
    w_v: &[f32],
    w_o: &[f32],
    w_gate: &[f32],
    w_up: &[f32],
    w_down: &[f32],
    max_seq_len: usize,
) -> Result<Self>
Create a new CUDA transformer block from CPU tensors
Uploads all weights to GPU memory.
pub fn set_qk_norm(&mut self, q_norm: &[f32], k_norm: &[f32]) -> Result<()>
Set QK-norm weights (ENT-270). Called after construction when loading Qwen3 models.
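As a rough CPU-side illustration of what QK-norm weights are used for, the sketch below applies RMSNorm independently to each attention head of a flattened Q (or K) buffer. It assumes that QK-norm here means per-head RMSNorm with a learned weight of length head_dim, as in the published Qwen3 architecture. The function name `rms_norm_heads`, the `eps` parameter, and the flat head-contiguous layout are illustrative assumptions, not part of this crate's API, and the GPU kernel may differ.

```rust
/// CPU reference for per-head RMSNorm (hypothetical helper, not part of the crate).
///
/// `x` holds one or more heads back to back, each `weight.len()` elements long.
/// Each head is scaled by the reciprocal of its root-mean-square, then
/// multiplied element-wise by the learned `weight`.
pub fn rms_norm_heads(x: &mut [f32], weight: &[f32], eps: f32) {
    let head_dim = weight.len();
    assert!(head_dim > 0, "weight must not be empty");
    assert!(x.len() % head_dim == 0, "x must hold whole heads");

    for head in x.chunks_mut(head_dim) {
        // Mean of squares over this head only.
        let mean_sq = head.iter().map(|v| v * v).sum::<f32>() / head_dim as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        for (v, w) in head.iter_mut().zip(weight) {
            *v = *v * inv_rms * *w;
        }
    }
}
```

With a weight of all ones and a negligible `eps`, every head comes out with a mean square of about 1, which is a quick sanity check when comparing against GPU output.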
pub fn forward(
    &mut self,
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    seq_len: usize,
    stream: &CudaStream,
) -> Result<()>
Forward pass - all operations on GPU
§Arguments
- input - Input tensor on GPU (seq_len * hidden_size)
- output - Output tensor on GPU (seq_len * hidden_size)
- seq_len - Sequence length
- stream - CUDA stream for async execution
pub fn config(&self) -> &TransformerConfig
Get configuration
pub fn backward(
    &mut self,
    input: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    seq_len: usize,
    stream: &CudaStream,
    grad_ws: &mut CudaGradWorkspace,
) -> Result<()>
Backward pass - gradient computation on GPU (ENT-151)
Computes gradients for all parameters given upstream gradient.
§Arguments
- input - Original input from forward pass (seq_len * hidden_size)
- grad_output - Gradient from upstream layer (seq_len * hidden_size)
- grad_input - Output: gradient w.r.t. input (seq_len * hidden_size)
- seq_len - Sequence length
- stream - CUDA stream for async execution
§Returns
Gradients are accumulated into the scratch buffers:
- scratch.grad_input_norm - Gradient for input RMSNorm weight
- scratch.grad_post_attn_norm - Gradient for post-attention RMSNorm weight
- scratch.grad_gate/up/down - Gradients for FFN weights
- scratch.grad_w_q/w_k/w_v/w_o - Gradients for attention projection weights
pub fn init_optimizer_state(&self) -> Result<GpuBlockOptimizerState>
Initialize GPU-resident AdamW optimizer state for all block weights.
Allocates zero-initialized first and second moment buffers for each of the 9 weight tensors (4 attention projections + 3 FFN projections + 2 RMSNorm).
§Contract (C-OPTINIT-001)
- Precondition: CUDA context is valid, sufficient GPU memory available
- Postcondition: All m/v buffers are zero-initialized with dimensions matching the corresponding weight tensors
- Invariant: Total GPU memory for optimizer state = 2 × sum(weight_sizes) × 4 bytes
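The memory invariant above can be checked with a few lines of host-side arithmetic. The sketch below is a minimal illustration of that formula only; `optimizer_state_bytes` is a hypothetical helper, not a function of this crate, and the caller is assumed to supply the element count of each weight tensor.

```rust
/// Bytes of GPU memory implied by the C-OPTINIT-001 invariant:
/// 2 moment buffers (m and v) x total weight elements x 4 bytes per f32.
/// Hypothetical helper for illustration; not part of the crate.
pub fn optimizer_state_bytes(weight_sizes: &[usize]) -> usize {
    const MOMENTS: usize = 2; // first moment (m) and second moment (v)
    const BYTES_PER_F32: usize = std::mem::size_of::<f32>();

    let total_elems: usize = weight_sizes.iter().sum();
    MOMENTS * total_elems * BYTES_PER_F32
}
```

For example, with made-up tensor sizes of 16 elements for each of the four attention projections, 32 for each of the three FFN projections, and 4 for each of the two RMSNorm weights, the total is 168 elements, so the optimizer state would occupy 2 × 168 × 4 = 1344 bytes.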
pub fn optimizer_step(
    &mut self,
    state: &mut GpuBlockOptimizerState,
    step: u32,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    stream: &CudaStream,
    grad_ws: &CudaGradWorkspace,
) -> Result<()>
Run GPU-resident AdamW optimizer step on all block weights.
Updates weights in-place using gradients computed by backward().
All operations run on GPU — zero CPU↔GPU data transfers.
§Contract (C-OPTSTEP-001)
- Precondition: backward() completed for this block (scratch grad buffers valid), state initialized via init_optimizer_state(), step > 0
- Postcondition: All 9 weight tensors updated by AdamW rule, m/v states updated with current gradient statistics
- Invariant: Weight dimensions unchanged; no GPU memory allocated or freed
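For readers who want a CPU reference to compare against, the following sketch applies one step of the standard AdamW rule (bias-corrected moments with decoupled weight decay) to a flat f32 slice. It is an assumption that the GPU kernel follows this common formulation; the function name `adamw_step` and its slice-based signature are illustrative and not part of this crate. The parameter names mirror those of optimizer_step.

```rust
/// One AdamW update over a flat parameter slice (CPU reference sketch).
/// Assumes the common formulation with bias correction and decoupled
/// weight decay; the actual GPU kernel may differ in detail.
#[allow(clippy::too_many_arguments)]
pub fn adamw_step(
    w: &mut [f32],
    g: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    step: u32,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
) {
    // Mirrors the documented precondition: bias correction divides by
    // (1 - beta^step), which is zero when step == 0.
    assert!(step > 0, "step must be > 0");
    assert!(w.len() == g.len() && w.len() == m.len() && w.len() == v.len());

    // Cast is fine for realistic step counts (below i32::MAX).
    let bias1 = 1.0 - beta1.powi(step as i32);
    let bias2 = 1.0 - beta2.powi(step as i32);

    for i in 0..w.len() {
        // Update first and second moment estimates in place.
        m[i] = beta1 * m[i] + (1.0 - beta1) * g[i];
        v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i];

        let m_hat = m[i] / bias1;
        let v_hat = v[i] / bias2;

        // Decoupled weight decay: applied to the weight directly,
        // not folded into the gradient.
        let wi = w[i];
        w[i] = wi - lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * wi);
    }
}
```

Like the documented invariant, this sketch allocates nothing and leaves slice lengths unchanged; only the contents of `w`, `m`, and `v` are modified.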
pub fn download_weights(&self) -> Result<BlockWeights>
Download all weight data from GPU to host vectors.
Used to synchronize GPU-updated weights back to CPU model for checkpointing.
§Contract (C-DLWEIGHTS-001)
- Precondition: Block weights are valid GPU allocations
- Postcondition: Returned vectors have exact same length and content as GPU buffers
- Invariant: GPU buffers are not modified
Auto Trait Implementations§
impl Freeze for CudaTransformerBlock
impl RefUnwindSafe for CudaTransformerBlock
impl Send for CudaTransformerBlock
impl Sync for CudaTransformerBlock
impl Unpin for CudaTransformerBlock
impl UnsafeUnpin for CudaTransformerBlock
impl UnwindSafe for CudaTransformerBlock
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> FmtForward for T
fn fmt_binary(self) -> FmtBinary<Self>
where
    Self: Binary,
Causes self to use its Binary implementation when Debug-formatted.
fn fmt_display(self) -> FmtDisplay<Self>
where
    Self: Display,
Causes self to use its Display implementation when Debug-formatted.
fn fmt_lower_exp(self) -> FmtLowerExp<Self>
where
    Self: LowerExp,
Causes self to use its LowerExp implementation when Debug-formatted.
fn fmt_lower_hex(self) -> FmtLowerHex<Self>
where
    Self: LowerHex,
Causes self to use its LowerHex implementation when Debug-formatted.
fn fmt_octal(self) -> FmtOctal<Self>
where
    Self: Octal,
Causes self to use its Octal implementation when Debug-formatted.
fn fmt_pointer(self) -> FmtPointer<Self>
where
    Self: Pointer,
Causes self to use its Pointer implementation when Debug-formatted.
fn fmt_upper_exp(self) -> FmtUpperExp<Self>
where
    Self: UpperExp,
Causes self to use its UpperExp implementation when Debug-formatted.
fn fmt_upper_hex(self) -> FmtUpperHex<Self>
where
    Self: UpperHex,
Causes self to use its UpperHex implementation when Debug-formatted.
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<T> Pipe for T
where
    T: ?Sized,
fn pipe<R>(self, func: impl FnOnce(Self) -> R) -> R
where
    Self: Sized,
fn pipe_ref<'a, R>(&'a self, func: impl FnOnce(&'a Self) -> R) -> R
where
    R: 'a,
Borrows self and passes that borrow into the pipe function.
fn pipe_ref_mut<'a, R>(&'a mut self, func: impl FnOnce(&'a mut Self) -> R) -> R
where
    R: 'a,
Mutably borrows self and passes that borrow into the pipe function.
fn pipe_borrow<'a, B, R>(&'a self, func: impl FnOnce(&'a B) -> R) -> R
fn pipe_borrow_mut<'a, B, R>(
    &'a mut self,
    func: impl FnOnce(&'a mut B) -> R,
) -> R
fn pipe_as_ref<'a, U, R>(&'a self, func: impl FnOnce(&'a U) -> R) -> R
Borrows self, then passes self.as_ref() into the pipe function.
fn pipe_as_mut<'a, U, R>(&'a mut self, func: impl FnOnce(&'a mut U) -> R) -> R
Mutably borrows self, then passes self.as_mut() into the pipe function.
fn pipe_deref<'a, T, R>(&'a self, func: impl FnOnce(&'a T) -> R) -> R
Borrows self, then passes self.deref() into the pipe function.
impl<T> Pointable for T
impl<T> PolicyExt for T
where
    T: ?Sized,
impl<T> Tap for T
fn tap_borrow<B>(self, func: impl FnOnce(&B)) -> Self
Immutable access to the Borrow<B> of a value.
fn tap_borrow_mut<B>(self, func: impl FnOnce(&mut B)) -> Self
Mutable access to the BorrowMut<B> of a value.
fn tap_ref<R>(self, func: impl FnOnce(&R)) -> Self
Immutable access to the AsRef<R> view of a value.
fn tap_ref_mut<R>(self, func: impl FnOnce(&mut R)) -> Self
Mutable access to the AsMut<R> view of a value.
fn tap_deref<T>(self, func: impl FnOnce(&T)) -> Self
Immutable access to the Deref::Target of a value.
fn tap_deref_mut<T>(self, func: impl FnOnce(&mut T)) -> Self
Mutable access to the Deref::Target of a value.
fn tap_dbg(self, func: impl FnOnce(&Self)) -> Self
Calls .tap() only in debug builds, and is erased in release builds.
fn tap_mut_dbg(self, func: impl FnOnce(&mut Self)) -> Self
Calls .tap_mut() only in debug builds, and is erased in release builds.
fn tap_borrow_dbg<B>(self, func: impl FnOnce(&B)) -> Self
Calls .tap_borrow() only in debug builds, and is erased in release builds.
fn tap_borrow_mut_dbg<B>(self, func: impl FnOnce(&mut B)) -> Self
Calls .tap_borrow_mut() only in debug builds, and is erased in release builds.
fn tap_ref_dbg<R>(self, func: impl FnOnce(&R)) -> Self
Calls .tap_ref() only in debug builds, and is erased in release builds.
fn tap_ref_mut_dbg<R>(self, func: impl FnOnce(&mut R)) -> Self
Calls .tap_ref_mut() only in debug builds, and is erased in release builds.
fn tap_deref_dbg<T>(self, func: impl FnOnce(&T)) -> Self
Calls .tap_deref() only in debug builds, and is erased in release builds.