moshi_db/
nn.rs

1use candle::quantized::QTensor;
2use candle::{DType, Device, Module, Result, Shape, Tensor};
3use candle_transformers::quantized_nn as candle_qnn;
4use candle_transformers::quantized_var_builder::VarBuilder as QuantizedVarBuilder;
5
6use std::sync::Arc;
7
8#[derive(Clone)]
9pub enum MaybeQuantizedWeight {
10    // Enum types around real and quantized model weights
11    Real(Tensor),
12    Quantized(Arc<QTensor>),
13}
14
15impl MaybeQuantizedWeight {
16    fn to_tensor(&self, dev: &Device) -> Result<Tensor> {
17        match self {
18            Self::Real(t) => Ok(t.clone()),
19            Self::Quantized(t) => t.dequantize(dev),
20        }
21    }
22}
23
24pub fn matmul_dtype(device: &candle::Device) -> DType {
25    // Dtype used for intermediate matmul in attention during quantized execution
26    if device.is_cuda() {
27        DType::BF16
28    } else {
29        DType::F32
30    }
31}
32
33#[derive(Clone)]
34pub enum MaybeQuantizedVarBuilder<'a> {
35    // Enum types around real and quantized var builders
36    Real(candle_nn::VarBuilder<'a>),
37    Quantized(QuantizedVarBuilder),
38}
39
40impl MaybeQuantizedVarBuilder<'_> {
41    pub fn pp<S: ToString>(&self, s: S) -> Self {
42        match self {
43            Self::Real(weights) => MaybeQuantizedVarBuilder::Real(weights.pp(s)),
44            Self::Quantized(weights) => MaybeQuantizedVarBuilder::Quantized(weights.pp(s)),
45        }
46    }
47
48    pub fn get<S: Into<Shape>>(&self, s: S, path: &str) -> Result<MaybeQuantizedWeight> {
49        let w = match self {
50            Self::Real(weights) => MaybeQuantizedWeight::Real(weights.get(s, path)?),
51            Self::Quantized(weights) => MaybeQuantizedWeight::Quantized(weights.get(s, path)?),
52        };
53        Ok(w)
54    }
55
56    pub fn get_as_tensor<S: Into<Shape>>(&self, s: S, path: &str) -> Result<Tensor> {
57        let w = match self {
58            Self::Real(weights) => MaybeQuantizedWeight::Real(weights.get(s, path)?),
59            Self::Quantized(weights) => MaybeQuantizedWeight::Quantized(weights.get(s, path)?),
60        };
61        w.to_tensor(self.device())
62    }
63
64    pub fn get_unquantized<S: Into<Shape>>(&self, s: S, path: &str) -> Result<Tensor> {
65        match self {
66            Self::Real(weights) => weights.get(s, path),
67            Self::Quantized(weights) => weights.get(s, path)?.dequantize(weights.device()),
68        }
69    }
70
71    pub fn contains_key(&self, name: &str) -> bool {
72        match self {
73            Self::Real(weights) => weights.contains_tensor(name),
74            Self::Quantized(weights) => weights.contains_key(name),
75        }
76    }
77
78    pub fn device(&self) -> &Device {
79        match self {
80            Self::Real(weights) => weights.device(),
81            Self::Quantized(weights) => weights.device(),
82        }
83    }
84
85    pub fn dtype(&self) -> DType {
86        match self {
87            Self::Real(weights) => weights.dtype(),
88            Self::Quantized(_) => DType::F32,
89        }
90    }
91}
92
93#[derive(Debug, Clone)]
94pub enum MaybeQuantizedLinear {
95    Real(candle_nn::Linear),
96    Quantized(candle_qnn::Linear),
97}
98
99impl Module for MaybeQuantizedLinear {
100    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
101        match self {
102            Self::Real(module) => module.forward(xs),
103            Self::Quantized(module) => module.forward(xs),
104        }
105    }
106}
107
108impl MaybeQuantizedLinear {
109    pub fn dtype(&self) -> DType {
110        match self {
111            Self::Real(l) => l.weight().dtype(),
112            Self::Quantized(_) => DType::F32,
113        }
114    }
115}
116
117#[derive(Debug, Clone)]
118pub enum MaybeQuantizedEmbedding {
119    Real(candle_nn::Embedding),
120    Quantized(candle_qnn::Embedding),
121}
122
123impl MaybeQuantizedEmbedding {
124    pub fn new(in_vocab_size: usize, dim: usize, vb: MaybeQuantizedVarBuilder) -> Result<Self> {
125        let emb = match vb {
126            MaybeQuantizedVarBuilder::Real(weights) => {
127                MaybeQuantizedEmbedding::Real(candle_nn::embedding(in_vocab_size, dim, weights)?)
128            }
129            MaybeQuantizedVarBuilder::Quantized(weights) => MaybeQuantizedEmbedding::Quantized(
130                candle_transformers::quantized_nn::Embedding::new(in_vocab_size, dim, weights)?,
131            ),
132        };
133        Ok(emb)
134    }
135
136    pub fn embeddings(&self) -> &Tensor {
137        match self {
138            MaybeQuantizedEmbedding::Real(weights) => weights.embeddings(),
139            MaybeQuantizedEmbedding::Quantized(weights) => weights.embeddings(),
140        }
141    }
142
143    pub fn hidden_size(&self) -> Result<usize> {
144        let size = match self {
145            MaybeQuantizedEmbedding::Real(weights) => weights.hidden_size(),
146            MaybeQuantizedEmbedding::Quantized(weights) => weights.embeddings().dim(1)?,
147        };
148        Ok(size)
149    }
150
151    pub fn dtype(&self) -> DType {
152        match self {
153            Self::Real(l) => l.embeddings().dtype(),
154            Self::Quantized(_) => DType::F32,
155        }
156    }
157}
158
159impl Module for MaybeQuantizedEmbedding {
160    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
161        match self {
162            Self::Real(module) => module.forward(xs),
163            Self::Quantized(module) => module.forward(xs),
164        }
165    }
166}
167
168pub fn linear(
169    in_d: usize,
170    out_d: usize,
171    bias: bool,
172    vb: MaybeQuantizedVarBuilder,
173) -> Result<MaybeQuantizedLinear> {
174    let output_linear = match vb {
175        MaybeQuantizedVarBuilder::Real(weights) => {
176            if bias {
177                MaybeQuantizedLinear::Real(candle_nn::linear(in_d, out_d, weights)?)
178            } else {
179                MaybeQuantizedLinear::Real(candle_nn::linear_no_bias(in_d, out_d, weights)?)
180            }
181        }
182        MaybeQuantizedVarBuilder::Quantized(weights) => {
183            MaybeQuantizedLinear::Quantized(candle_qnn::linear_b(in_d, out_d, bias, weights)?)
184        }
185    };
186    Ok(output_linear)
187}
188
189pub fn linear_from(
190    weight: MaybeQuantizedWeight,
191    bias: Option<Tensor>,
192) -> Result<MaybeQuantizedLinear> {
193    let layer = match weight {
194        MaybeQuantizedWeight::Real(w) => {
195            MaybeQuantizedLinear::Real(candle_nn::Linear::new(w, bias))
196        }
197        MaybeQuantizedWeight::Quantized(w) => {
198            MaybeQuantizedLinear::Quantized(candle_qnn::Linear::from_arc(w, bias)?)
199        }
200    };
201    Ok(layer)
202}