Struct WgslForwardPass 

pub struct WgslForwardPass { /* private fields */ }

PMAT-324: WGSL transformer forward-pass shaders and GPU-resident transformer layer state. All buffers persist across tokens; only the input and output change per step.

Implementations

impl WgslForwardPass

pub fn rmsnorm_shader() -> &'static str

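Returns the WGSL shader source for external inspection and testing. The accessors silu_mul_shader, residual_shader, and rope_shader do the same for their respective shaders.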

pub fn silu_mul_shader() -> &'static str

pub fn residual_shader() -> &'static str

pub fn rope_shader() -> &'static str

pub fn new(
    device: Device,
    queue: Queue,
    hidden_dim: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    intermediate_dim: usize,
) -> Self

PMAT-325: Create a new WGSL forward pass context.

Compiles all shader pipelines and allocates persistent intermediate buffers. Call once at model init. All GPU resources persist until dropped.

pub fn upload_weight(&mut self, name: &str, data: &[f32])

Upload a weight matrix (call once per layer at init). PMAT-342: Bias weights (name contains “bias”) are stored CPU-side.

pub fn upload_q4k_weight(&mut self, name: &str, data: &[u8])

GH-560: Upload raw Q4K weight bytes for fused dequant+GEMV on GPU.
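A quick length check before uploading can catch truncated or mis-sliced weight bytes. The sketch below assumes that "Q4K" here means the GGML-style Q4_K layout (256 weights per super-block, 144 bytes per super-block); the source does not confirm this, and the helper name q4k_expected_bytes is hypothetical.

```rust
/// ASSUMPTION: GGML-style Q4_K layout, 256 weights per super-block.
const Q4K_BLOCK_WEIGHTS: usize = 256;
/// ASSUMPTION: 144 bytes per super-block
/// (4 bytes of f16 scale/min + 12 bytes of packed sub-block scales + 128 bytes of 4-bit quants).
const Q4K_BLOCK_BYTES: usize = 144;

/// Expected byte length for `num_weights` Q4_K-quantized values.
/// Returns `None` if the count is not a whole number of super-blocks.
fn q4k_expected_bytes(num_weights: usize) -> Option<usize> {
    if num_weights % Q4K_BLOCK_WEIGHTS != 0 {
        return None;
    }
    Some(num_weights / Q4K_BLOCK_WEIGHTS * Q4K_BLOCK_BYTES)
}
```

A caller could compare q4k_expected_bytes(rows * cols) against data.len() before passing data to upload_q4k_weight.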

pub fn init_kv_cache(&mut self, num_layers: usize)

GH-560: Initialize per-layer KV cache buffers on GPU.

pub fn weight_count(&self) -> usize

Number of uploaded weight buffers.

pub fn weight_buffer(&self, name: &str) -> Option<&Buffer>

Access a dequantized weight buffer by name (e.g. “layer.0.down_proj”). Used by the backward pass for gradient propagation through frozen base weights.

pub fn device_ref(&self) -> &Device

Reference to the wgpu device.

pub fn queue_ref(&self) -> &Queue

Reference to the wgpu queue.

pub fn hidden_buffer(&self) -> &Buffer

Reference to the hidden state buffer (for writing input).

pub fn q_buffer(&self) -> &Buffer

Reference to Q buffer (for LoRA addmm after Q projection).

pub fn k_buffer(&self) -> &Buffer

Reference to K buffer.

pub fn v_buffer(&self) -> &Buffer

Reference to V buffer.

pub fn gpu_residual_add(
    &self,
    a: &Buffer,
    b: &Buffer,
    output: &Buffer,
    len: u32,
)

Elementwise add: output = a + b. Dispatches residual add shader.

pub fn gpu_rmsnorm(&self, weight: &Buffer, output: &Buffer, _seq_len: u32)

Apply RMSNorm on GPU: normed = rmsnorm(hidden_buf, weight) → output_buf. Contract: gpu-output-norm-v1 / gpu_resident — hidden state never leaves GPU.
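When testing the GPU path, a CPU reference is useful for comparing against values read back with download_hidden. The sketch below uses the standard RMSNorm formula, y[i] = x[i] / sqrt(mean(x²) + eps) * weight[i]; the exact placement of eps in the shader is an assumption, and the function names are hypothetical.

```rust
/// CPU reference for RMSNorm over one hidden vector.
/// ASSUMPTION: eps is added inside the square root, as in the common formulation.
fn rmsnorm_reference(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    assert!(!x.is_empty(), "input must not be empty");
    assert_eq!(x.len(), weight.len(), "weight must match the hidden dimension");
    // Mean of squares across the hidden dimension.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Normalize, then apply the learned per-channel scale.
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}

/// Largest absolute elementwise difference, for comparing GPU output to the reference.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "slices must have equal length");
    a.iter().zip(b).map(|(p, q)| (p - q).abs()).fold(0.0_f32, f32::max)
}
```

A test would typically assert that max_abs_diff between the GPU result and the reference stays under a small tolerance rather than expecting bit-exact equality.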

pub fn download_hidden(&self, len: usize) -> Vec<f32>

Download hidden state from GPU.

pub fn total_vram_bytes(&self) -> usize

Total VRAM used by all buffers (bytes).

pub fn forward_model(
    &self,
    token_id: u32,
    position: usize,
    num_layers: usize,
    token_embedding: &[f32],
    output_norm_weight: &[f32],
    lm_head_weight: &[f32],
    vocab_size: usize,
    eps: f32,
    kv_caches: &mut Vec<(Vec<f32>, Vec<f32>)>,
) -> Result<Vec<f32>, String>

PMAT-336: Full model forward: embedding, all layers, output norm, and LM head.

Returns logits [vocab_size] for the given token at the given position. The embedding lookup and the final LM head run on the CPU (not yet GPU-accelerated). PMAT-344: Added kv_caches for multi-token context.
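Because forward_model returns logits for one token at one position, generation amounts to a loop that feeds each token in order and picks the next one from the logits. The following is a minimal greedy-decoding sketch; the step closure is a hypothetical stand-in for a call to forward_model with the remaining arguments captured, and both function names are illustrative.

```rust
/// Index of the largest logit (greedy sampling). Returns `None` for empty logits.
fn argmax(logits: &[f32]) -> Option<u32> {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
}

/// Greedy decode loop. `step(token_id, position)` is a HYPOTHETICAL stand-in for
/// a call that returns logits for that token, such as forward_model.
fn greedy_decode<F>(prompt: &[u32], max_new: usize, mut step: F) -> Result<Vec<u32>, String>
where
    F: FnMut(u32, usize) -> Result<Vec<f32>, String>,
{
    let mut tokens = prompt.to_vec();
    let mut logits = Vec::new();
    // Prefill: feed prompt tokens one at a time so every position is processed.
    for (pos, &tok) in prompt.iter().enumerate() {
        logits = step(tok, pos)?;
    }
    for i in 0..max_new {
        let next = argmax(&logits).ok_or("empty logits")?;
        tokens.push(next);
        // Skip the forward pass after the final token; its logits are never used.
        if i + 1 < max_new {
            logits = step(next, tokens.len() - 1)?;
        }
    }
    Ok(tokens)
}
```

Passing the position explicitly mirrors the forward_model signature, where the same kv_caches value would be reused across calls so that earlier tokens stay in context.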

pub fn forward_layer(
    &self,
    hidden: &mut [f32],
    layer_prefix: &str,
    _position: usize,
    kv_cache_k: &mut Vec<f32>,
    kv_cache_v: &mut Vec<f32>,
) -> Result<(), String>

PMAT-325: Execute one transformer layer in 14 passes, 1 submit, and 1 readback.

Input: hidden state [hidden_dim] on CPU. Output: updated hidden state [hidden_dim] on CPU, written in place. All intermediate computation stays GPU-resident. PMAT-344: KV cache parameters for multi-token context.

pub fn encode_forward_layer_training(
    &self,
    encoder: &mut CommandEncoder,
    seq_len: u32,
    layer_prefix: &str,
    saved: &LayerActivations,
    lora: Option<&QkvLoRA<'_>>,
) -> Result<(), String>

Training forward pass for a single transformer layer.

Unlike forward_layer (M=1 decode), this processes the full sequence at once (M=seq_len) and keeps everything on GPU, with no CPU readback.

Saves norm_output (pre-projection activations) for the backward pass.

Arguments
  • seq_len: number of tokens in the sequence
  • layer_prefix: e.g. “model.layers.0”
  • saved: OUTPUT: activations saved for backward. The documented outputs are saved_norm_attn (pre-attention norm) and saved_norm_ffn (pre-FFN norm), each a wgpu::Buffer of shape [seq×hidden]. They do not appear as separate parameters in the signature and are presumably carried by saved.

Forwards one layer into an EXISTING encoder (no submit). The caller batches multiple layers into one encoder and submits once.

pub fn forward_layer_traced(
    &self,
    seq_len: u32,
    layer_prefix: &str,
    saved: &LayerActivations,
    lora: Option<&QkvLoRA<'_>>,
) -> Result<(), String>

Run one layer with per-operation GPU timing (submit+poll between each op group). Contract: forward-pass-perf-v1 / bottleneck_identified

pub fn alloc_layer_activations(&self, seq_len: u32) -> LayerActivations

Allocate saved activations for one layer.

pub fn forward_layer_training(
    &self,
    seq_len: u32,
    layer_prefix: &str,
) -> Result<LayerActivations, String>

Forward one layer with its own encoder and submit (original API, kept for compatibility).

pub fn forward_all_layers_training(
    &self,
    seq_len: u32,
    num_layers: usize,
) -> Result<Vec<LayerActivations>, String>

Forward ALL layers in a single encoder submit. For example, 28 layers require only 1 GPU sync instead of 28.
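The difference between per-layer submission and batched submission can be sketched with mock types that only count work. MockEncoder and MockQueue below are hypothetical stand-ins, not wgpu types, and the figure of 14 passes per layer used in the usage note is borrowed from the forward_layer description purely for illustration.

```rust
/// HYPOTHETICAL stand-in for a command encoder; it only counts recorded passes.
#[derive(Default)]
struct MockEncoder {
    passes: usize,
}

/// HYPOTHETICAL stand-in for a queue; it counts submits and executed passes.
#[derive(Default)]
struct MockQueue {
    submits: usize,
    passes_executed: usize,
}

impl MockQueue {
    /// Consume an encoder, as a real submit would consume its command buffer.
    fn submit(&mut self, encoder: MockEncoder) {
        self.submits += 1;
        self.passes_executed += encoder.passes;
    }
}

/// Record one layer's passes into an existing encoder (no submit).
fn encode_layer(encoder: &mut MockEncoder, passes_per_layer: usize) {
    encoder.passes += passes_per_layer;
}

/// One encoder and one submit per layer.
fn run_per_layer(queue: &mut MockQueue, num_layers: usize, passes_per_layer: usize) {
    for _ in 0..num_layers {
        let mut encoder = MockEncoder::default();
        encode_layer(&mut encoder, passes_per_layer);
        queue.submit(encoder);
    }
}

/// All layers recorded into one encoder, submitted once.
fn run_batched(queue: &mut MockQueue, num_layers: usize, passes_per_layer: usize) {
    let mut encoder = MockEncoder::default();
    for _ in 0..num_layers {
        encode_layer(&mut encoder, passes_per_layer);
    }
    queue.submit(encoder);
}
```

With 28 layers and 14 passes each, both variants execute the same 392 passes, but run_per_layer performs 28 submits while run_batched performs 1.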

pub fn encode_broadcast_bias(
    &self,
    encoder: &mut CommandEncoder,
    buf: &Buffer,
    bias: &[f32],
    seq_len: u32,
)

PMAT-509: Add a broadcast bias to a [seq_len, dim] buffer. bias has shape [dim] and is applied to each of the seq_len rows.

The same doc comment also describes encoding causal multi-head attention on GPU: Q is [seq_len, num_heads * head_dim], K/V are [seq_len, num_kv_heads * head_dim], and the output is written to q_buf (reused as the attention output).
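The broadcast-bias operation is simple enough to state as a CPU reference, which can serve as a check for the GPU result. This sketch assumes the buffer is stored row-major as [seq_len, dim]; the function name is hypothetical.

```rust
/// CPU reference: add `bias` ([dim]) to every row of `buf` ([seq_len, dim]).
/// ASSUMPTION: `buf` is row-major, so each row is a contiguous run of `dim` values.
fn broadcast_bias_reference(buf: &mut [f32], bias: &[f32], seq_len: usize) {
    let dim = bias.len();
    assert_eq!(buf.len(), seq_len * dim, "buf must hold seq_len * dim values");
    if dim == 0 {
        return; // Nothing to add; also avoids a zero-sized chunk.
    }
    for row in buf.chunks_exact_mut(dim) {
        for (x, b) in row.iter_mut().zip(bias) {
            *x += b;
        }
    }
}
```

For a 2×2 buffer [0, 0, 1, 1] and bias [1, 2], each row receives the same bias, giving [1, 2, 2, 3].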
