use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
use crate::error::{Result, RullamaError};
use crate::gguf::{GgmlDtype, GgufReader};
pub struct TiledTensor {
pub buffer: wgpu::Buffer,
pub row_start: usize,
pub n_rows: usize,
}
type TileKey = (String, usize);
type TileMeta = Vec<(usize, usize)>;
pub struct WeightCache {
reader: Arc<GgufReader>,
device: wgpu::Device,
queue: wgpu::Queue,
buffers: RefCell<HashMap<String, wgpu::Buffer>>,
tiles: RefCell<HashMap<TileKey, Vec<wgpu::Buffer>>>,
tile_meta: RefCell<HashMap<TileKey, TileMeta>>,
}
impl WeightCache {
pub fn new(reader: Arc<GgufReader>, device: wgpu::Device, queue: wgpu::Queue) -> Self {
Self {
reader,
device,
queue,
buffers: RefCell::new(HashMap::new()),
tiles: RefCell::new(HashMap::new()),
tile_meta: RefCell::new(HashMap::new()),
}
}
pub fn reader(&self) -> &GgufReader {
&self.reader
}
pub fn reader_arc(&self) -> Arc<GgufReader> {
self.reader.clone()
}
fn upload(&self, name: &str, bytes: &[u8]) -> wgpu::Buffer {
let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(name),
size: bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.queue.write_buffer(&buf, 0, bytes);
buf
}
pub fn buffer(&self, name: &str) -> Result<wgpu::Buffer> {
if let Some(b) = self.buffers.borrow().get(name) {
return Ok(b.clone());
}
let bytes = self.reader.tensor_bytes(name)?;
let buf = self.upload(name, bytes);
let cloned = buf.clone();
self.buffers.borrow_mut().insert(name.to_string(), buf);
Ok(cloned)
}
pub async fn buffer_async(&self, name: &str) -> Result<wgpu::Buffer> {
if let Some(b) = self.buffers.borrow().get(name) {
return Ok(b.clone());
}
let bytes = self.reader.fetch_tensor_bytes(name).await?;
let buf = self.upload(name, &bytes);
drop(bytes);
let cloned = buf.clone();
self.buffers.borrow_mut().insert(name.to_string(), buf);
Ok(cloned)
}
pub fn buffer_opt(&self, name: &str) -> Result<Option<wgpu::Buffer>> {
if self.reader.tensor(name).is_err() {
return Ok(None);
}
self.buffer(name).map(Some)
}
pub async fn buffer_opt_async(&self, name: &str) -> Result<Option<wgpu::Buffer>> {
if self.reader.tensor(name).is_err() {
return Ok(None);
}
self.buffer_async(name).await.map(Some)
}
pub fn dtype(&self, name: &str) -> Result<GgmlDtype> {
Ok(self.reader.tensor(name)?.dtype)
}
pub fn cached_count(&self) -> usize {
self.buffers.borrow().len()
}
pub fn cached_bytes(&self) -> u64 {
let single: u64 = self.buffers.borrow().values().map(|b| b.size()).sum();
let tiled: u64 = self
.tiles
.borrow()
.values()
.flat_map(|v| v.iter().map(|b| b.size()))
.sum();
single + tiled
}
pub fn drop_prefix(&self, prefix: &str) -> usize {
let mut removed = 0usize;
self.buffers.borrow_mut().retain(|k, _| {
if k.starts_with(prefix) {
removed += 1;
false
} else {
true
}
});
self.tiles.borrow_mut().retain(|(k, _), _| {
if k.starts_with(prefix) {
removed += 1;
false
} else {
true
}
});
self.tile_meta
.borrow_mut()
.retain(|(k, _), _| !k.starts_with(prefix));
removed
}
fn tile_layout(&self, name: &str, max_bytes_per_tile: usize) -> Result<TileLayout> {
let desc = self.reader.tensor(name)?;
if desc.dims.len() != 2 {
return Err(RullamaError::Inference(format!(
"buffer_tiles: tensor {name} has {} dims, expected 2",
desc.dims.len()
)));
}
let row_len = desc.dims[0] as usize;
let n_rows = desc.dims[1] as usize;
let block_elems = desc.dtype.block_elems();
if !row_len.is_multiple_of(block_elems) {
return Err(RullamaError::Inference(format!(
"buffer_tiles: row_len {row_len} not multiple of block_elems {block_elems}"
)));
}
let blocks_per_row = row_len / block_elems;
let row_bytes = blocks_per_row * desc.dtype.block_bytes();
if row_bytes == 0 {
return Err(RullamaError::Inference(format!(
"buffer_tiles: row_bytes is 0 for {name}"
)));
}
let rows_per_tile = (max_bytes_per_tile / row_bytes).max(1);
Ok(TileLayout {
n_rows,
row_bytes,
rows_per_tile,
})
}
pub fn buffer_tiles(&self, name: &str, max_bytes_per_tile: usize) -> Result<Vec<TiledTensor>> {
let key = (name.to_string(), max_bytes_per_tile);
if let Some(out) = self.tiles_cached(&key) {
return Ok(out);
}
let layout = self.tile_layout(name, max_bytes_per_tile)?;
let all_bytes = self.reader.tensor_bytes(name)?;
let mut bufs = Vec::new();
let mut metas = Vec::new();
let mut row_start = 0usize;
while row_start < layout.n_rows {
let row_end = (row_start + layout.rows_per_tile).min(layout.n_rows);
let byte_start = row_start * layout.row_bytes;
let byte_end = row_end * layout.row_bytes;
let chunk = &all_bytes[byte_start..byte_end];
let buf = self.upload(&format!("{name}#tile{row_start}"), chunk);
metas.push((row_start, row_end - row_start));
bufs.push(buf);
row_start = row_end;
}
Ok(self.commit_tiles(key, bufs, metas))
}
pub async fn buffer_tiles_async(
&self,
name: &str,
max_bytes_per_tile: usize,
) -> Result<Vec<TiledTensor>> {
let key = (name.to_string(), max_bytes_per_tile);
if let Some(out) = self.tiles_cached(&key) {
return Ok(out);
}
let layout = self.tile_layout(name, max_bytes_per_tile)?;
let mut bufs = Vec::new();
let mut metas = Vec::new();
let mut row_start = 0usize;
while row_start < layout.n_rows {
let row_end = (row_start + layout.rows_per_tile).min(layout.n_rows);
let byte_start = (row_start * layout.row_bytes) as u64;
let byte_end = (row_end * layout.row_bytes) as u64;
let chunk = self
.reader
.fetch_tensor_range(name, byte_start, byte_end - byte_start)
.await?;
let buf = self.upload(&format!("{name}#tile{row_start}"), &chunk);
drop(chunk);
metas.push((row_start, row_end - row_start));
bufs.push(buf);
row_start = row_end;
}
Ok(self.commit_tiles(key, bufs, metas))
}
fn tiles_cached(&self, key: &(String, usize)) -> Option<Vec<TiledTensor>> {
let tiles = self.tiles.borrow();
let meta = self.tile_meta.borrow();
match (tiles.get(key), meta.get(key)) {
(Some(bufs), Some(metas)) => Some(
bufs.iter()
.zip(metas.iter())
.map(|(buf, &(row_start, n_rows))| TiledTensor {
buffer: buf.clone(),
row_start,
n_rows,
})
.collect(),
),
_ => None,
}
}
fn commit_tiles(
&self,
key: (String, usize),
bufs: Vec<wgpu::Buffer>,
metas: Vec<(usize, usize)>,
) -> Vec<TiledTensor> {
let result: Vec<TiledTensor> = bufs
.iter()
.zip(metas.iter())
.map(|(buf, &(rs, nr))| TiledTensor {
buffer: buf.clone(),
row_start: rs,
n_rows: nr,
})
.collect();
self.tiles.borrow_mut().insert(key.clone(), bufs);
self.tile_meta.borrow_mut().insert(key, metas);
result
}
}
struct TileLayout {
n_rows: usize,
row_bytes: usize,
rows_per_tile: usize,
}