candle_transformers/
quantized_var_builder.rs1use candle::quantized::QTensor;
8use candle::{Device, Result, Shape};
9use std::sync::Arc;
10
11#[derive(Clone)]
13pub struct VarBuilder {
14 data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
15 path: Vec<String>,
16 device: Device,
17}
18
19impl VarBuilder {
20 pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
21 let mut file = std::fs::File::open(p)?;
22 let content = candle::quantized::gguf_file::Content::read(&mut file)?;
23 let mut data = std::collections::HashMap::new();
24 for tensor_name in content.tensor_infos.keys() {
25 let tensor = content.tensor(&mut file, tensor_name, device)?;
26 data.insert(tensor_name.to_string(), Arc::new(tensor));
27 }
28 Ok(Self {
29 data: Arc::new(data),
30 path: Vec::new(),
31 device: device.clone(),
32 })
33 }
34
35 pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
36 let mut cursor = std::io::Cursor::new(buffer);
37 let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
38 let mut data = std::collections::HashMap::new();
39 for tensor_name in content.tensor_infos.keys() {
40 let tensor = content.tensor(&mut cursor, tensor_name, device)?;
41 data.insert(tensor_name.to_string(), Arc::new(tensor));
42 }
43 Ok(Self {
44 data: Arc::new(data),
45 path: Vec::new(),
46 device: device.clone(),
47 })
48 }
49
50 pub fn pp<S: ToString>(&self, s: S) -> Self {
51 let mut path = self.path.clone();
52 path.push(s.to_string());
53 Self {
54 data: self.data.clone(),
55 path,
56 device: self.device.clone(),
57 }
58 }
59
60 fn path(&self, tensor_name: &str) -> String {
61 if self.path.is_empty() {
62 tensor_name.to_string()
63 } else {
64 [&self.path.join("."), tensor_name].join(".")
65 }
66 }
67
68 pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
69 let path = self.path(name);
70 match self.data.get(&path) {
71 None => {
72 candle::bail!("cannot find tensor {path}")
73 }
74 Some(qtensor) => {
75 let shape = s.into();
76 if qtensor.shape() != &shape {
77 candle::bail!(
78 "shape mismatch for {name}, got {:?}, expected {shape:?}",
79 qtensor.shape()
80 )
81 }
82 Ok(qtensor.clone())
83 }
84 }
85 }
86
87 pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
88 let path = self.path(name);
89 match self.data.get(&path) {
90 None => {
91 candle::bail!("cannot find tensor {name}")
92 }
93 Some(qtensor) => Ok(qtensor.clone()),
94 }
95 }
96
97 pub fn device(&self) -> &Device {
98 &self.device
99 }
100
101 pub fn contains_key(&self, key: &str) -> bool {
102 self.data.contains_key(key)
103 }
104}