bitnet_quantize/adapter.rs

//! peft-rs Adapter integration for BitNet.
//!
//! This module provides `BitNetAdapter`, which implements the peft-rs `Adapter` trait,
//! enabling BitNet quantization to be used within the PEFT fine-tuning framework.
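//!
//! A minimal construction sketch, mirroring the unit tests at the bottom of
//! this file. It is illustrative only; the `use` paths are assumed from the
//! crate layout and may differ.
//!
//! ```rust,ignore
//! use bitnet_quantize::adapter::{BitNetAdapter, BitNetAdapterConfig};
//! use bitnet_quantize::{BitLinear, BitNetConfig};
//! use candle_core::{Device, Tensor};
//!
//! let config = BitNetAdapterConfig::new(BitNetConfig::default().with_group_size(64));
//! let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &Device::Cpu).unwrap();
//! let layer = BitLinear::from_weight(&weight, None, &config.bitnet).unwrap();
//! let adapter = BitNetAdapter::new(layer, config);
//! assert_eq!(adapter.num_parameters(), 64 * 128);
//! ```
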
#[cfg(feature = "peft")]
use candle_core::Tensor;
#[cfg(feature = "peft")]
use candle_nn::VarMap;

use crate::config::BitNetConfig;
use crate::layer::BitLinear;
/// BitNet adapter configuration for peft-rs integration.
#[derive(Debug, Clone)]
pub struct BitNetAdapterConfig {
    /// BitNet quantization configuration.
    pub bitnet: BitNetConfig,

    /// Target modules to apply BitNet quantization to.
    pub target_modules: Vec<String>,
}
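
/// The default configuration targets the linear projections of LLaMA-style
/// transformer blocks: the attention projections (`q_proj`, `k_proj`,
/// `v_proj`, `o_proj`) and the feed-forward projections (`gate_proj`,
/// `up_proj`, `down_proj`).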
impl Default for BitNetAdapterConfig {
    fn default() -> Self {
        Self {
            bitnet: BitNetConfig::default(),
            target_modules: vec![
                "q_proj".to_string(),
                "k_proj".to_string(),
                "v_proj".to_string(),
                "o_proj".to_string(),
                "gate_proj".to_string(),
                "up_proj".to_string(),
                "down_proj".to_string(),
            ],
        }
    }
}

impl BitNetAdapterConfig {
    /// Create a new adapter configuration.
    #[must_use]
    pub fn new(bitnet: BitNetConfig) -> Self {
        Self {
            bitnet,
            ..Default::default()
        }
    }

    /// Set the target modules, replacing the defaults.
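    ///
    /// An illustrative sketch: restrict quantization to the attention
    /// projections only.
    ///
    /// ```rust,ignore
    /// let cfg = BitNetAdapterConfig::default()
    ///     .with_target_modules(vec!["q_proj".into(), "v_proj".into()]);
    /// ```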
    #[must_use]
    pub fn with_target_modules(mut self, modules: Vec<String>) -> Self {
        self.target_modules = modules;
        self
    }
}

/// BitNet adapter for peft-rs integration.
///
/// This adapter wraps a `BitLinear` layer and implements the peft-rs `Adapter`
/// trait (when the `peft` feature is enabled).
#[derive(Debug)]
pub struct BitNetAdapter {
    /// The underlying BitLinear layer.
    layer: BitLinear,

    /// Adapter configuration.
    config: BitNetAdapterConfig,

    /// Whether the adapter is frozen.
    frozen: bool,
}

impl BitNetAdapter {
    /// Create a new BitNet adapter from a BitLinear layer.
    #[must_use]
    pub fn new(layer: BitLinear, config: BitNetAdapterConfig) -> Self {
        Self {
            layer,
            config,
            frozen: false,
        }
    }

    /// Get a reference to the underlying layer.
    #[must_use]
    pub const fn layer(&self) -> &BitLinear {
        &self.layer
    }

    /// Get a mutable reference to the underlying layer.
    pub fn layer_mut(&mut self) -> &mut BitLinear {
        &mut self.layer
    }

    /// Get a reference to the configuration.
    #[must_use]
    pub const fn config(&self) -> &BitNetAdapterConfig {
        &self.config
    }

    /// Check if the adapter is frozen.
    #[must_use]
    pub const fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Freeze the adapter (disable gradient computation).
    pub fn freeze(&mut self) {
        self.frozen = true;
    }

    /// Unfreeze the adapter (enable gradient computation).
    pub fn unfreeze(&mut self) {
        self.frozen = false;
    }

    /// Get the number of quantized parameters (weight entries only; any bias
    /// is not counted).
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        self.layer.in_features() * self.layer.out_features()
    }

    /// Get the compression ratio.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        self.layer.compression_ratio()
    }
}

#[cfg(feature = "peft")]
impl peft_rs::AdapterConfig for BitNetAdapterConfig {
    fn validate(&self) -> peft_rs::Result<()> {
        self.bitnet
            .validate()
            .map_err(|e| peft_rs::Error::Config(e.to_string()))
    }
}

#[cfg(feature = "peft")]
impl peft_rs::Adapter for BitNetAdapter {
    type Config = BitNetAdapterConfig;
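
    // Note: `_base_output` is ignored. The quantized `BitLinear` acts as a
    // drop-in replacement for the base layer rather than an additive delta
    // on top of its output (unlike LoRA-style adapters).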
    fn forward(&self, input: &Tensor, _base_output: Option<&Tensor>) -> peft_rs::Result<Tensor> {
        use candle_nn::Module;
        self.layer
            .forward(input)
            .map_err(|e| peft_rs::Error::Forward(e.to_string()))
    }

    fn num_parameters(&self) -> usize {
        self.num_parameters()
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

#[cfg(feature = "peft")]
impl peft_rs::Trainable for BitNetAdapter {
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> peft_rs::Result<()> {
        // BitNet weights are quantized and typically not trained directly, so
        // there is nothing to register here yet; the bias (if present) could
        // be registered in the future.
        Ok(())
    }

    fn freeze(&mut self) {
        self.frozen = true;
    }

    fn unfreeze(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Tensor};

    #[test]
    fn test_adapter_creation() {
        let device = Device::Cpu;
        let bitnet_config = BitNetConfig::default().with_group_size(64);
        let adapter_config = BitNetAdapterConfig::new(bitnet_config);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &adapter_config.bitnet).unwrap();

        let adapter = BitNetAdapter::new(layer, adapter_config);

        assert_eq!(adapter.num_parameters(), 64 * 128);
        assert!(!adapter.is_frozen());
    }

    #[test]
    fn test_adapter_freeze_unfreeze() {
        let device = Device::Cpu;
        let bitnet_config = BitNetConfig::default().with_group_size(64);
        let adapter_config = BitNetAdapterConfig::new(bitnet_config);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &adapter_config.bitnet).unwrap();

        let mut adapter = BitNetAdapter::new(layer, adapter_config);

        adapter.freeze();
        assert!(adapter.is_frozen());

        adapter.unfreeze();
        assert!(!adapter.is_frozen());
    }

    #[test]
    fn test_adapter_config_default() {
        let config = BitNetAdapterConfig::default();

        assert!(!config.target_modules.is_empty());
        assert!(config.target_modules.contains(&"q_proj".to_string()));
    }
}
231}