1use candle::quantized::QTensor;
2use candle::{DType, Device, Module, Result, Shape, Tensor};
3use candle_transformers::quantized_nn as candle_qnn;
4use candle_transformers::quantized_var_builder::VarBuilder as QuantizedVarBuilder;
5
6use std::sync::Arc;
7
8#[derive(Clone)]
9pub enum MaybeQuantizedWeight {
10 Real(Tensor),
12 Quantized(Arc<QTensor>),
13}
14
15impl MaybeQuantizedWeight {
16 fn to_tensor(&self, dev: &Device) -> Result<Tensor> {
17 match self {
18 Self::Real(t) => Ok(t.clone()),
19 Self::Quantized(t) => t.dequantize(dev),
20 }
21 }
22}
23
24pub fn matmul_dtype(device: &candle::Device) -> DType {
25 if device.is_cuda() {
27 DType::BF16
28 } else {
29 DType::F32
30 }
31}
32
33#[derive(Clone)]
34pub enum MaybeQuantizedVarBuilder<'a> {
35 Real(candle_nn::VarBuilder<'a>),
37 Quantized(QuantizedVarBuilder),
38}
39
40impl MaybeQuantizedVarBuilder<'_> {
41 pub fn pp<S: ToString>(&self, s: S) -> Self {
42 match self {
43 Self::Real(weights) => MaybeQuantizedVarBuilder::Real(weights.pp(s)),
44 Self::Quantized(weights) => MaybeQuantizedVarBuilder::Quantized(weights.pp(s)),
45 }
46 }
47
48 pub fn get<S: Into<Shape>>(&self, s: S, path: &str) -> Result<MaybeQuantizedWeight> {
49 let w = match self {
50 Self::Real(weights) => MaybeQuantizedWeight::Real(weights.get(s, path)?),
51 Self::Quantized(weights) => MaybeQuantizedWeight::Quantized(weights.get(s, path)?),
52 };
53 Ok(w)
54 }
55
56 pub fn get_as_tensor<S: Into<Shape>>(&self, s: S, path: &str) -> Result<Tensor> {
57 let w = match self {
58 Self::Real(weights) => MaybeQuantizedWeight::Real(weights.get(s, path)?),
59 Self::Quantized(weights) => MaybeQuantizedWeight::Quantized(weights.get(s, path)?),
60 };
61 w.to_tensor(self.device())
62 }
63
64 pub fn get_unquantized<S: Into<Shape>>(&self, s: S, path: &str) -> Result<Tensor> {
65 match self {
66 Self::Real(weights) => weights.get(s, path),
67 Self::Quantized(weights) => weights.get(s, path)?.dequantize(weights.device()),
68 }
69 }
70
71 pub fn contains_key(&self, name: &str) -> bool {
72 match self {
73 Self::Real(weights) => weights.contains_tensor(name),
74 Self::Quantized(weights) => weights.contains_key(name),
75 }
76 }
77
78 pub fn device(&self) -> &Device {
79 match self {
80 Self::Real(weights) => weights.device(),
81 Self::Quantized(weights) => weights.device(),
82 }
83 }
84
85 pub fn dtype(&self) -> DType {
86 match self {
87 Self::Real(weights) => weights.dtype(),
88 Self::Quantized(_) => DType::F32,
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
94pub enum MaybeQuantizedLinear {
95 Real(candle_nn::Linear),
96 Quantized(candle_qnn::Linear),
97}
98
99impl Module for MaybeQuantizedLinear {
100 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
101 match self {
102 Self::Real(module) => module.forward(xs),
103 Self::Quantized(module) => module.forward(xs),
104 }
105 }
106}
107
108impl MaybeQuantizedLinear {
109 pub fn dtype(&self) -> DType {
110 match self {
111 Self::Real(l) => l.weight().dtype(),
112 Self::Quantized(_) => DType::F32,
113 }
114 }
115}
116
117#[derive(Debug, Clone)]
118pub enum MaybeQuantizedEmbedding {
119 Real(candle_nn::Embedding),
120 Quantized(candle_qnn::Embedding),
121}
122
123impl MaybeQuantizedEmbedding {
124 pub fn new(in_vocab_size: usize, dim: usize, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
125 let emb = match vb {
126 MaybeQuantizedVarBuilder::Real(weights) => {
127 MaybeQuantizedEmbedding::Real(candle_nn::embedding(in_vocab_size, dim, weights)?)
128 }
129 MaybeQuantizedVarBuilder::Quantized(weights) => MaybeQuantizedEmbedding::Quantized(
130 candle_transformers::quantized_nn::Embedding::new(in_vocab_size, dim, weights)?,
131 ),
132 };
133 Ok(emb)
134 }
135
136 pub fn embeddings(&self) -> &Tensor {
137 match self {
138 MaybeQuantizedEmbedding::Real(weights) => weights.embeddings(),
139 MaybeQuantizedEmbedding::Quantized(weights) => weights.embeddings(),
140 }
141 }
142
143 pub fn hidden_size(&self) -> Result<usize> {
144 let size = match self {
145 MaybeQuantizedEmbedding::Real(weights) => weights.hidden_size(),
146 MaybeQuantizedEmbedding::Quantized(weights) => weights.embeddings().dim(1)?,
147 };
148 Ok(size)
149 }
150
151 pub fn dtype(&self) -> DType {
152 match self {
153 Self::Real(l) => l.embeddings().dtype(),
154 Self::Quantized(_) => DType::F32,
155 }
156 }
157}
158
159impl Module for MaybeQuantizedEmbedding {
160 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
161 match self {
162 Self::Real(module) => module.forward(xs),
163 Self::Quantized(module) => module.forward(xs),
164 }
165 }
166}
167
168pub fn linear(
169 in_d: usize,
170 out_d: usize,
171 bias: bool,
172 vb: MaybeQuantizedVarBuilder,
173) -> Result<MaybeQuantizedLinear> {
174 let output_linear = match vb {
175 MaybeQuantizedVarBuilder::Real(weights) => {
176 if bias {
177 MaybeQuantizedLinear::Real(candle_nn::linear(in_d, out_d, weights)?)
178 } else {
179 MaybeQuantizedLinear::Real(candle_nn::linear_no_bias(in_d, out_d, weights)?)
180 }
181 }
182 MaybeQuantizedVarBuilder::Quantized(weights) => {
183 MaybeQuantizedLinear::Quantized(candle_qnn::linear_b(in_d, out_d, bias, weights)?)
184 }
185 };
186 Ok(output_linear)
187}
188
189pub fn linear_from(
190 weight: MaybeQuantizedWeight,
191 bias: Option<Tensor>,
192) -> Result<MaybeQuantizedLinear> {
193 let layer = match weight {
194 MaybeQuantizedWeight::Real(w) => {
195 MaybeQuantizedLinear::Real(candle_nn::Linear::new(w, bias))
196 }
197 MaybeQuantizedWeight::Quantized(w) => {
198 MaybeQuantizedLinear::Quantized(candle_qnn::Linear::from_arc(w, bias)?)
199 }
200 };
201 Ok(layer)
202}