Skip to main content

llama_rs/tensor/
dtype.rs

1//! Tensor data types
2
3use crate::gguf::GgmlType;
4
/// Data type for tensor elements.
///
/// Covers plain floating-point and integer element types as well as the
/// block-quantized formats that can be converted from a GGUF [`GgmlType`]
/// tag. Quantized types store elements in fixed-size blocks; use
/// [`DType::block_size`] and [`DType::block_bytes`] to reason about storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    // Floating point types
    F32,
    F16,
    BF16,
    F64,
    // Integer types
    I8,
    I16,
    I32,
    I64,
    U8,
    // Legacy quantized types (block size 32)
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // K-quant types (block size 256)
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    // IQ (importance-weighted) quant types (block size 256, except IQ4NL: 32)
    IQ2XXS,
    IQ2XS,
    IQ2S,
    IQ3XXS,
    IQ3S,
    IQ4XS,
    IQ4NL,
    IQ1S,
    IQ1M,
    // Ternary quant types (block size 256)
    TQ1_0,
    TQ2_0,
}

impl DType {
    /// Number of elements stored per block of this type.
    ///
    /// Non-quantized types are treated as one-element blocks, so for them
    /// [`DType::block_bytes`] is simply the element size.
    pub const fn block_size(&self) -> usize {
        match self {
            // Non-quantized types have block size 1
            Self::F32 | Self::F16 | Self::BF16 | Self::F64 => 1,
            Self::I8 | Self::I16 | Self::I32 | Self::I64 | Self::U8 => 1,
            // Legacy quants: 32 elements per block
            Self::Q4_0 | Self::Q4_1 | Self::Q5_0 | Self::Q5_1 | Self::Q8_0 | Self::Q8_1 => 32,
            // K-quants: 256 elements per block
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => 256,
            // IQ types: 256 elements per block, except IQ4NL (below)
            Self::IQ2XXS
            | Self::IQ2XS
            | Self::IQ2S
            | Self::IQ3XXS
            | Self::IQ3S
            | Self::IQ4XS
            | Self::IQ1S
            | Self::IQ1M => 256,
            // IQ4NL uses the same 32-element block as the legacy quants
            Self::IQ4NL => 32,
            // Ternary quants: 256 elements per block
            Self::TQ1_0 | Self::TQ2_0 => 256,
        }
    }

    /// Size in bytes of one block of this type.
    ///
    /// For non-quantized types this is the size of a single element. For
    /// quantized types it is the packed size of one block of
    /// [`DType::block_size`] elements; these values must match the on-disk
    /// block layout of the corresponding GGML type.
    pub const fn block_bytes(&self) -> usize {
        match self {
            Self::F32 => 4,
            Self::F16 | Self::BF16 => 2,
            Self::F64 => 8,
            Self::I8 | Self::U8 => 1,
            Self::I16 => 2,
            Self::I32 => 4,
            Self::I64 => 8,
            Self::Q4_0 => 18,
            Self::Q4_1 => 20,
            Self::Q5_0 => 22,
            Self::Q5_1 => 24,
            Self::Q8_0 => 34,
            Self::Q8_1 => 36,
            Self::Q2K => 84,
            Self::Q3K => 110,
            Self::Q4K => 144,
            Self::Q5K => 176,
            Self::Q6K => 210,
            Self::Q8K => 292,
            Self::IQ2XXS => 66,
            Self::IQ2XS => 74,
            Self::IQ2S => 82,
            Self::IQ3XXS => 98,
            Self::IQ3S => 110,
            Self::IQ4XS => 136,
            Self::IQ4NL => 18,
            Self::IQ1S => 50,
            Self::IQ1M => 56,
            Self::TQ1_0 => 54,
            Self::TQ2_0 => 66,
        }
    }

    /// Returns `true` for block-quantized types and `false` for the plain
    /// floating-point and integer element types.
    pub const fn is_quantized(&self) -> bool {
        !matches!(
            self,
            Self::F32
                | Self::F16
                | Self::BF16
                | Self::F64
                | Self::I8
                | Self::I16
                | Self::I32
                | Self::I64
                | Self::U8
        )
    }

    /// Human-readable name for the dtype, using underscore-separated
    /// spellings for quantized types (e.g. `"Q4_K"`, `"IQ2_XXS"`).
    pub const fn name(&self) -> &'static str {
        match self {
            Self::F32 => "F32",
            Self::F16 => "F16",
            Self::BF16 => "BF16",
            Self::F64 => "F64",
            Self::I8 => "I8",
            Self::I16 => "I16",
            Self::I32 => "I32",
            Self::I64 => "I64",
            Self::U8 => "U8",
            Self::Q4_0 => "Q4_0",
            Self::Q4_1 => "Q4_1",
            Self::Q5_0 => "Q5_0",
            Self::Q5_1 => "Q5_1",
            Self::Q8_0 => "Q8_0",
            Self::Q8_1 => "Q8_1",
            Self::Q2K => "Q2_K",
            Self::Q3K => "Q3_K",
            Self::Q4K => "Q4_K",
            Self::Q5K => "Q5_K",
            Self::Q6K => "Q6_K",
            Self::Q8K => "Q8_K",
            Self::IQ2XXS => "IQ2_XXS",
            Self::IQ2XS => "IQ2_XS",
            Self::IQ2S => "IQ2_S",
            Self::IQ3XXS => "IQ3_XXS",
            Self::IQ3S => "IQ3_S",
            Self::IQ4XS => "IQ4_XS",
            Self::IQ4NL => "IQ4_NL",
            Self::IQ1S => "IQ1_S",
            Self::IQ1M => "IQ1_M",
            Self::TQ1_0 => "TQ1_0",
            Self::TQ2_0 => "TQ2_0",
        }
    }

    /// Byte size needed to store `n_elements` elements of this type, or
    /// `None` if that size does not fit in `usize`.
    ///
    /// Well-formed quantized tensors have element counts that are a multiple
    /// of [`DType::block_size`]; a trailing partial block is nevertheless
    /// rounded up to a whole block so the result is never too small.
    ///
    /// Prefer this over [`DType::size_for_elements`] when `n_elements` comes
    /// from untrusted input such as a file header.
    pub const fn checked_size_for_elements(&self, n_elements: usize) -> Option<usize> {
        // `div_ceil` cannot overflow; only the multiplication can.
        let n_blocks = n_elements.div_ceil(self.block_size());
        n_blocks.checked_mul(self.block_bytes())
    }

    /// Byte size needed to store `n_elements` elements of this type.
    ///
    /// A trailing partial block is rounded up to a whole block, as in
    /// [`DType::checked_size_for_elements`].
    ///
    /// # Panics
    ///
    /// Panics if the byte size does not fit in `usize`.
    pub const fn size_for_elements(&self, n_elements: usize) -> usize {
        match self.checked_size_for_elements(n_elements) {
            Some(bytes) => bytes,
            // Panic in every build profile: an unchecked multiply would wrap
            // in release builds and report an undersized byte count.
            None => panic!("tensor byte size overflows usize"),
        }
    }
}
174
175impl From<GgmlType> for DType {
176    fn from(ggml_type: GgmlType) -> Self {
177        match ggml_type {
178            GgmlType::F32 => DType::F32,
179            GgmlType::F16 => DType::F16,
180            GgmlType::BF16 => DType::BF16,
181            GgmlType::F64 => DType::F64,
182            GgmlType::I8 => DType::I8,
183            GgmlType::I16 => DType::I16,
184            GgmlType::I32 => DType::I32,
185            GgmlType::I64 => DType::I64,
186            GgmlType::Q4_0 => DType::Q4_0,
187            GgmlType::Q4_1 => DType::Q4_1,
188            GgmlType::Q5_0 => DType::Q5_0,
189            GgmlType::Q5_1 => DType::Q5_1,
190            GgmlType::Q8_0 => DType::Q8_0,
191            GgmlType::Q8_1 => DType::Q8_1,
192            GgmlType::Q2K => DType::Q2K,
193            GgmlType::Q3K => DType::Q3K,
194            GgmlType::Q4K => DType::Q4K,
195            GgmlType::Q5K => DType::Q5K,
196            GgmlType::Q6K => DType::Q6K,
197            GgmlType::Q8K => DType::Q8K,
198            GgmlType::IQ2XXS => DType::IQ2XXS,
199            GgmlType::IQ2XS => DType::IQ2XS,
200            GgmlType::IQ2S => DType::IQ2S,
201            GgmlType::IQ3XXS => DType::IQ3XXS,
202            GgmlType::IQ3S => DType::IQ3S,
203            GgmlType::IQ4XS => DType::IQ4XS,
204            GgmlType::IQ4NL => DType::IQ4NL,
205            GgmlType::IQ1S => DType::IQ1S,
206            GgmlType::IQ1M => DType::IQ1M,
207            GgmlType::TQ1_0 => DType::TQ1_0,
208            GgmlType::TQ2_0 => DType::TQ2_0,
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_size_for_elements() {
        // F32: 4 bytes per element
        assert_eq!(DType::F32.size_for_elements(10), 40);

        // Q4_0: 18 bytes per 32 elements
        assert_eq!(DType::Q4_0.size_for_elements(32), 18);
        assert_eq!(DType::Q4_0.size_for_elements(64), 36);

        // Q4K: 144 bytes per 256 elements
        assert_eq!(DType::Q4K.size_for_elements(256), 144);
        assert_eq!(DType::Q4K.size_for_elements(512), 288);
    }

    #[test]
    fn test_size_for_elements_rounds_up_partial_blocks() {
        // A trailing partial block still costs a whole block.
        assert_eq!(DType::Q4_0.size_for_elements(1), 18);
        assert_eq!(DType::Q4_0.size_for_elements(33), 36);
        assert_eq!(DType::Q4K.size_for_elements(257), 288);

        // Zero elements need zero bytes regardless of block size.
        assert_eq!(DType::F32.size_for_elements(0), 0);
        assert_eq!(DType::Q4K.size_for_elements(0), 0);
    }

    #[test]
    fn test_block_size() {
        assert_eq!(DType::F16.block_size(), 1);
        assert_eq!(DType::U8.block_size(), 1);
        assert_eq!(DType::Q8_0.block_size(), 32);
        assert_eq!(DType::Q6K.block_size(), 256);
        assert_eq!(DType::IQ2XXS.block_size(), 256);
        // IQ4NL is the one IQ type that uses 32-element blocks.
        assert_eq!(DType::IQ4NL.block_size(), 32);
        assert_eq!(DType::TQ2_0.block_size(), 256);
    }

    #[test]
    fn test_is_quantized() {
        assert!(!DType::F32.is_quantized());
        assert!(!DType::F16.is_quantized());
        assert!(!DType::BF16.is_quantized());
        assert!(!DType::I32.is_quantized());
        assert!(!DType::U8.is_quantized());
        assert!(DType::Q4_0.is_quantized());
        assert!(DType::Q4K.is_quantized());
        assert!(DType::IQ2XXS.is_quantized());
        assert!(DType::IQ4NL.is_quantized());
        assert!(DType::TQ1_0.is_quantized());
    }

    #[test]
    fn test_name() {
        assert_eq!(DType::F32.name(), "F32");
        assert_eq!(DType::Q4_0.name(), "Q4_0");
        assert_eq!(DType::Q4K.name(), "Q4_K");
        assert_eq!(DType::IQ2XXS.name(), "IQ2_XXS");
        assert_eq!(DType::TQ2_0.name(), "TQ2_0");
    }

    #[test]
    fn test_from_ggml_type() {
        assert_eq!(DType::from(GgmlType::F32), DType::F32);
        assert_eq!(DType::from(GgmlType::BF16), DType::BF16);
        assert_eq!(DType::from(GgmlType::Q4_0), DType::Q4_0);
        assert_eq!(DType::from(GgmlType::Q4K), DType::Q4K);
        assert_eq!(DType::from(GgmlType::IQ4NL), DType::IQ4NL);
        assert_eq!(DType::from(GgmlType::TQ2_0), DType::TQ2_0);
    }
}