use crate::error::{Error, Result};
use crate::nn::linear::MaybeQuantLinear;
use crate::nn::varmap::VarMap;
use crate::nn::weight::Weight;
use crate::quant::tensor::QuantTensor;
use crate::quant::traits::DequantOps;
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Hierarchical, prefix-aware view over a [`VarMap`] used while assembling models.
///
/// A `VarBuilder` resolves short names (e.g. `"weight"`) against a
/// dot-separated prefix (e.g. `"model.layers.0"`) before looking them up in,
/// or moving them out of, the underlying mutable `VarMap`.
pub struct VarBuilder<'a, R: Runtime> {
    // Backing store of named weights; borrowed mutably so the `take_*`
    // methods can move tensors out of it.
    varmap: &'a mut VarMap<R>,
    // Current dot-separated name prefix; empty at the root builder.
    prefix: String,
    // Device handle exposed via `device()` and used to obtain a runtime
    // client (e.g. for dequantization in `take_tensor_dequant`).
    device: &'a R::Device,
}
impl<'a, R: Runtime> VarBuilder<'a, R> {
    /// Creates a root builder with an empty prefix over `varmap`.
    pub fn new(varmap: &'a mut VarMap<R>, device: &'a R::Device) -> Self {
        Self {
            varmap,
            prefix: String::new(),
            device,
        }
    }

    /// Returns a child builder whose prefix is this builder's prefix extended
    /// with `segment` (dot-separated). The child reborrows the same `VarMap`,
    /// so this builder is unusable until the child is dropped.
    pub fn push_prefix(&mut self, segment: &str) -> VarBuilder<'_, R> {
        let prefix = if self.prefix.is_empty() {
            segment.to_string()
        } else {
            format!("{}.{}", self.prefix, segment)
        };
        VarBuilder {
            varmap: self.varmap,
            prefix,
            device: self.device,
        }
    }

    /// Shorthand for [`Self::push_prefix`].
    pub fn pp(&mut self, segment: &str) -> VarBuilder<'_, R> {
        self.push_prefix(segment)
    }

    /// Joins the current prefix and `name` with a dot (or returns `name`
    /// unchanged when the prefix is empty).
    fn full_name(&self, name: &str) -> String {
        if self.prefix.is_empty() {
            name.to_string()
        } else {
            format!("{}.{}", self.prefix, name)
        }
    }

    /// Borrows the weight stored under the prefixed `name`.
    pub fn get(&self, name: &str) -> Result<&Weight<R>> {
        let full = self.full_name(name);
        self.varmap.get(&full)
    }

    /// Borrows the standard (non-quantized) tensor stored under the
    /// prefixed `name`.
    pub fn get_tensor(&self, name: &str) -> Result<&Tensor<R>> {
        let full = self.full_name(name);
        self.varmap.get_tensor(&full)
    }

    /// Borrows the quantized tensor stored under the prefixed `name`.
    pub fn get_quant_tensor(&self, name: &str) -> Result<&QuantTensor<R>> {
        let full = self.full_name(name);
        self.varmap.get_quant_tensor(&full)
    }

    /// Moves the standard tensor stored under the prefixed `name` out of the
    /// map. Subsequent lookups of the same name will fail.
    pub fn take_tensor(&mut self, name: &str) -> Result<Tensor<R>> {
        let full = self.full_name(name);
        self.varmap.take_tensor(&full)
    }

    /// Like [`Self::take_tensor`], but returns `Ok(None)` when the name is
    /// absent instead of an error.
    pub fn take_tensor_optional(&mut self, name: &str) -> Result<Option<Tensor<R>>> {
        if self.contains(name) {
            self.take_tensor(name).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Like [`Self::take_weight`], but returns `Ok(None)` when the name is
    /// absent instead of an error.
    pub fn take_weight_optional(&mut self, name: &str) -> Result<Option<Weight<R>>> {
        if self.contains(name) {
            self.take_weight(name).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Like [`Self::take_maybe_quant_linear`], but returns `Ok(None)` when
    /// the weight name is absent instead of an error.
    pub fn take_maybe_quant_linear_optional(
        &mut self,
        name: &str,
        bias_name: Option<&str>,
    ) -> Result<Option<MaybeQuantLinear<R>>> {
        if self.contains(name) {
            self.take_maybe_quant_linear(name, bias_name).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Moves the quantized tensor stored under the prefixed `name` out of
    /// the map.
    pub fn take_quant_tensor(&mut self, name: &str) -> Result<QuantTensor<R>> {
        let full = self.full_name(name);
        self.varmap.take_quant_tensor(&full)
    }

    /// Moves the weight (standard or quantized) stored under the prefixed
    /// `name` out of the map.
    pub fn take_weight(&mut self, name: &str) -> Result<Weight<R>> {
        let full = self.full_name(name);
        self.varmap.take(&full)
    }

    /// Builds a [`MaybeQuantLinear`] layer from the weight under `name` and,
    /// when `bias_name` is given AND present in the map, a bias tensor.
    /// A missing bias name silently yields a bias-free layer.
    pub fn take_maybe_quant_linear(
        &mut self,
        name: &str,
        bias_name: Option<&str>,
    ) -> Result<MaybeQuantLinear<R>> {
        let weight = self.take_weight(name)?;
        let bias = match bias_name {
            Some(bn) => {
                if self.contains(bn) {
                    Some(self.take_tensor(bn)?)
                } else {
                    None
                }
            }
            None => None,
        };
        Ok(MaybeQuantLinear::from_weight(weight, bias))
    }

    /// Borrows the tensor under the prefixed `name`, verifying its shape.
    ///
    /// # Errors
    /// Returns [`Error::ModelError`] when the stored shape differs from
    /// `expected_shape`.
    pub fn get_with_shape(&self, name: &str, expected_shape: &[usize]) -> Result<&Tensor<R>> {
        let full = self.full_name(name);
        let t = self.varmap.get_tensor(&full)?;
        if t.shape() != expected_shape {
            return Err(Error::ModelError {
                reason: format!(
                    "shape mismatch for '{}': expected {:?}, got {:?}",
                    full,
                    expected_shape,
                    t.shape()
                ),
            });
        }
        Ok(t)
    }

    /// Device this builder was created with.
    pub fn device(&self) -> &R::Device {
        self.device
    }

    /// Returns `true` when the prefixed `name` exists in the map.
    pub fn contains(&self, name: &str) -> bool {
        let full = self.full_name(name);
        self.varmap.contains(&full)
    }

    /// Current dot-separated prefix (empty for the root builder).
    pub fn prefix(&self) -> &str {
        &self.prefix
    }

    /// Moves the weight under `name` out of the map as a standard tensor,
    /// dequantizing it to `target_dtype` when it is stored quantized.
    ///
    /// # Errors
    /// Fails for decomposed-quant weights, which cannot be materialized as a
    /// single standard tensor.
    pub fn take_tensor_dequant(&mut self, name: &str, target_dtype: DType) -> Result<Tensor<R>>
    where
        R: Runtime<DType = DType>,
        R::Client: DequantOps<R>,
    {
        match self.take_weight(name)? {
            Weight::Standard(t) => Ok(t),
            Weight::Quantized(qt) => {
                let client = R::default_client(self.device);
                client.dequantize(&qt, target_dtype)
            }
            Weight::DecomposedQuant(_) => Err(Error::ModelError {
                reason: "cannot dequantize decomposed quantized tensor to standard tensor".into(),
            }),
        }
    }

    /// Moves the tensor under `name` out of the map and returns this rank's
    /// contiguous shard of it, splitting dimension `dim` evenly across
    /// `world_size` ranks (tensor-parallel loading).
    ///
    /// # Errors
    /// Returns [`Error::ModelError`] when `dim` is out of range, when
    /// `world_size` is zero or does not evenly divide the dimension, when
    /// `rank >= world_size`, or when the narrow itself fails. Note the tensor
    /// is removed from the map before validation, matching `take` semantics.
    pub fn take_tensor_shard(
        &mut self,
        name: &str,
        dim: usize,
        rank: usize,
        world_size: usize,
    ) -> Result<Tensor<R>> {
        let full = self.take_tensor(name)?;
        let shape = full.shape();
        if dim >= shape.len() {
            return Err(Error::ModelError {
                reason: format!(
                    "take_tensor_shard: dim {} out of range for {}D tensor '{}'",
                    dim,
                    shape.len(),
                    name
                ),
            });
        }
        // Guard before the `%` below: remainder by zero panics in Rust.
        if world_size == 0 {
            return Err(Error::ModelError {
                reason: format!("take_tensor_shard: world_size is 0 for '{}'", name),
            });
        }
        // An out-of-range rank would otherwise surface as a confusing
        // narrow failure (or silently index past the end).
        if rank >= world_size {
            return Err(Error::ModelError {
                reason: format!(
                    "take_tensor_shard: rank {} out of range for world_size {} for '{}'",
                    rank, world_size, name
                ),
            });
        }
        let dim_size = shape[dim];
        if dim_size % world_size != 0 {
            return Err(Error::ModelError {
                reason: format!(
                    "take_tensor_shard: dim {} size ({}) not divisible by world_size ({}) for '{}'",
                    dim, dim_size, world_size, name
                ),
            });
        }
        let shard_size = dim_size / world_size;
        let start = rank * shard_size;
        // Force a contiguous copy so downstream kernels see a dense layout.
        full.narrow(dim as isize, start, shard_size)
            .map(|t| t.contiguous())
            .map_err(|e| Error::ModelError {
                reason: format!("take_tensor_shard narrow failed for '{}': {e}", name),
            })
    }
}
impl<R: Runtime> VarBuilder<'static, R> {
    /// Builds a `'static` root builder by leaking the boxed `varmap`.
    ///
    /// The map is handed to [`Box::leak`], so its memory is intentionally
    /// never reclaimed — suitable for model state that lives for the whole
    /// process.
    pub fn from_var_map(varmap: Box<VarMap<R>>, device: &'static R::Device) -> Self {
        Self {
            varmap: Box::leak(varmap),
            prefix: String::new(),
            device,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quant::QuantFormat;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};

    fn device() -> CpuDevice {
        CpuDevice::new()
    }

    // Nested `pp` calls must resolve a dotted path segment by segment.
    #[test]
    fn test_varbuilder_prefix() {
        let d = device();
        let mut map = VarMap::<CpuRuntime>::new();
        map.insert(
            "model.layers.0.self_attn.q_proj.weight".into(),
            Tensor::from_slice(&[1.0f32], &[1], &d),
        );
        let mut vb = VarBuilder::new(&mut map, &d);
        let mut vb = vb.pp("model");
        let mut vb = vb.pp("layers");
        let mut vb = vb.pp("0");
        let vb = vb.pp("self_attn");
        let t = vb.get_tensor("q_proj.weight").unwrap();
        assert_eq!(t.shape(), &[1]);
    }

    // `get_with_shape` accepts a matching shape and rejects a mismatch.
    #[test]
    fn test_varbuilder_get_with_shape() {
        let d = device();
        let mut map = VarMap::<CpuRuntime>::new();
        map.insert(
            "w".into(),
            Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &d),
        );
        let vb = VarBuilder::new(&mut map, &d);
        assert!(vb.get_with_shape("w", &[2, 2]).is_ok());
        assert!(vb.get_with_shape("w", &[4]).is_err());
    }

    // `take_tensor` removes the entry, so a second take must fail.
    #[test]
    fn test_varbuilder_take_tensor() {
        let d = device();
        let mut map = VarMap::<CpuRuntime>::new();
        map.insert(
            "layer.weight".into(),
            Tensor::from_slice(&[1.0f32, 2.0], &[2], &d),
        );
        let mut vb = VarBuilder::new(&mut map, &d);
        let mut vb = vb.pp("layer");
        let t = vb.take_tensor("weight").unwrap();
        assert_eq!(t.shape(), &[2]);
        assert!(vb.take_tensor("weight").is_err());
    }

    // Sharding a [4, 6] tensor across 2 ranks along each dimension.
    // The tensor is re-inserted between takes because `take_tensor_shard`
    // consumes the map entry.
    #[test]
    fn test_varbuilder_take_tensor_shard() {
        let d = device();
        let mut map = VarMap::<CpuRuntime>::new();
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        map.insert("weight".into(), Tensor::from_slice(&data, &[4, 6], &d));
        let mut vb = VarBuilder::new(&mut map, &d);
        let shard = vb.take_tensor_shard("weight", 0, 0, 2).unwrap();
        assert_eq!(shard.shape(), &[2, 6]);
        drop(vb);
        map.insert("weight".into(), Tensor::from_slice(&data, &[4, 6], &d));
        let mut vb = VarBuilder::new(&mut map, &d);
        let shard = vb.take_tensor_shard("weight", 1, 1, 2).unwrap();
        assert_eq!(shard.shape(), &[4, 3]);
    }

    // A dimension of size 3 cannot be split evenly across 2 ranks.
    #[test]
    fn test_varbuilder_take_tensor_shard_not_divisible() {
        let d = device();
        let mut map = VarMap::<CpuRuntime>::new();
        map.insert(
            "weight".into(),
            Tensor::from_slice(&[1.0f32; 15], &[3, 5], &d),
        );
        let mut vb = VarBuilder::new(&mut map, &d);
        assert!(vb.take_tensor_shard("weight", 0, 0, 2).is_err());
    }

    // Quantized tensors resolve through prefixes the same way as standard
    // ones. 18 bytes is one Q4_0 block covering 32 elements.
    #[test]
    fn test_varbuilder_quant_prefix() {
        let d = device();
        let mut map = VarMap::<CpuRuntime>::new();
        let data = vec![0u8; 18];
        let qt = QuantTensor::from_bytes(&data, QuantFormat::Q4_0, &[32], &d).unwrap();
        map.insert_quant("layers.0.weight".into(), qt);
        let mut vb = VarBuilder::new(&mut map, &d);
        let mut vb = vb.pp("layers");
        let vb = vb.pp("0");
        let qt = vb.get_quant_tensor("weight").unwrap();
        assert_eq!(qt.shape(), &[32]);
    }
}