bitnet_quantize/
adapter.rs1#[cfg(feature = "peft")]
7use candle_core::Tensor;
8#[cfg(feature = "peft")]
9use candle_nn::VarMap;
10
11use crate::config::BitNetConfig;
12use crate::layer::BitLinear;
13
14#[cfg(feature = "peft")]
15use crate::error::Result;
16
/// Configuration for a [`BitNetAdapter`].
///
/// Bundles the BitNet settings with a list of module names.
#[derive(Debug, Clone)]
pub struct BitNetAdapterConfig {
    /// BitNet configuration; `AdapterConfig::validate` (under the `peft`
    /// feature) delegates to this value's own `validate`.
    pub bitnet: BitNetConfig,

    /// Names of the modules this configuration targets.
    ///
    /// NOTE(review): presumably matched against model layer names by the
    /// caller; nothing in this file reads the list — confirm.
    pub target_modules: Vec<String>,
}
26
27impl Default for BitNetAdapterConfig {
28 fn default() -> Self {
29 Self {
30 bitnet: BitNetConfig::default(),
31 target_modules: vec![
32 "q_proj".to_string(),
33 "k_proj".to_string(),
34 "v_proj".to_string(),
35 "o_proj".to_string(),
36 "gate_proj".to_string(),
37 "up_proj".to_string(),
38 "down_proj".to_string(),
39 ],
40 }
41 }
42}
43
44impl BitNetAdapterConfig {
45 #[must_use]
47 pub fn new(bitnet: BitNetConfig) -> Self {
48 Self {
49 bitnet,
50 ..Default::default()
51 }
52 }
53
54 #[must_use]
56 pub fn with_target_modules(mut self, modules: Vec<String>) -> Self {
57 self.target_modules = modules;
58 self
59 }
60}
61
/// Adapter that wraps a single [`BitLinear`] layer together with its
/// configuration and a frozen/unfrozen flag.
#[derive(Debug)]
pub struct BitNetAdapter {
    /// The wrapped layer; exposed through `layer()` and `layer_mut()`.
    layer: BitLinear,

    /// Configuration this adapter was created with.
    config: BitNetAdapterConfig,

    /// Frozen flag: `false` after `new`, toggled by `freeze` / `unfreeze`.
    frozen: bool,
}
77
78impl BitNetAdapter {
79 #[must_use]
81 pub fn new(layer: BitLinear, config: BitNetAdapterConfig) -> Self {
82 Self {
83 layer,
84 config,
85 frozen: false,
86 }
87 }
88
89 #[must_use]
91 pub const fn layer(&self) -> &BitLinear {
92 &self.layer
93 }
94
95 pub fn layer_mut(&mut self) -> &mut BitLinear {
97 &mut self.layer
98 }
99
100 #[must_use]
102 pub const fn config(&self) -> &BitNetAdapterConfig {
103 &self.config
104 }
105
106 #[must_use]
108 pub const fn is_frozen(&self) -> bool {
109 self.frozen
110 }
111
112 pub fn freeze(&mut self) {
114 self.frozen = true;
115 }
116
117 pub fn unfreeze(&mut self) {
119 self.frozen = false;
120 }
121
122 #[must_use]
124 pub fn num_parameters(&self) -> usize {
125 self.layer.in_features() * self.layer.out_features()
126 }
127
128 #[must_use]
130 pub fn compression_ratio(&self) -> f32 {
131 self.layer.compression_ratio()
132 }
133}
134
#[cfg(feature = "peft")]
impl peft_rs::AdapterConfig for BitNetAdapterConfig {
    /// Validates the embedded BitNet configuration.
    ///
    /// # Errors
    ///
    /// Returns `peft_rs::Error::Config` carrying the string form of the
    /// error reported by `BitNetConfig::validate`.
    fn validate(&self) -> peft_rs::Result<()> {
        if let Err(err) = self.bitnet.validate() {
            return Err(peft_rs::Error::Config(err.to_string()));
        }
        Ok(())
    }
}
143
#[cfg(feature = "peft")]
impl peft_rs::Adapter for BitNetAdapter {
    type Config = BitNetAdapterConfig;

    /// Runs the wrapped layer's `forward` on `input`.
    ///
    /// `_base_output` is ignored: the result comes from the layer alone.
    ///
    /// # Errors
    ///
    /// Returns `peft_rs::Error::Forward` carrying the string form of the
    /// error produced by the layer.
    fn forward(&self, input: &Tensor, _base_output: Option<&Tensor>) -> peft_rs::Result<Tensor> {
        use candle_nn::Module;

        let output = self
            .layer
            .forward(input)
            .map_err(|err| peft_rs::Error::Forward(err.to_string()))?;
        Ok(output)
    }

    /// Same count as the inherent `BitNetAdapter::num_parameters`.
    fn num_parameters(&self) -> usize {
        // The type-qualified path resolves to the inherent method (inherent
        // items take priority over trait items), so this does not recurse.
        BitNetAdapter::num_parameters(self)
    }

    /// Configuration this adapter was created with.
    fn config(&self) -> &Self::Config {
        // Resolves to the inherent accessor, which returns `&self.config`.
        BitNetAdapter::config(self)
    }
}
163
#[cfg(feature = "peft")]
impl peft_rs::Trainable for BitNetAdapter {
    /// Registers nothing and always succeeds; both arguments are unused.
    ///
    /// NOTE(review): presumably the wrapped layer has no `VarMap`-managed
    /// parameters to expose — confirm this no-op is intentional.
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> peft_rs::Result<()> {
        Ok(())
    }

    /// Marks the adapter as frozen.
    fn freeze(&mut self) {
        // Type-qualified paths pick the inherent methods, not these trait
        // methods, so the calls below do not recurse.
        BitNetAdapter::freeze(self);
    }

    /// Clears the frozen mark.
    fn unfreeze(&mut self) {
        BitNetAdapter::unfreeze(self);
    }

    /// Whether the adapter is currently marked frozen.
    fn is_frozen(&self) -> bool {
        BitNetAdapter::is_frozen(self)
    }
}
184
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Tensor};

    /// Builds an adapter around a `BitLinear` created from a random
    /// `(64, 128)` weight on the CPU, with no bias and a group size of 64.
    fn sample_adapter() -> BitNetAdapter {
        let cpu = Device::Cpu;
        let config = BitNetAdapterConfig::new(BitNetConfig::default().with_group_size(64));

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &cpu).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config.bitnet).unwrap();

        BitNetAdapter::new(layer, config)
    }

    #[test]
    fn test_adapter_creation() {
        let adapter = sample_adapter();

        assert_eq!(adapter.num_parameters(), 64 * 128);
        assert!(!adapter.is_frozen());
    }

    #[test]
    fn test_adapter_freeze_unfreeze() {
        let mut adapter = sample_adapter();

        adapter.freeze();
        assert!(adapter.is_frozen());

        adapter.unfreeze();
        assert!(!adapter.is_frozen());
    }

    #[test]
    fn test_adapter_config_default() {
        let BitNetAdapterConfig { target_modules, .. } = BitNetAdapterConfig::default();

        assert!(!target_modules.is_empty());
        assert!(target_modules.iter().any(|name| name == "q_proj"));
    }
}