#[cfg(feature = "gpu")]
use trueno::backends::gpu::wgpu;
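/// GPU-resident weights for a single transformer block: the input and
/// post-attention norm vectors, the attention projections (Q/K/V/O), the gated
/// MLP projections (gate/up/down), and optional LoRA adapters.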
#[cfg(feature = "gpu")]
pub struct WgpuBlock {
pub layer_idx: usize,
pub input_norm: wgpu::Buffer,
pub post_attn_norm: wgpu::Buffer,
pub w_q: wgpu::Buffer,
pub w_k: wgpu::Buffer,
pub w_v: wgpu::Buffer,
pub w_o: wgpu::Buffer,
pub w_gate: wgpu::Buffer,
pub w_up: wgpu::Buffer,
pub w_down: wgpu::Buffer,
pub lora: Option<WgpuLoraAdapters>,
}
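/// Trainable LoRA buffers for one block: an A/B pair per adapted projection,
/// plus first/second-moment optimizer state (`m_states`/`v_states`) stored in
/// the same projection order with A and B interleaved.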
#[cfg(feature = "gpu")]
pub struct WgpuLoraAdapters {
pub rank: u32,
pub scale: f32,
pub a_q: wgpu::Buffer,
pub b_q: wgpu::Buffer,
pub a_k: wgpu::Buffer,
pub b_k: wgpu::Buffer,
pub a_v: wgpu::Buffer,
pub b_v: wgpu::Buffer,
pub a_o: wgpu::Buffer,
pub b_o: wgpu::Buffer,
pub a_gate: wgpu::Buffer,
pub b_gate: wgpu::Buffer,
pub a_up: wgpu::Buffer,
pub b_up: wgpu::Buffer,
pub a_down: wgpu::Buffer,
pub b_down: wgpu::Buffer,
pub m_states: Vec<wgpu::Buffer>,
pub v_states: Vec<wgpu::Buffer>,
}
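/// Owns the wgpu device and queue, the per-layer weight blocks, and the shared
/// activation, logits, and gradient workspaces sized for `max_seq_len`.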
#[cfg(feature = "gpu")]
pub struct WgpuBlockManager {
pub device: wgpu::Device,
pub queue: wgpu::Queue,
pub blocks: Vec<WgpuBlock>,
pub hidden_buf: wgpu::Buffer,
pub hidden_buf2: wgpu::Buffer,
pub attn_out_buf: wgpu::Buffer,
pub ffn_gate_buf: wgpu::Buffer,
pub ffn_up_buf: wgpu::Buffer,
pub ffn_silu_buf: wgpu::Buffer,
pub norm_buf: wgpu::Buffer,
pub q_buf: wgpu::Buffer,
pub k_buf: wgpu::Buffer,
pub v_buf: wgpu::Buffer,
pub embed_weight: wgpu::Buffer,
pub lm_head_weight: wgpu::Buffer,
pub logits_buf: wgpu::Buffer,
pub grad_hidden_buf: wgpu::Buffer,
pub grad_logits_buf: wgpu::Buffer,
pub hidden_size: u32,
pub intermediate_size: u32,
pub num_heads: u32,
pub num_kv_heads: u32,
pub head_dim: u32,
pub max_seq_len: u32,
pub vocab_size: u32,
pub num_layers: u32,
}
#[cfg(feature = "gpu")]
impl WgpuBlockManager {
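/// Allocates all shared activation and gradient buffers up front; per-layer
/// weights are added afterwards with `upload_layer`. The `_lora_rank` and
/// `_lora_alpha` arguments are currently unused here (LoRA configuration is
/// passed per layer to `upload_layer`).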
pub fn new(
device: wgpu::Device,
queue: wgpu::Queue,
hidden_size: u32,
intermediate_size: u32,
num_heads: u32,
num_kv_heads: u32,
head_dim: u32,
num_layers: u32,
vocab_size: u32,
max_seq_len: u32,
_lora_rank: Option<u32>,
_lora_alpha: Option<f32>,
) -> Self {
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let max = max_seq_len;
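// `size` is an element count of f32 values; every buffer is 4 bytes/element.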
let buf = |size: u32, label: &str| -> wgpu::Buffer {
device.create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: u64::from(size) * 4,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
})
};
Self {
blocks: Vec::with_capacity(num_layers as usize),
hidden_buf: buf(max * hidden_size, "hidden"),
hidden_buf2: buf(max * hidden_size, "hidden2"),
attn_out_buf: buf(max * hidden_size, "attn_out"),
ffn_gate_buf: buf(max * intermediate_size, "ffn_gate"),
ffn_up_buf: buf(max * intermediate_size, "ffn_up"),
ffn_silu_buf: buf(max * intermediate_size, "ffn_silu"),
norm_buf: buf(max * hidden_size, "norm"),
q_buf: buf(max * q_dim, "q"),
k_buf: buf(max * kv_dim, "k"),
v_buf: buf(max * kv_dim, "v"),
embed_weight: buf(vocab_size * hidden_size, "embed"),
lm_head_weight: buf(vocab_size * hidden_size, "lm_head"),
logits_buf: buf(max * vocab_size, "logits"),
grad_hidden_buf: buf(max * hidden_size, "grad_hidden"),
grad_logits_buf: buf(max * vocab_size, "grad_logits"),
hidden_size,
intermediate_size,
num_heads,
num_kv_heads,
head_dim,
max_seq_len: max,
vocab_size,
num_layers,
device,
queue,
}
}
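/// Uploads one block's weights to the GPU as f32 storage buffers. When
/// `lora_rank` is `Some`, also allocates LoRA A/B adapters (A filled with a
/// deterministic Kaiming-scaled pattern, B zeroed) together with their
/// optimizer moment buffers.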
pub fn upload_layer(
&mut self,
layer_idx: usize,
input_norm: &[f32],
post_attn_norm: &[f32],
w_q: &[f32],
w_k: &[f32],
w_v: &[f32],
w_o: &[f32],
w_gate: &[f32],
w_up: &[f32],
w_down: &[f32],
lora_rank: Option<u32>,
lora_scale: Option<f32>,
) {
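// Create a storage buffer and fill it from host f32 data via the queue.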
let upload = |data: &[f32], label: &str| -> wgpu::Buffer {
let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: (data.len() * 4) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
buffer
};
let prefix = format!("L{layer_idx}");
let lora = lora_rank.map(|rank| {
let scale = lora_scale.unwrap_or(1.0);
let h = self.hidden_size as usize;
let q = (self.num_heads * self.head_dim) as usize;
let kv = (self.num_kv_heads * self.head_dim) as usize;
let inter = self.intermediate_size as usize;
let r = rank as usize;
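// Deterministic, RNG-free initializer: a sine pattern scaled by the Kaiming
// factor sqrt(2 / fan_in), so repeated uploads are reproducible.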
let kaiming = |fan_in: usize, len: usize| -> Vec<f32> {
let std = (2.0 / fan_in as f32).sqrt();
(0..len).map(|i| ((i as f32 * 0.013 + layer_idx as f32).sin() * std)).collect()
};
let zeros = |len: usize| vec![0.0f32; len];
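// (in_features, out_features, name) per adapted projection: A is [in, r],
// B is [r, out]. B starts at zero so each adapter is a no-op before training.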
let pairs: Vec<(usize, usize, &str)> = vec![
(h, q, "q"),
(h, kv, "k"),
(h, kv, "v"),
(q, h, "o"),
(h, inter, "gate"),
(h, inter, "up"),
(inter, h, "down"),
];
let mut m_states = Vec::with_capacity(14);
let mut v_states = Vec::with_capacity(14);
let mut a_bufs = Vec::with_capacity(7);
let mut b_bufs = Vec::with_capacity(7);
for (in_d, out_d, name) in &pairs {
let a = upload(&kaiming(*in_d, in_d * r), &format!("{prefix}.lora_a_{name}"));
let b = upload(&zeros(r * out_d), &format!("{prefix}.lora_b_{name}"));
m_states.push(upload(&zeros(in_d * r), &format!("{prefix}.m_a_{name}")));
m_states.push(upload(&zeros(r * out_d), &format!("{prefix}.m_b_{name}")));
v_states.push(upload(&zeros(in_d * r), &format!("{prefix}.v_a_{name}")));
v_states.push(upload(&zeros(r * out_d), &format!("{prefix}.v_b_{name}")));
a_bufs.push(a);
b_bufs.push(b);
}
WgpuLoraAdapters {
rank,
scale,
a_q: a_bufs.remove(0),
b_q: b_bufs.remove(0),
a_k: a_bufs.remove(0),
b_k: b_bufs.remove(0),
a_v: a_bufs.remove(0),
b_v: b_bufs.remove(0),
a_o: a_bufs.remove(0),
b_o: b_bufs.remove(0),
a_gate: a_bufs.remove(0),
b_gate: b_bufs.remove(0),
a_up: a_bufs.remove(0),
b_up: b_bufs.remove(0),
a_down: a_bufs.remove(0),
b_down: b_bufs.remove(0),
m_states,
v_states,
}
});
self.blocks.push(WgpuBlock {
layer_idx,
input_norm: upload(input_norm, &format!("{prefix}.input_norm")),
post_attn_norm: upload(post_attn_norm, &format!("{prefix}.post_attn_norm")),
w_q: upload(w_q, &format!("{prefix}.q_proj")),
w_k: upload(w_k, &format!("{prefix}.k_proj")),
w_v: upload(w_v, &format!("{prefix}.v_proj")),
w_o: upload(w_o, &format!("{prefix}.o_proj")),
w_gate: upload(w_gate, &format!("{prefix}.gate_proj")),
w_up: upload(w_up, &format!("{prefix}.up_proj")),
w_down: upload(w_down, &format!("{prefix}.down_proj")),
lora,
});
eprintln!(
"[wgpu] Uploaded layer {}/{} ({})",
layer_idx + 1,
self.num_layers,
if self.blocks.last().unwrap().lora.is_some() { "with LoRA" } else { "frozen" }
);
}
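/// Writes the token embedding and LM head matrices into their pre-allocated
/// buffers; both are expected to hold `vocab_size * hidden_size` f32 values.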
pub fn upload_embeddings(&mut self, embed: &[f32], lm_head: &[f32]) {
self.queue.write_buffer(&self.embed_weight, 0, bytemuck::cast_slice(embed));
self.queue.write_buffer(&self.lm_head_weight, 0, bytemuck::cast_slice(lm_head));
eprintln!(
"[wgpu] Uploaded embeddings: embed=[{}×{}], lm_head=[{}×{}]",
self.vocab_size, self.hidden_size, self.vocab_size, self.hidden_size
);
}
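/// Estimated GPU memory footprint in bytes: dense per-layer weights plus the
/// shared activation/embedding buffers. LoRA adapters and optimizer state are
/// not included in the estimate.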
pub fn gpu_memory_bytes(&self) -> u64 {
let h = u64::from(self.hidden_size);
let inter = u64::from(self.intermediate_size);
let q = u64::from(self.num_heads * self.head_dim);
let kv = u64::from(self.num_kv_heads * self.head_dim);
let v = u64::from(self.vocab_size);
let s = u64::from(self.max_seq_len);
let l = u64::from(self.num_layers);
// Dense weights per layer: two norm vectors plus the seven projection matrices.
let per_layer_weights =
(2 * h + q * h + kv * h * 2 + h * q + inter * h * 2 + h * inter) * 4;
// Shared buffers: five seq×hidden workspaces (hidden, hidden2, attn_out, norm,
// grad_hidden), three seq×intermediate FFN buffers, Q/K/V, logits and
// grad_logits, and the embedding + LM head matrices.
let shared_bufs =
(s * h * 5 + s * inter * 3 + s * q + s * kv * 2 + s * v * 2 + v * h * 2) * 4;
per_layer_weights * l + shared_bufs
}
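/// Number of blocks uploaded so far.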
pub fn layer_count(&self) -> usize {
self.blocks.len()
}
}
#[cfg(test)]
#[cfg(feature = "gpu")]
mod tests {
use super::*;
#[test]
fn test_wgpu_block_manager_creation() {
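// Skip the test silently when no GPU adapter or device is available.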
let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
let adapter = match trueno::backends::gpu::runtime::block_on(
instance.request_adapter(&wgpu::RequestAdapterOptions::default()),
) {
Ok(a) => a,
Err(_) => return,
};
let (device, queue) = match trueno::backends::gpu::runtime::block_on(
adapter.request_device(&wgpu::DeviceDescriptor::default()),
) {
Ok(dq) => dq,
Err(_) => return,
};
let mut mgr = WgpuBlockManager::new(
device,
queue,
64,        // hidden_size
128,       // intermediate_size
4,         // num_heads
4,         // num_kv_heads
16,        // head_dim
2,         // num_layers
100,       // vocab_size
32,        // max_seq_len
Some(8),   // lora_rank
Some(2.0), // lora_alpha
);
for i in 0..2 {
let h = 64;
let inter = 128;
let q_dim = 4 * 16;
let kv_dim = 4 * 16;
mgr.upload_layer(
i,
&vec![1.0; h],           // input_norm
&vec![1.0; h],           // post_attn_norm
&vec![0.01; q_dim * h],  // w_q
&vec![0.01; kv_dim * h], // w_k
&vec![0.01; kv_dim * h], // w_v
&vec![0.01; h * q_dim],  // w_o
&vec![0.01; inter * h],  // w_gate
&vec![0.01; inter * h],  // w_up
&vec![0.01; h * inter],  // w_down
Some(8),         // lora_rank
Some(2.0 / 8.0), // lora_scale (alpha / rank)
);
}
assert_eq!(mgr.layer_count(), 2);
assert!(mgr.gpu_memory_bytes() > 0);
}
}