1use std::collections::HashMap;
6use std::fs::File;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use anyhow::{anyhow, Context, Result};
11use memmap2::Mmap;
12
13use crate::runtime::tensor_store::{MappedSlice, TensorRef, TensorStore};
14use crate::tensor::{
15 BF16, Bitset, DType, F16, F8, I1, I2, I4, T1, T2, U1, U2, U4, Tensor, TensorValue,
16};
17use crate::types::VarInfo;
18
/// Magic bytes every OINF file starts with: ASCII `OINF` followed by a NUL byte.
const MAGIC: &[u8; 5] = &[b'O', b'I', b'N', b'F', 0];
/// Size in bytes of the fixed header read by `ModelLoader::open`:
/// the 5-byte magic, six `u32` fields and five `u64` fields.
const HEADER_SIZE: usize = 5 + 6 * 4 + 5 * 8;
21
/// Type and location of one metadata value, as recorded in the metadata table.
// Not every field is read back after parsing (`dims` is only stored), hence the allow.
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct MetadataInfo {
    /// Raw `ValueType` code of the value.
    value_type: u32,
    /// Byte offset of the value, measured from the start of the file.
    value_offset: u64,
    /// Size of the value in bytes.
    value_nbytes: u64,
    /// Dimensions parsed from the value header when the value is an ndarray;
    /// empty for every other value type.
    dims: Vec<u64>,
}
30
31#[derive(Debug, Clone)]
33pub struct ModelLoader {
34 #[allow(dead_code)]
35 path: PathBuf,
36 sizes: HashMap<String, usize>,
37 vars: HashMap<String, VarInfo>,
38 #[allow(dead_code)]
39 metadata: HashMap<String, MetadataInfo>,
40 mmap: Arc<Mmap>,
41 tensor_store: TensorStore,
42}
43
44impl ModelLoader {
45 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
55 let path = path.as_ref().to_path_buf();
56 let file = File::open(&path).with_context(|| "open model file")?;
57 let mmap = unsafe { Mmap::map(&file).with_context(|| "mmap model file")? };
58 let data = &mmap[..];
59 if data.len() < HEADER_SIZE {
60 return Err(anyhow!("file too small for OINF header"));
61 }
62
63 let mut cursor = 0usize;
64 let magic = read_bytes(data, &mut cursor, 5)?;
65 if magic != MAGIC {
66 return Err(anyhow!("invalid OINF magic"));
67 }
68 let version = read_u32(data, &mut cursor)?;
69 if version != 1 {
70 return Err(anyhow!("unsupported OINF version {}", version));
71 }
72 let _flags = read_u32(data, &mut cursor)?;
73 let n_sizevars = read_u32(data, &mut cursor)? as usize;
74 let n_metadata = read_u32(data, &mut cursor)? as usize;
75 let n_tensors = read_u32(data, &mut cursor)? as usize;
76 let _reserved = read_u32(data, &mut cursor)?;
77 let offset_sizevars = read_u64(data, &mut cursor)? as usize;
78 let offset_metadata = read_u64(data, &mut cursor)? as usize;
79 let offset_tensors = read_u64(data, &mut cursor)? as usize;
80 let offset_data = read_u64(data, &mut cursor)? as usize;
81 let file_size = read_u64(data, &mut cursor)? as usize;
82
83 if file_size != data.len() {
84 return Err(anyhow!("file size mismatch"));
85 }
86 let offsets = vec![
87 offset_sizevars,
88 offset_metadata,
89 offset_tensors,
90 offset_data,
91 file_size,
92 ];
93 let mut sorted = offsets.clone();
94 sorted.sort_unstable();
95 if offsets != sorted {
96 return Err(anyhow!("OINF offsets are not ascending"));
97 }
98 for off in offsets.iter().take(4) {
99 if *off % 8 != 0 {
100 return Err(anyhow!("OINF section offset not aligned"));
101 }
102 if *off > file_size {
103 return Err(anyhow!("OINF section offset out of bounds"));
104 }
105 }
106
107 let mut sizes = HashMap::new();
108 let mut size_cursor = offset_sizevars;
109 for _ in 0..n_sizevars {
110 let name = read_string(data, &mut size_cursor)?;
111 if sizes.contains_key(&name) {
112 return Err(anyhow!("duplicate sizevar {}", name));
113 }
114 let value = read_u64_at(data, size_cursor)?;
115 size_cursor += 8;
116 sizes.insert(name, value as usize);
117 }
118
119 let mut metadata = HashMap::new();
120 let mut meta_cursor = offset_metadata;
121 for _ in 0..n_metadata {
122 let key = read_string(data, &mut meta_cursor)?;
123 if metadata.contains_key(&key) {
124 return Err(anyhow!("duplicate metadata key {}", key));
125 }
126 let value_type = read_u32_at(data, meta_cursor)?;
127 let flags = read_u32_at(data, meta_cursor + 4)?;
128 let value_nbytes = read_u64_at(data, meta_cursor + 8)?;
129 let value_offset = read_u64_at(data, meta_cursor + 16)?;
130 meta_cursor += 24;
131 if flags != 0 {
132 return Err(anyhow!("metadata flags must be 0"));
133 }
134 if value_offset % 8 != 0 {
135 return Err(anyhow!("metadata value offset not aligned"));
136 }
137 let value_end = value_offset
138 .checked_add(value_nbytes)
139 .ok_or_else(|| anyhow!("metadata value offset overflow"))?;
140 if value_end as usize > file_size {
141 return Err(anyhow!("metadata value out of bounds"));
142 }
143
144 let mut dims = Vec::new();
145 if value_type == ValueType::NDARRAY {
146 let mut cursor = value_offset as usize;
147 let element_type = read_u32(data, &mut cursor)?;
148 let ndim = read_u32(data, &mut cursor)? as usize;
149 if !ValueType::is_scalar(element_type) {
150 return Err(anyhow!("metadata ndarray has invalid element type"));
151 }
152 for _ in 0..ndim {
153 dims.push(read_u64(data, &mut cursor)?);
154 }
155 }
156
157 metadata.insert(
158 key,
159 MetadataInfo {
160 value_type,
161 value_offset,
162 value_nbytes,
163 dims,
164 },
165 );
166 }
167
168 let mut vars = HashMap::new();
169 let mut tensor_cursor = offset_tensors;
170 for _ in 0..n_tensors {
171 let name = read_string(data, &mut tensor_cursor)?;
172 if vars.contains_key(&name) {
173 return Err(anyhow!("duplicate tensor name {}", name));
174 }
175 let dtype_raw = read_u32(data, &mut tensor_cursor)?;
176 let ndim = read_u32(data, &mut tensor_cursor)? as usize;
177 let flags = read_u32(data, &mut tensor_cursor)?;
178 let mut dims = Vec::new();
179 for _ in 0..ndim {
180 dims.push(read_u64(data, &mut tensor_cursor)?);
181 }
182 let data_nbytes = read_u64(data, &mut tensor_cursor)? as usize;
183 let data_offset = read_u64(data, &mut tensor_cursor)? as usize;
184
185 let dtype = ValueType::to_dtype(dtype_raw)?;
186 let has_data = (flags & 1) != 0;
187 if has_data {
188 if data_offset % 8 != 0 {
189 return Err(anyhow!("tensor data offset not aligned"));
190 }
191 if data_offset < offset_data {
192 return Err(anyhow!("tensor data offset precedes data section"));
193 }
194 if data_offset + data_nbytes > file_size {
195 return Err(anyhow!("tensor data out of bounds"));
196 }
197 } else if data_offset != 0 || data_nbytes != 0 {
198 return Err(anyhow!("tensor without data must have zero offset/size"));
199 }
200
201 let dims_str = dims.iter().map(|d| d.to_string()).collect();
202 let value_range = if has_data {
203 Some((data_offset, data_offset + data_nbytes))
204 } else {
205 None
206 };
207 vars.insert(
208 name.clone(),
209 VarInfo {
210 name,
211 dtype,
212 dims: dims_str,
213 value_range,
214 has_data,
215 },
216 );
217 }
218
219 let mmap = Arc::new(mmap);
220 let tensor_store = build_tensor_store(&sizes, &vars, mmap.clone())?;
221
222 Ok(Self {
223 path,
224 sizes,
225 vars,
226 metadata,
227 mmap,
228 tensor_store,
229 })
230 }
231
232 pub fn size_of(&self, name: &str) -> Result<usize> {
243 self.sizes
244 .get(name)
245 .copied()
246 .ok_or_else(|| anyhow!("unknown size: {}", name))
247 }
248
249 pub fn resolve_len(&self, dims: &[String]) -> Result<usize> {
251 let mut total = 1usize;
252 for dim in dims {
253 total = total.saturating_mul(self.resolve_dim_value(dim)?);
254 }
255 Ok(total)
256 }
257
258 pub fn resolve_shape(&self, dims: &[String]) -> Result<Vec<usize>> {
269 let mut shape = Vec::with_capacity(dims.len());
270 for dim in dims {
271 shape.push(self.resolve_dim_value(dim)?);
272 }
273 Ok(shape)
274 }
275
276 pub fn resolve_dim_value(&self, dim: &str) -> Result<usize> {
278 if let Ok(val) = dim.parse::<usize>() {
279 return Ok(val);
280 }
281 let trimmed = dim.trim();
282 if let Some((left, right)) = trimmed.split_once('*') {
283 let left = left.trim();
284 let right = right.trim();
285 let left_val = match left.parse::<usize>() {
286 Ok(value) => value,
287 Err(_) => self.size_of(left)?,
288 };
289 let right_val = match right.parse::<usize>() {
290 Ok(value) => value,
291 Err(_) => self.size_of(right)?,
292 };
293 return Ok(left_val.saturating_mul(right_val));
294 }
295 self.size_of(trimmed)
296 }
297
298 pub fn var_info(&self, name: &str) -> Option<&VarInfo> {
300 self.vars.get(name)
301 }
302
303 pub fn tensor_store(&self) -> &TensorStore {
305 &self.tensor_store
306 }
307
308 pub fn load_tensor(&self, name: &str) -> Result<TensorValue> {
319 let info = self
320 .vars
321 .get(name)
322 .ok_or_else(|| anyhow!("unknown variable: {}", name))?;
323 if !info.has_data {
324 return Err(anyhow!("no data found for {}", name));
325 }
326 let range = info
327 .value_range
328 .ok_or_else(|| anyhow!("missing data range for {}", name))?;
329 let data = &self.mmap[range.0..range.1];
330 tensor_value_from_bytes(info, data)
331 }
332
333 pub fn load_metadata_tensor(&self, name: &str) -> Result<Option<TensorValue>> {
335 let info = match self.metadata.get(name) {
336 Some(info) => info,
337 None => return Ok(None),
338 };
339 let data = &self.mmap[..];
340 let start = info.value_offset as usize;
341 let end = start + info.value_nbytes as usize;
342 if end > data.len() {
343 return Err(anyhow!("metadata value out of bounds for {}", name));
344 }
345
346 if info.value_type == ValueType::STRING {
347 return Err(anyhow!("metadata {} is a string, not a tensor", name));
348 }
349
350 if info.value_type == ValueType::BITSET {
351 if info.value_nbytes < 8 {
352 return Err(anyhow!("bitset metadata too small for {}", name));
353 }
354 let bits = read_u32_at(data, start)? as usize;
355 let packed_len = read_u32_at(data, start + 4)? as usize;
356 if start + 8 + packed_len > end {
357 return Err(anyhow!("bitset metadata payload out of bounds for {}", name));
358 }
359 let packed = &data[start + 8..start + 8 + packed_len];
360 let first = packed.first().copied().unwrap_or(0);
361 if bits > 8 {
362 return Err(anyhow!("bitset metadata too large for {}", name));
363 }
364 return Ok(Some(TensorValue::from(Bitset { bits: first })));
365 }
366
367 if info.value_type == ValueType::NDARRAY {
368 let mut cursor = start;
369 let element_type = read_u32(data, &mut cursor)?;
370 let ndim = read_u32(data, &mut cursor)? as usize;
371 let mut dims = Vec::with_capacity(ndim);
372 for _ in 0..ndim {
373 dims.push(read_u64(data, &mut cursor)?);
374 }
375 let dtype = ValueType::to_dtype(element_type)?;
376 let var_info = VarInfo {
377 name: name.to_string(),
378 dtype,
379 dims: dims.iter().map(|d| d.to_string()).collect(),
380 value_range: None,
381 has_data: true,
382 };
383 let payload = &data[cursor..end];
384 return tensor_value_from_bytes(&var_info, payload).map(Some);
385 }
386
387 let dtype = ValueType::to_dtype(info.value_type)?;
388 let var_info = VarInfo {
389 name: name.to_string(),
390 dtype,
391 dims: Vec::new(),
392 value_range: None,
393 has_data: true,
394 };
395 let payload = &data[start..end];
396 tensor_value_from_bytes(&var_info, payload).map(Some)
397 }
398
399 pub fn has_metadata_string(&self, name: &str) -> bool {
401 self.metadata
402 .get(name)
403 .map(|info| info.value_type == ValueType::STRING)
404 .unwrap_or(false)
405 }
406
407 pub fn load_metadata_string(&self, name: &str) -> Result<Option<String>> {
409 let info = match self.metadata.get(name) {
410 Some(info) => info,
411 None => return Ok(None),
412 };
413 if info.value_type != ValueType::STRING {
414 return Ok(None);
415 }
416 let data = &self.mmap[..];
417 let start = info.value_offset as usize;
418 let end = start + info.value_nbytes as usize;
419 if end > data.len() {
420 return Err(anyhow!("metadata value out of bounds for {}", name));
421 }
422 if info.value_nbytes < 4 {
423 return Err(anyhow!("metadata string too small for {}", name));
424 }
425
426 let len = read_u32_at(data, start)? as usize;
427 let payload_end = start + 4 + len;
428 if payload_end > end {
429 return Err(anyhow!("metadata string payload out of bounds for {}", name));
430 }
431 let raw = &data[start + 4..payload_end];
432 let text = std::str::from_utf8(raw).context("invalid UTF-8 string")?;
433 let padded = align_up(4 + len, 8);
434 if start + padded > end {
435 return Err(anyhow!("metadata string padding out of bounds for {}", name));
436 }
437 Ok(Some(text.to_string()))
438 }
439}
440
441fn build_tensor_store(
442 sizes: &HashMap<String, usize>,
443 vars: &HashMap<String, VarInfo>,
444 mmap: Arc<Mmap>,
445) -> Result<TensorStore> {
446 let mut tensors = HashMap::new();
447 for (name, info) in vars {
448 let shape = resolve_shape(sizes, &info.dims)?;
449 let data = info.value_range.map(|(start, end)| {
450 MappedSlice::new(mmap.clone(), start..end)
451 });
452 tensors.insert(
453 name.clone(),
454 TensorRef {
455 name: name.clone(),
456 dtype: info.dtype,
457 dims: info.dims.clone(),
458 shape,
459 data,
460 },
461 );
462 }
463 Ok(TensorStore::new(tensors))
464}
465
466fn resolve_shape(sizes: &HashMap<String, usize>, dims: &[String]) -> Result<Vec<usize>> {
467 let mut shape = Vec::with_capacity(dims.len());
468 for dim in dims {
469 shape.push(resolve_dim_value(sizes, dim)?);
470 }
471 Ok(shape)
472}
473
474fn resolve_dim_value(sizes: &HashMap<String, usize>, dim: &str) -> Result<usize> {
475 if let Ok(val) = dim.parse::<usize>() {
476 return Ok(val);
477 }
478 let trimmed = dim.trim();
479 if let Some((left, right)) = trimmed.split_once('*') {
480 let left = left.trim();
481 let right = right.trim();
482 let left_val = match left.parse::<usize>() {
483 Ok(value) => value,
484 Err(_) => sizes
485 .get(left)
486 .copied()
487 .ok_or_else(|| anyhow!("unknown size: {}", left))?,
488 };
489 let right_val = match right.parse::<usize>() {
490 Ok(value) => value,
491 Err(_) => sizes
492 .get(right)
493 .copied()
494 .ok_or_else(|| anyhow!("unknown size: {}", right))?,
495 };
496 return Ok(left_val.saturating_mul(right_val));
497 }
498 sizes
499 .get(trimmed)
500 .copied()
501 .ok_or_else(|| anyhow!("unknown size: {}", trimmed))
502}
503
504fn read_bytes<'a>(data: &'a [u8], cursor: &mut usize, len: usize) -> Result<&'a [u8]> {
505 if *cursor + len > data.len() {
506 return Err(anyhow!("unexpected EOF"));
507 }
508 let out = &data[*cursor..*cursor + len];
509 *cursor += len;
510 Ok(out)
511}
512
513fn read_u32(data: &[u8], cursor: &mut usize) -> Result<u32> {
514 let bytes = read_bytes(data, cursor, 4)?;
515 Ok(u32::from_le_bytes(bytes.try_into().unwrap()))
516}
517
518fn read_u64(data: &[u8], cursor: &mut usize) -> Result<u64> {
519 let bytes = read_bytes(data, cursor, 8)?;
520 Ok(u64::from_le_bytes(bytes.try_into().unwrap()))
521}
522
523fn read_u32_at(data: &[u8], offset: usize) -> Result<u32> {
524 if offset + 4 > data.len() {
525 return Err(anyhow!("unexpected EOF"));
526 }
527 Ok(u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()))
528}
529
530fn read_u64_at(data: &[u8], offset: usize) -> Result<u64> {
531 if offset + 8 > data.len() {
532 return Err(anyhow!("unexpected EOF"));
533 }
534 Ok(u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()))
535}
536
537fn read_string(data: &[u8], cursor: &mut usize) -> Result<String> {
538 let len = read_u32(data, cursor)? as usize;
539 let bytes = read_bytes(data, cursor, len)?;
540 let s = std::str::from_utf8(bytes).context("invalid UTF-8 string")?;
541 let padded = align_up(4 + len, 8);
542 let consumed = 4 + len;
543 if padded > consumed {
544 let skip = padded - consumed;
545 if *cursor + skip > data.len() {
546 return Err(anyhow!("unexpected EOF"));
547 }
548 *cursor += skip;
549 }
550 Ok(s.to_string())
551}
552
/// Rounds `value` up to the nearest multiple of `alignment`.
///
/// `alignment` must be non-zero; a value that is already a multiple is
/// returned unchanged.
fn align_up(value: usize, alignment: usize) -> usize {
    match value % alignment {
        0 => value,
        remainder => value + (alignment - remainder),
    }
}
556
557fn tensor_value_from_bytes(info: &VarInfo, bytes: &[u8]) -> Result<TensorValue> {
558 match info.dtype {
559 DType::I8 => tensor_from_bytes::<i8>(info, bytes).map(TensorValue::I8),
560 DType::I16 => tensor_from_bytes::<i16>(info, bytes).map(TensorValue::I16),
561 DType::I32 => tensor_from_bytes::<i32>(info, bytes).map(TensorValue::I32),
562 DType::I64 => tensor_from_bytes::<i64>(info, bytes).map(TensorValue::I64),
563 DType::U8 => tensor_from_bytes::<u8>(info, bytes).map(TensorValue::U8),
564 DType::U16 => tensor_from_bytes::<u16>(info, bytes).map(TensorValue::U16),
565 DType::U32 => tensor_from_bytes::<u32>(info, bytes).map(TensorValue::U32),
566 DType::U64 => tensor_from_bytes::<u64>(info, bytes).map(TensorValue::U64),
567 DType::F16 => tensor_from_bits::<u16, F16>(info, bytes, |bits| F16 { bits }).map(TensorValue::F16),
568 DType::BF16 => tensor_from_bits::<u16, BF16>(info, bytes, |bits| BF16 { bits }).map(TensorValue::BF16),
569 DType::F8 => tensor_from_bits::<u8, F8>(info, bytes, |bits| F8 { bits }).map(TensorValue::F8),
570 DType::F32 => tensor_from_bytes::<f32>(info, bytes).map(TensorValue::F32),
571 DType::F64 => tensor_from_bytes::<f64>(info, bytes).map(TensorValue::F64),
572 DType::Bool => tensor_from_bytes::<bool>(info, bytes).map(TensorValue::Bool),
573 DType::Bitset => tensor_from_bits::<u8, Bitset>(info, bytes, |bits| Bitset { bits }).map(TensorValue::Bitset),
574 DType::I4 => tensor_from_bits::<u8, I4>(info, bytes, |bits| I4 { bits }).map(TensorValue::I4),
575 DType::I2 => tensor_from_bits::<u8, I2>(info, bytes, |bits| I2 { bits }).map(TensorValue::I2),
576 DType::I1 => tensor_from_bits::<u8, I1>(info, bytes, |bits| I1 { bits }).map(TensorValue::I1),
577 DType::U4 => tensor_from_bits::<u8, U4>(info, bytes, |bits| U4 { bits }).map(TensorValue::U4),
578 DType::U2 => tensor_from_bits::<u8, U2>(info, bytes, |bits| U2 { bits }).map(TensorValue::U2),
579 DType::U1 => tensor_from_bits::<u8, U1>(info, bytes, |bits| U1 { bits }).map(TensorValue::U1),
580 DType::T2 => tensor_from_bits::<u8, T2>(info, bytes, |bits| T2 { bits }).map(TensorValue::T2),
581 DType::T1 => tensor_from_bits::<u8, T1>(info, bytes, |bits| T1 { bits }).map(TensorValue::T1),
582 }
583}
584
585fn tensor_from_bytes<T: Copy>(info: &VarInfo, bytes: &[u8]) -> Result<Tensor<T>> {
586 let shape = info
587 .dims
588 .iter()
589 .map(|dim| dim.parse::<usize>())
590 .collect::<std::result::Result<Vec<_>, _>>()
591 .map_err(|_| anyhow!("invalid tensor dims for {}", info.name))?;
592 let len = shape.iter().product::<usize>();
593 let expected = len * std::mem::size_of::<T>();
594 if bytes.len() != expected {
595 return Err(anyhow!(
596 "tensor {} byte length mismatch: expected {}, got {}",
597 info.name,
598 expected,
599 bytes.len()
600 ));
601 }
602 let mut out = Vec::with_capacity(len);
603 let mut cursor = 0usize;
604 while cursor < bytes.len() {
605 let end = cursor + std::mem::size_of::<T>();
606 let value = read_t::<T>(&bytes[cursor..end])?;
607 out.push(value);
608 cursor = end;
609 }
610 Tensor::from_vec_with_opts(
611 out,
612 crate::tensor::TensorOptions {
613 shape: Some(shape),
614 ..crate::tensor::TensorOptions::default()
615 },
616 )
617}
618
619fn tensor_from_bits<B: Copy, T>(
620 info: &VarInfo,
621 bytes: &[u8],
622 map: fn(B) -> T,
623) -> Result<Tensor<T>> {
624 let shape = info
625 .dims
626 .iter()
627 .map(|dim| dim.parse::<usize>())
628 .collect::<std::result::Result<Vec<_>, _>>()
629 .map_err(|_| anyhow!("invalid tensor dims for {}", info.name))?;
630 let len = shape.iter().product::<usize>();
631 if bytes.is_empty() && len == 0 {
632 return Tensor::from_vec_with_opts(
633 Vec::new(),
634 crate::tensor::TensorOptions {
635 shape: Some(shape),
636 ..crate::tensor::TensorOptions::default()
637 },
638 );
639 }
640 let mut out = Vec::with_capacity(bytes.len());
641 let mut cursor = 0usize;
642 while cursor < bytes.len() {
643 let end = cursor + std::mem::size_of::<B>();
644 let value = read_t::<B>(&bytes[cursor..end])?;
645 out.push(map(value));
646 cursor = end;
647 }
648 Tensor::from_vec_with_opts(
649 out,
650 crate::tensor::TensorOptions {
651 shape: Some(shape),
652 allow_len_mismatch: true,
653 ..crate::tensor::TensorOptions::default()
654 },
655 )
656}
657
658fn read_t<T: Copy>(bytes: &[u8]) -> Result<T> {
659 let mut value = std::mem::MaybeUninit::<T>::uninit();
660 let len = std::mem::size_of::<T>();
661 if bytes.len() != len {
662 return Err(anyhow!("invalid byte length"));
663 }
664 unsafe {
665 std::ptr::copy_nonoverlapping(bytes.as_ptr(), value.as_mut_ptr() as *mut u8, len);
666 Ok(value.assume_init())
667 }
668}
669
670struct ValueType;
671
672impl ValueType {
673 const I8: u32 = 1;
674 const I16: u32 = 2;
675 const I32: u32 = 3;
676 const I64: u32 = 4;
677 const U8: u32 = 5;
678 const U16: u32 = 6;
679 const U32: u32 = 7;
680 const U64: u32 = 8;
681 const F16: u32 = 9;
682 const F32: u32 = 10;
683 const F64: u32 = 11;
684 const BOOL: u32 = 12;
685 const BITSET: u32 = 13;
686 #[allow(dead_code)]
687 const STRING: u32 = 14;
688 const NDARRAY: u32 = 15;
689 const BF16: u32 = 16;
690 const F8: u32 = 17;
691 const I4: u32 = 18;
692 const I2: u32 = 19;
693 const I1: u32 = 20;
694 const U4: u32 = 21;
695 const U2: u32 = 22;
696 const U1: u32 = 23;
697 const T2: u32 = 24;
698 const T1: u32 = 25;
699
700 fn is_scalar(value_type: u32) -> bool {
701 value_type >= Self::I8 && value_type <= Self::T1
702 }
703
704 fn to_dtype(value_type: u32) -> Result<DType> {
705 Ok(match value_type {
706 Self::I8 => DType::I8,
707 Self::I16 => DType::I16,
708 Self::I32 => DType::I32,
709 Self::I64 => DType::I64,
710 Self::U8 => DType::U8,
711 Self::U16 => DType::U16,
712 Self::U32 => DType::U32,
713 Self::U64 => DType::U64,
714 Self::F16 => DType::F16,
715 Self::F32 => DType::F32,
716 Self::F64 => DType::F64,
717 Self::BOOL => DType::Bool,
718 Self::BITSET => DType::Bitset,
719 Self::BF16 => DType::BF16,
720 Self::F8 => DType::F8,
721 Self::I4 => DType::I4,
722 Self::I2 => DType::I2,
723 Self::I1 => DType::I1,
724 Self::U4 => DType::U4,
725 Self::U2 => DType::U2,
726 Self::U1 => DType::U1,
727 Self::T2 => DType::T2,
728 Self::T1 => DType::T1,
729 _ => return Err(anyhow!("unknown tensor dtype {}", value_type)),
730 })
731 }
732}