use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
const NF4_VALUES: [f32; 16] = [
-1.0,
-0.696_192_8,
-0.525_073_05,
-0.394_917_5,
-0.284_441_38,
-0.184_773_43,
-0.091_050_036,
0.0,
0.079_580_3,
0.160_930_2,
0.246_112_3,
0.337_915_24,
0.440_709_83,
0.562_617,
0.722_956_84,
1.0,
];
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedQuantizationConfig {
pub method: QuantizationMethod,
pub block_size: usize,
pub adaptation_rate: f32,
pub min_scale: f32,
pub max_scale: f32,
pub double_quantization: bool,
}
impl Default for AdvancedQuantizationConfig {
fn default() -> Self {
Self {
method: QuantizationMethod::NF4,
block_size: 64,
adaptation_rate: 0.01,
min_scale: 1e-8,
max_scale: 1e8,
double_quantization: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum QuantizationMethod {
Int4,
NF4,
Int8,
Dynamic,
BlockWise,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
pub data: Vec<f32>,
pub scales: Vec<f32>,
pub zero_points: Vec<f32>,
pub shape: Vec<usize>,
pub method: QuantizationMethod,
pub block_size: usize,
}
impl QuantizedTensor {
pub fn new(
data: Vec<f32>,
scales: Vec<f32>,
zero_points: Vec<f32>,
shape: Vec<usize>,
method: QuantizationMethod,
block_size: usize,
) -> Self {
Self {
data,
scales,
zero_points,
shape,
method,
block_size,
}
}
pub fn memory_usage(&self) -> usize {
self.data.len() * 4 + self.scales.len() * 4 + self.zero_points.len() * 4
}
pub fn compression_ratio(&self) -> f32 {
let original_size = self.shape.iter().product::<usize>() * 4; match self.method {
QuantizationMethod::NF4 | QuantizationMethod::Int4 => 8.0, QuantizationMethod::Int8 => 4.0, _ => {
let compressed_size = self.memory_usage();
if compressed_size > 0 {
original_size as f32 / compressed_size as f32
} else {
1.0
}
},
}
}
}
pub struct QuantizationUtils;
impl QuantizationUtils {
pub fn quantize_nf4(tensor: &Tensor, block_size: usize) -> Result<QuantizedTensor> {
let data = tensor.data()?;
let shape = tensor.shape();
let num_elements = data.len();
let num_blocks = num_elements.div_ceil(block_size);
let mut quantized_data = Vec::new();
let mut scales = Vec::with_capacity(num_blocks);
let mut zero_points = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let start = block_idx * block_size;
let end = (start + block_size).min(num_elements);
let block = &data[start..end];
let min_val = block.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = block.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let scale = (max_val - min_val) / 15.0; let zero_point = -min_val / scale;
scales.push(scale);
zero_points.push(zero_point);
for &value in block {
let normalized = (value - min_val) / scale;
let quantized = Self::find_closest_nf4(normalized / 15.0);
quantized_data.push(quantized);
}
}
Ok(QuantizedTensor::new(
quantized_data,
scales,
zero_points,
shape,
QuantizationMethod::NF4,
block_size,
))
}
fn find_closest_nf4(value: f32) -> f32 {
let clamped = value.clamp(-1.0, 1.0);
let mut best_val = NF4_VALUES[0];
let mut best_diff = (NF4_VALUES[0] - clamped).abs();
for &nf4_val in NF4_VALUES.iter() {
let diff = (nf4_val - clamped).abs();
if diff < best_diff {
best_diff = diff;
best_val = nf4_val;
}
}
best_val
}
pub fn dequantize_nf4(quantized: &QuantizedTensor) -> Result<Tensor> {
let num_elements: usize = quantized.shape.iter().product();
let mut data = Vec::with_capacity(num_elements);
let block_size = quantized.block_size;
let num_blocks = num_elements.div_ceil(block_size);
let mut data_idx = 0;
for block_idx in 0..num_blocks {
let start = block_idx * block_size;
let end = (start + block_size).min(num_elements);
let block_len = end - start;
let scale = quantized.scales[block_idx];
let zero_point = quantized.zero_points[block_idx];
for _ in 0..block_len {
if data_idx < quantized.data.len() {
let nf4_val = quantized.data[data_idx];
let dequantized = (nf4_val * 15.0 + zero_point) * scale;
data.push(dequantized);
data_idx += 1;
}
}
}
Tensor::new(data)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientStatistics {
pub mean: f32,
pub variance: f32,
pub skewness: f32,
pub kurtosis: f32,
pub l2_norm: f32,
}
impl GradientStatistics {
pub fn compute(data: &[f32]) -> Self {
let n = data.len() as f32;
let mean = data.iter().sum::<f32>() / n;
let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
let std_dev = variance.sqrt();
let skewness = if std_dev > 1e-8 {
data.iter().map(|x| ((x - mean) / std_dev).powi(3)).sum::<f32>() / n
} else {
0.0
};
let kurtosis = if std_dev > 1e-8 {
data.iter().map(|x| ((x - mean) / std_dev).powi(4)).sum::<f32>() / n - 3.0
} else {
0.0
};
let l2_norm = data.iter().map(|x| x * x).sum::<f32>().sqrt();
Self {
mean,
variance,
skewness,
kurtosis,
l2_norm,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Adam4bitOptimizerConfig {
pub learning_rate: f32,
pub beta1: f32,
pub beta2: f32,
pub epsilon: f32,
pub weight_decay: f32,
}
impl Default for Adam4bitOptimizerConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
weight_decay: 0.0,
}
}
}
#[derive(Debug)]
pub struct Adam4bit {
config: AdvancedQuantizationConfig,
optimizer_config: Adam4bitOptimizerConfig,
state: OptimizerState,
momentum_quantized: HashMap<String, QuantizedTensor>,
variance_quantized: HashMap<String, QuantizedTensor>,
gradient_stats: HashMap<String, GradientStatistics>,
}
impl Adam4bit {
pub fn new(
learning_rate: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
weight_decay: f32,
) -> Self {
let optimizer_config = Adam4bitOptimizerConfig {
learning_rate,
beta1,
beta2,
epsilon,
weight_decay,
};
Self {
config: AdvancedQuantizationConfig::default(),
optimizer_config,
state: OptimizerState::new(),
momentum_quantized: HashMap::new(),
variance_quantized: HashMap::new(),
gradient_stats: HashMap::new(),
}
}
pub fn with_quantization_config(
optimizer_config: Adam4bitOptimizerConfig,
quantization_config: AdvancedQuantizationConfig,
) -> Self {
Self {
config: quantization_config,
optimizer_config,
state: OptimizerState::new(),
momentum_quantized: HashMap::new(),
variance_quantized: HashMap::new(),
gradient_stats: HashMap::new(),
}
}
pub fn memory_savings(&self) -> f32 {
0.75
}
fn update_gradient_stats(&mut self, param_id: &str, gradient_data: &[f32]) {
let stats = GradientStatistics::compute(gradient_data);
if let Some(existing_stats) = self.gradient_stats.get_mut(param_id) {
let alpha = self.config.adaptation_rate;
existing_stats.mean = (1.0 - alpha) * existing_stats.mean + alpha * stats.mean;
existing_stats.variance =
(1.0 - alpha) * existing_stats.variance + alpha * stats.variance;
} else {
self.gradient_stats.insert(param_id.to_string(), stats);
}
}
}
impl Optimizer for Adam4bit {
fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
match (parameter, grad) {
(Tensor::F32(param), Tensor::F32(grad_arr)) => {
let param_id = format!("{:p}", param.as_ptr());
let size = grad_arr.len();
self.update_gradient_stats(
¶m_id,
&grad_arr.iter().cloned().collect::<Vec<f32>>(),
);
if !self.momentum_quantized.contains_key(¶m_id) {
let zeros = vec![0.0; size];
let zero_tensor = Tensor::new(zeros)?;
let momentum_q =
QuantizationUtils::quantize_nf4(&zero_tensor, self.config.block_size)?;
let variance_q =
QuantizationUtils::quantize_nf4(&zero_tensor, self.config.block_size)?;
self.momentum_quantized.insert(param_id.clone(), momentum_q);
self.variance_quantized.insert(param_id.clone(), variance_q);
}
let momentum_q = self
.momentum_quantized
.get(¶m_id)
.expect("momentum_quantized should exist after insert");
let variance_q = self
.variance_quantized
.get(¶m_id)
.expect("variance_quantized should exist after insert");
let momentum_tensor = QuantizationUtils::dequantize_nf4(momentum_q)?;
let variance_tensor = QuantizationUtils::dequantize_nf4(variance_q)?;
let momentum_data = momentum_tensor.data()?;
let variance_data = variance_tensor.data()?;
let mut new_momentum = Vec::with_capacity(size);
let mut new_variance = Vec::with_capacity(size);
let step = (self.state.step + 1) as f32;
let bias_correction1 = 1.0 - self.optimizer_config.beta1.powf(step);
let bias_correction2 = 1.0 - self.optimizer_config.beta2.powf(step);
for i in 0..size {
let mut g = grad_arr[i];
if self.optimizer_config.weight_decay > 0.0 {
g += self.optimizer_config.weight_decay * param[i];
}
let m = self.optimizer_config.beta1 * momentum_data[i]
+ (1.0 - self.optimizer_config.beta1) * g;
let v = self.optimizer_config.beta2 * variance_data[i]
+ (1.0 - self.optimizer_config.beta2) * g * g;
new_momentum.push(m);
new_variance.push(v);
let m_hat = m / bias_correction1;
let v_hat = v / bias_correction2;
param[i] -= self.optimizer_config.learning_rate * m_hat
/ (v_hat.sqrt() + self.optimizer_config.epsilon);
}
let new_momentum_tensor = Tensor::new(new_momentum)?;
let new_variance_tensor = Tensor::new(new_variance)?;
let momentum_q_new =
QuantizationUtils::quantize_nf4(&new_momentum_tensor, self.config.block_size)?;
let variance_q_new =
QuantizationUtils::quantize_nf4(&new_variance_tensor, self.config.block_size)?;
self.momentum_quantized.insert(param_id.clone(), momentum_q_new);
self.variance_quantized.insert(param_id, variance_q_new);
Ok(())
},
_ => Err(TrustformersError::tensor_op_error(
"Unsupported tensor types for Adam4bit",
"Adam4bit::update",
)),
}
}
fn zero_grad(&mut self) {
}
fn step(&mut self) {
self.state.step();
}
fn get_lr(&self) -> f32 {
self.optimizer_config.learning_rate
}
fn set_lr(&mut self, lr: f32) {
self.optimizer_config.learning_rate = lr;
}
}
impl StatefulOptimizer for Adam4bit {
type Config = Adam4bitOptimizerConfig;
type State = OptimizerState;
fn config(&self) -> &Self::Config {
&self.optimizer_config
}
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
let mut state_dict = HashMap::new();
state_dict.insert(
"learning_rate".to_string(),
Tensor::new(vec![self.optimizer_config.learning_rate])?,
);
state_dict.insert(
"beta1".to_string(),
Tensor::new(vec![self.optimizer_config.beta1])?,
);
state_dict.insert(
"beta2".to_string(),
Tensor::new(vec![self.optimizer_config.beta2])?,
);
state_dict.insert(
"epsilon".to_string(),
Tensor::new(vec![self.optimizer_config.epsilon])?,
);
state_dict.insert(
"weight_decay".to_string(),
Tensor::new(vec![self.optimizer_config.weight_decay])?,
);
state_dict.insert(
"step".to_string(),
Tensor::new(vec![self.state.step as f32])?,
);
for (param_id, momentum_q) in &self.momentum_quantized {
state_dict.insert(
format!("momentum_q_{}", param_id),
Tensor::new(momentum_q.data.clone())?,
);
}
for (param_id, variance_q) in &self.variance_quantized {
state_dict.insert(
format!("variance_q_{}", param_id),
Tensor::new(variance_q.data.clone())?,
);
}
Ok(state_dict)
}
fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
if let Some(lr_tensor) = state.get("learning_rate") {
if let Ok(lr_vec) = lr_tensor.data() {
if !lr_vec.is_empty() {
self.optimizer_config.learning_rate = lr_vec[0];
}
}
}
Ok(())
}
fn memory_usage(&self) -> StateMemoryStats {
let total_memory =
self.momentum_quantized.values().map(|q| q.memory_usage()).sum::<usize>()
+ self.variance_quantized.values().map(|q| q.memory_usage()).sum::<usize>();
StateMemoryStats {
momentum_elements: self.momentum_quantized.values().map(|q| q.data.len()).sum(),
variance_elements: self.variance_quantized.values().map(|q| q.data.len()).sum(),
third_moment_elements: 0,
total_bytes: total_memory,
num_parameters: self.momentum_quantized.len(),
}
}
fn reset_state(&mut self) {
self.state.clear();
self.momentum_quantized.clear();
self.variance_quantized.clear();
self.gradient_stats.clear();
}
fn num_parameters(&self) -> usize {
self.momentum_quantized.values().map(|q| q.data.len()).sum()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_nf4_quantization() {
let data = vec![1.0, -0.5, 0.0, 0.8, -1.2];
let tensor = Tensor::new(data.clone()).expect("Failed to create tensor");
let quantized =
QuantizationUtils::quantize_nf4(&tensor, 64).expect("Operation failed in test");
assert_eq!(quantized.method, QuantizationMethod::NF4);
assert!(quantized.compression_ratio() >= 1.0);
}
#[test]
fn test_gradient_statistics() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let stats = GradientStatistics::compute(&data);
assert!((stats.mean - 3.0).abs() < 1e-6);
assert!(stats.variance > 0.0);
assert!(stats.l2_norm > 0.0);
}
#[test]
fn test_adam4bit_creation() {
let optimizer = Adam4bit::new(0.001, 0.9, 0.999, 1e-8, 0.01);
assert_eq!(optimizer.get_lr(), 0.001);
assert!(optimizer.memory_savings() > 0.5); }
#[test]
fn test_quantized_tensor_memory() {
let quantized = QuantizedTensor::new(
vec![0.0, 1.0, 2.0, 3.0],
vec![1.0],
vec![0.0],
vec![4],
QuantizationMethod::NF4,
64,
);
assert!(quantized.memory_usage() > 0);
assert!(quantized.compression_ratio() >= 1.0);
}
}