Skip to main content

candle_transformers/
quantized_var_builder.rs

//! VarBuilder for loading GGUF files.
//!
//! [`VarBuilder`] is a utility to store quantized tensors from a [GGUF model file](https://huggingface.co/docs/hub/gguf).
//! These tensors can be loaded from disk using [`VarBuilder::from_gguf`] or from an in-memory
//! buffer using [`VarBuilder::from_gguf_buffer`].
6
7use candle::quantized::QTensor;
8use candle::{Device, Result, Shape};
9use std::sync::Arc;
10
11// VarBuilder specialized for QTensors
12#[derive(Clone)]
13pub struct VarBuilder {
14    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
15    path: Vec<String>,
16    device: Device,
17}
18
19impl VarBuilder {
20    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
21        let mut file = std::fs::File::open(p)?;
22        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
23        let mut data = std::collections::HashMap::new();
24        for tensor_name in content.tensor_infos.keys() {
25            let tensor = content.tensor(&mut file, tensor_name, device)?;
26            data.insert(tensor_name.to_string(), Arc::new(tensor));
27        }
28        Ok(Self {
29            data: Arc::new(data),
30            path: Vec::new(),
31            device: device.clone(),
32        })
33    }
34
35    pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
36        let mut cursor = std::io::Cursor::new(buffer);
37        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
38        let mut data = std::collections::HashMap::new();
39        for tensor_name in content.tensor_infos.keys() {
40            let tensor = content.tensor(&mut cursor, tensor_name, device)?;
41            data.insert(tensor_name.to_string(), Arc::new(tensor));
42        }
43        Ok(Self {
44            data: Arc::new(data),
45            path: Vec::new(),
46            device: device.clone(),
47        })
48    }
49
50    pub fn pp<S: ToString>(&self, s: S) -> Self {
51        let mut path = self.path.clone();
52        path.push(s.to_string());
53        Self {
54            data: self.data.clone(),
55            path,
56            device: self.device.clone(),
57        }
58    }
59
60    fn path(&self, tensor_name: &str) -> String {
61        if self.path.is_empty() {
62            tensor_name.to_string()
63        } else {
64            [&self.path.join("."), tensor_name].join(".")
65        }
66    }
67
68    pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
69        let path = self.path(name);
70        match self.data.get(&path) {
71            None => {
72                candle::bail!("cannot find tensor {path}")
73            }
74            Some(qtensor) => {
75                let shape = s.into();
76                if qtensor.shape() != &shape {
77                    candle::bail!(
78                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
79                        qtensor.shape()
80                    )
81                }
82                Ok(qtensor.clone())
83            }
84        }
85    }
86
87    pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
88        let path = self.path(name);
89        match self.data.get(&path) {
90            None => {
91                candle::bail!("cannot find tensor {name}")
92            }
93            Some(qtensor) => Ok(qtensor.clone()),
94        }
95    }
96
97    pub fn device(&self) -> &Device {
98        &self.device
99    }
100
101    pub fn contains_key(&self, key: &str) -> bool {
102        self.data.contains_key(key)
103    }
104}