#[cfg(feature = "peft")]
use candle_core::Tensor;
#[cfg(feature = "peft")]
use candle_nn::VarMap;
use crate::config::BitNetConfig;
use crate::layer::BitLinear;
#[cfg(feature = "peft")]
use crate::error::Result;
/// Configuration for a [`BitNetAdapter`].
///
/// Bundles the underlying [`BitNetConfig`] with a list of module names.
#[derive(Debug, Clone)]
pub struct BitNetAdapterConfig {
/// BitNet settings carried by this adapter configuration.
pub bitnet: BitNetConfig,
/// Module names associated with this configuration.
///
/// NOTE(review): presumably the names of the model sub-modules the adapter
/// should be applied to; nothing in this file consumes the list, so confirm
/// against the caller.
pub target_modules: Vec<String>,
}
impl Default for BitNetAdapterConfig {
    /// Returns a configuration with the default [`BitNetConfig`] and the
    /// built-in list of seven projection-module names.
    fn default() -> Self {
        // Order is kept stable because it is observable through the public
        // `target_modules` field.
        const DEFAULT_TARGET_MODULES: [&str; 7] = [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ];

        Self {
            bitnet: BitNetConfig::default(),
            target_modules: DEFAULT_TARGET_MODULES
                .iter()
                .map(|&name| name.to_owned())
                .collect(),
        }
    }
}
impl BitNetAdapterConfig {
    /// Creates a configuration around `bitnet`.
    ///
    /// `target_modules` is taken from [`BitNetAdapterConfig::default`].
    #[must_use]
    pub fn new(bitnet: BitNetConfig) -> Self {
        // Only the module list of the default value is kept; its `bitnet`
        // field is replaced by the caller-supplied one.
        let Self { target_modules, .. } = Self::default();
        Self {
            bitnet,
            target_modules,
        }
    }

    /// Consumes the configuration and returns it with `target_modules`
    /// replaced by `modules`.
    #[must_use]
    pub fn with_target_modules(self, modules: Vec<String>) -> Self {
        Self {
            target_modules: modules,
            ..self
        }
    }
}
/// Adapter that owns a [`BitLinear`] layer together with its
/// [`BitNetAdapterConfig`] and a frozen/unfrozen flag.
#[derive(Debug)]
pub struct BitNetAdapter {
/// The wrapped layer.
layer: BitLinear,
/// Configuration supplied when the adapter was created.
config: BitNetAdapterConfig,
/// Frozen flag; in this file it is only stored and reported, it does not
/// alter what the adapter computes.
frozen: bool,
}
impl BitNetAdapter {
    /// Wraps `layer` together with `config`.
    ///
    /// A freshly created adapter is not frozen.
    #[must_use]
    pub fn new(layer: BitLinear, config: BitNetAdapterConfig) -> Self {
        let frozen = false;
        Self {
            layer,
            config,
            frozen,
        }
    }

    /// Shared access to the wrapped layer.
    #[must_use]
    pub const fn layer(&self) -> &BitLinear {
        &self.layer
    }

    /// Exclusive access to the wrapped layer.
    pub fn layer_mut(&mut self) -> &mut BitLinear {
        &mut self.layer
    }

    /// The configuration this adapter was created with.
    #[must_use]
    pub const fn config(&self) -> &BitNetAdapterConfig {
        &self.config
    }

    /// Whether the adapter is currently marked as frozen.
    #[must_use]
    pub const fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Marks the adapter as frozen.
    pub fn freeze(&mut self) {
        self.frozen = true;
    }

    /// Clears the frozen mark.
    pub fn unfreeze(&mut self) {
        self.frozen = false;
    }

    /// Parameter count, computed as the layer's `in_features` multiplied by
    /// its `out_features`.
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        let inputs = self.layer.in_features();
        let outputs = self.layer.out_features();
        inputs * outputs
    }

    /// Compression ratio as reported by the wrapped layer.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        let layer = &self.layer;
        layer.compression_ratio()
    }
}
#[cfg(feature = "peft")]
impl peft_rs::AdapterConfig for BitNetAdapterConfig {
    /// Validates the embedded [`BitNetConfig`].
    ///
    /// # Errors
    ///
    /// Returns `peft_rs::Error::Config` carrying the stringified validation
    /// error when `bitnet.validate()` fails.
    fn validate(&self) -> peft_rs::Result<()> {
        match self.bitnet.validate() {
            Ok(()) => Ok(()),
            Err(err) => Err(peft_rs::Error::Config(err.to_string())),
        }
    }
}
#[cfg(feature = "peft")]
impl peft_rs::Adapter for BitNetAdapter {
    type Config = BitNetAdapterConfig;

    /// Runs `input` through the wrapped layer.
    ///
    /// `_base_output` is accepted to satisfy the trait signature and is not
    /// used.
    ///
    /// # Errors
    ///
    /// Returns `peft_rs::Error::Forward` carrying the stringified layer error
    /// when the layer's forward pass fails.
    fn forward(&self, input: &Tensor, _base_output: Option<&Tensor>) -> peft_rs::Result<Tensor> {
        use candle_nn::Module;

        match self.layer.forward(input) {
            Ok(output) => Ok(output),
            Err(err) => Err(peft_rs::Error::Forward(err.to_string())),
        }
    }

    /// Same value as the inherent [`BitNetAdapter::num_parameters`].
    fn num_parameters(&self) -> usize {
        // Explicit path: inherent associated functions take priority over
        // trait methods, so this does not recurse.
        BitNetAdapter::num_parameters(self)
    }

    /// The configuration this adapter was created with.
    fn config(&self) -> &Self::Config {
        BitNetAdapter::config(self)
    }
}
#[cfg(feature = "peft")]
impl peft_rs::Trainable for BitNetAdapter {
    /// Registers nothing: `_var_map` and `_prefix` are left untouched and the
    /// call always succeeds.
    fn register_parameters(&self, _var_map: &mut VarMap, _prefix: &str) -> peft_rs::Result<()> {
        Ok(())
    }

    // The three methods below forward to the inherent methods of the same
    // name. The explicit `BitNetAdapter::` path resolves to the inherent
    // function, so none of these calls recurse.

    /// Marks the adapter as frozen.
    fn freeze(&mut self) {
        BitNetAdapter::freeze(self);
    }

    /// Clears the frozen mark.
    fn unfreeze(&mut self) {
        BitNetAdapter::unfreeze(self);
    }

    /// Whether the adapter is currently marked as frozen.
    fn is_frozen(&self) -> bool {
        BitNetAdapter::is_frozen(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use candle_core::Tensor;

    /// Builds an adapter around a random 64x128 weight on the CPU, using a
    /// BitNet group size of 64.
    fn sample_adapter() -> BitNetAdapter {
        let cpu = Device::Cpu;
        let adapter_config = BitNetAdapterConfig::new(BitNetConfig::default().with_group_size(64));
        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &cpu).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &adapter_config.bitnet).unwrap();
        BitNetAdapter::new(layer, adapter_config)
    }

    /// A new adapter reports the full weight size and starts unfrozen.
    #[test]
    fn test_adapter_creation() {
        let adapter = sample_adapter();

        assert_eq!(adapter.num_parameters(), 64 * 128);
        assert!(!adapter.is_frozen());
    }

    /// `freeze` and `unfreeze` toggle the flag reported by `is_frozen`.
    #[test]
    fn test_adapter_freeze_unfreeze() {
        let mut adapter = sample_adapter();

        adapter.freeze();
        assert!(adapter.is_frozen());

        adapter.unfreeze();
        assert!(!adapter.is_frozen());
    }

    /// The default configuration ships a non-empty module list that
    /// includes `q_proj`.
    #[test]
    fn test_adapter_config_default() {
        let defaults = BitNetAdapterConfig::default();

        assert!(!defaults.target_modules.is_empty());
        assert!(defaults.target_modules.contains(&"q_proj".to_string()));
    }
}