// candle_transformers/quantized_nn.rs
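//! Quantized counterparts of the `candle_nn` building blocks used by the
//! quantized model implementations in this crate. Weights are read from a
//! quantized [`VarBuilder`] (e.g. a GGUF file); matmul weights stay quantized
//! behind [`QMatMul`], while embedding, layer-norm and rms-norm weights are
//! dequantized at load time.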
use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;
use candle::quantized::QTensor;
use candle::{Module, Result, Tensor};

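/// Embedding layer backed by a quantized weight tensor.
///
/// The quantized `weight` is dequantized once at construction time, so lookups
/// run through a regular [`candle_nn::Embedding`]; a tracing span wraps the
/// forward pass.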
#[derive(Debug, Clone)]
pub struct Embedding {
    inner: candle_nn::Embedding,
    span: tracing::Span,
}

impl Embedding {
    pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
        let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
        let inner = candle_nn::Embedding::new(embeddings, d2);
        let span = tracing::span!(tracing::Level::TRACE, "embedding");
        Ok(Self { inner, span })
    }

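    /// Returns the (dequantized) embedding matrix.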
    pub fn embeddings(&self) -> &Tensor {
        self.inner.embeddings()
    }
}

impl Module for Embedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

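/// Linear layer that keeps its weight in quantized form.
///
/// The matmul runs through [`QMatMul`] without ever materializing the full
/// dequantized weight matrix; only the optional bias is stored as a regular
/// [`Tensor`].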
#[derive(Debug, Clone)]
pub struct Linear {
    weight: QMatMul,
    bias: Option<Tensor>,
}

impl Linear {
    pub fn from_arc(weight: std::sync::Arc<QTensor>, bias: Option<Tensor>) -> Result<Self> {
        let weight = QMatMul::from_weights(weight)?;
        Ok(Self { weight, bias })
    }

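    /// Builds a `Linear` from an already-constructed [`QMatMul`] and an
    /// optional bias.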
    pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
        Self { weight, bias }
    }
}

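// The forward pass applies the quantized matmul first, then broadcast-adds the
// bias when one is present.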
impl Module for Linear {
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let x = x.apply(&self.weight)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

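/// Loads a linear layer from `vb`, reading a bias tensor only when `bias` is
/// true.
///
/// A minimal usage sketch, assuming a GGUF checkpoint on disk; the file name,
/// tensor prefix and dimensions below are illustrative, not part of this API:
///
/// ```no_run
/// # fn main() -> candle::Result<()> {
/// use candle_transformers::{quantized_nn, quantized_var_builder::VarBuilder};
/// let vb = VarBuilder::from_gguf("model.gguf", &candle::Device::Cpu)?;
/// let _proj = quantized_nn::linear_b(4096, 11008, /* bias */ false, vb.pp("ffn_up"))?;
/// # Ok(())
/// # }
/// ```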
pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let bias = if bias {
        Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?)
    } else {
        None
    };
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear { weight, bias })
}

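/// Loads a linear layer with a mandatory bias from `vb`.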
pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear {
        weight,
        bias: Some(bias),
    })
}

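/// Loads a [`candle_nn::LayerNorm`], dequantizing its weight and bias.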
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}

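/// Loads a bias-free [`candle_nn::LayerNorm`], dequantizing its weight.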
pub fn layer_norm_no_bias(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    Ok(candle_nn::LayerNorm::new_no_bias(weight, eps))
}

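/// Loads a linear layer without a bias from `vb`.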
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear { weight, bias: None })
}

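/// RMS normalization with a learned, per-channel scale.
///
/// Computes `x / sqrt(mean(x^2) + eps) * weight` over the last dimension; the
/// quantized scale is dequantized at construction time.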
#[derive(Debug, Clone)]
pub struct RmsNorm {
    weight: Tensor,
    eps: f64,
    span: tracing::Span,
}

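// `RmsNorm` can be built either from a var-builder entry or directly from a
// `QTensor`.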
impl RmsNorm {
    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
        Ok(Self { weight, eps, span })
    }

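    /// Builds an `RmsNorm` directly from a quantized weight tensor,
    /// dequantizing it on the tensor's own device.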
    pub fn from_qtensor(weight: QTensor, eps: f64) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let weight = weight.dequantize(&weight.device())?;
        Ok(Self { weight, eps, span })
    }
}

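// Note: `eps` is stored as f64 but narrowed to f32 here, matching the
// signature of `candle_nn::ops::rms_norm`.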
impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        candle_nn::ops::rms_norm(x, &self.weight, self.eps as f32)
    }
}