rullama 0.5.0

Browser-resident Gemma 4 inference: pure Rust → WebAssembly + WebGPU. Loads Ollama's on-disk GGUF blobs and runs the forward pass on the local GPU via hand-written WGSL.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
//! Lazy GPU weight buffer cache.
//!
//! Each tensor in the GGUF gets uploaded to a `wgpu::Buffer` on first access; future
//! calls return clones of the same buffer (wgpu Buffers are Arc-internally, so
//! cloning is `Arc::clone` and free). Eliminates the per-call weight upload that
//! dominates `forward_token_gpu` cost.
//!
//! Two access modes:
//!
//! * Sync (`buffer`, `buffer_tiles`) — borrows tensor bytes from the in-memory reader
//!   and uploads. Only valid for `GgufReader::is_in_memory()` readers; errors otherwise.
//!   Used by all native / test callers.
//! * Async (`buffer_async`, `buffer_tiles_async`) — fetches the bytes through the
//!   reader's `TensorFetcher`, uploads, and drops the temporary `Vec<u8>` immediately.
//!   The streaming wasm32 path uses this so peak CPU memory stays bounded.
//!
//! Both paths populate the same buffer cache: once a tensor is on the GPU it doesn't
//! matter how it got there.

use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;

use crate::backend::{BindGroupCache, buf_id};
use crate::error::{Result, RullamaError};
use crate::gguf::{GgmlDtype, GgufReader};

/// One tile of a row-tiled tensor: a GPU buffer plus the contiguous range of
/// rows (along the slow / second axis) whose bytes it holds.
pub struct TiledTensor {
    /// GPU buffer holding this tile's rows.
    pub buffer: wgpu::Buffer,
    /// Index of the first row (along the slow / second axis) covered by this buffer.
    pub row_start: usize,
    /// Number of rows covered.
    pub n_rows: usize,
}

/// Key for tile / tile-metadata maps: tensor name + the `max_bytes_per_tile`
/// cap (in bytes) the tensor was tiled with.
type TileKey = (String, usize);
/// Tile metadata: `(row_start, row_count)` per tile, in tile order.
type TileMeta = Vec<(usize, usize)>;

/// Lazily-populated cache of GGUF tensors uploaded to GPU buffers.
///
/// Tensors are uploaded on first request and cloned handles are returned on
/// later requests. The maps sit behind `RefCell`s so every method can take
/// `&self`.
pub struct WeightCache {
    /// Source of tensor descriptors and tensor bytes.
    reader: Arc<GgufReader>,
    /// Device used to create the weight buffers.
    device: wgpu::Device,
    /// Queue used to upload tensor bytes (`write_buffer`).
    queue: wgpu::Queue,
    /// Shared bind-group cache (same handle as `WgpuCtx::bind_cache`).
    /// Each `drop_*_destroy` invalidates any cached bind groups that
    /// reference the buffers about to be destroyed, BEFORE calling
    /// `Buffer::destroy()` — guards against the use-after-destroy
    /// observed on iOS Safari WebGPU.
    bind_cache: Arc<BindGroupCache>,
    /// Whole-tensor buffers keyed by tensor name; per-expert slices are stored
    /// here too, under `{name}#e{idx}`.
    buffers: RefCell<HashMap<String, wgpu::Buffer>>,
    /// Tile buffers of row-tiled tensors, in row order, keyed by
    /// `(name, max_bytes_per_tile)`.
    tiles: RefCell<HashMap<TileKey, Vec<wgpu::Buffer>>>,
    /// `(row_start, row_count)` of each tile in `tiles`, stored under the same
    /// key and in the same order.
    tile_meta: RefCell<HashMap<TileKey, TileMeta>>,
}

impl WeightCache {
    pub fn new(
        reader: Arc<GgufReader>,
        device: wgpu::Device,
        queue: wgpu::Queue,
        bind_cache: Arc<BindGroupCache>,
    ) -> Self {
        Self {
            reader,
            device,
            queue,
            bind_cache,
            buffers: RefCell::new(HashMap::new()),
            tiles: RefCell::new(HashMap::new()),
            tile_meta: RefCell::new(HashMap::new()),
        }
    }

    /// Borrow the GGUF reader backing this cache. Handy for callers that need
    /// tensor data outside the GPU buffer path now and then (e.g. an f32 dequant
    /// of the small RoPE freq-factors tensor).
    pub fn reader(&self) -> &GgufReader {
        &*self.reader
    }

    /// A new shared handle (`Arc`) to the GGUF reader backing this cache. For
    /// callers that must re-build a sibling such as `VisionForward` after the
    /// cache + struct was released to free GPU memory (the rebuild re-reads
    /// `VisionConfig::from_gguf`).
    pub fn reader_arc(&self) -> Arc<GgufReader> {
        Arc::clone(&self.reader)
    }

    /// Internal: create a `STORAGE | COPY_DST` buffer labelled `name`, upload
    /// `bytes` into it, and record the allocation with `gpu_mem` under
    /// `weight:{name}`.
    ///
    /// The buffer size is `bytes.len()` rounded up to a multiple of 4; any
    /// padding bytes are written as zero.
    fn upload(&self, name: &str, bytes: &[u8]) -> wgpu::Buffer {
        // write_buffer requires a COPY_BUFFER_ALIGNMENT (4-byte) multiple. Most
        // tensors are already aligned, but a row-aligned Q6_K tile can land on a
        // non-/4 boundary (row_bytes = (k/256)*210 is odd-multiple when k/256 is
        // odd — e.g. the 12b token_embd, d_model 3840 → 3150 B/row → an 8,388,450 B
        // tile). Pad up; the extra bytes are past the last block and never read.
        const ALIGN: usize = 4;
        let aligned_len = bytes.len() & !(ALIGN - 1);
        let padded_len = (bytes.len() + ALIGN - 1) & !(ALIGN - 1);
        let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(name),
            size: padded_len as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        // Upload the 4-byte-aligned prefix straight from the caller's slice and
        // zero-pad only the trailing partial word. This avoids duplicating a
        // multi-MiB tensor on the CPU just to append up to three padding bytes.
        let (head, tail) = bytes.split_at(aligned_len);
        if !head.is_empty() {
            self.queue.write_buffer(&buf, 0, head);
        }
        if !tail.is_empty() {
            let mut last_word = [0u8; ALIGN];
            last_word[..tail.len()].copy_from_slice(tail);
            self.queue
                .write_buffer(&buf, aligned_len as u64, &last_word);
        }
        crate::backend::gpu_mem::record_alloc(&format!("weight:{name}"), padded_len as u64);
        buf
    }

    /// Return the GPU buffer for tensor `name`, uploading it on the first request
    /// and handing out a cloned handle afterwards. Sync path: the bytes are
    /// borrowed straight from an in-memory reader, so this errors on a streaming
    /// reader.
    ///
    /// # Errors
    /// Propagates any error from `GgufReader::tensor_bytes`.
    pub fn buffer(&self, name: &str) -> Result<wgpu::Buffer> {
        let cached = self.buffers.borrow().get(name).cloned();
        if let Some(existing) = cached {
            return Ok(existing);
        }
        let uploaded = self.upload(name, self.reader.tensor_bytes(name)?);
        self.buffers
            .borrow_mut()
            .insert(name.to_string(), uploaded.clone());
        Ok(uploaded)
    }

    /// Get the GPU buffer for the named tensor, uploading on first access. Async path:
    /// works for both in-memory and streaming readers. The fetched `Vec<u8>` is dropped
    /// the moment the upload finishes — important for wasm32 (4 GB linear-memory cap).
    ///
    /// If another call for the same tensor populated the cache while this one
    /// was awaiting its fetch, the already-cached buffer is returned and the
    /// fetched bytes are discarded without a second upload.
    ///
    /// # Errors
    /// Propagates any error from `GgufReader::fetch_tensor_bytes`.
    pub async fn buffer_async(&self, name: &str) -> Result<wgpu::Buffer> {
        // Bind the lookup result in its own statement so the `RefCell` borrow is
        // released before the await below.
        let cached = self.buffers.borrow().get(name).cloned();
        if let Some(existing) = cached {
            return Ok(existing);
        }
        let bytes = self.reader.fetch_tensor_bytes(name).await?;
        // Re-check after the suspension point: an overlapping call may have
        // uploaded this tensor meanwhile. Uploading again would replace the map
        // entry, leaving the first buffer outside the cache (never destroyed or
        // invalidated by the drop_* methods) with an unmatched `record_alloc`.
        let raced = self.buffers.borrow().get(name).cloned();
        if let Some(existing) = raced {
            return Ok(existing);
        }
        let buf = self.upload(name, &bytes);
        drop(bytes);
        self.buffers
            .borrow_mut()
            .insert(name.to_string(), buf.clone());
        Ok(buf)
    }

    /// Get a GPU buffer holding ONE expert's slice of a 3-D stacked MoE tensor
    /// (`blk.N.ffn_*_exps.weight`, dims `[in, out, n_experts]`), range-fetching
    /// only that expert's bytes instead of the whole `n_experts`-deep tensor.
    /// This is the per-expert streaming lever: a token routes to top-8 of 128
    /// experts, so fetching only the selected slices is ~16× less bandwidth
    /// than `buffer_async` of the full tensor.
    ///
    /// Cached under `{name}#e{idx}` — still `blk.N.*`-prefixed, so the per-layer
    /// destroy (`drop_blk_layer_range_destroy`) reclaims it like any layer
    /// weight. The returned buffer is a standalone 1-expert tensor: index it in
    /// the expert matmul with expert id 0.
    pub async fn buffer_expert_async(&self, name: &str, expert_idx: usize) -> Result<wgpu::Buffer> {
        let key = format!("{name}#e{expert_idx}");
        if let Some(b) = self.buffers.borrow().get(&key) {
            return Ok(b.clone());
        }
        let desc = self.reader.tensor(name)?.clone();
        if desc.dims.len() != 3 {
            return Err(crate::error::RullamaError::Gguf(format!(
                "buffer_expert_async: {name} has {} dims, expected 3",
                desc.dims.len()
            )));
        }
        let in_len = desc.dims[0] as usize;
        let out_len = desc.dims[1] as usize;
        let n_experts = desc.dims[2] as usize;
        if expert_idx >= n_experts {
            return Err(crate::error::RullamaError::Gguf(format!(
                "buffer_expert_async: expert {expert_idx} >= {n_experts} for {name}"
            )));
        }
        let bytes_per_row = (in_len / desc.dtype.block_elems()) * desc.dtype.block_bytes();
        let slice_bytes = bytes_per_row * out_len;
        let abs =
            self.reader.data_section_offset() + desc.offset + (expert_idx * slice_bytes) as u64;
        let bytes = self.reader.fetcher().fetch(abs, slice_bytes as u64).await?;
        let buf = self.upload(&key, &bytes);
        drop(bytes);
        let cloned = buf.clone();
        self.buffers.borrow_mut().insert(key, buf);
        Ok(cloned)
    }

    /// Optional variant of [`buffer`](Self::buffer): yields `Ok(None)` when the
    /// reader's lookup of `name` fails, otherwise the cached-or-uploaded buffer.
    pub fn buffer_opt(&self, name: &str) -> Result<Option<wgpu::Buffer>> {
        let present = self.reader.tensor(name).is_ok();
        if present {
            Ok(Some(self.buffer(name)?))
        } else {
            Ok(None)
        }
    }

    /// Async counterpart of [`buffer_opt`](Self::buffer_opt): `Ok(None)` when the
    /// reader's lookup of `name` fails, otherwise the result of
    /// [`buffer_async`](Self::buffer_async).
    pub async fn buffer_opt_async(&self, name: &str) -> Result<Option<wgpu::Buffer>> {
        let present = self.reader.tensor(name).is_ok();
        if present {
            Ok(Some(self.buffer_async(name).await?))
        } else {
            Ok(None)
        }
    }

    /// GGML dtype of tensor `name`, read from the reader's descriptor only —
    /// nothing is fetched or uploaded.
    pub fn dtype(&self, name: &str) -> Result<GgmlDtype> {
        let desc = self.reader.tensor(name)?;
        Ok(desc.dtype)
    }

    /// Number of entries in the whole-tensor buffer map. Tiled tensors are not
    /// included in this count.
    pub fn cached_count(&self) -> usize {
        let singles = self.buffers.borrow();
        singles.len()
    }

    /// Total size in bytes of every cached GPU buffer: whole-tensor buffers plus
    /// all tile buffers, as reported by `wgpu::Buffer::size`.
    pub fn cached_bytes(&self) -> u64 {
        let mut single = 0u64;
        for buf in self.buffers.borrow().values() {
            single += buf.size();
        }
        let mut tiled = 0u64;
        for tile_bufs in self.tiles.borrow().values() {
            for buf in tile_bufs {
                tiled += buf.size();
            }
        }
        single + tiled
    }

    /// Remove every cache entry whose tensor name begins with `prefix`, releasing
    /// only the Rust handles — `destroy()` is never called here. Returns how many
    /// entries were removed (whole-tensor entries plus tiled entries; a tiled
    /// tensor counts once regardless of its tile count).
    ///
    /// **Safe to call mid-step**, even while in-flight GPU commands or cached
    /// bind groups still reference the buffers: dropping the handle doesn't
    /// invalidate the underlying `GPUBuffer`, it just lets wgpu's allocator
    /// reuse that memory for later allocations on the same device. Intended for
    /// the forward→backward boundary (the backward re-fetches the layers it
    /// needs), where the forward's commands may still be in flight.
    ///
    /// On the WebGPU backend this does NOT reclaim GPU memory promptly (that
    /// waits for browser GC of the dropped wrapper). For cross-step reclaim call
    /// [`drop_prefix_destroy`](Self::drop_prefix_destroy) at a GPU-idle point.
    pub fn drop_prefix(&self, prefix: &str) -> usize {
        let mut evicted = 0usize;

        // Whole-tensor buffers.
        self.buffers.borrow_mut().retain(|name, buf| {
            if !name.starts_with(prefix) {
                return true;
            }
            crate::backend::gpu_mem::record_free(&format!("weight:{name}"), buf.size());
            evicted += 1;
            false
        });

        // Tiled tensors: one free record per tile buffer, one count per tensor.
        self.tiles.borrow_mut().retain(|(name, _), tile_bufs| {
            if !name.starts_with(prefix) {
                return true;
            }
            for buf in tile_bufs.iter() {
                crate::backend::gpu_mem::record_free(&format!("weight:{name}"), buf.size());
            }
            evicted += 1;
            false
        });

        // Keep the tile metadata map in step with the tile map.
        self.tile_meta
            .borrow_mut()
            .retain(|(name, _), _| !name.starts_with(prefix));

        evicted
    }

    /// Like [`drop_prefix`](Self::drop_prefix) but ALSO calls
    /// `wgpu::Buffer::destroy()` on every evicted buffer to force prompt
    /// GPU-memory reclaim. Returns the number of cache entries removed.
    ///
    /// **Only call at a GPU-idle point** — after a fence / map / readback
    /// that guarantees no in-flight command (and no cached bind group
    /// about to be re-used) references these buffers. `destroy()` while a
    /// buffer is still referenced by pending work or a live bind group is
    /// a use-after-destroy: on iOS Safari WebGPU it crashes the tab (we
    /// observed training die at the head→backward transition when destroy
    /// fired at the backward *start*, before the forward's commands had
    /// drained).
    ///
    /// On the WebGPU backend dropping the handle alone leaves the
    /// `GPUBuffer` resident until GC; `destroy()` reclaims it immediately
    /// so the next training step's forward re-cache starts from genuinely
    /// freed VRAM instead of stacking on the previous step's leftovers and
    /// crossing the iOS WebContent ceiling → jetsam. On native it frees
    /// immediately either way. Used by the training step at end-of-step
    /// (post loss-readback, GPU drained) and by
    /// `Model::release_vision_weights` between inference turns.
    pub fn drop_prefix_destroy(&self, prefix: &str) -> usize {
        self.destroy_matching(|name| name.starts_with(prefix))
    }

    /// Single-pass targeted destroy for the fwd→bwd boundary on iOS.
    /// Destroys every cached `blk.{i}.*` weight where `i` is in
    /// `[start_layer, end_layer)`, in ONE iteration through the cache
    /// instead of the N separate `drop_prefix_destroy` calls the caller
    /// would otherwise make. On iOS Safari WebGPU each
    /// `GPUBuffer.destroy()` is an IPC round-trip to the GPU process;
    /// firing 25 × ~7 = 175 of them in a tight loop with separate
    /// HashMap traversals was empirically tripping jetsam right at the
    /// forward→head transition (real-device trail: `step 2 forward 35/35
    /// gpuMiB=1417` → 💥).
    ///
    /// Returns the number of cache entries removed; an empty or inverted
    /// range removes nothing and returns 0.
    pub fn drop_blk_layer_range_destroy(&self, start_layer: u32, end_layer: u32) -> usize {
        if end_layer <= start_layer {
            return 0;
        }
        // Parse the "blk.{N}." prefix out of a key without allocating;
        // returns the layer number or None if the key doesn't match the
        // "blk.<digits>.<rest>" shape.
        fn parse_blk_layer(key: &str) -> Option<u32> {
            let rest = key.strip_prefix("blk.")?;
            let dot = rest.find('.')?;
            rest[..dot].parse().ok()
        }
        let layers = start_layer..end_layer;
        self.destroy_matching(|name| {
            parse_blk_layer(name).is_some_and(|layer| layers.contains(&layer))
        })
    }

    /// Shared implementation of the `drop_*_destroy` methods: evict and
    /// `destroy()` every cached buffer whose tensor name satisfies `is_victim`.
    /// Returns the number of cache entries removed (a tiled tensor counts once).
    ///
    /// Keeping the sequence in one place ensures every destroy path runs the
    /// bind-group invalidation first.
    fn destroy_matching(&self, is_victim: impl Fn(&str) -> bool) -> usize {
        // **Use-after-destroy guard.** Collect ids of every buffer
        // about to be destroyed and invalidate any cached bind groups
        // referencing them BEFORE we call `Buffer::destroy()`. On iOS
        // Safari WebGPU a bind group referencing a destroyed buffer
        // becomes a device-lost trigger on next use; per WebGPU spec
        // destroy is supposed to be safe but WebKit's implementation
        // is observably non-compliant (bug 302711 family).
        let mut victims: Vec<u64> = Vec::new();
        for (name, buf) in self.buffers.borrow().iter() {
            if is_victim(name.as_str()) {
                victims.push(buf_id(buf));
            }
        }
        for ((name, _), tile_bufs) in self.tiles.borrow().iter() {
            if is_victim(name.as_str()) {
                for buf in tile_bufs {
                    victims.push(buf_id(buf));
                }
            }
        }
        self.bind_cache.invalidate_buffers(&victims);

        let mut removed = 0usize;
        self.buffers.borrow_mut().retain(|name, buf| {
            if !is_victim(name.as_str()) {
                return true;
            }
            crate::backend::gpu_mem::record_free(&format!("weight:{name}"), buf.size());
            buf.destroy();
            removed += 1;
            false
        });
        self.tiles.borrow_mut().retain(|(name, _), tile_bufs| {
            if !is_victim(name.as_str()) {
                return true;
            }
            for buf in tile_bufs.iter() {
                crate::backend::gpu_mem::record_free(&format!("weight:{name}"), buf.size());
                buf.destroy();
            }
            removed += 1;
            false
        });
        // Keep the tile metadata map in step with the tile map.
        self.tile_meta
            .borrow_mut()
            .retain(|(name, _), _| !is_victim(name.as_str()));
        removed
    }

    /// Internal: work out how a 2-D quantized tensor splits into row tiles.
    ///
    /// Reads only the tensor descriptor. The result carries the row count, the
    /// byte size of one row, and how many whole rows fit in
    /// `max_bytes_per_tile` — never fewer than one, so a single row larger than
    /// the cap still gets a tile of its own.
    ///
    /// # Errors
    /// `RullamaError::Inference` if the tensor is not 2-D, if the row length is
    /// not a multiple of the dtype's block size, or if a row is 0 bytes; also
    /// propagates the reader's lookup error.
    fn tile_layout(&self, name: &str, max_bytes_per_tile: usize) -> Result<TileLayout> {
        let desc = self.reader.tensor(name)?;

        let (row_len, n_rows) = match desc.dims.len() {
            2 => (desc.dims[0] as usize, desc.dims[1] as usize),
            _ => {
                return Err(RullamaError::Inference(format!(
                    "buffer_tiles: tensor {name} has {} dims, expected 2",
                    desc.dims.len()
                )));
            }
        };

        let block_elems = desc.dtype.block_elems();
        if !row_len.is_multiple_of(block_elems) {
            return Err(RullamaError::Inference(format!(
                "buffer_tiles: row_len {row_len} not multiple of block_elems {block_elems}"
            )));
        }

        let row_bytes = (row_len / block_elems) * desc.dtype.block_bytes();
        if row_bytes == 0 {
            return Err(RullamaError::Inference(format!(
                "buffer_tiles: row_bytes is 0 for {name}"
            )));
        }

        let whole_rows = max_bytes_per_tile / row_bytes;
        Ok(TileLayout {
            n_rows,
            row_bytes,
            rows_per_tile: whole_rows.max(1),
        })
    }

    /// Split a 2-D quantized tensor along its slow (second) axis into multiple GPU
    /// buffers, each ≤ `max_bytes_per_tile` bytes. Sync; in-memory reader only.
    pub fn buffer_tiles(&self, name: &str, max_bytes_per_tile: usize) -> Result<Vec<TiledTensor>> {
        let key = (name.to_string(), max_bytes_per_tile);
        if let Some(out) = self.tiles_cached(&key) {
            return Ok(out);
        }
        let layout = self.tile_layout(name, max_bytes_per_tile)?;
        let all_bytes = self.reader.tensor_bytes(name)?;

        let mut bufs = Vec::new();
        let mut metas = Vec::new();
        let mut row_start = 0usize;
        while row_start < layout.n_rows {
            let row_end = (row_start + layout.rows_per_tile).min(layout.n_rows);
            let byte_start = row_start * layout.row_bytes;
            let byte_end = row_end * layout.row_bytes;
            let chunk = &all_bytes[byte_start..byte_end];
            let buf = self.upload(&format!("{name}#tile{row_start}"), chunk);
            metas.push((row_start, row_end - row_start));
            bufs.push(buf);
            row_start = row_end;
        }

        Ok(self.commit_tiles(key, bufs, metas))
    }

    /// Async variant of [`buffer_tiles`](Self::buffer_tiles). Fetches each tile's
    /// bytes through the fetcher (one Range request per tile when streaming),
    /// uploads, drops the temporary buffer. Works for any reader.
    ///
    /// Results are cached per `(name, max_bytes_per_tile)`.
    ///
    /// # Errors
    /// Any error from `tile_layout` or `GgufReader::fetch_tensor_range`.
    pub async fn buffer_tiles_async(
        &self,
        name: &str,
        max_bytes_per_tile: usize,
    ) -> Result<Vec<TiledTensor>> {
        let key = (name.to_string(), max_bytes_per_tile);
        if let Some(out) = self.tiles_cached(&key) {
            return Ok(out);
        }
        let layout = self.tile_layout(name, max_bytes_per_tile)?;
        // Byte offsets within the tensor are computed in u64: `usize` is 32-bit
        // on wasm32, where `row_start * row_bytes` would silently wrap for a
        // tensor larger than 4 GiB and fetch the wrong range.
        let row_bytes = layout.row_bytes as u64;

        // Per-tile fetch — only one tile's bytes live in wasm linear memory at a
        // time. The old code path pulled the whole tensor (315 MiB for
        // `token_embd.weight` in gemma4:e2b) into one `Vec<u8>` before tiling,
        // which on iPhone 16e (8 GB shared RAM) was the spike that crashed the
        // WebContent process during the first `step()` — even with 1 GB
        // `max_buffer_size`, the wasm-side 315 MB allocation on top of ~2 GB of
        // already-resident layer weights tipped iOS Jetsam over.
        let mut bufs = Vec::new();
        let mut metas = Vec::new();
        let mut row_start = 0usize;
        while row_start < layout.n_rows {
            let row_end = (row_start + layout.rows_per_tile).min(layout.n_rows);
            let tile_rows = row_end - row_start;
            let byte_start = (row_start as u64) * row_bytes;
            let byte_len = (tile_rows as u64) * row_bytes;
            let chunk = self
                .reader
                .fetch_tensor_range(name, byte_start, byte_len)
                .await?;
            let buf = self.upload(&format!("{name}#tile{row_start}"), &chunk);
            drop(chunk);
            metas.push((row_start, tile_rows));
            bufs.push(buf);
            row_start = row_end;
        }

        Ok(self.commit_tiles(key, bufs, metas))
    }

    /// How many tiles `name` splits into at this tile size — the row count
    /// divided by rows-per-tile, rounded up. Computed from the descriptor only;
    /// nothing is fetched or uploaded.
    pub fn tile_count(&self, name: &str, max_bytes_per_tile: usize) -> Result<usize> {
        self.tile_layout(name, max_bytes_per_tile)
            .map(|layout| layout.n_rows.div_ceil(layout.rows_per_tile))
    }

    /// Fetch ONE tile (by index) and upload it to a fresh GPU buffer **without caching**, so the
    /// caller can `.destroy()` it right after use. The memory-tight counterpart to
    /// `buffer_tiles_async`, which uploads *every* tile and caches them all (~315 MiB resident for
    /// `token_embd.weight` on gemma4:e2b — the brick at the iPhone forward→head jetsam wall).
    /// Streaming + destroying one tile at a time holds the head-projection peak at
    /// ~`max_bytes_per_tile` instead of the whole tensor.
    ///
    /// # Errors
    /// * `RullamaError::Inference` if `tile_idx` is past the last tile (including
    ///   indices so large that `tile_idx * rows_per_tile` would overflow).
    /// * Any error from `tile_layout` or `GgufReader::fetch_tensor_range`.
    pub async fn fetch_tile_uncached(
        &self,
        name: &str,
        max_bytes_per_tile: usize,
        tile_idx: usize,
    ) -> Result<TiledTensor> {
        let layout = self.tile_layout(name, max_bytes_per_tile)?;
        // Checked multiply: a wrapped product could land back inside the row
        // range and silently select the wrong tile instead of erroring.
        let row_start = match tile_idx.checked_mul(layout.rows_per_tile) {
            Some(start) if start < layout.n_rows => start,
            _ => {
                return Err(RullamaError::Inference(format!(
                    "fetch_tile_uncached: tile {tile_idx} out of range for {name} ({} rows)",
                    layout.n_rows
                )));
            }
        };
        let row_end = row_start
            .saturating_add(layout.rows_per_tile)
            .min(layout.n_rows);
        // Byte offsets in u64: `usize` is 32-bit on wasm32 and the product could
        // wrap for a tensor larger than 4 GiB.
        let row_bytes = layout.row_bytes as u64;
        let byte_start = (row_start as u64) * row_bytes;
        let byte_len = ((row_end - row_start) as u64) * row_bytes;
        let chunk = self
            .reader
            .fetch_tensor_range(name, byte_start, byte_len)
            .await?;
        // Create the tile buffer directly, NOT via `upload` — `upload` calls
        // `gpu_mem::record_alloc("weight:…")`, but the caller destroys this tile one matmul later
        // and there's no matching `record_free`. Routing transient tiles through `upload` makes the
        // gpu_mem `w` counter climb one tile per call with no balancing free — the false
        // "+315 MiB/token" the iPhone training beacon showed (the GPU memory itself is freed fine by
        // the caller's invalidate + destroy). These tiles are intentionally untracked.
        // write_buffer requires a COPY_BUFFER_ALIGNMENT (4-byte) multiple. A
        // row-aligned tile can land non-4-aligned when row_bytes isn't /4 — e.g.
        // Q6_K row_bytes = (k/256)*210 is odd-multiple-of-210 when k/256 is odd
        // (the 12b's token_embd, d_model 3840 → 3150 B/row → an 8,388,450 B tile).
        // Pad up; the extra bytes sit past the last block of the last row and are
        // never read by the matmul kernel.
        let mut data = chunk;
        let padded_len = (data.len() + 3) & !3;
        data.resize(padded_len, 0);
        let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("weight_cache.stream_tile"),
            size: data.len() as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&buf, 0, &data);
        drop(data);
        Ok(TiledTensor {
            buffer: buf,
            row_start,
            n_rows: row_end - row_start,
        })
    }

    /// Internal: rebuild the `TiledTensor` list for `key` from the cache, cloning
    /// each tile's buffer handle. Returns `None` unless both the tile buffers and
    /// their metadata are present.
    fn tiles_cached(&self, key: &(String, usize)) -> Option<Vec<TiledTensor>> {
        let tile_map = self.tiles.borrow();
        let meta_map = self.tile_meta.borrow();
        let bufs = tile_map.get(key)?;
        let metas = meta_map.get(key)?;

        let mut out = Vec::with_capacity(bufs.len().min(metas.len()));
        for (buf, &(row_start, n_rows)) in bufs.iter().zip(metas) {
            out.push(TiledTensor {
                buffer: buf.clone(),
                row_start,
                n_rows,
            });
        }
        Some(out)
    }

    /// Internal: store freshly uploaded tiles and their `(row_start, row_count)`
    /// metadata under `key`, and return the matching `TiledTensor` list with
    /// cloned buffer handles. `bufs` and `metas` are paired by position.
    fn commit_tiles(
        &self,
        key: (String, usize),
        bufs: Vec<wgpu::Buffer>,
        metas: Vec<(usize, usize)>,
    ) -> Vec<TiledTensor> {
        let mut handed_out = Vec::with_capacity(bufs.len().min(metas.len()));
        for (buf, &(row_start, n_rows)) in bufs.iter().zip(&metas) {
            handed_out.push(TiledTensor {
                buffer: buf.clone(),
                row_start,
                n_rows,
            });
        }
        let meta_key = key.clone();
        self.tiles.borrow_mut().insert(key, bufs);
        self.tile_meta.borrow_mut().insert(meta_key, metas);
        handed_out
    }
}

/// Row-tiling plan for one 2-D tensor, as computed by `WeightCache::tile_layout`.
struct TileLayout {
    /// Total number of rows in the tensor (its second dimension).
    n_rows: usize,
    /// Size of one row in bytes (blocks per row × bytes per block).
    row_bytes: usize,
    /// Whole rows placed in each tile; always at least 1.
    rows_per_tile: usize,
}