use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
use crate::backend::{BindGroupCache, buf_id};
use crate::error::{Result, RullamaError};
use crate::gguf::{GgmlDtype, GgufReader};
/// One row-range slice of a 2-D tensor, uploaded as its own GPU buffer.
///
/// Produced by `WeightCache::buffer_tiles` / `WeightCache::buffer_tiles_async`,
/// which split a tensor into consecutive groups of rows.
pub struct TiledTensor {
/// GPU buffer holding the bytes of the rows covered by this tile.
pub buffer: wgpu::Buffer,
/// Index of the first tensor row stored in this tile.
pub row_start: usize,
/// Number of consecutive rows stored in this tile.
pub n_rows: usize,
}
/// Cache key for a tiled tensor: `(tensor name, max_bytes_per_tile)`.
type TileKey = (String, usize);
/// Per-tile row ranges as `(row_start, n_rows)`, stored in the same order as
/// the tile buffers they describe.
type TileMeta = Vec<(usize, usize)>;
/// Lazily uploads GGUF tensors to GPU buffers and caches the resulting handles.
///
/// Interior mutability is provided by `RefCell`, so all methods take `&self`.
pub struct WeightCache {
/// Source of tensor descriptors and raw tensor bytes.
reader: Arc<GgufReader>,
/// Device used to create the GPU buffers.
device: wgpu::Device,
/// Queue used to write tensor bytes into freshly created buffers.
queue: wgpu::Queue,
/// Bind-group cache that is told to invalidate buffers before they are destroyed.
bind_cache: Arc<BindGroupCache>,
/// Whole-tensor buffers, keyed by tensor name.
buffers: RefCell<HashMap<String, wgpu::Buffer>>,
/// Tile buffers, keyed by `(tensor name, max_bytes_per_tile)`.
tiles: RefCell<HashMap<TileKey, Vec<wgpu::Buffer>>>,
/// `(row_start, n_rows)` of every tile, keyed like `tiles` and in the same order.
tile_meta: RefCell<HashMap<TileKey, TileMeta>>,
}
impl WeightCache {
pub fn new(
reader: Arc<GgufReader>,
device: wgpu::Device,
queue: wgpu::Queue,
bind_cache: Arc<BindGroupCache>,
) -> Self {
Self {
reader,
device,
queue,
bind_cache,
buffers: RefCell::new(HashMap::new()),
tiles: RefCell::new(HashMap::new()),
tile_meta: RefCell::new(HashMap::new()),
}
}
/// Borrows the GGUF reader this cache loads tensors from.
pub fn reader(&self) -> &GgufReader {
    self.reader.as_ref()
}
/// Returns another shared handle to the GGUF reader (reference-count bump only).
pub fn reader_arc(&self) -> Arc<GgufReader> {
    Arc::clone(&self.reader)
}
/// Creates a GPU buffer sized to `bytes`, queues a write of `bytes` into it,
/// and reports the allocation to `gpu_mem` under the label `weight:{name}`.
///
/// The buffer is labelled `name` and created with `STORAGE | COPY_DST` usage.
fn upload(&self, name: &str, bytes: &[u8]) -> wgpu::Buffer {
    let byte_len = bytes.len() as u64;
    let descriptor = wgpu::BufferDescriptor {
        label: Some(name),
        size: byte_len,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    };
    let gpu_buf = self.device.create_buffer(&descriptor);
    self.queue.write_buffer(&gpu_buf, 0, bytes);
    let label = format!("weight:{name}");
    crate::backend::gpu_mem::record_alloc(&label, byte_len);
    gpu_buf
}
/// Returns the GPU buffer for tensor `name`, uploading it on first use.
///
/// # Errors
/// Propagates the reader's error when the tensor bytes cannot be obtained.
pub fn buffer(&self, name: &str) -> Result<wgpu::Buffer> {
    // Clone the handle out first so the shared borrow ends before any insert.
    let cached = self.buffers.borrow().get(name).cloned();
    if let Some(hit) = cached {
        return Ok(hit);
    }
    let bytes = self.reader.tensor_bytes(name)?;
    let uploaded = self.upload(name, bytes);
    self.buffers
        .borrow_mut()
        .insert(name.to_string(), uploaded.clone());
    Ok(uploaded)
}
/// Asynchronous variant of `buffer`: returns the GPU buffer for tensor `name`,
/// fetching its bytes through the reader and uploading them on first use.
///
/// The cache is checked both before and after the fetch. Another call for the
/// same tensor may complete while this one is suspended at the `.await`; in
/// that case the already-cached buffer is returned instead of uploading a
/// second copy (which would replace the first entry and leave its recorded
/// allocation without a matching free).
///
/// # Errors
/// Propagates the reader's error when the tensor bytes cannot be fetched.
pub async fn buffer_async(&self, name: &str) -> Result<wgpu::Buffer> {
    if let Some(cached) = self.buffers.borrow().get(name) {
        return Ok(cached.clone());
    }
    let bytes = self.reader.fetch_tensor_bytes(name).await?;
    // Re-check: the entry may have been populated while we were awaiting.
    if let Some(cached) = self.buffers.borrow().get(name) {
        return Ok(cached.clone());
    }
    let buf = self.upload(name, &bytes);
    // Release the CPU-side copy as soon as the write has been queued.
    drop(bytes);
    self.buffers
        .borrow_mut()
        .insert(name.to_string(), buf.clone());
    Ok(buf)
}
/// Like `buffer`, but yields `Ok(None)` when the reader's `tensor(name)`
/// lookup fails instead of returning that error.
///
/// # Errors
/// Propagates any error from `buffer` once the tensor lookup has succeeded.
pub fn buffer_opt(&self, name: &str) -> Result<Option<wgpu::Buffer>> {
    let known = self.reader.tensor(name).is_ok();
    if known {
        self.buffer(name).map(Some)
    } else {
        Ok(None)
    }
}
/// Like `buffer_async`, but yields `Ok(None)` when the reader's `tensor(name)`
/// lookup fails instead of returning that error.
///
/// # Errors
/// Propagates any error from `buffer_async` once the tensor lookup has succeeded.
pub async fn buffer_opt_async(&self, name: &str) -> Result<Option<wgpu::Buffer>> {
    let known = self.reader.tensor(name).is_ok();
    if !known {
        return Ok(None);
    }
    let buf = self.buffer_async(name).await?;
    Ok(Some(buf))
}
/// Looks up tensor `name` in the reader and returns the dtype of its descriptor.
///
/// # Errors
/// Propagates the reader's error when the tensor lookup fails.
pub fn dtype(&self, name: &str) -> Result<GgmlDtype> {
    let desc = self.reader.tensor(name)?;
    Ok(desc.dtype)
}
/// Number of whole-tensor buffers currently cached (tiled entries are not counted).
pub fn cached_count(&self) -> usize {
    let whole = self.buffers.borrow();
    whole.len()
}
/// Total size in bytes of every cached buffer: whole-tensor buffers plus all tiles.
pub fn cached_bytes(&self) -> u64 {
    let mut whole_bytes: u64 = 0;
    for buf in self.buffers.borrow().values() {
        whole_bytes += buf.size();
    }
    let mut tile_bytes: u64 = 0;
    for tile_set in self.tiles.borrow().values() {
        for buf in tile_set {
            tile_bytes += buf.size();
        }
    }
    whole_bytes + tile_bytes
}
/// Removes every cached entry whose tensor name starts with `prefix` and
/// reports each removed buffer to `gpu_mem::record_free`.
///
/// Buffers are only removed from the cache: there is no `destroy()` call and
/// no bind-group invalidation (see `drop_prefix_destroy` for that).
///
/// Tiles are freed under the same label they were allocated with,
/// `weight:{name}#tile{row_start}`, reconstructed from the tile metadata.
/// Previously they were freed under `weight:{name}`, which did not match the
/// label passed to `record_alloc` at upload time.
///
/// Returns the number of removed cache entries; a tiled tensor counts once
/// regardless of how many tiles it has.
pub fn drop_prefix(&self, prefix: &str) -> usize {
    /// Accounting label of tile `idx`; falls back to the bare tensor label
    /// if no metadata is available for that tile.
    fn tile_label(name: &str, metas: Option<&[(usize, usize)]>, idx: usize) -> String {
        match metas.and_then(|m| m.get(idx)) {
            Some(&(row_start, _)) => format!("weight:{name}#tile{row_start}"),
            None => format!("weight:{name}"),
        }
    }

    let mut removed = 0usize;
    self.buffers.borrow_mut().retain(|name, buf| {
        let hit = name.starts_with(prefix);
        if hit {
            crate::backend::gpu_mem::record_free(&format!("weight:{name}"), buf.size());
            removed += 1;
        }
        !hit
    });

    let mut tile_meta = self.tile_meta.borrow_mut();
    self.tiles.borrow_mut().retain(|key, bufs| {
        let name = &key.0;
        if !name.starts_with(prefix) {
            return true;
        }
        // Take the metadata for this entry so each tile's row_start is known.
        let metas = tile_meta.remove(key);
        for (idx, buf) in bufs.iter().enumerate() {
            let label = tile_label(name, metas.as_deref(), idx);
            crate::backend::gpu_mem::record_free(&label, buf.size());
        }
        removed += 1;
        false
    });
    // Sweep any metadata that had no matching tile entry.
    tile_meta.retain(|(name, _), _| !name.starts_with(prefix));
    removed
}
/// Removes every cached entry whose tensor name starts with `prefix`,
/// explicitly destroying the GPU buffers.
///
/// Steps, in order:
/// 1. collect the ids of all affected buffers and ask the bind-group cache to
///    invalidate them, so that happens before any buffer is destroyed;
/// 2. for each affected buffer, report it to `gpu_mem::record_free` and call
///    `destroy()` on it;
/// 3. drop the corresponding tile metadata.
///
/// Tiles are freed under the same label they were allocated with,
/// `weight:{name}#tile{row_start}`, reconstructed from the tile metadata.
/// Previously they were freed under `weight:{name}`, which did not match the
/// label passed to `record_alloc` at upload time.
///
/// Returns the number of removed cache entries; a tiled tensor counts once
/// regardless of how many tiles it has.
pub fn drop_prefix_destroy(&self, prefix: &str) -> usize {
    /// Accounting label of tile `idx`; falls back to the bare tensor label
    /// if no metadata is available for that tile.
    fn tile_label(name: &str, metas: Option<&[(usize, usize)]>, idx: usize) -> String {
        match metas.and_then(|m| m.get(idx)) {
            Some(&(row_start, _)) => format!("weight:{name}#tile{row_start}"),
            None => format!("weight:{name}"),
        }
    }

    let mut victims: Vec<u64> = Vec::new();
    for (name, buf) in self.buffers.borrow().iter() {
        if name.starts_with(prefix) {
            victims.push(buf_id(buf));
        }
    }
    for ((name, _), tile_bufs) in self.tiles.borrow().iter() {
        if name.starts_with(prefix) {
            for buf in tile_bufs {
                victims.push(buf_id(buf));
            }
        }
    }
    self.bind_cache.invalidate_buffers(&victims);

    let mut removed = 0usize;
    self.buffers.borrow_mut().retain(|name, buf| {
        if !name.starts_with(prefix) {
            return true;
        }
        crate::backend::gpu_mem::record_free(&format!("weight:{name}"), buf.size());
        buf.destroy();
        removed += 1;
        false
    });

    let mut tile_meta = self.tile_meta.borrow_mut();
    self.tiles.borrow_mut().retain(|key, bufs| {
        let name = &key.0;
        if !name.starts_with(prefix) {
            return true;
        }
        // Take the metadata for this entry so each tile's row_start is known.
        let metas = tile_meta.remove(key);
        for (idx, buf) in bufs.iter().enumerate() {
            let label = tile_label(name, metas.as_deref(), idx);
            crate::backend::gpu_mem::record_free(&label, buf.size());
            buf.destroy();
        }
        removed += 1;
        false
    });
    // Sweep any metadata that had no matching tile entry.
    tile_meta.retain(|(name, _), _| !name.starts_with(prefix));
    removed
}
/// Removes and destroys every cached buffer belonging to layers in the
/// half-open range `start_layer..end_layer`.
///
/// A tensor belongs to layer `N` when its name has the form `blk.N.<rest>`
/// with `N` parsing as a `u32`; all other names are left untouched. An empty
/// or inverted range (`end_layer <= start_layer`) is a no-op returning `0`.
///
/// The bind-group cache is asked to invalidate all affected buffers before
/// any of them is destroyed. Tiles are freed under the same label they were
/// allocated with, `weight:{name}#tile{row_start}`, reconstructed from the
/// tile metadata. Previously they were freed under `weight:{name}`, which did
/// not match the label passed to `record_alloc` at upload time.
///
/// Returns the number of removed cache entries; a tiled tensor counts once
/// regardless of how many tiles it has.
pub fn drop_blk_layer_range_destroy(&self, start_layer: u32, end_layer: u32) -> usize {
    if end_layer <= start_layer {
        return 0;
    }

    /// Extracts `N` from a name of the form `blk.N.<rest>`.
    fn parse_blk_layer(key: &str) -> Option<u32> {
        let rest = key.strip_prefix("blk.")?;
        let (layer, _) = rest.split_once('.')?;
        layer.parse().ok()
    }

    /// Accounting label of tile `idx`; falls back to the bare tensor label
    /// if no metadata is available for that tile.
    fn tile_label(name: &str, metas: Option<&[(usize, usize)]>, idx: usize) -> String {
        match metas.and_then(|m| m.get(idx)) {
            Some(&(row_start, _)) => format!("weight:{name}#tile{row_start}"),
            None => format!("weight:{name}"),
        }
    }

    let in_range = |key: &str| -> bool {
        matches!(parse_blk_layer(key), Some(n) if (start_layer..end_layer).contains(&n))
    };

    let mut victims: Vec<u64> = Vec::new();
    for (name, buf) in self.buffers.borrow().iter() {
        if in_range(name) {
            victims.push(buf_id(buf));
        }
    }
    for ((name, _), tile_bufs) in self.tiles.borrow().iter() {
        if in_range(name) {
            for buf in tile_bufs {
                victims.push(buf_id(buf));
            }
        }
    }
    self.bind_cache.invalidate_buffers(&victims);

    let mut removed = 0usize;
    self.buffers.borrow_mut().retain(|name, buf| {
        if !in_range(name) {
            return true;
        }
        crate::backend::gpu_mem::record_free(&format!("weight:{name}"), buf.size());
        buf.destroy();
        removed += 1;
        false
    });

    let mut tile_meta = self.tile_meta.borrow_mut();
    self.tiles.borrow_mut().retain(|key, bufs| {
        let name = &key.0;
        if !in_range(name) {
            return true;
        }
        // Take the metadata for this entry so each tile's row_start is known.
        let metas = tile_meta.remove(key);
        for (idx, buf) in bufs.iter().enumerate() {
            let label = tile_label(name, metas.as_deref(), idx);
            crate::backend::gpu_mem::record_free(&label, buf.size());
            buf.destroy();
        }
        removed += 1;
        false
    });
    // Sweep any metadata that had no matching tile entry.
    tile_meta.retain(|(name, _), _| !in_range(name));
    removed
}
/// Works out how tensor `name` is split into row tiles of at most
/// `max_bytes_per_tile` bytes.
///
/// `dims[0]` is taken as the row length (in elements) and `dims[1]` as the
/// number of rows. A row occupies `row_len / block_elems` blocks of
/// `block_bytes` bytes each. Every tile holds at least one row, so a single
/// row larger than `max_bytes_per_tile` still gets a tile of its own.
///
/// # Errors
/// Propagates the reader's lookup error, and returns
/// `RullamaError::Inference` when the tensor is not 2-D, when the row length
/// is not a whole number of blocks, or when a row would occupy zero bytes.
fn tile_layout(&self, name: &str, max_bytes_per_tile: usize) -> Result<TileLayout> {
    let desc = self.reader.tensor(name)?;
    let rank = desc.dims.len();
    if rank != 2 {
        let msg = format!("buffer_tiles: tensor {name} has {} dims, expected 2", rank);
        return Err(RullamaError::Inference(msg));
    }
    let (row_len, n_rows) = (desc.dims[0] as usize, desc.dims[1] as usize);
    let block_elems = desc.dtype.block_elems();
    if !row_len.is_multiple_of(block_elems) {
        let msg =
            format!("buffer_tiles: row_len {row_len} not multiple of block_elems {block_elems}");
        return Err(RullamaError::Inference(msg));
    }
    let row_bytes = (row_len / block_elems) * desc.dtype.block_bytes();
    if row_bytes == 0 {
        let msg = format!("buffer_tiles: row_bytes is 0 for {name}");
        return Err(RullamaError::Inference(msg));
    }
    // Never fewer than one row per tile, even if a row exceeds the budget.
    let rows_per_tile = std::cmp::max(max_bytes_per_tile / row_bytes, 1);
    Ok(TileLayout {
        n_rows,
        row_bytes,
        rows_per_tile,
    })
}
pub fn buffer_tiles(&self, name: &str, max_bytes_per_tile: usize) -> Result<Vec<TiledTensor>> {
let key = (name.to_string(), max_bytes_per_tile);
if let Some(out) = self.tiles_cached(&key) {
return Ok(out);
}
let layout = self.tile_layout(name, max_bytes_per_tile)?;
let all_bytes = self.reader.tensor_bytes(name)?;
let mut bufs = Vec::new();
let mut metas = Vec::new();
let mut row_start = 0usize;
while row_start < layout.n_rows {
let row_end = (row_start + layout.rows_per_tile).min(layout.n_rows);
let byte_start = row_start * layout.row_bytes;
let byte_end = row_end * layout.row_bytes;
let chunk = &all_bytes[byte_start..byte_end];
let buf = self.upload(&format!("{name}#tile{row_start}"), chunk);
metas.push((row_start, row_end - row_start));
bufs.push(buf);
row_start = row_end;
}
Ok(self.commit_tiles(key, bufs, metas))
}
pub async fn buffer_tiles_async(
&self,
name: &str,
max_bytes_per_tile: usize,
) -> Result<Vec<TiledTensor>> {
let key = (name.to_string(), max_bytes_per_tile);
if let Some(out) = self.tiles_cached(&key) {
return Ok(out);
}
let layout = self.tile_layout(name, max_bytes_per_tile)?;
let mut bufs = Vec::new();
let mut metas = Vec::new();
let mut row_start = 0usize;
while row_start < layout.n_rows {
let row_end = (row_start + layout.rows_per_tile).min(layout.n_rows);
let byte_start = (row_start * layout.row_bytes) as u64;
let byte_end = (row_end * layout.row_bytes) as u64;
let chunk = self
.reader
.fetch_tensor_range(name, byte_start, byte_end - byte_start)
.await?;
let buf = self.upload(&format!("{name}#tile{row_start}"), &chunk);
drop(chunk);
metas.push((row_start, row_end - row_start));
bufs.push(buf);
row_start = row_end;
}
Ok(self.commit_tiles(key, bufs, metas))
}
/// Rebuilds the `TiledTensor` list for `key` from the cache.
///
/// Returns `None` unless both the tile buffers and their metadata are
/// present for `key`. Buffers and metadata are paired up in order.
fn tiles_cached(&self, key: &(String, usize)) -> Option<Vec<TiledTensor>> {
    let tile_map = self.tiles.borrow();
    let meta_map = self.tile_meta.borrow();
    let tile_bufs = tile_map.get(key)?;
    let ranges = meta_map.get(key)?;
    let mut out = Vec::with_capacity(tile_bufs.len().min(ranges.len()));
    for (buf, &(row_start, n_rows)) in tile_bufs.iter().zip(ranges) {
        out.push(TiledTensor {
            buffer: buf.clone(),
            row_start,
            n_rows,
        });
    }
    Some(out)
}
/// Stores freshly built tiles under `key` and returns them as `TiledTensor`s.
///
/// `bufs` and `metas` are paired up in order; `metas` holds
/// `(row_start, n_rows)` per tile. Any existing entry for `key` is replaced.
fn commit_tiles(
    &self,
    key: (String, usize),
    bufs: Vec<wgpu::Buffer>,
    metas: Vec<(usize, usize)>,
) -> Vec<TiledTensor> {
    // Build the handles first; the vectors are moved into the cache below.
    let mut handles = Vec::with_capacity(bufs.len().min(metas.len()));
    for (buf, &(row_start, n_rows)) in bufs.iter().zip(&metas) {
        handles.push(TiledTensor {
            buffer: buf.clone(),
            row_start,
            n_rows,
        });
    }
    let meta_key = key.clone();
    self.tiles.borrow_mut().insert(key, bufs);
    self.tile_meta.borrow_mut().insert(meta_key, metas);
    handles
}
}
/// How a 2-D tensor is split into row tiles.
///
/// A small plain-value type, so it derives the usual value-type traits and
/// can be copied, compared and logged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TileLayout {
    /// Total number of rows in the tensor.
    n_rows: usize,
    /// Size of one row in bytes.
    row_bytes: usize,
    /// Number of rows placed in each tile (the last tile may hold fewer).
    rows_per_tile: usize,
}