pub struct WgslForwardPass { /* private fields */ }
PMAT-324: WGSL transformer forward pass shaders. GPU-resident transformer layer state. All buffers persist across tokens — only input/output change per step.
Implementations
impl WgslForwardPass
pub fn rmsnorm_shader() -> &'static str
Get the shader sources for external inspection/testing
pub fn silu_mul_shader() -> &'static str
pub fn residual_shader() -> &'static str
pub fn rope_shader() -> &'static str
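Because the getters expose the shader sources for testing, a test can compare GPU output against a CPU reference. The sketch below is one such reference for two of the operations. It rests on assumptions this page does not confirm: that the SiLU-multiply shader computes `silu(gate) * up`, and that RoPE rotates adjacent element pairs with a base of 10000. The function names `silu_mul_ref` and `rope_ref` are hypothetical.

```rust
/// CPU reference for a fused SiLU-multiply: out[i] = silu(gate[i]) * up[i].
/// Assumption: this is the operation the silu_mul shader implements.
pub fn silu_mul_ref(gate: &[f32], up: &[f32]) -> Vec<f32> {
    assert_eq!(gate.len(), up.len());
    gate.iter()
        .zip(up)
        // silu(g) = g * sigmoid(g) = g / (1 + e^-g)
        .map(|(&g, &u)| g / (1.0 + (-g).exp()) * u)
        .collect()
}

/// CPU reference for rotary position embedding on one head, in place.
/// Assumption: adjacent pairs (2i, 2i+1) are rotated with base 10000.
/// Some models use a split-half pairing instead.
pub fn rope_ref(head: &mut [f32], position: usize) {
    let head_dim = head.len();
    assert!(head_dim % 2 == 0, "head_dim must be even");
    for i in 0..head_dim / 2 {
        let freq = 1.0 / 10000f32.powf(2.0 * i as f32 / head_dim as f32);
        let (sin, cos) = (position as f32 * freq).sin_cos();
        let (x0, x1) = (head[2 * i], head[2 * i + 1]);
        head[2 * i] = x0 * cos - x1 * sin;
        head[2 * i + 1] = x0 * sin + x1 * cos;
    }
}
```

Two properties make handy test oracles: position 0 leaves the vector unchanged, and any rotation preserves its Euclidean norm.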
pub fn new(
    device: Device,
    queue: Queue,
    hidden_dim: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    intermediate_dim: usize,
) -> Self
PMAT-325: Create a new WGSL forward pass context.
Compiles all shader pipelines and allocates persistent intermediate buffers. Call once at model init. All GPU resources persist until dropped.
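The five dimension arguments are related in the usual grouped-query attention way. The sketch below uses a hypothetical `LayerDims` helper (not part of this crate) to show the derived projection widths and the divisibility check a caller may want to run before constructing the pass. Whether `new` performs such a check itself is not stated here.

```rust
/// Hypothetical helper mirroring the dimension arguments of `new`.
#[derive(Debug, Clone, Copy)]
pub struct LayerDims {
    pub hidden_dim: usize,
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub intermediate_dim: usize,
}

impl LayerDims {
    /// Width of the Q projection output.
    pub fn q_dim(&self) -> usize {
        self.num_heads * self.head_dim
    }

    /// Width of each of the K and V projection outputs.
    pub fn kv_dim(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }

    /// Query heads sharing one KV head. Call `validate` first.
    pub fn group_size(&self) -> usize {
        self.num_heads / self.num_kv_heads
    }

    /// Grouped-query attention needs num_heads to be a multiple of num_kv_heads.
    pub fn validate(&self) -> Result<(), String> {
        if self.num_kv_heads == 0 || self.num_heads % self.num_kv_heads != 0 {
            return Err(format!(
                "num_heads ({}) must be a positive multiple of num_kv_heads ({})",
                self.num_heads, self.num_kv_heads
            ));
        }
        Ok(())
    }
}
```

With illustrative values of 28 query heads, 4 KV heads and `head_dim = 128`, the Q projection is 3584 wide, K and V are 512 wide each, and 7 query heads share each KV head.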
pub fn upload_weight(&mut self, name: &str, data: &[f32])
Upload a weight matrix (call once per layer at init). PMAT-342: Bias weights (name contains “bias”) are stored CPU-side.
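The name-based routing described above can be pictured with a small stand-in. This is a sketch only: `WeightStore` and `Placement` are hypothetical types, and plain `Vec<f32>` values stand in for GPU buffers. The one behavior taken from the description is that a name containing "bias" stays CPU-side.

```rust
use std::collections::HashMap;

/// Where an uploaded tensor ends up. Hypothetical type for illustration.
#[derive(Debug, PartialEq)]
pub enum Placement {
    Cpu,
    Gpu,
}

/// Stand-in for the weight tables; Vec<f32> replaces real GPU buffers.
#[derive(Default)]
pub struct WeightStore {
    cpu_side: HashMap<String, Vec<f32>>,
    gpu_side: HashMap<String, Vec<f32>>,
}

impl WeightStore {
    /// Route by name: anything containing "bias" is kept on the CPU.
    pub fn upload(&mut self, name: &str, data: &[f32]) -> Placement {
        if name.contains("bias") {
            self.cpu_side.insert(name.to_string(), data.to_vec());
            Placement::Cpu
        } else {
            self.gpu_side.insert(name.to_string(), data.to_vec());
            Placement::Gpu
        }
    }

    /// Look up a CPU-side bias by name.
    pub fn cpu_bias(&self, name: &str) -> Option<&[f32]> {
        self.cpu_side.get(name).map(|v| v.as_slice())
    }

    /// Number of entries that would live in GPU buffers.
    pub fn gpu_count(&self) -> usize {
        self.gpu_side.len()
    }
}
```

One consequence of substring matching is that any tensor whose name happens to contain "bias" takes the CPU path, so naming conventions matter.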
pub fn upload_q4k_weight(&mut self, name: &str, data: &[u8])
GH-560: Upload raw Q4K weight bytes for fused dequant+GEMV on GPU.
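Since the method takes raw bytes, a length check before upload can catch a mismatched tensor early. The sketch below assumes "Q4K" means the GGML Q4_K layout: super-blocks of 256 weights, each stored in 144 bytes (two f16 scale factors, 12 bytes of packed sub-block scales and minimums, 128 bytes of 4-bit values). That assumption, and the helper name `q4k_expected_bytes`, are not confirmed by this page.

```rust
/// Weights per Q4_K super-block (assumed GGML layout).
pub const Q4K_BLOCK_ELEMS: usize = 256;
/// Bytes per super-block: f16 d + f16 dmin + 12 scale bytes + 128 quant bytes.
pub const Q4K_BLOCK_BYTES: usize = 2 + 2 + 12 + 128;

/// Expected byte length of a row-major [rows, cols] Q4_K tensor.
/// Each row must be a whole number of super-blocks.
pub fn q4k_expected_bytes(rows: usize, cols: usize) -> Result<usize, String> {
    if cols % Q4K_BLOCK_ELEMS != 0 {
        return Err(format!(
            "cols ({cols}) is not a multiple of {Q4K_BLOCK_ELEMS}"
        ));
    }
    Ok(rows * (cols / Q4K_BLOCK_ELEMS) * Q4K_BLOCK_BYTES)
}
```

Under this layout the cost is 144 * 8 / 256 = 4.5 bits per weight, which is where the VRAM saving over f32 uploads comes from.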
pub fn init_kv_cache(&mut self, num_layers: usize)
GH-560: Initialize per-layer KV cache buffers on GPU.
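For budgeting, the size of a per-layer KV cache can be estimated ahead of time. The sketch below assumes f32 entries and a fixed maximum sequence length; `max_seq_len` is a hypothetical parameter, since the signature above takes only `num_layers`.

```rust
/// Estimated bytes for K and V caches across all layers.
/// Assumptions: f32 storage and a fixed `max_seq_len` per layer.
pub fn kv_cache_bytes(
    num_layers: usize,
    max_seq_len: usize,
    num_kv_heads: usize,
    head_dim: usize,
) -> usize {
    let kv_dim = num_kv_heads * head_dim;
    let per_tensor = max_seq_len * kv_dim * std::mem::size_of::<f32>();
    // Two tensors (K and V) per layer.
    num_layers * 2 * per_tensor
}
```

With illustrative values of 28 layers, 2048 positions, 4 KV heads and `head_dim = 128`, this comes to 234,881,024 bytes (224 MiB).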
pub fn weight_count(&self) -> usize
Number of uploaded weight buffers.
pub fn weight_buffer(&self, name: &str) -> Option<&Buffer>
Access a dequantized weight buffer by name (e.g. “layer.0.down_proj”). Used by backward pass for gradient propagation through frozen base weights.
pub fn device_ref(&self) -> &Device
Reference to the wgpu device.
Reference to the hidden state buffer (for writing input).
pub fn gpu_residual_add(
    &self,
    a: &Buffer,
    b: &Buffer,
    output: &Buffer,
    len: u32,
)
Elementwise add: output = a + b. Dispatches residual add shader.
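A CPU reference for the same operation is useful when validating the shader, and the dispatch size follows from `len` by ceiling division. The workgroup size of 256 used in the example is an assumption, as are the helper names.

```rust
/// CPU reference: output[i] = a[i] + b[i].
pub fn residual_add_ref(a: &[f32], b: &[f32], output: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), output.len());
    for ((o, &x), &y) in output.iter_mut().zip(a).zip(b) {
        *o = x + y;
    }
}

/// Number of workgroups needed to cover `len` elements.
/// The caller supplies the workgroup size; 256 is a common choice.
pub fn workgroup_count(len: u32, workgroup_size: u32) -> u32 {
    assert!(workgroup_size > 0);
    // Ceiling division without overflow for len near u32::MAX.
    len / workgroup_size + u32::from(len % workgroup_size != 0)
}
```

Because the last workgroup may run past `len`, an elementwise shader dispatched this way needs a bounds check on its global invocation index.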
pub fn gpu_rmsnorm(&self, weight: &Buffer, output: &Buffer, _seq_len: u32)
Apply RMSNorm on GPU: normed = rmsnorm(hidden_buf, weight) → output_buf. Contract: gpu-output-norm-v1 / gpu_resident — hidden state never leaves GPU.
Download hidden state from GPU.
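For checking the GPU result after a download, a CPU version of RMSNorm is short enough to keep in a test module. This sketch uses the standard definition, x_i / sqrt(mean(x^2) + eps) * w_i. The `eps` argument is an assumption here, since `gpu_rmsnorm` itself does not take one.

```rust
/// CPU reference for RMSNorm over one hidden vector.
/// out[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
pub fn rmsnorm_ref(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), weight.len());
    assert!(!x.is_empty());
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}
```

With unit weights and a negligible `eps`, the output has a mean square of 1, which gives a quick sanity check that does not depend on the input scale.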
pub fn total_vram_bytes(&self) -> usize
Total VRAM used by all buffers (bytes).
pub fn forward_model(
    &self,
    token_id: u32,
    position: usize,
    num_layers: usize,
    token_embedding: &[f32],
    output_norm_weight: &[f32],
    lm_head_weight: &[f32],
    vocab_size: usize,
    eps: f32,
    kv_caches: &mut Vec<(Vec<f32>, Vec<f32>)>,
) -> Result<Vec<f32>, String>
PMAT-336: Full model forward — embedding + all layers + output norm + LM head.
Returns logits [vocab_size] for the given token at the given position. Embedding lookup and final LM head are CPU-side (not yet GPU-accelerated). PMAT-344: Added kv_caches for multi-token context
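The two CPU-side steps named above, embedding lookup and the LM head, amount to a row slice and a matrix-vector product. The sketch below assumes row-major layouts of `[vocab_size, hidden_dim]` for both tables, which this page does not state. All three function names are hypothetical, and `argmax` is included only to show how the returned logits might be consumed for greedy decoding.

```rust
/// Slice out the embedding row for `token_id`.
/// Assumes `token_embedding` is row-major [vocab_size, hidden_dim].
pub fn embed_token(token_embedding: &[f32], token_id: u32, hidden_dim: usize) -> Option<&[f32]> {
    let start = (token_id as usize).checked_mul(hidden_dim)?;
    let end = start.checked_add(hidden_dim)?;
    token_embedding.get(start..end)
}

/// LM head as a matrix-vector product: logits[v] = dot(weight[v, :], hidden).
/// Assumes `lm_head_weight` is row-major [vocab_size, hidden_dim].
pub fn lm_head(hidden: &[f32], lm_head_weight: &[f32], vocab_size: usize) -> Vec<f32> {
    let hidden_dim = hidden.len();
    assert!(hidden_dim > 0);
    assert_eq!(lm_head_weight.len(), vocab_size * hidden_dim);
    lm_head_weight
        .chunks_exact(hidden_dim)
        .map(|row| row.iter().zip(hidden).map(|(w, h)| w * h).sum())
        .collect()
}

/// Index of the largest logit, or None for an empty slice.
pub fn argmax(logits: &[f32]) -> Option<usize> {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}
```

An out-of-range `token_id` yields `None` from `embed_token` rather than a panic, which maps naturally onto the `Result<_, String>` style used by `forward_model`.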
pub fn forward_layer(
    &self,
    hidden: &mut [f32],
    layer_prefix: &str,
    _position: usize,
    kv_cache_k: &mut Vec<f32>,
    kv_cache_v: &mut Vec<f32>,
) -> Result<(), String>
PMAT-325: Execute one transformer layer — 14 passes, 1 submit, 1 readback.
Input: hidden state [hidden_dim] on CPU. Output: updated hidden state [hidden_dim] on CPU. All intermediate computation stays GPU-resident. PMAT-344: KV cache parameters for multi-token context
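To make the role of the two cache vectors concrete, here is a CPU sketch of single-token decode attention over a growing cache. It assumes the caches are flat row-major `[seq, kv_dim]` arrays that gain one row per step, and that query heads map to KV heads by integer division (grouped-query attention). Neither layout is confirmed by this page, and the function names are hypothetical.

```rust
/// KV head shared by a given query head under grouped-query attention.
pub fn kv_head_for(q_head: usize, num_heads: usize, num_kv_heads: usize) -> usize {
    q_head / (num_heads / num_kv_heads)
}

/// Append this step's K and V rows (each [kv_dim]) to the caches.
pub fn append_kv(cache_k: &mut Vec<f32>, cache_v: &mut Vec<f32>, k: &[f32], v: &[f32]) {
    cache_k.extend_from_slice(k);
    cache_v.extend_from_slice(v);
}

/// Attend one query head (q: [head_dim]) over every cached position.
pub fn attend_one_head(
    q: &[f32],
    cache_k: &[f32],
    cache_v: &[f32],
    kv_head: usize,
    kv_dim: usize,
) -> Vec<f32> {
    let head_dim = q.len();
    let seq = cache_k.len() / kv_dim;
    let off = kv_head * head_dim;
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Scaled dot-product scores against each cached key.
    let scores: Vec<f32> = (0..seq)
        .map(|t| {
            let k = &cache_k[t * kv_dim + off..t * kv_dim + off + head_dim];
            q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale
        })
        .collect();
    // Softmax with max subtraction for numerical stability.
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    // Weighted sum of cached values.
    let mut out = vec![0.0; head_dim];
    for (t, e) in exps.iter().enumerate() {
        let v = &cache_v[t * kv_dim + off..t * kv_dim + off + head_dim];
        for (o, x) in out.iter_mut().zip(v) {
            *o += e / denom * x;
        }
    }
    out
}
```

No causal mask is needed in this decode setting because the cache only ever holds the current and earlier positions.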
pub fn encode_forward_layer_training(
    &self,
    encoder: &mut CommandEncoder,
    seq_len: u32,
    layer_prefix: &str,
    saved: &LayerActivations,
    lora: Option<&QkvLoRA<'_>>,
) -> Result<(), String>
Training forward pass for a single transformer layer.
Unlike forward_layer (M=1 decode), this processes the full sequence
at once (M=seq_len) and keeps everything on GPU. No CPU readback.
Saves norm_output (pre-projection activations) for backward pass.
Arguments
seq_len: number of tokens in the sequence
layer_prefix: e.g. “model.layers.0”
saved_norm_attn (output): saved pre-attention norm for backward (wgpu::Buffer, [seq×hidden])
saved_norm_ffn (output): saved pre-FFN norm for backward (wgpu::Buffer, [seq×hidden])
Forward one layer into an EXISTING encoder (no submit). Caller batches multiple layers into one encoder, submits once.
pub fn forward_layer_traced(
    &self,
    seq_len: u32,
    layer_prefix: &str,
    saved: &LayerActivations,
    lora: Option<&QkvLoRA<'_>>,
) -> Result<(), String>
Run one layer with per-operation GPU timing (submit+poll between each op group). Contract: forward-pass-perf-v1 / bottleneck_identified
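The timing approach described above, where each op group is submitted and polled before the next starts, can be modelled with a small wall-clock recorder. This is a sketch with a hypothetical `OpTimer` type. The closure passed to `time` stands in for "encode, submit, and poll one op group", and nothing here reflects how the crate actually reports its timings.

```rust
use std::time::{Duration, Instant};

/// Collects wall-clock timings per labelled op group. Hypothetical helper.
#[derive(Default)]
pub struct OpTimer {
    samples: Vec<(String, Duration)>,
}

impl OpTimer {
    /// Run `op` to completion and record how long it took.
    pub fn time<T>(&mut self, label: &str, op: impl FnOnce() -> T) -> T {
        let start = Instant::now();
        let out = op();
        self.record(label, start.elapsed());
        out
    }

    /// Record an externally measured duration.
    pub fn record(&mut self, label: &str, elapsed: Duration) {
        self.samples.push((label.to_string(), elapsed));
    }

    /// The slowest op group recorded so far.
    pub fn bottleneck(&self) -> Option<(&str, Duration)> {
        self.samples
            .iter()
            .max_by_key(|(_, d)| *d)
            .map(|(l, d)| (l.as_str(), *d))
    }
}
```

Forcing a sync between op groups removes overlap between them, so the per-group numbers are for finding the slowest stage rather than for estimating end-to-end latency.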
pub fn alloc_layer_activations(&self, seq_len: u32) -> LayerActivations
Allocate saved activations for one layer.
pub fn forward_layer_training(
    &self,
    seq_len: u32,
    layer_prefix: &str,
) -> Result<LayerActivations, String>
Forward one layer with its own encoder + submit (original API, kept for compat).
pub fn forward_all_layers_training(
    &self,
    seq_len: u32,
    num_layers: usize,
) -> Result<Vec<LayerActivations>, String>
Forward ALL layers in one encoder submit. 28 layers → 1 GPU sync.
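The difference between submitting per layer and submitting once can be shown with mock types that only count work. `MockEncoder` and `MockQueue` below are hypothetical stand-ins for the wgpu encoder and queue, used so the sketch runs without a GPU.

```rust
/// Records encoded passes; stand-in for a command encoder.
#[derive(Default)]
pub struct MockEncoder {
    passes: Vec<String>,
}

/// Counts submits and executed passes; stand-in for a queue.
#[derive(Default)]
pub struct MockQueue {
    pub submits: usize,
    pub passes_run: usize,
}

impl MockQueue {
    pub fn submit(&mut self, encoder: MockEncoder) {
        self.submits += 1;
        self.passes_run += encoder.passes.len();
    }
}

/// Encode one layer into an existing encoder, without submitting.
pub fn encode_layer(encoder: &mut MockEncoder, layer: usize) {
    encoder.passes.push(format!("model.layers.{layer}"));
}

/// One encoder and one submit per layer.
pub fn run_per_layer(queue: &mut MockQueue, num_layers: usize) {
    for layer in 0..num_layers {
        let mut encoder = MockEncoder::default();
        encode_layer(&mut encoder, layer);
        queue.submit(encoder);
    }
}

/// All layers encoded into one encoder, then a single submit.
pub fn run_batched(queue: &mut MockQueue, num_layers: usize) {
    let mut encoder = MockEncoder::default();
    for layer in 0..num_layers {
        encode_layer(&mut encoder, layer);
    }
    queue.submit(encoder);
}
```

Both variants execute the same 28 layer passes for a 28-layer model; only the number of submits, and therefore CPU-GPU sync points, changes from 28 to 1.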
pub fn encode_broadcast_bias(
    &self,
    encoder: &mut CommandEncoder,
    buf: &Buffer,
    bias: &[f32],
    seq_len: u32,
)
Encode causal multi-head attention on GPU. Q: [seq_len, num_heads * head_dim], K/V: [seq_len, num_kv_heads * head_dim]. Output written to q_buf (reused as attn output).
PMAT-509: Add broadcast bias to a [seq_len, dim] buffer. bias has shape [dim], applied to each of seq_len rows.
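The broadcast bias step has a direct CPU equivalent, sketched here for use as a test oracle. It assumes `buf` is row-major `[seq_len, dim]`, with `dim` taken from `bias.len()`; the function name is hypothetical.

```rust
/// CPU reference: add `bias` ([dim]) to every row of `buf` ([seq_len, dim]).
pub fn broadcast_bias_ref(buf: &mut [f32], bias: &[f32], seq_len: usize) {
    let dim = bias.len();
    assert_eq!(buf.len(), seq_len * dim, "buf must be [seq_len, dim]");
    if dim == 0 {
        return; // nothing to add, and chunks_exact_mut(0) would panic
    }
    for row in buf.chunks_exact_mut(dim) {
        for (x, b) in row.iter_mut().zip(bias) {
            *x += b;
        }
    }
}
```

For a two-row buffer `[0, 0, 1, 1]` and bias `[10, 20]`, every row receives the same offsets, giving `[10, 20, 11, 21]`.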
Auto Trait Implementations
impl Freeze for WgslForwardPass
impl !RefUnwindSafe for WgslForwardPass
impl Send for WgslForwardPass
impl Sync for WgslForwardPass
impl Unpin for WgslForwardPass
impl UnsafeUnpin for WgslForwardPass
impl !UnwindSafe for WgslForwardPass