use super::dtype::GgmlDtype;
use super::quant::dequant_into_f32;
use super::reader::GgufReader;
use crate::error::{Result, RullamaError};
pub fn dequant_tensor_to_f32(r: &GgufReader, name: &str) -> Result<Vec<f32>> {
let desc = r.tensor(name)?;
let bytes = r.tensor_bytes(name)?;
let elems = desc.elem_count() as usize;
let mut out = vec![0f32; elems];
dequant_into_f32(desc.dtype, bytes, &mut out)?;
Ok(out)
}
/// Dequantizes row `row_idx` of the 2-D tensor `name` into a freshly allocated
/// `Vec<f32>` of length `dims[0]`.
///
/// The tensor is treated as `dims[1]` rows of `dims[0]` elements each, stored back
/// to back. Each row must span a whole number of quantization blocks, so its byte
/// range can be sliced out without decoding neighbouring rows.
///
/// # Errors
///
/// Propagates lookup errors from the reader and any error from `dequant_into_f32`.
/// Returns `RullamaError::Gguf` in these cases:
///
/// - the tensor is not 2-D;
/// - a dimension does not fit in `usize`;
/// - `row_idx` is out of range;
/// - the row length is not a multiple of the dtype's block size;
/// - the row's byte range overflows `usize`;
/// - the row extends past the tensor's data.
pub fn dequant_row_to_f32(r: &GgufReader, name: &str, row_idx: usize) -> Result<Vec<f32>> {
    let desc = r.tensor(name)?;
    if desc.dims.len() != 2 {
        return Err(RullamaError::Gguf(format!(
            "dequant_row_to_f32: tensor {} has {} dims, expected 2",
            desc.name,
            desc.dims.len()
        )));
    }
    // Checked instead of `as usize` so oversized dims cannot silently truncate on
    // 32-bit targets and produce a wrong-but-"valid" row layout.
    let (row_len, n_rows) = {
        let dim = |axis: usize| {
            usize::try_from(desc.dims[axis]).map_err(|_| {
                RullamaError::Gguf(format!(
                    "dim {axis} ({}) of tensor {} does not fit in usize",
                    desc.dims[axis], desc.name
                ))
            })
        };
        (dim(0)?, dim(1)?)
    };
    if row_idx >= n_rows {
        return Err(RullamaError::Gguf(format!(
            "row {row_idx} out of bounds for tensor {} ({} rows)",
            desc.name, n_rows
        )));
    }
    let block_elems = desc.dtype.block_elems();
    // `0.is_multiple_of(0)` is true, so a zero block size must be rejected
    // explicitly or the division below would panic.
    if block_elems == 0 || !row_len.is_multiple_of(block_elems) {
        return Err(RullamaError::Gguf(format!(
            "row_len {} not multiple of block_elems {} for {}",
            row_len, block_elems, desc.name
        )));
    }
    let overflow = || {
        RullamaError::Gguf(format!(
            "row {row_idx} byte range overflows usize for tensor {}",
            desc.name
        ))
    };
    let bytes_per_row = (row_len / block_elems)
        .checked_mul(desc.dtype.block_bytes())
        .ok_or_else(overflow)?;
    let start = row_idx.checked_mul(bytes_per_row).ok_or_else(overflow)?;
    let end = start.checked_add(bytes_per_row).ok_or_else(overflow)?;

    let all_bytes = r.tensor_bytes(name)?;
    let Some(row_bytes) = all_bytes.get(start..end) else {
        return Err(RullamaError::Gguf(format!(
            "row bytes {start}..{end} extend past tensor data {} for {}",
            all_bytes.len(),
            desc.name
        )));
    };
    let mut out = vec![0f32; row_len];
    dequant_into_f32(desc.dtype, row_bytes, &mut out)?;
    Ok(out)
}
/// Returns the storage dtype recorded for tensor `name`.
///
/// # Errors
///
/// Propagates any error from `GgufReader::tensor` (presumably an unknown `name`).
// `dead_code` is allowed because this helper may currently have no caller in the
// crate. NOTE(review): drop the allowance once it is used.
#[allow(dead_code)]
pub(crate) fn dtype_of(r: &GgufReader, name: &str) -> Result<GgmlDtype> {
    let desc = r.tensor(name)?;
    Ok(desc.dtype)
}
pub async fn dequant_tensor_to_f32_async(r: &GgufReader, name: &str) -> Result<Vec<f32>> {
let desc = r.tensor(name)?.clone();
let bytes = r.fetch_tensor_bytes(name).await?;
let elems = desc.elem_count() as usize;
let mut out = vec![0f32; elems];
dequant_into_f32(desc.dtype, &bytes, &mut out)?;
Ok(out)
}
/// Async counterpart of `dequant_row_to_f32`: fetches only the bytes of row
/// `row_idx` of the 2-D tensor `name` and dequantizes them into a `Vec<f32>` of
/// length `dims[0]`.
///
/// The row's absolute position is computed as the tensor's `offset`, plus
/// `row_idx * bytes_per_row`, plus the reader's data-section offset. It is then
/// requested from the reader's fetcher.
///
/// # Errors
///
/// Propagates lookup and fetch errors from the reader and any error from
/// `dequant_into_f32`. Returns `RullamaError::Gguf` in these cases:
///
/// - the tensor is not 2-D;
/// - a dimension does not fit in `usize`;
/// - `row_idx` is out of range;
/// - the row length is not a multiple of the dtype's block size;
/// - the row's size or absolute offset overflows;
/// - the fetcher returns fewer bytes than one row.
pub async fn dequant_row_to_f32_async(
    r: &GgufReader,
    name: &str,
    row_idx: usize,
) -> Result<Vec<f32>> {
    // Owned copy of the descriptor so nothing borrowed from `r`'s tensor table has
    // to be kept alive across the `.await` below.
    let desc = r.tensor(name)?.clone();
    if desc.dims.len() != 2 {
        return Err(RullamaError::Gguf(format!(
            "dequant_row_to_f32_async: tensor {} has {} dims, expected 2",
            desc.name,
            desc.dims.len()
        )));
    }
    // Checked instead of `as usize` so oversized dims cannot silently truncate on
    // 32-bit targets.
    let (row_len, n_rows) = {
        let dim = |axis: usize| {
            usize::try_from(desc.dims[axis]).map_err(|_| {
                RullamaError::Gguf(format!(
                    "dim {axis} ({}) of tensor {} does not fit in usize",
                    desc.dims[axis], desc.name
                ))
            })
        };
        (dim(0)?, dim(1)?)
    };
    if row_idx >= n_rows {
        return Err(RullamaError::Gguf(format!(
            "row {row_idx} out of bounds for tensor {} ({} rows)",
            desc.name, n_rows
        )));
    }
    let block_elems = desc.dtype.block_elems();
    // `0.is_multiple_of(0)` is true, so a zero block size must be rejected
    // explicitly or the division below would panic.
    if block_elems == 0 || !row_len.is_multiple_of(block_elems) {
        return Err(RullamaError::Gguf(format!(
            "row_len {} not multiple of block_elems {} for {}",
            row_len, block_elems, desc.name
        )));
    }
    let bytes_per_row = (row_len / block_elems)
        .checked_mul(desc.dtype.block_bytes())
        .ok_or_else(|| {
            RullamaError::Gguf(format!(
                "row size overflows usize for tensor {}",
                desc.name
            ))
        })?;
    // The offset is computed in u64, the fetcher's offset type, so a large tensor
    // cannot overflow `usize` on 32-bit targets before widening.
    // `usize -> u64` is lossless on every target Rust supports.
    let abs_offset = (row_idx as u64)
        .checked_mul(bytes_per_row as u64)
        .and_then(|rel| rel.checked_add(desc.offset))
        .and_then(|off| off.checked_add(r.data_section_offset()))
        .ok_or_else(|| {
            RullamaError::Gguf(format!(
                "absolute offset of row {row_idx} overflows u64 for tensor {}",
                desc.name
            ))
        })?;
    let fetched = r.fetcher().fetch(abs_offset, bytes_per_row as u64).await?;
    // Guard against the fetcher returning fewer bytes than requested, rather than
    // handing a truncated buffer to the decoder (mirrors the sync version's check).
    let Some(row_bytes) = fetched.get(..bytes_per_row) else {
        return Err(RullamaError::Gguf(format!(
            "fetched {} bytes for row {row_idx} of tensor {}, expected {bytes_per_row}",
            fetched.len(),
            desc.name
        )));
    };
    let mut out = vec![0f32; row_len];
    dequant_into_f32(desc.dtype, row_bytes, &mut out)?;
    Ok(out)
}