1#[cfg(feature = "gpu")]
18use trueno::backends::gpu::wgpu;
19
/// GPU-resident weights for a single layer, as created by
/// `WgpuBlockManager::upload_layer`.
///
/// Every buffer holds the `f32` data of the same-named slice handed to
/// `upload_layer`; the element counts are whatever the caller supplied
/// (no shape is enforced by this type).
#[cfg(feature = "gpu")]
pub struct WgpuBlock {
    /// Layer index given by the caller at upload time.
    pub layer_idx: usize,

    /// Norm weights uploaded under the label `L{idx}.input_norm`.
    pub input_norm: wgpu::Buffer,
    /// Norm weights uploaded under the label `L{idx}.post_attn_norm`.
    pub post_attn_norm: wgpu::Buffer,

    /// Weights uploaded under the label `L{idx}.q_proj`.
    pub w_q: wgpu::Buffer,
    /// Weights uploaded under the label `L{idx}.k_proj`.
    pub w_k: wgpu::Buffer,
    /// Weights uploaded under the label `L{idx}.v_proj`.
    pub w_v: wgpu::Buffer,
    /// Weights uploaded under the label `L{idx}.o_proj`.
    pub w_o: wgpu::Buffer,
    /// Weights uploaded under the label `L{idx}.gate_proj`.
    pub w_gate: wgpu::Buffer,
    /// Weights uploaded under the label `L{idx}.up_proj`.
    pub w_up: wgpu::Buffer,
    /// Weights uploaded under the label `L{idx}.down_proj`.
    pub w_down: wgpu::Buffer,

    /// LoRA adapters for this layer; `None` when the layer was uploaded
    /// without a LoRA rank.
    pub lora: Option<WgpuLoraAdapters>,
}
42
/// LoRA adapter buffers (and their per-parameter state buffers) for one layer.
///
/// For each of the seven targets (`q`, `k`, `v`, `o`, `gate`, `up`, `down`)
/// `WgpuBlockManager::upload_layer` creates an `a_*` buffer of
/// `in_dim * rank` elements and a `b_*` buffer of `rank * out_dim` elements.
/// The `b_*` buffers start as all zeros.
#[cfg(feature = "gpu")]
pub struct WgpuLoraAdapters {
    /// Adapter rank supplied at upload time.
    pub rank: u32,
    /// Scale supplied at upload time (`1.0` when the caller passed `None`).
    pub scale: f32,

    /// `a` factor for the `q` target.
    pub a_q: wgpu::Buffer,
    /// `b` factor for the `q` target.
    pub b_q: wgpu::Buffer,
    /// `a` factor for the `k` target.
    pub a_k: wgpu::Buffer,
    /// `b` factor for the `k` target.
    pub b_k: wgpu::Buffer,
    /// `a` factor for the `v` target.
    pub a_v: wgpu::Buffer,
    /// `b` factor for the `v` target.
    pub b_v: wgpu::Buffer,
    /// `a` factor for the `o` target.
    pub a_o: wgpu::Buffer,
    /// `b` factor for the `o` target.
    pub b_o: wgpu::Buffer,
    /// `a` factor for the `gate` target.
    pub a_gate: wgpu::Buffer,
    /// `b` factor for the `gate` target.
    pub b_gate: wgpu::Buffer,
    /// `a` factor for the `up` target.
    pub a_up: wgpu::Buffer,
    /// `b` factor for the `up` target.
    pub b_up: wgpu::Buffer,
    /// `a` factor for the `down` target.
    pub a_down: wgpu::Buffer,
    /// `b` factor for the `down` target.
    pub b_down: wgpu::Buffer,

    /// Zero-initialised state buffers, two per target in the order
    /// `[a_q, b_q, a_k, b_k, ...]`, each sized like the adapter it mirrors.
    /// NOTE(review): presumably optimizer first-moment state — confirm
    /// against the training kernels.
    pub m_states: Vec<wgpu::Buffer>,
    /// Zero-initialised state buffers with the same layout as `m_states`.
    /// NOTE(review): presumably optimizer second-moment state — confirm.
    pub v_states: Vec<wgpu::Buffer>,
}
69
/// Owns the wgpu device/queue, the uploaded per-layer blocks, and the shared
/// `f32` buffers that `WgpuBlockManager::new` allocates once up front.
///
/// Below, `S` is `max_seq_len`, `H` is `hidden_size`, `I` is
/// `intermediate_size`, `V` is `vocab_size`, `Q` is `num_heads * head_dim`
/// and `KV` is `num_kv_heads * head_dim`; sizes are element counts.
#[cfg(feature = "gpu")]
pub struct WgpuBlockManager {
    /// Device used to create every buffer.
    pub device: wgpu::Device,
    /// Queue used to write host data into buffers.
    pub queue: wgpu::Queue,
    /// Layers in the order they were uploaded.
    pub blocks: Vec<WgpuBlock>,

    /// Shared buffer labelled `hidden` (`S * H`).
    pub hidden_buf: wgpu::Buffer,
    /// Shared buffer labelled `hidden2` (`S * H`).
    pub hidden_buf2: wgpu::Buffer,
    /// Shared buffer labelled `attn_out` (`S * H`).
    pub attn_out_buf: wgpu::Buffer,
    /// Shared buffer labelled `ffn_gate` (`S * I`).
    pub ffn_gate_buf: wgpu::Buffer,
    /// Shared buffer labelled `ffn_up` (`S * I`).
    pub ffn_up_buf: wgpu::Buffer,
    /// Shared buffer labelled `ffn_silu` (`S * I`).
    pub ffn_silu_buf: wgpu::Buffer,
    /// Shared buffer labelled `norm` (`S * H`).
    pub norm_buf: wgpu::Buffer,
    /// Shared buffer labelled `q` (`S * Q`).
    pub q_buf: wgpu::Buffer,
    /// Shared buffer labelled `k` (`S * KV`).
    pub k_buf: wgpu::Buffer,
    /// Shared buffer labelled `v` (`S * KV`).
    pub v_buf: wgpu::Buffer,
    /// Buffer labelled `embed` (`V * H`), filled by `upload_embeddings`.
    pub embed_weight: wgpu::Buffer,
    /// Buffer labelled `lm_head` (`V * H`), filled by `upload_embeddings`.
    pub lm_head_weight: wgpu::Buffer,
    /// Shared buffer labelled `logits` (`S * V`).
    pub logits_buf: wgpu::Buffer,
    /// Shared buffer labelled `grad_hidden` (`S * H`).
    pub grad_hidden_buf: wgpu::Buffer,
    /// Shared buffer labelled `grad_logits` (`S * V`).
    pub grad_logits_buf: wgpu::Buffer,

    /// Hidden dimension `H`.
    pub hidden_size: u32,
    /// Feed-forward intermediate dimension `I`.
    pub intermediate_size: u32,
    /// Number of heads used to derive `Q`.
    pub num_heads: u32,
    /// Number of key/value heads used to derive `KV`.
    pub num_kv_heads: u32,
    /// Per-head dimension.
    pub head_dim: u32,
    /// Sequence capacity `S` of the shared buffers.
    pub max_seq_len: u32,
    /// Vocabulary size `V`.
    pub vocab_size: u32,
    /// Number of layers the model is configured for (not the number uploaded;
    /// see `layer_count`).
    pub num_layers: u32,
}
108
#[cfg(feature = "gpu")]
impl WgpuBlockManager {
    /// Creates a manager with every shared buffer allocated and no layers
    /// uploaded yet.
    ///
    /// All shared buffers hold `f32` elements and are created with
    /// `STORAGE | COPY_SRC | COPY_DST` usage. With `S = max_seq_len`,
    /// `H = hidden_size`, `I = intermediate_size`, `V = vocab_size`,
    /// `Q = num_heads * head_dim` and `KV = num_kv_heads * head_dim`:
    ///
    /// * `hidden`, `hidden2`, `attn_out`, `norm`, `grad_hidden`: `S * H`
    /// * `ffn_gate`, `ffn_up`, `ffn_silu`: `S * I`
    /// * `q`: `S * Q`; `k`, `v`: `S * KV`
    /// * `embed`, `lm_head`: `V * H`
    /// * `logits`, `grad_logits`: `S * V`
    ///
    /// `_lora_rank` and `_lora_alpha` are accepted but not used here; LoRA
    /// adapters are created per layer by [`Self::upload_layer`].
    pub fn new(
        device: wgpu::Device,
        queue: wgpu::Queue,
        hidden_size: u32,
        intermediate_size: u32,
        num_heads: u32,
        num_kv_heads: u32,
        head_dim: u32,
        num_layers: u32,
        vocab_size: u32,
        max_seq_len: u32,
        _lora_rank: Option<u32>,
        _lora_alpha: Option<f32>,
    ) -> Self {
        // Widen every dimension BEFORE multiplying. Products such as
        // `max_seq_len * vocab_size` can exceed `u32::MAX`; in `u32` that
        // panics in debug builds and silently wraps (undersized buffer) in
        // release builds.
        let seq = u64::from(max_seq_len);
        let hidden = u64::from(hidden_size);
        let inter = u64::from(intermediate_size);
        let vocab = u64::from(vocab_size);
        let q_dim = u64::from(num_heads) * u64::from(head_dim);
        let kv_dim = u64::from(num_kv_heads) * u64::from(head_dim);

        // Allocates a buffer of `elems` f32 values (4 bytes each).
        let buf = |elems: u64, label: &str| -> wgpu::Buffer {
            device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: elems * 4,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            })
        };

        // `device` and `queue` are moved last so that `buf` (which borrows
        // `device`) is no longer in use by the time the move happens.
        Self {
            blocks: Vec::with_capacity(num_layers as usize),
            hidden_buf: buf(seq * hidden, "hidden"),
            hidden_buf2: buf(seq * hidden, "hidden2"),
            attn_out_buf: buf(seq * hidden, "attn_out"),
            ffn_gate_buf: buf(seq * inter, "ffn_gate"),
            ffn_up_buf: buf(seq * inter, "ffn_up"),
            ffn_silu_buf: buf(seq * inter, "ffn_silu"),
            norm_buf: buf(seq * hidden, "norm"),
            q_buf: buf(seq * q_dim, "q"),
            k_buf: buf(seq * kv_dim, "k"),
            v_buf: buf(seq * kv_dim, "v"),
            embed_weight: buf(vocab * hidden, "embed"),
            lm_head_weight: buf(vocab * hidden, "lm_head"),
            logits_buf: buf(seq * vocab, "logits"),
            grad_hidden_buf: buf(seq * hidden, "grad_hidden"),
            grad_logits_buf: buf(seq * vocab, "grad_logits"),
            hidden_size,
            intermediate_size,
            num_heads,
            num_kv_heads,
            head_dim,
            max_seq_len,
            vocab_size,
            num_layers,
            device,
            queue,
        }
    }

    /// Uploads one layer's weights to the GPU and appends it to `blocks`.
    ///
    /// Each slice becomes its own buffer, sized to the slice and labelled
    /// `L{layer_idx}.<name>`. Slice lengths are not validated against the
    /// configured dimensions.
    ///
    /// When `lora_rank` is `Some(rank)`, LoRA adapters are also created for
    /// the `q`, `k`, `v`, `o`, `gate`, `up` and `down` targets:
    ///
    /// * each `a` buffer (`in_dim * rank` elements) is filled with a
    ///   deterministic sine sequence scaled by `sqrt(2 / in_dim)`;
    /// * each `b` buffer (`rank * out_dim` elements) is zero;
    /// * `m_states` / `v_states` receive two zero buffers per target
    ///   (one shaped like `a`, one like `b`), in target order.
    ///
    /// `lora_scale` defaults to `1.0` when `None`.
    ///
    /// The block is always pushed to the end of `blocks`; `layer_idx` is only
    /// used for labels, the stored `layer_idx` field, the adapter
    /// initialisation and the log line.
    pub fn upload_layer(
        &mut self,
        layer_idx: usize,
        input_norm: &[f32],
        post_attn_norm: &[f32],
        w_q: &[f32],
        w_k: &[f32],
        w_v: &[f32],
        w_o: &[f32],
        w_gate: &[f32],
        w_up: &[f32],
        w_down: &[f32],
        lora_rank: Option<u32>,
        lora_scale: Option<f32>,
    ) {
        // Borrow only the two fields the uploader needs so that
        // `self.blocks` stays free for the mutable push below.
        let device = &self.device;
        let queue = &self.queue;

        // Creates a buffer exactly as large as `data` and schedules the copy.
        let upload = |data: &[f32], label: &str| -> wgpu::Buffer {
            let buffer = device.create_buffer(&wgpu::BufferDescriptor {
                label: Some(label),
                size: std::mem::size_of_val(data) as u64,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });
            queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
            buffer
        };

        let prefix = format!("L{layer_idx}");

        let lora = lora_rank.map(|rank| {
            let scale = lora_scale.unwrap_or(1.0);
            let r = rank as usize;
            let h = self.hidden_size as usize;
            // Widen before multiplying so the product cannot overflow `u32`.
            let q = (self.num_heads as usize) * (self.head_dim as usize);
            let kv = (self.num_kv_heads as usize) * (self.head_dim as usize);
            let inter = self.intermediate_size as usize;

            // Two state buffers (a-shaped, b-shaped) per target, 7 targets.
            let mut m_states: Vec<wgpu::Buffer> = Vec::with_capacity(14);
            let mut v_states: Vec<wgpu::Buffer> = Vec::with_capacity(14);

            // Uploads the (a, b) adapter pair for one target and records its
            // state buffers. Upload order per target is a, b, m_a, m_b, v_a,
            // v_b.
            let mut adapter_pair =
                |in_d: usize, out_d: usize, name: &str| -> (wgpu::Buffer, wgpu::Buffer) {
                    let a_len = in_d * r;
                    let b_len = r * out_d;

                    // Deterministic stand-in for a random init: a sine
                    // sequence offset by the layer index, scaled by
                    // sqrt(2 / fan_in).
                    let std_dev = (2.0 / in_d as f32).sqrt();
                    let a_init: Vec<f32> = (0..a_len)
                        .map(|i| (i as f32 * 0.013 + layer_idx as f32).sin() * std_dev)
                        .collect();

                    // One zero vector per shape is enough; it is reused for
                    // every zero-initialised upload of that shape.
                    let zeros_a = vec![0.0f32; a_len];
                    let zeros_b = vec![0.0f32; b_len];

                    let a = upload(&a_init, &format!("{prefix}.lora_a_{name}"));
                    let b = upload(&zeros_b, &format!("{prefix}.lora_b_{name}"));
                    m_states.push(upload(&zeros_a, &format!("{prefix}.m_a_{name}")));
                    m_states.push(upload(&zeros_b, &format!("{prefix}.m_b_{name}")));
                    v_states.push(upload(&zeros_a, &format!("{prefix}.v_a_{name}")));
                    v_states.push(upload(&zeros_b, &format!("{prefix}.v_b_{name}")));
                    (a, b)
                };

            // (in_dim, out_dim) per target; order defines the state layout.
            let (a_q, b_q) = adapter_pair(h, q, "q");
            let (a_k, b_k) = adapter_pair(h, kv, "k");
            let (a_v, b_v) = adapter_pair(h, kv, "v");
            let (a_o, b_o) = adapter_pair(q, h, "o");
            let (a_gate, b_gate) = adapter_pair(h, inter, "gate");
            let (a_up, b_up) = adapter_pair(h, inter, "up");
            let (a_down, b_down) = adapter_pair(inter, h, "down");

            WgpuLoraAdapters {
                rank,
                scale,
                a_q,
                b_q,
                a_k,
                b_k,
                a_v,
                b_v,
                a_o,
                b_o,
                a_gate,
                b_gate,
                a_up,
                b_up,
                a_down,
                b_down,
                m_states,
                v_states,
            }
        });

        // Captured before `lora` is moved into the block, so the log line
        // does not need to look the block up again.
        let has_lora = lora.is_some();

        self.blocks.push(WgpuBlock {
            layer_idx,
            input_norm: upload(input_norm, &format!("{prefix}.input_norm")),
            post_attn_norm: upload(post_attn_norm, &format!("{prefix}.post_attn_norm")),
            w_q: upload(w_q, &format!("{prefix}.q_proj")),
            w_k: upload(w_k, &format!("{prefix}.k_proj")),
            w_v: upload(w_v, &format!("{prefix}.v_proj")),
            w_o: upload(w_o, &format!("{prefix}.o_proj")),
            w_gate: upload(w_gate, &format!("{prefix}.gate_proj")),
            w_up: upload(w_up, &format!("{prefix}.up_proj")),
            w_down: upload(w_down, &format!("{prefix}.down_proj")),
            lora,
        });

        eprintln!(
            "[wgpu] Uploaded layer {}/{} ({})",
            layer_idx + 1,
            self.num_layers,
            if has_lora { "with LoRA" } else { "frozen" }
        );
    }

    /// Writes the embedding and LM-head weights into the buffers allocated by
    /// [`Self::new`], starting at offset 0.
    ///
    /// NOTE(review): slice lengths are not checked here. Both destination
    /// buffers hold `vocab_size * hidden_size` elements; a shorter slice only
    /// overwrites the front of the buffer, and a longer one is expected to be
    /// rejected by `wgpu::Queue::write_buffer` — confirm callers pass
    /// exactly-sized data.
    pub fn upload_embeddings(&mut self, embed: &[f32], lm_head: &[f32]) {
        self.queue
            .write_buffer(&self.embed_weight, 0, bytemuck::cast_slice(embed));
        self.queue
            .write_buffer(&self.lm_head_weight, 0, bytemuck::cast_slice(lm_head));
        eprintln!(
            "[wgpu] Uploaded embeddings: embed=[{}×{}], lm_head=[{}×{}]",
            self.vocab_size, self.hidden_size, self.vocab_size, self.hidden_size
        );
    }

    /// Estimates GPU memory use in bytes from the configured dimensions.
    ///
    /// The estimate is `num_layers` times the per-layer weight footprint
    /// (assuming each weight slice has its nominal size) plus the shared
    /// buffers allocated by [`Self::new`]. It is computed from the
    /// configuration, not from the layers actually uploaded, and it does not
    /// include LoRA adapter or state buffers.
    pub fn gpu_memory_bytes(&self) -> u64 {
        const F32_BYTES: u64 = 4;

        let h = u64::from(self.hidden_size);
        let inter = u64::from(self.intermediate_size);
        // Widen before multiplying so the product cannot overflow `u32`.
        let q = u64::from(self.num_heads) * u64::from(self.head_dim);
        let kv = u64::from(self.num_kv_heads) * u64::from(self.head_dim);
        let v = u64::from(self.vocab_size);
        let s = u64::from(self.max_seq_len);
        let l = u64::from(self.num_layers);

        // Two norms, q + o projections, k + v projections, gate + up + down.
        let per_layer_weights = (2 * h + 2 * q * h + 2 * kv * h + 3 * inter * h) * F32_BYTES;

        // Mirrors `new`: five S*H buffers (hidden, hidden2, attn_out, norm,
        // grad_hidden), three S*I, one S*Q, two S*KV, two S*V (logits,
        // grad_logits) and two V*H (embed, lm_head).
        let shared_bufs =
            (s * h * 5 + s * inter * 3 + s * q + s * kv * 2 + s * v * 2 + v * h * 2) * F32_BYTES;

        per_layer_weights * l + shared_bufs
    }

    /// Returns the number of layers uploaded so far.
    pub fn layer_count(&self) -> usize {
        self.blocks.len()
    }
}
325
#[cfg(test)]
#[cfg(feature = "gpu")]
mod tests {
    use super::*;

    /// Smoke test: build a small manager, upload two LoRA-enabled layers and
    /// check the bookkeeping. Returns early (passing) when no adapter or
    /// device can be acquired.
    #[test]
    fn test_wgpu_block_manager_creation() {
        use trueno::backends::gpu::runtime::block_on;

        // Model dimensions shared by the manager and the fake weights.
        const HIDDEN: usize = 64;
        const INTER: usize = 128;
        const HEADS: usize = 4;
        const KV_HEADS: usize = 4;
        const HEAD_DIM: usize = 16;
        const LAYERS: usize = 2;
        const VOCAB: u32 = 100;
        const MAX_SEQ: u32 = 32;

        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
        let Ok(adapter) =
            block_on(instance.request_adapter(&wgpu::RequestAdapterOptions::default()))
        else {
            return;
        };
        let Ok((device, queue)) =
            block_on(adapter.request_device(&wgpu::DeviceDescriptor::default()))
        else {
            return;
        };

        let mut mgr = WgpuBlockManager::new(
            device,
            queue,
            HIDDEN as u32,
            INTER as u32,
            HEADS as u32,
            KV_HEADS as u32,
            HEAD_DIM as u32,
            LAYERS as u32,
            VOCAB,
            MAX_SEQ,
            Some(8),
            Some(2.0),
        );

        // The same constant-filled weights are reused for every layer.
        let q_dim = HEADS * HEAD_DIM;
        let kv_dim = KV_HEADS * HEAD_DIM;
        let norm = vec![1.0f32; HIDDEN];
        let attn_w = vec![0.01f32; q_dim * HIDDEN];
        let kv_w = vec![0.01f32; kv_dim * HIDDEN];
        let ffn_w = vec![0.01f32; INTER * HIDDEN];

        for layer in 0..LAYERS {
            mgr.upload_layer(
                layer,
                &norm,
                &norm,
                &attn_w,
                &kv_w,
                &kv_w,
                &attn_w,
                &ffn_w,
                &ffn_w,
                &ffn_w,
                Some(8),
                Some(2.0 / 8.0),
            );
        }

        assert_eq!(mgr.layer_count(), 2);
        assert!(mgr.gpu_memory_bytes() > 0);
    }
}