use std::sync::Arc;
use bytemuck::cast_slice;
use futures_channel::oneshot;
use crate::backend::dispatch::{
add_bias_batched_chained, block_local_attention_chained, clamp_chained,
depthwise_conv1d_chained, glu_split_chained, half_residual_add_chained,
matmul_bf16_batched_chained, matmul_f16_batched_chained,
rmsnorm_per_row_chained, scale_chained, scale_per_inner_dim_chained,
silu_chained,
};
use crate::backend::{Pipelines, WeightCache, WgpuCtx};
use crate::error::{Result, RullamaError};
use crate::gguf::{dequant_tensor_to_f32_async, GgmlDtype};
use crate::multimodal::audio::{AudioConfig, AudioPrefix};
/// Largest encoder sequence length (rows produced by the CPU prefix) that the GPU
/// scratch buffers are sized for. `GpuAudioForward::encode` rejects longer inputs,
/// and its error message describes this limit as 30 s of audio.
const MAX_SEQ: usize = 768;
/// Input/output clamp bounds for one clippable linear projection.
///
/// Each field holds the first element of the matching `{prefix}.input_min`,
/// `input_max`, `output_min` or `output_max` tensor. A field is `0.0` when that
/// tensor is absent or cannot be dequantized (see `load_clamp`). Dispatch code
/// treats `in_max == 0.0` / `out_max == 0.0` as "no clamp on this side".
///
/// `Debug` and `PartialEq` are derived so loaded bounds can be logged and compared.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct Clamp {
    in_min: f32,
    in_max: f32,
    out_min: f32,
    out_max: f32,
}
/// Per-layer data that `GpuAudioForward::new` loads once and keeps for the lifetime
/// of the encoder: two small f32 buffers plus the clamp bounds of every clippable
/// projection in the layer.
struct GpuAudioBlockMeta {
    /// `per_dim_scale.weight` dequantized to f32 and multiplied by the query scale.
    per_dim_scale: wgpu::Buffer,
    /// `conv_dw.weight` dequantized to f32 (depthwise convolution kernel).
    conv_dw: wgpu::Buffer,

    // Attention projections.
    cl_attn_q: Clamp,
    cl_attn_k: Clamp,
    cl_attn_v: Clamp,
    cl_attn_o: Clamp,

    // First feed-forward sub-block.
    cl_ffw_up: Clamp,
    cl_ffw_down: Clamp,

    // Second feed-forward sub-block.
    cl_ffw_up_1: Clamp,
    cl_ffw_down_1: Clamp,

    // Light-convolution pointwise projections.
    cl_conv_pw1: Clamp,
    cl_conv_pw2: Clamp,
}
/// Weight buffers of one encoder layer, fetched from the weight cache for each
/// layer during `GpuAudioForward::encode` and dropped after that layer is submitted.
///
/// Each comment gives the tensor-name suffix under `a.blk.{i}.`.
struct GpuAudioBlockWeights {
    /// `layer_pre_norm.weight`
    pre_norm: wgpu::Buffer,

    /// `ffn_norm.weight`
    ffw_norm: wgpu::Buffer,
    /// `ffn_up.weight`
    ffw_up: wgpu::Buffer,
    /// `ffn_down.weight`
    ffw_down: wgpu::Buffer,
    /// `ffn_post_norm.weight`
    ffw_post_norm: wgpu::Buffer,

    /// `ffn_norm_1.weight`
    ffw_norm_1: wgpu::Buffer,
    /// `ffn_up_1.weight`
    ffw_up_1: wgpu::Buffer,
    /// `ffn_down_1.weight`
    ffw_down_1: wgpu::Buffer,
    /// `ffn_post_norm_1.weight`
    ffw_post_norm_1: wgpu::Buffer,

    /// `ln1.weight`
    attn_pre_norm: wgpu::Buffer,
    /// `ln2.weight`
    attn_post_norm: wgpu::Buffer,
    /// `attn_q.weight`
    attn_q: wgpu::Buffer,
    /// `attn_k.weight`
    attn_k: wgpu::Buffer,
    /// `attn_v.weight`
    attn_v: wgpu::Buffer,
    /// `attn_out.weight`
    attn_o: wgpu::Buffer,
    /// `linear_pos.weight`
    linear_pos: wgpu::Buffer,

    /// `conv_norm.weight`
    conv_norm: wgpu::Buffer,
    /// `norm_conv.weight`
    norm_conv: wgpu::Buffer,
    /// `conv_pw1.weight`
    conv_pw1: wgpu::Buffer,
    /// `conv_pw2.weight`
    conv_pw2: wgpu::Buffer,
}
/// Intermediate GPU buffers, allocated once in `GpuAudioForward::new` at worst-case
/// size and reused by every `encode` call.
struct Scratch {
    /// Running hidden state, read and written by every sub-block.
    h_main: wgpu::Buffer,
    /// Copy of `h_main` taken at the start of a sub-block, used for the residual add.
    residual: wgpu::Buffer,
    /// Normalized activations and sub-block outputs before the residual add.
    h_norm: wgpu::Buffer,
    /// Feed-forward intermediate activations (`ffn_inter` wide).
    ffw_h: wgpu::Buffer,
    /// Feed-forward output; also reused for the attention output projection and the
    /// final per-layer norm.
    ffw_out: wgpu::Buffer,
    /// First pointwise projection of the light convolution (`2 * hidden` wide).
    pw1_out: wgpu::Buffer,
    /// Output of the GLU split.
    glu_out: wgpu::Buffer,
    /// Output of the depthwise convolution.
    conv_dw_out: wgpu::Buffer,
    /// Second pointwise projection of the light convolution.
    pw2_out: wgpu::Buffer,
    /// Query projections, zero-padded to whole attention chunks.
    q_buf: wgpu::Buffer,
    /// Key projections with left/right padding rows.
    k_padded: wgpu::Buffer,
    /// Value projections with left/right padding rows.
    v_padded: wgpu::Buffer,
    /// Sinusoidal relative-position table, written once at construction.
    pos_emb: wgpu::Buffer,
    /// `pos_emb` projected through the layer's `linear_pos` weight.
    pos_proj: wgpu::Buffer,
    /// Attention output; also used to stage the K and V projections.
    attn_out: wgpu::Buffer,
    /// Projector FC output (`d_text` wide).
    fc_out: wgpu::Buffer,
    /// Row-wise RMS-normalized `fc_out`.
    fc_normed: wgpu::Buffer,
    /// Final projector output.
    soft: wgpu::Buffer,
    /// `MAP_READ` staging copy of `soft` for CPU readback.
    soft_read: wgpu::Buffer,
}
/// GPU audio encoder.
///
/// A CPU prefix (`AudioPrefix`) produces hidden rows. Every encoder layer and the
/// audio-to-text projector are then recorded as wgpu dispatches on command encoders.
///
/// NOTE(review): `encode` takes `&self` but every call writes the same `scratch`
/// buffers. Concurrent calls on one instance would presumably clobber each other,
/// so confirm that callers serialize them.
pub struct GpuAudioForward {
    /// Encoder hyper-parameters.
    cfg: AudioConfig,
    /// wgpu device and queue.
    ctx: WgpuCtx,
    /// Compiled pipelines used by the dispatch helpers.
    pipes: Arc<Pipelines>,
    /// Source of per-layer weights (fetched during each `encode`) and projector weights.
    wcache: Arc<WeightCache>,
    /// CPU front end whose output rows are uploaded into `scratch.h_main`.
    cpu_prefix: AudioPrefix,
    /// Resident per-layer metadata, one entry per encoder layer.
    blocks: Vec<GpuAudioBlockMeta>,
    // Projector weights and the storage dtype of each matmul weight.
    proj_fc: wgpu::Buffer,
    proj_fc_dtype: GgmlDtype,
    proj_fc_bias: Option<wgpu::Buffer>,
    proj_input: wgpu::Buffer,
    proj_input_dtype: GgmlDtype,
    /// Worst-case-sized intermediate buffers shared by every `encode` call.
    scratch: Scratch,
}
impl GpuAudioForward {
pub async fn new(
cfg: AudioConfig,
ctx: WgpuCtx,
pipes: Arc<Pipelines>,
wcache: Arc<WeightCache>,
) -> Result<Self> {
let cpu_prefix = AudioPrefix::new(cfg.clone(), wcache.clone()).await?;
let device = &ctx.device;
let queue = &ctx.queue;
let hidden = cfg.hidden as usize;
let ffn = cfg.ffn_inter as usize;
let head_dim = cfg.head_dim() as usize;
let max_span = (cfg.max_past + cfg.max_future + 1) as usize;
let max_padded = MAX_SEQ;
let pad_left = cfg.max_past as usize;
let pad_right = (cfg.max_future + cfg.chunk_size - 1) as usize;
let max_k_padded = pad_left + max_padded + pad_right;
let d_text = cfg.d_text as usize;
let alloc_storage = |label: &str, n_f32: usize| -> wgpu::Buffer {
device.create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: (n_f32 * 4).max(4) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
})
};
let h_main = alloc_storage("aud.h_main", MAX_SEQ * hidden);
let residual = alloc_storage("aud.residual", MAX_SEQ * hidden);
let h_norm = alloc_storage("aud.h_norm", MAX_SEQ * hidden);
let ffw_h = alloc_storage("aud.ffw_h", MAX_SEQ * ffn);
let ffw_out = alloc_storage("aud.ffw_out", MAX_SEQ * hidden);
let pw1_out = alloc_storage("aud.pw1", MAX_SEQ * 2 * hidden);
let glu_out = alloc_storage("aud.glu", MAX_SEQ * hidden);
let conv_dw_out = alloc_storage("aud.dw_out", MAX_SEQ * hidden);
let pw2_out = alloc_storage("aud.pw2", MAX_SEQ * hidden);
let q_buf = alloc_storage("aud.q", max_padded * hidden);
let k_padded = alloc_storage("aud.k_padded", max_k_padded * hidden);
let v_padded = alloc_storage("aud.v_padded", max_k_padded * hidden);
let pos_emb = alloc_storage("aud.pos_emb", max_span * hidden);
let pos_proj = alloc_storage("aud.pos_proj", max_span * hidden);
let attn_out = alloc_storage("aud.attn_out", max_padded * hidden);
let fc_out = alloc_storage("aud.fc_out", MAX_SEQ * d_text);
let fc_normed = alloc_storage("aud.fc_normed", MAX_SEQ * d_text);
let soft = alloc_storage("aud.soft", MAX_SEQ * d_text);
let soft_read = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("aud.soft_read"),
size: (MAX_SEQ * d_text * 4) as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
{
let half_dim = hidden / 2;
let log_inc = (10000f32).ln() / (half_dim.saturating_sub(1)).max(1) as f32;
let mut pos_emb_cpu = vec![0f32; max_span * hidden];
for p in 0..max_span {
let rel_pos = (cfg.max_past as f32) - (p as f32);
for d in 0..half_dim {
let angle = rel_pos * (-(d as f32) * log_inc).exp();
pos_emb_cpu[p * hidden + d] = angle.sin();
pos_emb_cpu[p * hidden + half_dim + d] = angle.cos();
}
}
queue.write_buffer(&pos_emb, 0, cast_slice(&pos_emb_cpu));
}
let scratch = Scratch {
h_main, residual, h_norm, ffw_h, ffw_out,
pw1_out, glu_out, conv_dw_out, pw2_out,
q_buf, k_padded, v_padded, pos_emb, pos_proj, attn_out,
fc_out, fc_normed, soft, soft_read,
};
let q_scale_base = (head_dim as f32).powf(-0.5) / std::f32::consts::LN_2;
let mut blocks = Vec::with_capacity(cfg.n_layers as usize);
for i in 0..cfg.n_layers {
blocks.push(load_gpu_block_meta(&wcache, i, &ctx, q_scale_base).await?);
}
let proj_fc = wcache.buffer_async("mm.a.fc.weight").await?;
let proj_fc_dtype = wcache.reader().tensor("mm.a.fc.weight")?.dtype;
let proj_fc_bias = wcache.buffer_opt_async("mm.a.fc.bias").await?;
let proj_input = wcache.buffer_async("mm.a.input_projection.weight").await?;
let proj_input_dtype = wcache.reader().tensor("mm.a.input_projection.weight")?.dtype;
Ok(Self {
cfg, ctx, pipes, wcache,
cpu_prefix,
blocks,
proj_fc, proj_fc_dtype, proj_fc_bias,
proj_input, proj_input_dtype,
scratch,
})
}
/// Returns the audio configuration this encoder was built with.
pub fn cfg(&self) -> &AudioConfig {
    let config = &self.cfg;
    config
}
/// Encodes PCM samples into projector output rows of width `d_text`.
///
/// The encode proceeds in these steps:
/// 1. Run the CPU prefix to obtain `seq` hidden rows and upload them to `h_main`.
/// 2. For each layer, fetch that layer's weights, record `dispatch_block` into its
///    own command encoder, and submit it.
/// 3. Run the projector and copy `seq * d_text` floats into a mappable buffer.
/// 4. Read that buffer back to the CPU.
///
/// Returns an empty vector when the prefix yields no rows.
///
/// # Errors
/// Returns `RullamaError::Inference` when `seq > MAX_SEQ`, or when polling the device
/// or mapping the readback buffer fails. Prefix and weight-fetch errors are
/// propagated.
pub async fn encode(&self, pcm: &[f32]) -> Result<Vec<f32>> {
    let (h_cpu, seq) = self.cpu_prefix.prefix_to_hidden(pcm)?;
    if seq == 0 {
        return Ok(Vec::new());
    }
    if seq > MAX_SEQ {
        return Err(RullamaError::Inference(format!(
            "audio: seq {seq} > MAX_SEQ {MAX_SEQ} (audio longer than 30 s)"
        )));
    }

    let cfg = &self.cfg;
    let hidden = cfg.hidden as usize;
    let n_heads = cfg.n_heads as usize;
    let head_dim = cfg.head_dim() as usize;
    let d_text = cfg.d_text as usize;
    let logit_cap = cfg.logit_cap;

    // Local-attention geometry. Queries are padded to whole chunks. Keys and values
    // get `max_past` rows of left padding and `max_future + chunk_size - 1` rows of
    // right padding.
    let chunk_size = cfg.chunk_size as usize;
    let max_past = cfg.max_past as usize;
    let max_future = cfg.max_future as usize;
    let pad_left = max_past;
    let pad_right = max_future + chunk_size - 1;
    let context_size = max_past + chunk_size + max_future;
    let max_span = max_past + max_future + 1;
    let padded_len = seq.div_ceil(chunk_size) * chunk_size;
    let k_padded_len = pad_left + padded_len + pad_right;

    // ln(1 + e) / ln 2
    let k_scale = (1.0f32 + std::f32::consts::E).ln() / std::f32::consts::LN_2;

    self.ctx.queue.write_buffer(&self.scratch.h_main, 0, cast_slice(&h_cpu));

    for (layer, meta) in self.blocks.iter().enumerate() {
        let weights = fetch_gpu_block_weights(&self.wcache, layer as u32).await?;
        let label = format!("aud.block{layer}");
        let mut encoder = self.ctx.device.create_command_encoder(
            &wgpu::CommandEncoderDescriptor { label: Some(label.as_str()) },
        );
        self.dispatch_block(
            &mut encoder, meta, &weights,
            seq, padded_len, k_padded_len,
            hidden, n_heads, head_dim, chunk_size,
            context_size, max_span, max_past, max_future,
            pad_left, logit_cap, k_scale,
        );
        self.ctx.queue.submit(Some(encoder.finish()));
        // The layer's weight buffers are released right after its commands are submitted.
        drop(weights);
    }

    let mut encoder = self.ctx.device.create_command_encoder(
        &wgpu::CommandEncoderDescriptor { label: Some("aud.projector") },
    );
    self.dispatch_projector(&mut encoder, seq, hidden, d_text);
    let read_bytes = (seq * d_text * 4) as u64;
    encoder.copy_buffer_to_buffer(&self.scratch.soft, 0, &self.scratch.soft_read, 0, read_bytes);
    self.ctx.queue.submit(Some(encoder.finish()));

    // Map the staging buffer. Block on the device poll, then wait for the map callback.
    let readback = self.scratch.soft_read.slice(..read_bytes);
    let (tx, rx) = oneshot::channel();
    readback.map_async(wgpu::MapMode::Read, move |r| {
        let _ = tx.send(r);
    });
    self.ctx
        .device
        .poll(wgpu::PollType::Wait { submission_index: None, timeout: None })
        .map_err(|e| RullamaError::Inference(format!("device.poll: {e}")))?;
    rx.await
        .map_err(|_| RullamaError::Inference("readback channel".into()))?
        .map_err(|e| RullamaError::Inference(format!("map_async: {e:?}")))?;

    let mapped = readback.get_mapped_range();
    let out: Vec<f32> = bytemuck::cast_slice(&mapped).to_vec();
    // The mapped view must be gone before the buffer can be unmapped.
    drop(mapped);
    self.scratch.soft_read.unmap();
    Ok(out)
}
/// Records one full encoder layer onto `enc`, reading and writing `scratch.h_main`.
///
/// The sub-blocks run in this order:
/// 1. Feed-forward using the `ffw_*` weights.
/// 2. Local self-attention.
/// 3. Light convolution (`dispatch_lightconv`).
/// 4. A second feed-forward using the `ffw_*_1` weights.
/// 5. A final `±grad_clip` clamp and RMS norm with `pre_norm`, copied back into `h_main`.
///
/// `padded_len` is `seq` rounded up to a multiple of `chunk_size`. `k_padded_len` is
/// `pad_left + padded_len + pad_right`. Both are computed by `encode`.
#[allow(clippy::too_many_arguments)]
fn dispatch_block(
&self,
enc: &mut wgpu::CommandEncoder,
meta: &GpuAudioBlockMeta,
w: &GpuAudioBlockWeights,
seq: usize, padded_len: usize, k_padded_len: usize,
hidden: usize, n_heads: usize, head_dim: usize, chunk_size: usize,
context_size: usize, max_span: usize, max_past: usize, max_future: usize,
pad_left: usize, logit_cap: f32, k_scale: f32,
) {
let cfg = &self.cfg;
let ffn = cfg.ffn_inter as usize;
let eps = cfg.eps;
// Symmetric activation clamp applied before/after sub-blocks.
let gc = cfg.grad_clip;
let s = &self.scratch;
let n_h = seq * hidden;
// 1) First feed-forward sub-block (updates h_main).
self.dispatch_ffw(
enc, &w.ffw_norm,
&w.ffw_up, &meta.cl_ffw_up,
&w.ffw_down, &meta.cl_ffw_down,
&w.ffw_post_norm,
seq, hidden, ffn, eps, gc,
);
// 2) Attention sub-block. Save the unclamped input as the residual, then clamp
// h_main in place and pre-norm it into h_norm.
enc.copy_buffer_to_buffer(&s.h_main, 0, &s.residual, 0, (n_h * 4) as u64);
clamp_chained(&self.ctx, &self.pipes, enc, &s.h_main, n_h, -gc, gc);
rmsnorm_per_row_chained(
&self.ctx, &self.pipes, enc,
&s.h_main, Some(&w.attn_pre_norm), &s.h_main,
&s.h_norm, seq, hidden, eps,
);
// Q projection into q_buf. A clamp side is skipped when its *_max bound is 0.0.
// NOTE(review): only cl_attn_q's *input* clamp is applied, and it clamps h_norm in
// place. The K and V matmuls below therefore also consume Q-clamped activations,
// and cl_attn_k.in_* / cl_attn_v.in_* are never read. Confirm this matches the
// reference model.
let cl_q = &meta.cl_attn_q;
if cl_q.in_max != 0.0 {
clamp_chained(&self.ctx, &self.pipes, enc, &s.h_norm, n_h, cl_q.in_min, cl_q.in_max);
}
matmul_bf16_batched_chained(
&self.ctx, &self.pipes, enc,
&w.attn_q, &s.h_norm, &s.q_buf,
hidden, hidden, seq,
);
if cl_q.out_max != 0.0 {
clamp_chained(&self.ctx, &self.pipes, enc, &s.q_buf,
seq * hidden, cl_q.out_min, cl_q.out_max);
}
// Zero all k_padded_len rows of K/V, so rows outside [pad_left, pad_left + seq)
// stay zero after the copies below.
enc.clear_buffer(&s.k_padded, 0, Some((k_padded_len * hidden * 4) as u64));
enc.clear_buffer(&s.v_padded, 0, Some((k_padded_len * hidden * 4) as u64));
// K projection: staged in attn_out, optionally clamped, scaled by k_scale, then
// copied into k_padded starting at row pad_left.
matmul_bf16_batched_chained(
&self.ctx, &self.pipes, enc,
&w.attn_k, &s.h_norm, &s.attn_out,
hidden, hidden, seq,
);
let cl_k = &meta.cl_attn_k;
if cl_k.out_max != 0.0 {
clamp_chained(&self.ctx, &self.pipes, enc, &s.attn_out,
seq * hidden, cl_k.out_min, cl_k.out_max);
}
scale_chained(&self.ctx, &self.pipes, enc, &s.attn_out, seq * hidden, k_scale);
enc.copy_buffer_to_buffer(
&s.attn_out, 0,
&s.k_padded, (pad_left * hidden * 4) as u64,
(seq * hidden * 4) as u64,
);
// V projection, staged the same way (no scaling) into v_padded.
matmul_bf16_batched_chained(
&self.ctx, &self.pipes, enc,
&w.attn_v, &s.h_norm, &s.attn_out,
hidden, hidden, seq,
);
let cl_v = &meta.cl_attn_v;
if cl_v.out_max != 0.0 {
clamp_chained(&self.ctx, &self.pipes, enc, &s.attn_out,
seq * hidden, cl_v.out_min, cl_v.out_max);
}
enc.copy_buffer_to_buffer(
&s.attn_out, 0,
&s.v_padded, (pad_left * hidden * 4) as u64,
(seq * hidden * 4) as u64,
);
// Per-dimension query scaling. load_gpu_block_meta pre-multiplies per_dim_scale
// by the head_dim^-0.5 / ln 2 query scale.
scale_per_inner_dim_chained(
&self.ctx, &self.pipes, enc,
&s.q_buf, &meta.per_dim_scale,
seq * hidden, head_dim,
);
// Project the fixed relative-position table (max_span rows) through linear_pos.
matmul_bf16_batched_chained(
&self.ctx, &self.pipes, enc,
&w.linear_pos, &s.pos_emb, &s.pos_proj,
hidden, hidden, max_span,
);
// Zero the query rows [seq, padded_len) that pad the last chunk.
if padded_len > seq {
enc.clear_buffer(
&s.q_buf,
(seq * hidden * 4) as u64,
Some(((padded_len - seq) * hidden * 4) as u64),
);
}
// Chunked local attention; output lands in attn_out.
block_local_attention_chained(
&self.ctx, &self.pipes, enc,
&s.q_buf, &s.k_padded, &s.v_padded, &s.pos_proj, &s.attn_out,
seq, padded_len, hidden, n_heads, head_dim,
chunk_size, context_size, max_span,
max_past, max_future, pad_left, logit_cap,
);
// Output projection; ffw_out is reused as scratch for its result.
let cl_o = &meta.cl_attn_o;
if cl_o.in_max != 0.0 {
clamp_chained(&self.ctx, &self.pipes, enc, &s.attn_out,
seq * hidden, cl_o.in_min, cl_o.in_max);
}
matmul_bf16_batched_chained(
&self.ctx, &self.pipes, enc,
&w.attn_o, &s.attn_out, &s.ffw_out,
hidden, hidden, seq,
);
if cl_o.out_max != 0.0 {
clamp_chained(&self.ctx, &self.pipes, enc, &s.ffw_out,
seq * hidden, cl_o.out_min, cl_o.out_max);
}
clamp_chained(&self.ctx, &self.pipes, enc, &s.ffw_out, seq * hidden, -gc, gc);
// Post-norm of the attention output into h_norm.
// NOTE(review): every other rmsnorm_per_row_chained call in this file passes the
// input buffer as both the 1st and the 3rd argument. Here the input is
// &s.ffw_out but the 3rd argument is &s.h_main. Confirm against the helper's
// signature that this is intended.
rmsnorm_per_row_chained(
&self.ctx, &self.pipes, enc,
&s.ffw_out, Some(&w.attn_post_norm), &s.h_main,
&s.h_norm, seq, hidden, eps,
);
// Restore the saved residual into h_main, then add h_norm.
// NOTE(review): residual_add_chained presumably accumulates h_norm into h_main in
// place; confirm.
enc.copy_buffer_to_buffer(&s.residual, 0, &s.h_main, 0, (n_h * 4) as u64);
crate::backend::dispatch::residual_add_chained(
&self.ctx, &self.pipes, enc,
&s.h_main, &s.h_norm, n_h,
);
// 3) Light-convolution sub-block.
self.dispatch_lightconv(
enc, meta, w,
seq, hidden, eps, gc,
);
// 4) Second feed-forward sub-block (the *_1 weights).
self.dispatch_ffw(
enc, &w.ffw_norm_1,
&w.ffw_up_1, &meta.cl_ffw_up_1,
&w.ffw_down_1, &meta.cl_ffw_down_1,
&w.ffw_post_norm_1,
seq, hidden, ffn, eps, gc,
);
// 5) Final clamp and norm of the layer output. The norm is written to ffw_out and
// then copied back into h_main.
clamp_chained(&self.ctx, &self.pipes, enc, &s.h_main, n_h, -gc, gc);
rmsnorm_per_row_chained(
&self.ctx, &self.pipes, enc,
&s.h_main, Some(&w.pre_norm), &s.h_main,
&s.ffw_out, seq, hidden, eps,
);
enc.copy_buffer_to_buffer(&s.ffw_out, 0, &s.h_main, 0, (n_h * 4) as u64);
}
/// Records one feed-forward sub-block onto `enc`, updating `scratch.h_main` in place.
///
/// The recorded commands, in order:
/// 1. Save `h_main` into `residual`.
/// 2. Clamp `h_main` to `±gc`, then RMS-norm it with `norm_w` into `h_norm`.
/// 3. Optional input clamp, up-projection (`hidden -> ffn`) into `ffw_h`, optional
///    output clamp.
/// 4. SiLU on `ffw_h`.
/// 5. Optional input clamp, down-projection (`ffn -> hidden`) into `ffw_out`,
///    optional output clamp.
/// 6. Clamp `ffw_out` to `±gc`, then RMS-norm it with `post_norm_w` into `h_norm`.
/// 7. `half_residual_add_chained(residual, h_norm)`, then copy `residual` back
///    into `h_main`.
///
/// A clamp side is skipped when its `*_max` bound is `0.0`.
#[allow(clippy::too_many_arguments)]
fn dispatch_ffw(
    &self,
    enc: &mut wgpu::CommandEncoder,
    norm_w: &wgpu::Buffer,
    up_w: &wgpu::Buffer, up_clamp: &Clamp,
    down_w: &wgpu::Buffer, down_clamp: &Clamp,
    post_norm_w: &wgpu::Buffer,
    seq: usize, hidden: usize, ffn: usize, eps: f32, gc: f32,
) {
    let (ctx, pipes, buf) = (&self.ctx, &self.pipes, &self.scratch);
    let hidden_elems = seq * hidden;
    let ffn_elems = seq * ffn;
    let hidden_bytes = (hidden_elems * 4) as u64;

    // Pre-norm. The residual keeps the unclamped input.
    enc.copy_buffer_to_buffer(&buf.h_main, 0, &buf.residual, 0, hidden_bytes);
    clamp_chained(ctx, pipes, enc, &buf.h_main, hidden_elems, -gc, gc);
    rmsnorm_per_row_chained(
        ctx, pipes, enc,
        &buf.h_main, Some(norm_w), &buf.h_main,
        &buf.h_norm, seq, hidden, eps,
    );

    // Up-projection followed by SiLU.
    if up_clamp.in_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.h_norm, hidden_elems, up_clamp.in_min, up_clamp.in_max);
    }
    matmul_bf16_batched_chained(ctx, pipes, enc, up_w, &buf.h_norm, &buf.ffw_h, hidden, ffn, seq);
    if up_clamp.out_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.ffw_h, ffn_elems, up_clamp.out_min, up_clamp.out_max);
    }
    silu_chained(ctx, pipes, enc, &buf.ffw_h, ffn_elems);

    // Down-projection back to `hidden`.
    if down_clamp.in_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.ffw_h, ffn_elems, down_clamp.in_min, down_clamp.in_max);
    }
    matmul_bf16_batched_chained(ctx, pipes, enc, down_w, &buf.ffw_h, &buf.ffw_out, ffn, hidden, seq);
    if down_clamp.out_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.ffw_out, hidden_elems, down_clamp.out_min, down_clamp.out_max);
    }

    // Post-norm, half-weighted residual add, write back into h_main.
    clamp_chained(ctx, pipes, enc, &buf.ffw_out, hidden_elems, -gc, gc);
    rmsnorm_per_row_chained(
        ctx, pipes, enc,
        &buf.ffw_out, Some(post_norm_w), &buf.ffw_out,
        &buf.h_norm, seq, hidden, eps,
    );
    half_residual_add_chained(ctx, pipes, enc, &buf.residual, &buf.h_norm, hidden_elems);
    enc.copy_buffer_to_buffer(&buf.residual, 0, &buf.h_main, 0, hidden_bytes);
}
/// Records the light-convolution sub-block onto `enc`, updating `scratch.h_main` in place.
///
/// The recorded commands, in order:
/// 1. Save `h_main` into `residual`, then RMS-norm with `conv_norm` into `h_norm`
///    (no `±gc` clamp beforehand).
/// 2. Pointwise projection `hidden -> 2*hidden` into `pw1_out`.
/// 3. GLU split into `glu_out`.
/// 4. Depthwise conv1d with `meta.conv_dw` (kernel size `cfg.conv_kernel`) into
///    `conv_dw_out`.
/// 5. Clamp to `±gc`, then RMS-norm with `norm_conv` into `h_norm`.
/// 6. SiLU, then pointwise projection `hidden -> hidden` into `pw2_out`.
/// 7. Restore `residual` into `h_main` and add `pw2_out`.
///
/// A clamp side is skipped when its `*_max` bound is `0.0`.
fn dispatch_lightconv(
    &self,
    enc: &mut wgpu::CommandEncoder,
    meta: &GpuAudioBlockMeta,
    w: &GpuAudioBlockWeights,
    seq: usize, hidden: usize, eps: f32, gc: f32,
) {
    let (ctx, pipes, buf) = (&self.ctx, &self.pipes, &self.scratch);
    let kernel = self.cfg.conv_kernel as usize;
    let hidden_elems = seq * hidden;
    let gated_elems = seq * 2 * hidden;
    let hidden_bytes = (hidden_elems * 4) as u64;
    let (pw1, pw2) = (&meta.cl_conv_pw1, &meta.cl_conv_pw2);

    // Pre-norm into h_norm; the residual keeps the input.
    enc.copy_buffer_to_buffer(&buf.h_main, 0, &buf.residual, 0, hidden_bytes);
    rmsnorm_per_row_chained(
        ctx, pipes, enc,
        &buf.h_main, Some(&w.conv_norm), &buf.h_main,
        &buf.h_norm, seq, hidden, eps,
    );

    // First pointwise projection (doubles the width), then GLU back to `hidden`.
    if pw1.in_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.h_norm, hidden_elems, pw1.in_min, pw1.in_max);
    }
    matmul_bf16_batched_chained(ctx, pipes, enc, &w.conv_pw1, &buf.h_norm, &buf.pw1_out, hidden, 2 * hidden, seq);
    if pw1.out_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.pw1_out, gated_elems, pw1.out_min, pw1.out_max);
    }
    glu_split_chained(ctx, pipes, enc, &buf.pw1_out, &buf.glu_out, seq, hidden);

    // Depthwise convolution, clamp, norm, SiLU.
    depthwise_conv1d_chained(ctx, pipes, enc, &buf.glu_out, &meta.conv_dw, &buf.conv_dw_out, seq, hidden, kernel);
    clamp_chained(ctx, pipes, enc, &buf.conv_dw_out, hidden_elems, -gc, gc);
    rmsnorm_per_row_chained(
        ctx, pipes, enc,
        &buf.conv_dw_out, Some(&w.norm_conv), &buf.conv_dw_out,
        &buf.h_norm, seq, hidden, eps,
    );
    silu_chained(ctx, pipes, enc, &buf.h_norm, hidden_elems);

    // Second pointwise projection.
    if pw2.in_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.h_norm, hidden_elems, pw2.in_min, pw2.in_max);
    }
    matmul_bf16_batched_chained(ctx, pipes, enc, &w.conv_pw2, &buf.h_norm, &buf.pw2_out, hidden, hidden, seq);
    if pw2.out_max != 0.0 {
        clamp_chained(ctx, pipes, enc, &buf.pw2_out, hidden_elems, pw2.out_min, pw2.out_max);
    }

    // Residual add: h_main = residual, then add pw2_out.
    enc.copy_buffer_to_buffer(&buf.residual, 0, &buf.h_main, 0, hidden_bytes);
    crate::backend::dispatch::residual_add_chained(ctx, pipes, enc, &buf.h_main, &buf.pw2_out, hidden_elems);
}
fn dispatch_projector(
&self,
enc: &mut wgpu::CommandEncoder,
seq: usize, hidden: usize, d_text: usize,
) {
let s = &self.scratch;
let eps = self.cfg.eps;
match self.proj_fc_dtype {
GgmlDtype::F16 => matmul_f16_batched_chained(
&self.ctx, &self.pipes, enc,
&self.proj_fc, &s.h_main, &s.fc_out,
hidden, d_text, seq,
),
GgmlDtype::BF16 => matmul_bf16_batched_chained(
&self.ctx, &self.pipes, enc,
&self.proj_fc, &s.h_main, &s.fc_out,
hidden, d_text, seq,
),
other => panic!("audio projector FC dtype {other:?} not supported"),
}
if let Some(bias) = self.proj_fc_bias.as_ref() {
add_bias_batched_chained(
&self.ctx, &self.pipes, enc,
&s.fc_out, bias, d_text, seq,
);
}
rmsnorm_per_row_chained(
&self.ctx, &self.pipes, enc,
&s.fc_out, None, &s.fc_out,
&s.fc_normed, seq, d_text, eps,
);
match self.proj_input_dtype {
GgmlDtype::F16 => matmul_f16_batched_chained(
&self.ctx, &self.pipes, enc,
&self.proj_input, &s.fc_normed, &s.soft,
d_text, d_text, seq,
),
GgmlDtype::BF16 => matmul_bf16_batched_chained(
&self.ctx, &self.pipes, enc,
&self.proj_input, &s.fc_normed, &s.soft,
d_text, d_text, seq,
),
other => panic!("audio projector input dtype {other:?} not supported"),
}
}
}
/// Loads the resident GPU metadata for encoder layer `i`.
///
/// The two weight tensors are handled as follows:
/// - `a.blk.{i}.per_dim_scale.weight` is dequantized to f32, multiplied by
///   `q_scale_base`, and uploaded to a new storage buffer.
/// - `a.blk.{i}.conv_dw.weight` is dequantized to f32 and uploaded the same way.
///
/// Afterwards the clamp bounds of every clippable projection in the layer are read,
/// one after another.
///
/// # Errors
/// Propagates dequantization errors for the two weight tensors. Missing or
/// unreadable clamp tensors fall back to `0.0` (see `load_clamp`).
async fn load_gpu_block_meta(
    wcache: &Arc<WeightCache>, i: u32, ctx: &WgpuCtx, q_scale_base: f32,
) -> Result<GpuAudioBlockMeta> {
    let prefix = format!("a.blk.{i}.");
    let reader = wcache.reader();

    // Creates a STORAGE | COPY_DST buffer sized to `data` and queues its upload.
    let upload = |label: &str, data: &[f32]| -> wgpu::Buffer {
        let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (data.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        ctx.queue.write_buffer(&buffer, 0, cast_slice(data));
        buffer
    };

    let raw_scale = dequant_tensor_to_f32_async(reader, &format!("{prefix}per_dim_scale.weight")).await?;
    let scaled: Vec<f32> = raw_scale.iter().map(|&v| v * q_scale_base).collect();
    let per_dim_scale = upload("aud.per_dim_scale", &scaled[..]);

    let conv_weights = dequant_tensor_to_f32_async(reader, &format!("{prefix}conv_dw.weight")).await?;
    let conv_dw = upload("aud.conv_dw", &conv_weights[..]);

    // Bounds live under `a.blk.{i}.{name}.{input,output}_{min,max}`.
    let clamp_for = |name: &str| {
        let layer_name = format!("{prefix}{name}");
        async move { load_clamp(wcache, &layer_name).await }
    };
    let cl_attn_q = clamp_for("attn_q").await;
    let cl_attn_k = clamp_for("attn_k").await;
    let cl_attn_v = clamp_for("attn_v").await;
    let cl_attn_o = clamp_for("attn_out").await;
    let cl_ffw_up = clamp_for("ffn_up").await;
    let cl_ffw_down = clamp_for("ffn_down").await;
    let cl_ffw_up_1 = clamp_for("ffn_up_1").await;
    let cl_ffw_down_1 = clamp_for("ffn_down_1").await;
    let cl_conv_pw1 = clamp_for("conv_pw1").await;
    let cl_conv_pw2 = clamp_for("conv_pw2").await;

    Ok(GpuAudioBlockMeta {
        per_dim_scale,
        conv_dw,
        cl_attn_q,
        cl_attn_k,
        cl_attn_v,
        cl_attn_o,
        cl_ffw_up,
        cl_ffw_down,
        cl_ffw_up_1,
        cl_ffw_down_1,
        cl_conv_pw1,
        cl_conv_pw2,
    })
}
/// Fetches every weight buffer of encoder layer `i` through
/// `WeightCache::buffer_async_ephemeral`, one tensor at a time.
///
/// # Errors
/// Returns the first fetch error; tensors after the failing one are not requested.
async fn fetch_gpu_block_weights(
    wcache: &Arc<WeightCache>, i: u32,
) -> Result<GpuAudioBlockWeights> {
    let prefix = format!("a.blk.{i}.");
    let fetch = |tensor: &str| {
        let full_name = format!("{prefix}{tensor}");
        async move { wcache.buffer_async_ephemeral(&full_name).await }
    };

    let pre_norm = fetch("layer_pre_norm.weight").await?;
    let ffw_norm = fetch("ffn_norm.weight").await?;
    let ffw_up = fetch("ffn_up.weight").await?;
    let ffw_down = fetch("ffn_down.weight").await?;
    let ffw_post_norm = fetch("ffn_post_norm.weight").await?;
    let ffw_norm_1 = fetch("ffn_norm_1.weight").await?;
    let ffw_up_1 = fetch("ffn_up_1.weight").await?;
    let ffw_down_1 = fetch("ffn_down_1.weight").await?;
    let ffw_post_norm_1 = fetch("ffn_post_norm_1.weight").await?;
    let attn_pre_norm = fetch("ln1.weight").await?;
    let attn_post_norm = fetch("ln2.weight").await?;
    let attn_q = fetch("attn_q.weight").await?;
    let attn_k = fetch("attn_k.weight").await?;
    let attn_v = fetch("attn_v.weight").await?;
    let attn_o = fetch("attn_out.weight").await?;
    let linear_pos = fetch("linear_pos.weight").await?;
    let conv_norm = fetch("conv_norm.weight").await?;
    let norm_conv = fetch("norm_conv.weight").await?;
    let conv_pw1 = fetch("conv_pw1.weight").await?;
    let conv_pw2 = fetch("conv_pw2.weight").await?;

    Ok(GpuAudioBlockWeights {
        pre_norm,
        ffw_norm,
        ffw_up,
        ffw_down,
        ffw_post_norm,
        ffw_norm_1,
        ffw_up_1,
        ffw_down_1,
        ffw_post_norm_1,
        attn_pre_norm,
        attn_post_norm,
        attn_q,
        attn_k,
        attn_v,
        attn_o,
        linear_pos,
        conv_norm,
        norm_conv,
        conv_pw1,
        conv_pw2,
    })
}
/// Reads the four clamp bounds stored under `{prefix}.input_min`, `.input_max`,
/// `.output_min` and `.output_max`.
///
/// Each bound is the first element of its tensor after dequantization to f32. A
/// bound is `0.0` when the tensor is missing, fails to dequantize, or is empty.
async fn load_clamp(wcache: &Arc<WeightCache>, prefix: &str) -> Clamp {
    let scalar = |suffix: &str| {
        let tensor_name = format!("{prefix}.{suffix}");
        async move {
            if wcache.reader().tensor(&tensor_name).is_err() {
                return 0.0;
            }
            match dequant_tensor_to_f32_async(wcache.reader(), &tensor_name).await {
                Ok(values) => values.first().copied().unwrap_or(0.0),
                Err(_) => 0.0,
            }
        }
    };

    let in_min = scalar("input_min").await;
    let in_max = scalar("input_max").await;
    let out_min = scalar("output_min").await;
    let out_max = scalar("output_max").await;
    Clamp { in_min, in_max, out_min, out_max }
}