1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
//! Matrix quantization types and implementations
//!
//! This module contains the QuantizedMatrix struct and QuantizedData2D enum
//! along with their implementations for handling quantized matrix data.
use half::{bf16, f16};
use scirs2_core::ndarray::{Array1, Array2};
use super::types::QuantizedDataType;
/// A matrix whose elements are held in a quantized representation.
///
/// Bundles the raw storage with the logical matrix shape and the quantization
/// data type that tells the accessors how to interpret that storage.
///
/// NOTE(review): all fields are public, so nothing here enforces that `data`,
/// `shape` and `data_type` agree with each other — confirm callers keep them
/// consistent.
#[derive(Debug, Clone)]
pub struct QuantizedMatrix {
/// The quantized values; the enum variant selects the storage format
/// (8-bit integers, `f16`, or `bf16`).
pub data: QuantizedData2D,
/// The original (logical) shape of the matrix as `(rows, columns)`.
/// `nrows()`/`ncols()` read from this, and the 4-bit accessors use the
/// column count to turn `(row, col)` into a flat element index.
pub shape: (usize, usize),
/// The data type used for quantization. With `Int8` storage the accessors
/// use it to decide whether a byte is one value (`Int8`) or two packed
/// 4-bit values (`Int4` / `UInt4`).
pub data_type: QuantizedDataType,
}
/// Storage for quantized 2D data (matrices) in different formats.
///
/// Each variant owns a two-dimensional array of one concrete element type.
#[derive(Debug, Clone)]
pub enum QuantizedData2D {
/// 8-bit integer storage. `QuantizedMatrix::get_i8` also reads this variant
/// as packed 4-bit data (two values per byte) when the matrix's data type
/// is `Int4` or `UInt4`.
Int8(Array2<i8>),
/// 16-bit float storage (IEEE 754 half-precision)
Float16(Array2<f16>),
/// 16-bit brain float storage
BFloat16(Array2<bf16>),
}
impl QuantizedData2D {
/// Get the number of elements in the storage
pub fn len(&self) -> usize {
match self {
QuantizedData2D::Int8(arr) => arr.len(),
QuantizedData2D::Float16(arr) => arr.len(),
QuantizedData2D::BFloat16(arr) => arr.len(),
}
}
/// Check if the storage is empty
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Borrows the raw `i8` array backing a [`QuantizedMatrix`], when there is one.
///
/// Returns `Some(&array)` if the matrix uses `Int8` storage and `None` for the
/// floating-point storage variants. The array is returned as stored; no
/// unpacking of 4-bit values is performed.
#[allow(dead_code)]
pub fn get_quantizedmatrix_2d_i8(matrix: &QuantizedMatrix) -> Option<&Array2<i8>> {
    if let QuantizedData2D::Int8(values) = &matrix.data {
        Some(values)
    } else {
        None
    }
}
impl QuantizedMatrix {
    /// Creates a quantized matrix backed by `i8` storage with an explicit data type.
    ///
    /// `data_type` is stored as given and is not validated here; the element
    /// accessors only understand `Int8`, `Int4` and `UInt4` for this storage.
    /// For the two 4-bit types, `data` is read as a flat, contiguous buffer
    /// holding two values per byte, while `shape` is the logical matrix shape.
    pub fn new_i8(data: Array2<i8>, shape: (usize, usize), data_type: QuantizedDataType) -> Self {
        Self {
            data: QuantizedData2D::Int8(data),
            shape,
            data_type,
        }
    }

    /// Creates a new f16 quantized matrix (data type `Float16`).
    pub fn new_f16(data: Array2<f16>, shape: (usize, usize)) -> Self {
        Self {
            data: QuantizedData2D::Float16(data),
            shape,
            data_type: QuantizedDataType::Float16,
        }
    }

    /// Creates a new bf16 quantized matrix (data type `BFloat16`).
    pub fn new_bf16(data: Array2<bf16>, shape: (usize, usize)) -> Self {
        Self {
            data: QuantizedData2D::BFloat16(data),
            shape,
            data_type: QuantizedDataType::BFloat16,
        }
    }

    /// Creates a standard Int8 quantized matrix (for backward compatibility).
    ///
    /// Equivalent to [`QuantizedMatrix::new_i8`] with `QuantizedDataType::Int8`.
    pub fn from_i8(data: Array2<i8>, shape: (usize, usize)) -> Self {
        Self::new_i8(data, shape, QuantizedDataType::Int8)
    }

    /// Returns the element at `(row, col)` as an `i8`.
    ///
    /// Kept for backward compatibility; it simply forwards to
    /// [`QuantizedMatrix::get_i8`] and therefore panics in the same situations.
    #[deprecated(since = "0.1.0", note = "Use get_i8 or get_f32 instead")]
    pub fn get(&self, row: usize, col: usize) -> i8 {
        self.get_i8(row, col)
    }

    /// Returns the shape of the matrix as `(rows, columns)`.
    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    /// Returns the number of rows in the matrix.
    pub fn nrows(&self) -> usize {
        self.shape.0
    }

    /// Returns the number of columns in the matrix.
    pub fn ncols(&self) -> usize {
        self.shape.1
    }

    /// Locates element `(row, col)` inside packed 4-bit storage.
    ///
    /// Returns the byte that contains the element and `true` when the element
    /// occupies the upper nibble (even flat index) or `false` for the lower
    /// nibble (odd flat index). The flat index is `row * ncols + col`.
    fn packed_byte_and_position(&self, arr: &Array2<i8>, row: usize, col: usize) -> (i8, bool) {
        let flat_idx = row * self.shape.1 + col;
        // `as_slice` yields `None` when the array is not contiguous in standard order.
        let byte = arr.as_slice().expect("Operation failed")[flat_idx / 2];
        (byte, flat_idx % 2 == 0)
    }

    /// Get value at specified position as i8 (for int quantization).
    ///
    /// * `Int8`: the stored byte at `[row, col]`.
    /// * `Int4`: the signed 4-bit value, sign-extended to the range `-8..=7`
    ///   regardless of which nibble it is stored in.
    /// * `UInt4`: the unsigned 4-bit value in the range `0..=15`.
    ///
    /// For the 4-bit types `row`/`col` are not checked against `shape`; only the
    /// resulting byte index is bounds-checked against the packed buffer.
    ///
    /// # Panics
    ///
    /// Panics if the matrix uses floating-point storage, if the data type is not
    /// one of `Int8`, `Int4`, `UInt4`, if the index is out of bounds, or if packed
    /// 4-bit storage is not available as a contiguous slice.
    pub fn get_i8(&self, row: usize, col: usize) -> i8 {
        match &self.data {
            QuantizedData2D::Int8(arr) => match self.data_type {
                QuantizedDataType::Int8 => arr[[row, col]],
                QuantizedDataType::Int4 => {
                    let (byte, in_upper_nibble) = self.packed_byte_and_position(arr, row, col);
                    if in_upper_nibble {
                        // Arithmetic shift on `i8` sign-extends the upper nibble.
                        byte >> 4
                    } else {
                        // Move the lower nibble into the top bits, then shift back
                        // arithmetically so its sign bit is extended as well.
                        // A plain `& 0x0F` would turn -1 into 15.
                        (byte << 4) >> 4
                    }
                }
                QuantizedDataType::UInt4 => {
                    let (byte, in_upper_nibble) = self.packed_byte_and_position(arr, row, col);
                    if in_upper_nibble {
                        // Mask after the arithmetic shift to drop the sign extension.
                        (byte >> 4) & 0x0F
                    } else {
                        byte & 0x0F
                    }
                }
                _ => unreachable!(
                    "Invalid quantization type for Int8 storage: expected Int8, Int4, or UInt4"
                ),
            },
            _ => unreachable!("Cannot get i8 value from floating-point quantized matrix"),
        }
    }

    /// Get value at specified position as f32 (for all quantization types).
    ///
    /// Integer-quantized values are returned as the raw quantized integer
    /// converted to `f32` (no scale or zero point is applied here); `f16` and
    /// `bf16` values are widened with `to_f32`.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of bounds, or if the matrix uses `Int8`
    /// storage with a data type other than `Int8`, `Int4` or `UInt4`.
    pub fn get_f32(&self, row: usize, col: usize) -> f32 {
        match &self.data {
            QuantizedData2D::Int8(_) => match self.data_type {
                QuantizedDataType::Int8 | QuantizedDataType::Int4 | QuantizedDataType::UInt4 => {
                    self.get_i8(row, col) as f32
                }
                _ => unreachable!(
                    "Invalid data type for Int8 storage: expected Int8, Int4, or UInt4"
                ),
            },
            QuantizedData2D::Float16(arr) => arr[[row, col]].to_f32(),
            QuantizedData2D::BFloat16(arr) => arr[[row, col]].to_f32(),
        }
    }
}