#![allow(clippy::too_many_arguments)]
use std::sync::Arc;
use bytemuck::cast_slice;
use futures_channel::oneshot;
use crate::backend::dispatch::{
add_bias_batched_chained, block_local_attention_chained, clamp_chained,
depthwise_conv1d_chained, fence_submitted_work, glu_split_chained, half_residual_add_chained,
matmul_bf16_batched_chained, matmul_f16_batched_chained, rmsnorm_per_row_chained,
scale_chained, scale_per_inner_dim_chained, silu_chained,
};
use crate::backend::{Pipelines, WeightCache, WgpuCtx};
use crate::error::{Result, RullamaError};
use crate::gguf::{GgmlDtype, dequant_tensor_to_f32_async};
use crate::multimodal::audio::{AudioConfig, AudioPrefix};
/// Largest number of post-prefix frames `GpuAudioForward::encode` accepts;
/// `encode` reports exceeding it as "audio longer than 30 s". Most scratch
/// buffers are sized from this bound.
const MAX_SEQ: usize = 768;
/// Optional value bounds attached to one projection, for its input and its output.
///
/// Every field defaults to `0.0`, and `load_clamp` also yields `0.0` for tensors it
/// cannot read. The dispatch code treats a `*_max` of exactly `0.0` as
/// "no clamp on this side".
#[derive(Default, Clone, Copy)]
struct Clamp {
    /// Lower bound applied to the projection input.
    in_min: f32,
    /// Upper bound applied to the projection input (`0.0` = disabled).
    in_max: f32,
    /// Lower bound applied to the projection output.
    out_min: f32,
    /// Upper bound applied to the projection output (`0.0` = disabled).
    out_max: f32,
}
/// Small per-block data that stays resident on the GPU for the encoder's
/// lifetime (built once per block by `load_gpu_block_meta`). The large block
/// weights are instead fetched per `encode` via `fetch_gpu_block_weights`.
struct GpuAudioBlockMeta {
    /// `per_dim_scale.weight` dequantized to f32 and pre-multiplied by the query
    /// scale base; applied to Q along its inner (head) dimension.
    per_dim_scale: wgpu::Buffer,
    /// `conv_dw.weight` dequantized to f32: the depthwise conv kernel.
    conv_dw: wgpu::Buffer,

    // Clamp bounds per projection; the trailing comment is the tensor prefix
    // under `a.blk.{i}.` that `load_clamp` reads them from.
    cl_attn_q: Clamp,     // attn_q
    cl_attn_k: Clamp,     // attn_k
    cl_attn_v: Clamp,     // attn_v
    cl_attn_o: Clamp,     // attn_out
    cl_ffw_up: Clamp,     // ffn_up
    cl_ffw_down: Clamp,   // ffn_down
    cl_ffw_up_1: Clamp,   // ffn_up_1
    cl_ffw_down_1: Clamp, // ffn_down_1
    cl_conv_pw1: Clamp,   // conv_pw1
    cl_conv_pw2: Clamp,   // conv_pw2
}
/// Per-block weight buffers obtained from the `WeightCache` for one block.
/// Each trailing comment names the tensor suffix under `a.blk.{i}.` that
/// `fetch_gpu_block_weights` loads into the field.
struct GpuAudioBlockWeights {
    pre_norm: wgpu::Buffer, // layer_pre_norm.weight (final norm of the block)

    // First feed-forward module.
    ffw_norm: wgpu::Buffer,      // ffn_norm.weight
    ffw_up: wgpu::Buffer,        // ffn_up.weight
    ffw_down: wgpu::Buffer,      // ffn_down.weight
    ffw_post_norm: wgpu::Buffer, // ffn_post_norm.weight

    // Second feed-forward module.
    ffw_norm_1: wgpu::Buffer,      // ffn_norm_1.weight
    ffw_up_1: wgpu::Buffer,        // ffn_up_1.weight
    ffw_down_1: wgpu::Buffer,      // ffn_down_1.weight
    ffw_post_norm_1: wgpu::Buffer, // ffn_post_norm_1.weight

    // Attention module.
    attn_pre_norm: wgpu::Buffer,  // ln1.weight
    attn_post_norm: wgpu::Buffer, // ln2.weight
    attn_q: wgpu::Buffer,         // attn_q.weight
    attn_k: wgpu::Buffer,         // attn_k.weight
    attn_v: wgpu::Buffer,         // attn_v.weight
    attn_o: wgpu::Buffer,         // attn_out.weight
    linear_pos: wgpu::Buffer,     // linear_pos.weight

    // Light-conv module.
    conv_norm: wgpu::Buffer, // conv_norm.weight (pre-norm)
    norm_conv: wgpu::Buffer, // norm_conv.weight (after depthwise conv)
    conv_pw1: wgpu::Buffer,  // conv_pw1.weight
    conv_pw2: wgpu::Buffer,  // conv_pw2.weight
}
/// Device buffers allocated once in `GpuAudioForward::new` and reused by every
/// `encode` call. All are f32 storage buffers except `soft_read`, which is the
/// host-mappable readback target for `soft`.
struct Scratch {
    // Block-level activations.
    h_main: wgpu::Buffer,   // running hidden state: block input and output
    residual: wgpu::Buffer, // residual saved at the start of each module
    h_norm: wgpu::Buffer,   // normalized activations / module outputs

    // Feed-forward temporaries.
    ffw_h: wgpu::Buffer,   // up-projection output (`ffn_inter` wide)
    ffw_out: wgpu::Buffer, // down-projection or attention output projection

    // Light-conv temporaries.
    pw1_out: wgpu::Buffer,     // pointwise conv 1 output (`2 * hidden` wide)
    glu_out: wgpu::Buffer,     // GLU output
    conv_dw_out: wgpu::Buffer, // depthwise conv output
    pw2_out: wgpu::Buffer,     // pointwise conv 2 output

    // Attention temporaries.
    q_buf: wgpu::Buffer,    // queries
    k_padded: wgpu::Buffer, // keys with zero rows on both sides
    v_padded: wgpu::Buffer, // values with zero rows on both sides
    pos_emb: wgpu::Buffer,  // sinusoidal relative-position table (uploaded in `new`)
    pos_proj: wgpu::Buffer, // `linear_pos` applied to `pos_emb`
    attn_out: wgpu::Buffer, // attention output, also K/V staging

    // Text-space projector.
    fc_out: wgpu::Buffer,
    fc_normed: wgpu::Buffer,
    soft: wgpu::Buffer,
    soft_read: wgpu::Buffer,
}
/// GPU audio encoder.
///
/// `AudioPrefix` produces the initial hidden frames on the CPU. Every encoder
/// block and the text-space projector then run as wgpu compute work over
/// preallocated [`Scratch`] buffers.
///
/// NOTE(review): `encode` takes `&self` but reuses the same scratch and readback
/// buffers on every call, so overlapping calls on one instance would race;
/// nothing here prevents that.
pub struct GpuAudioForward {
    /// Encoder hyper-parameters (dimensions, attention window, grad clip, eps, ...).
    cfg: AudioConfig,
    /// Device and queue used for uploads, dispatches and readback.
    ctx: WgpuCtx,
    /// Compute pipelines consumed by the `*_chained` dispatch helpers.
    pipes: Arc<Pipelines>,
    /// Source of per-block weight buffers and raw tensors.
    wcache: Arc<WeightCache>,
    /// CPU front end turning PCM into the first hidden sequence.
    cpu_prefix: AudioPrefix,
    /// Resident per-block metadata, indexed by block number.
    blocks: Vec<GpuAudioBlockMeta>,

    /// `mm.a.fc.weight` and its storage dtype.
    proj_fc: wgpu::Buffer,
    proj_fc_dtype: GgmlDtype,
    /// `mm.a.fc.bias`, if the model has one.
    proj_fc_bias: Option<wgpu::Buffer>,
    /// `mm.a.input_projection.weight` and its storage dtype.
    proj_input: wgpu::Buffer,
    proj_input_dtype: GgmlDtype,

    /// Preallocated working buffers.
    scratch: Scratch,
}
impl GpuAudioForward {
/// Builds the GPU audio encoder. It:
/// - constructs the CPU prefix;
/// - allocates every scratch buffer used by `encode`;
/// - uploads the sinusoidal relative-position table;
/// - loads per-block resident metadata;
/// - fetches the projector weights.
///
/// Buffers addressed by chunk-padded row counts (`q_buf`, `attn_out`, `k_padded`,
/// `v_padded`) are sized for `MAX_SEQ` rounded up to a multiple of
/// `cfg.chunk_size`. `encode` pads the sequence to whole chunks before attention,
/// so `MAX_SEQ` alone is too small whenever the chunk size does not divide it.
///
/// # Errors
/// Propagates failures from building the CPU prefix, loading block metadata, or
/// looking up the projector tensors.
pub async fn new(
    cfg: AudioConfig,
    ctx: WgpuCtx,
    pipes: Arc<Pipelines>,
    wcache: Arc<WeightCache>,
) -> Result<Self> {
    let cpu_prefix = AudioPrefix::new(cfg.clone(), wcache.clone()).await?;
    let device = &ctx.device;
    let queue = &ctx.queue;
    let hidden = cfg.hidden as usize;
    let ffn = cfg.ffn_inter as usize;
    let head_dim = cfg.head_dim() as usize;
    let max_span = (cfg.max_past + cfg.max_future + 1) as usize;
    // `encode` uses `padded_len = seq.div_ceil(chunk) * chunk`. Round MAX_SEQ up the
    // same way so the padded tail always fits. `max(1)` only keeps this
    // division from panicking on a degenerate config.
    let chunk = (cfg.chunk_size as usize).max(1);
    let max_padded = MAX_SEQ.div_ceil(chunk) * chunk;
    let pad_left = cfg.max_past as usize;
    let pad_right = (cfg.max_future + cfg.chunk_size - 1) as usize;
    let max_k_padded = pad_left + max_padded + pad_right;
    let d_text = cfg.d_text as usize;

    // f32 storage buffer usable as copy source/destination; never zero-sized.
    let alloc_storage = |label: &str, n_f32: usize| -> wgpu::Buffer {
        device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (n_f32 * 4).max(4) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    // Row-count bounded by `seq` (<= MAX_SEQ).
    let h_main = alloc_storage("aud.h_main", MAX_SEQ * hidden);
    let residual = alloc_storage("aud.residual", MAX_SEQ * hidden);
    let h_norm = alloc_storage("aud.h_norm", MAX_SEQ * hidden);
    let ffw_h = alloc_storage("aud.ffw_h", MAX_SEQ * ffn);
    let ffw_out = alloc_storage("aud.ffw_out", MAX_SEQ * hidden);
    let pw1_out = alloc_storage("aud.pw1", MAX_SEQ * 2 * hidden);
    let glu_out = alloc_storage("aud.glu", MAX_SEQ * hidden);
    let conv_dw_out = alloc_storage("aud.dw_out", MAX_SEQ * hidden);
    let pw2_out = alloc_storage("aud.pw2", MAX_SEQ * hidden);
    // Row-count bounded by the chunk-padded length.
    let q_buf = alloc_storage("aud.q", max_padded * hidden);
    let k_padded = alloc_storage("aud.k_padded", max_k_padded * hidden);
    let v_padded = alloc_storage("aud.v_padded", max_k_padded * hidden);
    let pos_emb = alloc_storage("aud.pos_emb", max_span * hidden);
    let pos_proj = alloc_storage("aud.pos_proj", max_span * hidden);
    let attn_out = alloc_storage("aud.attn_out", max_padded * hidden);
    // Projector.
    let fc_out = alloc_storage("aud.fc_out", MAX_SEQ * d_text);
    let fc_normed = alloc_storage("aud.fc_normed", MAX_SEQ * d_text);
    let soft = alloc_storage("aud.soft", MAX_SEQ * d_text);
    let soft_read = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("aud.soft_read"),
        size: (MAX_SEQ * d_text * 4) as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    // Sinusoidal table over relative positions max_past, max_past-1, ..., -max_future.
    // Columns [0, half) hold sin, [half, 2*half) hold cos.
    {
        let half_dim = hidden / 2;
        let log_inc = (10000f32).ln() / (half_dim.saturating_sub(1)).max(1) as f32;
        let mut pos_emb_cpu = vec![0f32; max_span * hidden];
        for p in 0..max_span {
            let rel_pos = (cfg.max_past as f32) - (p as f32);
            for d in 0..half_dim {
                let angle = rel_pos * (-(d as f32) * log_inc).exp();
                pos_emb_cpu[p * hidden + d] = angle.sin();
                pos_emb_cpu[p * hidden + half_dim + d] = angle.cos();
            }
        }
        queue.write_buffer(&pos_emb, 0, cast_slice(&pos_emb_cpu));
    }

    let scratch = Scratch {
        h_main,
        residual,
        h_norm,
        ffw_h,
        ffw_out,
        pw1_out,
        glu_out,
        conv_dw_out,
        pw2_out,
        q_buf,
        k_padded,
        v_padded,
        pos_emb,
        pos_proj,
        attn_out,
        fc_out,
        fc_normed,
        soft,
        soft_read,
    };

    // Query scale folded into each block's per-dim scale at load time.
    let q_scale_base = (head_dim as f32).powf(-0.5) / std::f32::consts::LN_2;
    let mut blocks = Vec::with_capacity(cfg.n_layers as usize);
    for i in 0..cfg.n_layers {
        blocks.push(load_gpu_block_meta(&wcache, i, &ctx, q_scale_base).await?);
    }

    let proj_fc = wcache.buffer_async("mm.a.fc.weight").await?;
    let proj_fc_dtype = wcache.reader().tensor("mm.a.fc.weight")?.dtype;
    let proj_fc_bias = wcache.buffer_opt_async("mm.a.fc.bias").await?;
    let proj_input = wcache.buffer_async("mm.a.input_projection.weight").await?;
    let proj_input_dtype = wcache
        .reader()
        .tensor("mm.a.input_projection.weight")?
        .dtype;

    Ok(Self {
        cfg,
        ctx,
        pipes,
        wcache,
        cpu_prefix,
        blocks,
        proj_fc,
        proj_fc_dtype,
        proj_fc_bias,
        proj_input,
        proj_input_dtype,
        scratch,
    })
}
pub fn cfg(&self) -> &AudioConfig {
&self.cfg
}
pub async fn encode(
&self,
pcm: &[f32],
cancel: Option<Arc<std::sync::atomic::AtomicBool>>,
) -> Result<Vec<f32>> {
let (h_cpu, seq) = self.cpu_prefix.prefix_to_hidden(pcm)?;
if seq == 0 {
return Ok(Vec::new());
}
if seq > MAX_SEQ {
return Err(RullamaError::Inference(format!(
"audio: seq {seq} > MAX_SEQ {MAX_SEQ} (audio longer than 30 s)"
)));
}
let cfg = &self.cfg;
let hidden = cfg.hidden as usize;
let n_heads = cfg.n_heads as usize;
let head_dim = cfg.head_dim() as usize;
let chunk_size = cfg.chunk_size as usize;
let max_past = cfg.max_past as usize;
let max_future = cfg.max_future as usize;
let context_size = max_past + chunk_size + max_future;
let max_span = max_past + max_future + 1;
let pad_left = max_past;
let pad_right = max_future + chunk_size - 1;
let num_chunks = seq.div_ceil(chunk_size);
let padded_len = num_chunks * chunk_size;
let k_padded_len = pad_left + padded_len + pad_right;
let d_text = cfg.d_text as usize;
let logit_cap = cfg.logit_cap;
let k_scale = (1.0f32 + std::f32::consts::E).ln() / std::f32::consts::LN_2;
let queue = &self.ctx.queue;
queue.write_buffer(&self.scratch.h_main, 0, cast_slice(&h_cpu));
for b in 0..self.blocks.len() {
if let Some(c) = cancel.as_ref()
&& c.load(std::sync::atomic::Ordering::Relaxed)
{
return Err(RullamaError::Cancelled);
}
let w = fetch_gpu_block_weights(&self.wcache, b as u32).await?;
let mut benc =
self.ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("aud.block"),
});
self.dispatch_block(
&mut benc,
&self.blocks[b],
&w,
seq,
padded_len,
k_padded_len,
hidden,
n_heads,
head_dim,
chunk_size,
context_size,
max_span,
max_past,
max_future,
pad_left,
logit_cap,
k_scale,
);
self.ctx.queue.submit(Some(benc.finish()));
fence_submitted_work(&self.ctx.device, &self.ctx.queue).await?;
}
let mut enc = self
.ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("aud.epilogue"),
});
self.dispatch_projector(&mut enc, seq, hidden, d_text);
let read_bytes = (seq * d_text * 4) as u64;
enc.copy_buffer_to_buffer(
&self.scratch.soft,
0,
&self.scratch.soft_read,
0,
read_bytes,
);
self.ctx.queue.submit(Some(enc.finish()));
let slice = self.scratch.soft_read.slice(..read_bytes);
let (tx, rx) = oneshot::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
let _ = tx.send(r);
});
self.ctx
.device
.poll(wgpu::PollType::Wait {
submission_index: None,
timeout: None,
})
.map_err(|e| RullamaError::Inference(format!("device.poll: {e}")))?;
rx.await
.map_err(|_| RullamaError::Inference("readback channel".into()))?
.map_err(|e| RullamaError::Inference(format!("map_async: {e:?}")))?;
let data = slice.get_mapped_range();
let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
drop(data);
self.scratch.soft_read.unmap();
Ok(out)
}
#[allow(clippy::too_many_arguments)]
fn dispatch_block(
&self,
enc: &mut wgpu::CommandEncoder,
meta: &GpuAudioBlockMeta,
w: &GpuAudioBlockWeights,
seq: usize,
padded_len: usize,
k_padded_len: usize,
hidden: usize,
n_heads: usize,
head_dim: usize,
chunk_size: usize,
context_size: usize,
max_span: usize,
max_past: usize,
max_future: usize,
pad_left: usize,
logit_cap: f32,
k_scale: f32,
) {
let cfg = &self.cfg;
let ffn = cfg.ffn_inter as usize;
let eps = cfg.eps;
let gc = cfg.grad_clip;
let s = &self.scratch;
let n_h = seq * hidden;
self.dispatch_ffw(
enc,
&w.ffw_norm,
&w.ffw_up,
&meta.cl_ffw_up,
&w.ffw_down,
&meta.cl_ffw_down,
&w.ffw_post_norm,
seq,
hidden,
ffn,
eps,
gc,
);
enc.copy_buffer_to_buffer(&s.h_main, 0, &s.residual, 0, (n_h * 4) as u64);
clamp_chained(&self.ctx, &self.pipes, enc, &s.h_main, n_h, -gc, gc);
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_main,
Some(&w.attn_pre_norm),
&s.h_main,
&s.h_norm,
seq,
hidden,
eps,
);
let cl_q = &meta.cl_attn_q;
if cl_q.in_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_norm,
n_h,
cl_q.in_min,
cl_q.in_max,
);
}
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
&w.attn_q,
&s.h_norm,
&s.q_buf,
hidden,
hidden,
seq,
);
if cl_q.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.q_buf,
seq * hidden,
cl_q.out_min,
cl_q.out_max,
);
}
enc.clear_buffer(&s.k_padded, 0, Some((k_padded_len * hidden * 4) as u64));
enc.clear_buffer(&s.v_padded, 0, Some((k_padded_len * hidden * 4) as u64));
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
&w.attn_k,
&s.h_norm,
&s.attn_out,
hidden,
hidden,
seq,
);
let cl_k = &meta.cl_attn_k;
if cl_k.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.attn_out,
seq * hidden,
cl_k.out_min,
cl_k.out_max,
);
}
scale_chained(
&self.ctx,
&self.pipes,
enc,
&s.attn_out,
seq * hidden,
k_scale,
);
enc.copy_buffer_to_buffer(
&s.attn_out,
0,
&s.k_padded,
(pad_left * hidden * 4) as u64,
(seq * hidden * 4) as u64,
);
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
&w.attn_v,
&s.h_norm,
&s.attn_out,
hidden,
hidden,
seq,
);
let cl_v = &meta.cl_attn_v;
if cl_v.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.attn_out,
seq * hidden,
cl_v.out_min,
cl_v.out_max,
);
}
enc.copy_buffer_to_buffer(
&s.attn_out,
0,
&s.v_padded,
(pad_left * hidden * 4) as u64,
(seq * hidden * 4) as u64,
);
scale_per_inner_dim_chained(
&self.ctx,
&self.pipes,
enc,
&s.q_buf,
&meta.per_dim_scale,
seq * hidden,
head_dim,
);
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
&w.linear_pos,
&s.pos_emb,
&s.pos_proj,
hidden,
hidden,
max_span,
);
if padded_len > seq {
enc.clear_buffer(
&s.q_buf,
(seq * hidden * 4) as u64,
Some(((padded_len - seq) * hidden * 4) as u64),
);
}
block_local_attention_chained(
&self.ctx,
&self.pipes,
enc,
&s.q_buf,
&s.k_padded,
&s.v_padded,
&s.pos_proj,
&s.attn_out,
seq,
padded_len,
hidden,
n_heads,
head_dim,
chunk_size,
context_size,
max_span,
max_past,
max_future,
pad_left,
logit_cap,
);
let cl_o = &meta.cl_attn_o;
if cl_o.in_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.attn_out,
seq * hidden,
cl_o.in_min,
cl_o.in_max,
);
}
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
&w.attn_o,
&s.attn_out,
&s.ffw_out,
hidden,
hidden,
seq,
);
if cl_o.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.ffw_out,
seq * hidden,
cl_o.out_min,
cl_o.out_max,
);
}
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.ffw_out,
seq * hidden,
-gc,
gc,
);
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&s.ffw_out,
Some(&w.attn_post_norm),
&s.h_main,
&s.h_norm,
seq,
hidden,
eps,
);
enc.copy_buffer_to_buffer(&s.residual, 0, &s.h_main, 0, (n_h * 4) as u64);
crate::backend::dispatch::residual_add_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_main,
&s.h_norm,
n_h,
);
self.dispatch_lightconv(enc, meta, w, seq, hidden, eps, gc);
self.dispatch_ffw(
enc,
&w.ffw_norm_1,
&w.ffw_up_1,
&meta.cl_ffw_up_1,
&w.ffw_down_1,
&meta.cl_ffw_down_1,
&w.ffw_post_norm_1,
seq,
hidden,
ffn,
eps,
gc,
);
clamp_chained(&self.ctx, &self.pipes, enc, &s.h_main, n_h, -gc, gc);
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_main,
Some(&w.pre_norm),
&s.h_main,
&s.ffw_out,
seq,
hidden,
eps,
);
enc.copy_buffer_to_buffer(&s.ffw_out, 0, &s.h_main, 0, (n_h * 4) as u64);
}
#[allow(clippy::too_many_arguments)]
fn dispatch_ffw(
&self,
enc: &mut wgpu::CommandEncoder,
norm_w: &wgpu::Buffer,
up_w: &wgpu::Buffer,
up_clamp: &Clamp,
down_w: &wgpu::Buffer,
down_clamp: &Clamp,
post_norm_w: &wgpu::Buffer,
seq: usize,
hidden: usize,
ffn: usize,
eps: f32,
gc: f32,
) {
let s = &self.scratch;
let n_h = seq * hidden;
let n_f = seq * ffn;
enc.copy_buffer_to_buffer(&s.h_main, 0, &s.residual, 0, (n_h * 4) as u64);
clamp_chained(&self.ctx, &self.pipes, enc, &s.h_main, n_h, -gc, gc);
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_main,
Some(norm_w),
&s.h_main,
&s.h_norm,
seq,
hidden,
eps,
);
if up_clamp.in_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_norm,
n_h,
up_clamp.in_min,
up_clamp.in_max,
);
}
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
up_w,
&s.h_norm,
&s.ffw_h,
hidden,
ffn,
seq,
);
if up_clamp.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.ffw_h,
n_f,
up_clamp.out_min,
up_clamp.out_max,
);
}
silu_chained(&self.ctx, &self.pipes, enc, &s.ffw_h, n_f);
if down_clamp.in_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.ffw_h,
n_f,
down_clamp.in_min,
down_clamp.in_max,
);
}
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
down_w,
&s.ffw_h,
&s.ffw_out,
ffn,
hidden,
seq,
);
if down_clamp.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.ffw_out,
n_h,
down_clamp.out_min,
down_clamp.out_max,
);
}
clamp_chained(&self.ctx, &self.pipes, enc, &s.ffw_out, n_h, -gc, gc);
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&s.ffw_out,
Some(post_norm_w),
&s.ffw_out,
&s.h_norm,
seq,
hidden,
eps,
);
half_residual_add_chained(&self.ctx, &self.pipes, enc, &s.residual, &s.h_norm, n_h);
enc.copy_buffer_to_buffer(&s.residual, 0, &s.h_main, 0, (n_h * 4) as u64);
}
fn dispatch_lightconv(
&self,
enc: &mut wgpu::CommandEncoder,
meta: &GpuAudioBlockMeta,
w: &GpuAudioBlockWeights,
seq: usize,
hidden: usize,
eps: f32,
gc: f32,
) {
let s = &self.scratch;
let n_h = seq * hidden;
let n_2h = seq * 2 * hidden;
let kernel = self.cfg.conv_kernel as usize;
enc.copy_buffer_to_buffer(&s.h_main, 0, &s.residual, 0, (n_h * 4) as u64);
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_main,
Some(&w.conv_norm),
&s.h_main,
&s.h_norm,
seq,
hidden,
eps,
);
let cl_pw1 = &meta.cl_conv_pw1;
if cl_pw1.in_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_norm,
n_h,
cl_pw1.in_min,
cl_pw1.in_max,
);
}
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
&w.conv_pw1,
&s.h_norm,
&s.pw1_out,
hidden,
2 * hidden,
seq,
);
if cl_pw1.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.pw1_out,
n_2h,
cl_pw1.out_min,
cl_pw1.out_max,
);
}
glu_split_chained(
&self.ctx,
&self.pipes,
enc,
&s.pw1_out,
&s.glu_out,
seq,
hidden,
);
depthwise_conv1d_chained(
&self.ctx,
&self.pipes,
enc,
&s.glu_out,
&meta.conv_dw,
&s.conv_dw_out,
seq,
hidden,
kernel,
);
clamp_chained(&self.ctx, &self.pipes, enc, &s.conv_dw_out, n_h, -gc, gc);
rmsnorm_per_row_chained(
&self.ctx,
&self.pipes,
enc,
&s.conv_dw_out,
Some(&w.norm_conv),
&s.conv_dw_out,
&s.h_norm,
seq,
hidden,
eps,
);
silu_chained(&self.ctx, &self.pipes, enc, &s.h_norm, n_h);
let cl_pw2 = &meta.cl_conv_pw2;
if cl_pw2.in_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_norm,
n_h,
cl_pw2.in_min,
cl_pw2.in_max,
);
}
matmul_bf16_batched_chained(
&self.ctx,
&self.pipes,
enc,
&w.conv_pw2,
&s.h_norm,
&s.pw2_out,
hidden,
hidden,
seq,
);
if cl_pw2.out_max != 0.0 {
clamp_chained(
&self.ctx,
&self.pipes,
enc,
&s.pw2_out,
n_h,
cl_pw2.out_min,
cl_pw2.out_max,
);
}
enc.copy_buffer_to_buffer(&s.residual, 0, &s.h_main, 0, (n_h * 4) as u64);
crate::backend::dispatch::residual_add_chained(
&self.ctx,
&self.pipes,
enc,
&s.h_main,
&s.pw2_out,
n_h,
);
}
/// Records the audio-to-text projector over `seq` rows of `scratch.h_main`:
///
/// 1. `fc_out <- fc · h_main (+ fc_bias)`, mapping hidden -> d_text;
/// 2. `fc_normed <- rmsnorm(fc_out)` with no learned weight;
/// 3. `soft <- input_projection · fc_normed`, mapping d_text -> d_text.
///
/// # Panics
/// If either projector weight is stored in a dtype other than F16 or BF16.
fn dispatch_projector(
    &self,
    enc: &mut wgpu::CommandEncoder,
    seq: usize,
    hidden: usize,
    d_text: usize,
) {
    let sc = &self.scratch;
    let (ctx, pipes) = (&self.ctx, &self.pipes);
    let eps = self.cfg.eps;

    // Stage 1: fc projection (+ optional bias).
    let (fc_w, fc_src, fc_dst) = (&self.proj_fc, &sc.h_main, &sc.fc_out);
    match self.proj_fc_dtype {
        GgmlDtype::BF16 => matmul_bf16_batched_chained(
            ctx, pipes, enc, fc_w, fc_src, fc_dst, hidden, d_text, seq,
        ),
        GgmlDtype::F16 => matmul_f16_batched_chained(
            ctx, pipes, enc, fc_w, fc_src, fc_dst, hidden, d_text, seq,
        ),
        other => panic!("audio projector FC dtype {other:?} not supported"),
    }
    if let Some(bias) = &self.proj_fc_bias {
        add_bias_batched_chained(ctx, pipes, enc, &sc.fc_out, bias, d_text, seq);
    }

    // Stage 2: weightless RMSNorm.
    rmsnorm_per_row_chained(
        ctx,
        pipes,
        enc,
        &sc.fc_out,
        None,
        &sc.fc_out,
        &sc.fc_normed,
        seq,
        d_text,
        eps,
    );

    // Stage 3: input projection into the soft-token buffer.
    let (in_w, in_src, in_dst) = (&self.proj_input, &sc.fc_normed, &sc.soft);
    match self.proj_input_dtype {
        GgmlDtype::BF16 => matmul_bf16_batched_chained(
            ctx, pipes, enc, in_w, in_src, in_dst, d_text, d_text, seq,
        ),
        GgmlDtype::F16 => matmul_f16_batched_chained(
            ctx, pipes, enc, in_w, in_src, in_dst, d_text, d_text, seq,
        ),
        other => panic!("audio projector input dtype {other:?} not supported"),
    }
}
}
async fn load_gpu_block_meta(
wcache: &Arc<WeightCache>,
i: u32,
ctx: &WgpuCtx,
q_scale_base: f32,
) -> Result<GpuAudioBlockMeta> {
let p = format!("a.blk.{i}.");
let r = wcache.reader();
let per_dim_scale_cpu =
dequant_tensor_to_f32_async(r, &format!("{p}per_dim_scale.weight")).await?;
let scaled: Vec<f32> = per_dim_scale_cpu
.iter()
.map(|&v| v * q_scale_base)
.collect();
let per_dim_scale_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("aud.per_dim_scale"),
size: (scaled.len() * 4) as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
ctx.queue
.write_buffer(&per_dim_scale_buf, 0, cast_slice(&scaled));
let conv_dw_cpu = dequant_tensor_to_f32_async(r, &format!("{p}conv_dw.weight")).await?;
let conv_dw_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("aud.conv_dw"),
size: (conv_dw_cpu.len() * 4) as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
ctx.queue
.write_buffer(&conv_dw_buf, 0, cast_slice(&conv_dw_cpu));
Ok(GpuAudioBlockMeta {
per_dim_scale: per_dim_scale_buf,
conv_dw: conv_dw_buf,
cl_attn_q: load_clamp(wcache, &format!("{p}attn_q")).await,
cl_attn_k: load_clamp(wcache, &format!("{p}attn_k")).await,
cl_attn_v: load_clamp(wcache, &format!("{p}attn_v")).await,
cl_attn_o: load_clamp(wcache, &format!("{p}attn_out")).await,
cl_ffw_up: load_clamp(wcache, &format!("{p}ffn_up")).await,
cl_ffw_down: load_clamp(wcache, &format!("{p}ffn_down")).await,
cl_ffw_up_1: load_clamp(wcache, &format!("{p}ffn_up_1")).await,
cl_ffw_down_1: load_clamp(wcache, &format!("{p}ffn_down_1")).await,
cl_conv_pw1: load_clamp(wcache, &format!("{p}conv_pw1")).await,
cl_conv_pw2: load_clamp(wcache, &format!("{p}conv_pw2")).await,
})
}
/// Fetches the per-block weight buffers of block `i` from the weight cache,
/// all concurrently, and returns them by role.
///
/// # Errors
/// Fails if any of the tensors cannot be provided by `WeightCache::buffer_async`.
async fn fetch_gpu_block_weights(
    wcache: &Arc<WeightCache>,
    i: u32,
) -> Result<GpuAudioBlockWeights> {
    // Tensor suffixes under `a.blk.{i}.`. The order here fixes the destructuring
    // order below.
    const TENSORS: [&str; 20] = [
        "layer_pre_norm.weight",
        "ffn_norm.weight",
        "ffn_up.weight",
        "ffn_down.weight",
        "ffn_post_norm.weight",
        "ffn_norm_1.weight",
        "ffn_up_1.weight",
        "ffn_down_1.weight",
        "ffn_post_norm_1.weight",
        "ln1.weight",
        "ln2.weight",
        "attn_q.weight",
        "attn_k.weight",
        "attn_v.weight",
        "attn_out.weight",
        "linear_pos.weight",
        "conv_norm.weight",
        "norm_conv.weight",
        "conv_pw1.weight",
        "conv_pw2.weight",
    ];

    let p = format!("a.blk.{i}.");
    let full_names: Vec<String> = TENSORS.iter().map(|n| format!("{p}{n}")).collect();
    let fetched = futures_util::future::try_join_all(
        full_names.iter().map(|name| wcache.buffer_async(name)),
    )
    .await?;

    // `try_join_all` yields exactly one buffer per requested name, in order.
    let Ok(
        [
            pre_norm,
            ffw_norm,
            ffw_up,
            ffw_down,
            ffw_post_norm,
            ffw_norm_1,
            ffw_up_1,
            ffw_down_1,
            ffw_post_norm_1,
            attn_pre_norm,
            attn_post_norm,
            attn_q,
            attn_k,
            attn_v,
            attn_o,
            linear_pos,
            conv_norm,
            norm_conv,
            conv_pw1,
            conv_pw2,
        ],
    ) = <[wgpu::Buffer; 20]>::try_from(fetched)
    else {
        unreachable!()
    };

    Ok(GpuAudioBlockWeights {
        pre_norm,
        ffw_norm,
        ffw_up,
        ffw_down,
        ffw_post_norm,
        ffw_norm_1,
        ffw_up_1,
        ffw_down_1,
        ffw_post_norm_1,
        attn_pre_norm,
        attn_post_norm,
        attn_q,
        attn_k,
        attn_v,
        attn_o,
        linear_pos,
        conv_norm,
        norm_conv,
        conv_pw1,
        conv_pw2,
    })
}
/// Reads the four clamp scalars `{prefix}.input_min`, `.input_max`,
/// `.output_min` and `.output_max`.
///
/// Deliberately best-effort: a tensor that is absent, fails to dequantize, or is
/// empty yields `0.0`. In a `*_max` field, `0.0` means "no clamp" to the dispatch
/// code. Only the first element of each tensor is used.
async fn load_clamp(wcache: &Arc<WeightCache>, prefix: &str) -> Clamp {
    /// First element of tensor `{prefix}.{suffix}`, or `0.0` if it is
    /// missing, unreadable or empty.
    async fn first_or_zero(wcache: &Arc<WeightCache>, prefix: &str, suffix: &str) -> f32 {
        let name = format!("{prefix}.{suffix}");
        if wcache.reader().tensor(&name).is_err() {
            return 0.0;
        }
        match dequant_tensor_to_f32_async(wcache.reader(), &name).await {
            Ok(values) => values.first().copied().unwrap_or(0.0),
            Err(_) => 0.0,
        }
    }

    let in_min = first_or_zero(wcache, prefix, "input_min").await;
    let in_max = first_or_zero(wcache, prefix, "input_max").await;
    let out_min = first_or_zero(wcache, prefix, "output_min").await;
    let out_max = first_or_zero(wcache, prefix, "output_max").await;
    Clamp {
        in_min,
        in_max,
        out_min,
        out_max,
    }
}