Skip to main content

openinfer_simulator/runtime/
model_loader.rs

1//! Lazy `.oinf` model loader.
2//!
3//! The loader memory-maps the model file, validates headers, and loads tensor
4//! payloads only on demand.
5use std::collections::HashMap;
6use std::fs::File;
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10use anyhow::{anyhow, Context, Result};
11use memmap2::Mmap;
12
13use crate::runtime::tensor_store::{MappedSlice, TensorRef, TensorStore};
14use crate::tensor::{
15    BF16, Bitset, DType, F16, F8, I1, I2, I4, T1, T2, U1, U2, U4, Tensor, TensorValue,
16};
17use crate::types::VarInfo;
18
/// Magic bytes expected at the very start of every `.oinf` file.
const MAGIC: &[u8; 5] = b"OINF\0";
/// Size in bytes of the fixed header parsed by `ModelLoader::open`: the 5-byte magic,
/// six `u32` fields and five `u64` fields (5 + 6 * 4 + 5 * 8 = 69).
const HEADER_SIZE: usize = 69;
21
/// Index entry for one metadata value, recorded while the metadata table is parsed.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct MetadataInfo {
    /// Raw `ValueType` code of the value.
    value_type: u32,
    /// Byte offset of the value, measured from the start of the mapped file.
    value_offset: u64,
    /// Size of the value in bytes.
    value_nbytes: u64,
    /// Dimensions read from the value's header when it is an ndarray; empty otherwise.
    dims: Vec<u64>,
}
30
/// Loads `.oinf` model files and exposes tensors/metadata.
///
/// The size-variable, metadata and tensor tables are parsed once when the file is opened;
/// tensor and metadata payloads are decoded from the mapping when they are requested.
#[derive(Debug, Clone)]
pub struct ModelLoader {
    /// Path the model was opened from.
    #[allow(dead_code)]
    path: PathBuf,
    /// Size variables by name.
    sizes: HashMap<String, usize>,
    /// Tensor descriptors by name.
    vars: HashMap<String, VarInfo>,
    /// Metadata index by key.
    #[allow(dead_code)]
    metadata: HashMap<String, MetadataInfo>,
    /// Mapping of the whole model file; also handed to the tensor store's slices.
    mmap: Arc<Mmap>,
    /// Tensor references built from `vars` and `sizes` at open time.
    tensor_store: TensorStore,
}
43
44impl ModelLoader {
45    /// Open an `.oinf` model file from disk.
46    ///
47    /// # Example
48    /// ```no_run
49    /// # use openinfer::ModelLoader;
50    /// # fn main() -> anyhow::Result<()> {
51    /// let model = ModelLoader::open("model.oinf")?;
52    /// # Ok(()) }
53    /// ```
54    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
55        let path = path.as_ref().to_path_buf();
56        let file = File::open(&path).with_context(|| "open model file")?;
57        let mmap = unsafe { Mmap::map(&file).with_context(|| "mmap model file")? };
58        let data = &mmap[..];
59        if data.len() < HEADER_SIZE {
60            return Err(anyhow!("file too small for OINF header"));
61        }
62
63        let mut cursor = 0usize;
64        let magic = read_bytes(data, &mut cursor, 5)?;
65        if magic != MAGIC {
66            return Err(anyhow!("invalid OINF magic"));
67        }
68        let version = read_u32(data, &mut cursor)?;
69        if version != 1 {
70            return Err(anyhow!("unsupported OINF version {}", version));
71        }
72        let _flags = read_u32(data, &mut cursor)?;
73        let n_sizevars = read_u32(data, &mut cursor)? as usize;
74        let n_metadata = read_u32(data, &mut cursor)? as usize;
75        let n_tensors = read_u32(data, &mut cursor)? as usize;
76        let _reserved = read_u32(data, &mut cursor)?;
77        let offset_sizevars = read_u64(data, &mut cursor)? as usize;
78        let offset_metadata = read_u64(data, &mut cursor)? as usize;
79        let offset_tensors = read_u64(data, &mut cursor)? as usize;
80        let offset_data = read_u64(data, &mut cursor)? as usize;
81        let file_size = read_u64(data, &mut cursor)? as usize;
82
83        if file_size != data.len() {
84            return Err(anyhow!("file size mismatch"));
85        }
86        let offsets = vec![
87            offset_sizevars,
88            offset_metadata,
89            offset_tensors,
90            offset_data,
91            file_size,
92        ];
93        let mut sorted = offsets.clone();
94        sorted.sort_unstable();
95        if offsets != sorted {
96            return Err(anyhow!("OINF offsets are not ascending"));
97        }
98        for off in offsets.iter().take(4) {
99            if *off % 8 != 0 {
100                return Err(anyhow!("OINF section offset not aligned"));
101            }
102            if *off > file_size {
103                return Err(anyhow!("OINF section offset out of bounds"));
104            }
105        }
106
107        let mut sizes = HashMap::new();
108        let mut size_cursor = offset_sizevars;
109        for _ in 0..n_sizevars {
110            let name = read_string(data, &mut size_cursor)?;
111            if sizes.contains_key(&name) {
112                return Err(anyhow!("duplicate sizevar {}", name));
113            }
114            let value = read_u64_at(data, size_cursor)?;
115            size_cursor += 8;
116            sizes.insert(name, value as usize);
117        }
118
119        let mut metadata = HashMap::new();
120        let mut meta_cursor = offset_metadata;
121        for _ in 0..n_metadata {
122            let key = read_string(data, &mut meta_cursor)?;
123            if metadata.contains_key(&key) {
124                return Err(anyhow!("duplicate metadata key {}", key));
125            }
126            let value_type = read_u32_at(data, meta_cursor)?;
127            let flags = read_u32_at(data, meta_cursor + 4)?;
128            let value_nbytes = read_u64_at(data, meta_cursor + 8)?;
129            let value_offset = read_u64_at(data, meta_cursor + 16)?;
130            meta_cursor += 24;
131            if flags != 0 {
132                return Err(anyhow!("metadata flags must be 0"));
133            }
134            if value_offset % 8 != 0 {
135                return Err(anyhow!("metadata value offset not aligned"));
136            }
137            let value_end = value_offset
138                .checked_add(value_nbytes)
139                .ok_or_else(|| anyhow!("metadata value offset overflow"))?;
140            if value_end as usize > file_size {
141                return Err(anyhow!("metadata value out of bounds"));
142            }
143
144            let mut dims = Vec::new();
145            if value_type == ValueType::NDARRAY {
146                let mut cursor = value_offset as usize;
147                let element_type = read_u32(data, &mut cursor)?;
148                let ndim = read_u32(data, &mut cursor)? as usize;
149                if !ValueType::is_scalar(element_type) {
150                    return Err(anyhow!("metadata ndarray has invalid element type"));
151                }
152                for _ in 0..ndim {
153                    dims.push(read_u64(data, &mut cursor)?);
154                }
155            }
156
157            metadata.insert(
158                key,
159                MetadataInfo {
160                    value_type,
161                    value_offset,
162                    value_nbytes,
163                    dims,
164                },
165            );
166        }
167
168        let mut vars = HashMap::new();
169        let mut tensor_cursor = offset_tensors;
170        for _ in 0..n_tensors {
171            let name = read_string(data, &mut tensor_cursor)?;
172            if vars.contains_key(&name) {
173                return Err(anyhow!("duplicate tensor name {}", name));
174            }
175            let dtype_raw = read_u32(data, &mut tensor_cursor)?;
176            let ndim = read_u32(data, &mut tensor_cursor)? as usize;
177            let flags = read_u32(data, &mut tensor_cursor)?;
178            let mut dims = Vec::new();
179            for _ in 0..ndim {
180                dims.push(read_u64(data, &mut tensor_cursor)?);
181            }
182            let data_nbytes = read_u64(data, &mut tensor_cursor)? as usize;
183            let data_offset = read_u64(data, &mut tensor_cursor)? as usize;
184
185            let dtype = ValueType::to_dtype(dtype_raw)?;
186            let has_data = (flags & 1) != 0;
187            if has_data {
188                if data_offset % 8 != 0 {
189                    return Err(anyhow!("tensor data offset not aligned"));
190                }
191                if data_offset < offset_data {
192                    return Err(anyhow!("tensor data offset precedes data section"));
193                }
194                if data_offset + data_nbytes > file_size {
195                    return Err(anyhow!("tensor data out of bounds"));
196                }
197            } else if data_offset != 0 || data_nbytes != 0 {
198                return Err(anyhow!("tensor without data must have zero offset/size"));
199            }
200
201            let dims_str = dims.iter().map(|d| d.to_string()).collect();
202            let value_range = if has_data {
203                Some((data_offset, data_offset + data_nbytes))
204            } else {
205                None
206            };
207            vars.insert(
208                name.clone(),
209                VarInfo {
210                    name,
211                    dtype,
212                    dims: dims_str,
213                    value_range,
214                    has_data,
215                },
216            );
217        }
218
219        let mmap = Arc::new(mmap);
220        let tensor_store = build_tensor_store(&sizes, &vars, mmap.clone())?;
221
222        Ok(Self {
223            path,
224            sizes,
225            vars,
226            metadata,
227            mmap,
228            tensor_store,
229        })
230    }
231
232    /// Lookup a size variable by name.
233    ///
234    /// # Example
235    /// ```no_run
236    /// # use openinfer::ModelLoader;
237    /// # fn main() -> anyhow::Result<()> {
238    /// let model = ModelLoader::open("model.oinf")?;
239    /// let b = model.size_of("B")?;
240    /// # Ok(()) }
241    /// ```
242    pub fn size_of(&self, name: &str) -> Result<usize> {
243        self.sizes
244            .get(name)
245            .copied()
246            .ok_or_else(|| anyhow!("unknown size: {}", name))
247    }
248
249    /// Resolve a product of dimension strings into a length.
250    pub fn resolve_len(&self, dims: &[String]) -> Result<usize> {
251        let mut total = 1usize;
252        for dim in dims {
253            total = total.saturating_mul(self.resolve_dim_value(dim)?);
254        }
255        Ok(total)
256    }
257
258    /// Resolve dimension strings into a concrete shape.
259    ///
260    /// # Example
261    /// ```no_run
262    /// # use openinfer::ModelLoader;
263    /// # fn main() -> anyhow::Result<()> {
264    /// let model = ModelLoader::open("model.oinf")?;
265    /// let shape = model.resolve_shape(&["B".into(), "D".into()])?;
266    /// # Ok(()) }
267    /// ```
268    pub fn resolve_shape(&self, dims: &[String]) -> Result<Vec<usize>> {
269        let mut shape = Vec::with_capacity(dims.len());
270        for dim in dims {
271            shape.push(self.resolve_dim_value(dim)?);
272        }
273        Ok(shape)
274    }
275
276    /// Resolve a single dimension expression (literal, sizevar, or product).
277    pub fn resolve_dim_value(&self, dim: &str) -> Result<usize> {
278        if let Ok(val) = dim.parse::<usize>() {
279            return Ok(val);
280        }
281        let trimmed = dim.trim();
282        if let Some((left, right)) = trimmed.split_once('*') {
283            let left = left.trim();
284            let right = right.trim();
285            let left_val = match left.parse::<usize>() {
286                Ok(value) => value,
287                Err(_) => self.size_of(left)?,
288            };
289            let right_val = match right.parse::<usize>() {
290                Ok(value) => value,
291                Err(_) => self.size_of(right)?,
292            };
293            return Ok(left_val.saturating_mul(right_val));
294        }
295        self.size_of(trimmed)
296    }
297
    /// Fetch variable metadata by name.
    ///
    /// Returns `None` when the loader holds no tensor entry under `name`.
    pub fn var_info(&self, name: &str) -> Option<&VarInfo> {
        self.vars.get(name)
    }
302
    /// Access the underlying tensor store.
    ///
    /// The store is built once while the model is opened and is returned by reference.
    pub fn tensor_store(&self) -> &TensorStore {
        &self.tensor_store
    }
307
308    /// Load a tensor payload by name from the mapped file.
309    ///
310    /// # Example
311    /// ```no_run
312    /// # use openinfer::ModelLoader;
313    /// # fn main() -> anyhow::Result<()> {
314    /// let model = ModelLoader::open("model.oinf")?;
315    /// let tensor = model.load_tensor("w1")?;
316    /// # Ok(()) }
317    /// ```
318    pub fn load_tensor(&self, name: &str) -> Result<TensorValue> {
319        let info = self
320            .vars
321            .get(name)
322            .ok_or_else(|| anyhow!("unknown variable: {}", name))?;
323        if !info.has_data {
324            return Err(anyhow!("no data found for {}", name));
325        }
326        let range = info
327            .value_range
328            .ok_or_else(|| anyhow!("missing data range for {}", name))?;
329        let data = &self.mmap[range.0..range.1];
330        tensor_value_from_bytes(info, data)
331    }
332
333    /// Load a metadata tensor by name, if present.
334    pub fn load_metadata_tensor(&self, name: &str) -> Result<Option<TensorValue>> {
335        let info = match self.metadata.get(name) {
336            Some(info) => info,
337            None => return Ok(None),
338        };
339        let data = &self.mmap[..];
340        let start = info.value_offset as usize;
341        let end = start + info.value_nbytes as usize;
342        if end > data.len() {
343            return Err(anyhow!("metadata value out of bounds for {}", name));
344        }
345
346        if info.value_type == ValueType::STRING {
347            return Err(anyhow!("metadata {} is a string, not a tensor", name));
348        }
349
350        if info.value_type == ValueType::BITSET {
351            if info.value_nbytes < 8 {
352                return Err(anyhow!("bitset metadata too small for {}", name));
353            }
354            let bits = read_u32_at(data, start)? as usize;
355            let packed_len = read_u32_at(data, start + 4)? as usize;
356            if start + 8 + packed_len > end {
357                return Err(anyhow!("bitset metadata payload out of bounds for {}", name));
358            }
359            let packed = &data[start + 8..start + 8 + packed_len];
360            let first = packed.first().copied().unwrap_or(0);
361            if bits > 8 {
362                return Err(anyhow!("bitset metadata too large for {}", name));
363            }
364            return Ok(Some(TensorValue::from(Bitset { bits: first })));
365        }
366
367        if info.value_type == ValueType::NDARRAY {
368            let mut cursor = start;
369            let element_type = read_u32(data, &mut cursor)?;
370            let ndim = read_u32(data, &mut cursor)? as usize;
371            let mut dims = Vec::with_capacity(ndim);
372            for _ in 0..ndim {
373                dims.push(read_u64(data, &mut cursor)?);
374            }
375            let dtype = ValueType::to_dtype(element_type)?;
376            let var_info = VarInfo {
377                name: name.to_string(),
378                dtype,
379                dims: dims.iter().map(|d| d.to_string()).collect(),
380                value_range: None,
381                has_data: true,
382            };
383            let payload = &data[cursor..end];
384            return tensor_value_from_bytes(&var_info, payload).map(Some);
385        }
386
387        let dtype = ValueType::to_dtype(info.value_type)?;
388        let var_info = VarInfo {
389            name: name.to_string(),
390            dtype,
391            dims: Vec::new(),
392            value_range: None,
393            has_data: true,
394        };
395        let payload = &data[start..end];
396        tensor_value_from_bytes(&var_info, payload).map(Some)
397    }
398
399    /// True if a named metadata entry is a string.
400    pub fn has_metadata_string(&self, name: &str) -> bool {
401        self.metadata
402            .get(name)
403            .map(|info| info.value_type == ValueType::STRING)
404            .unwrap_or(false)
405    }
406
407    /// Load a metadata string by name, if present.
408    pub fn load_metadata_string(&self, name: &str) -> Result<Option<String>> {
409        let info = match self.metadata.get(name) {
410            Some(info) => info,
411            None => return Ok(None),
412        };
413        if info.value_type != ValueType::STRING {
414            return Ok(None);
415        }
416        let data = &self.mmap[..];
417        let start = info.value_offset as usize;
418        let end = start + info.value_nbytes as usize;
419        if end > data.len() {
420            return Err(anyhow!("metadata value out of bounds for {}", name));
421        }
422        if info.value_nbytes < 4 {
423            return Err(anyhow!("metadata string too small for {}", name));
424        }
425
426        let len = read_u32_at(data, start)? as usize;
427        let payload_end = start + 4 + len;
428        if payload_end > end {
429            return Err(anyhow!("metadata string payload out of bounds for {}", name));
430        }
431        let raw = &data[start + 4..payload_end];
432        let text = std::str::from_utf8(raw).context("invalid UTF-8 string")?;
433        let padded = align_up(4 + len, 8);
434        if start + padded > end {
435            return Err(anyhow!("metadata string padding out of bounds for {}", name));
436        }
437        Ok(Some(text.to_string()))
438    }
439}
440
441fn build_tensor_store(
442    sizes: &HashMap<String, usize>,
443    vars: &HashMap<String, VarInfo>,
444    mmap: Arc<Mmap>,
445) -> Result<TensorStore> {
446    let mut tensors = HashMap::new();
447    for (name, info) in vars {
448        let shape = resolve_shape(sizes, &info.dims)?;
449        let data = info.value_range.map(|(start, end)| {
450            MappedSlice::new(mmap.clone(), start..end)
451        });
452        tensors.insert(
453            name.clone(),
454            TensorRef {
455                name: name.clone(),
456                dtype: info.dtype,
457                dims: info.dims.clone(),
458                shape,
459                data,
460            },
461        );
462    }
463    Ok(TensorStore::new(tensors))
464}
465
466fn resolve_shape(sizes: &HashMap<String, usize>, dims: &[String]) -> Result<Vec<usize>> {
467    let mut shape = Vec::with_capacity(dims.len());
468    for dim in dims {
469        shape.push(resolve_dim_value(sizes, dim)?);
470    }
471    Ok(shape)
472}
473
474fn resolve_dim_value(sizes: &HashMap<String, usize>, dim: &str) -> Result<usize> {
475    if let Ok(val) = dim.parse::<usize>() {
476        return Ok(val);
477    }
478    let trimmed = dim.trim();
479    if let Some((left, right)) = trimmed.split_once('*') {
480        let left = left.trim();
481        let right = right.trim();
482        let left_val = match left.parse::<usize>() {
483            Ok(value) => value,
484            Err(_) => sizes
485                .get(left)
486                .copied()
487                .ok_or_else(|| anyhow!("unknown size: {}", left))?,
488        };
489        let right_val = match right.parse::<usize>() {
490            Ok(value) => value,
491            Err(_) => sizes
492                .get(right)
493                .copied()
494                .ok_or_else(|| anyhow!("unknown size: {}", right))?,
495        };
496        return Ok(left_val.saturating_mul(right_val));
497    }
498    sizes
499        .get(trimmed)
500        .copied()
501        .ok_or_else(|| anyhow!("unknown size: {}", trimmed))
502}
503
504fn read_bytes<'a>(data: &'a [u8], cursor: &mut usize, len: usize) -> Result<&'a [u8]> {
505    if *cursor + len > data.len() {
506        return Err(anyhow!("unexpected EOF"));
507    }
508    let out = &data[*cursor..*cursor + len];
509    *cursor += len;
510    Ok(out)
511}
512
513fn read_u32(data: &[u8], cursor: &mut usize) -> Result<u32> {
514    let bytes = read_bytes(data, cursor, 4)?;
515    Ok(u32::from_le_bytes(bytes.try_into().unwrap()))
516}
517
518fn read_u64(data: &[u8], cursor: &mut usize) -> Result<u64> {
519    let bytes = read_bytes(data, cursor, 8)?;
520    Ok(u64::from_le_bytes(bytes.try_into().unwrap()))
521}
522
523fn read_u32_at(data: &[u8], offset: usize) -> Result<u32> {
524    if offset + 4 > data.len() {
525        return Err(anyhow!("unexpected EOF"));
526    }
527    Ok(u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()))
528}
529
530fn read_u64_at(data: &[u8], offset: usize) -> Result<u64> {
531    if offset + 8 > data.len() {
532        return Err(anyhow!("unexpected EOF"));
533    }
534    Ok(u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()))
535}
536
537fn read_string(data: &[u8], cursor: &mut usize) -> Result<String> {
538    let len = read_u32(data, cursor)? as usize;
539    let bytes = read_bytes(data, cursor, len)?;
540    let s = std::str::from_utf8(bytes).context("invalid UTF-8 string")?;
541    let padded = align_up(4 + len, 8);
542    let consumed = 4 + len;
543    if padded > consumed {
544        let skip = padded - consumed;
545        if *cursor + skip > data.len() {
546            return Err(anyhow!("unexpected EOF"));
547        }
548        *cursor += skip;
549    }
550    Ok(s.to_string())
551}
552
/// Round `value` up to the nearest multiple of `alignment`.
///
/// A `value` that is already a multiple is returned unchanged. `alignment` must be
/// non-zero.
fn align_up(value: usize, alignment: usize) -> usize {
    let bumped = value + alignment - 1;
    bumped - bumped % alignment
}
556
557fn tensor_value_from_bytes(info: &VarInfo, bytes: &[u8]) -> Result<TensorValue> {
558    match info.dtype {
559        DType::I8 => tensor_from_bytes::<i8>(info, bytes).map(TensorValue::I8),
560        DType::I16 => tensor_from_bytes::<i16>(info, bytes).map(TensorValue::I16),
561        DType::I32 => tensor_from_bytes::<i32>(info, bytes).map(TensorValue::I32),
562        DType::I64 => tensor_from_bytes::<i64>(info, bytes).map(TensorValue::I64),
563        DType::U8 => tensor_from_bytes::<u8>(info, bytes).map(TensorValue::U8),
564        DType::U16 => tensor_from_bytes::<u16>(info, bytes).map(TensorValue::U16),
565        DType::U32 => tensor_from_bytes::<u32>(info, bytes).map(TensorValue::U32),
566        DType::U64 => tensor_from_bytes::<u64>(info, bytes).map(TensorValue::U64),
567        DType::F16 => tensor_from_bits::<u16, F16>(info, bytes, |bits| F16 { bits }).map(TensorValue::F16),
568        DType::BF16 => tensor_from_bits::<u16, BF16>(info, bytes, |bits| BF16 { bits }).map(TensorValue::BF16),
569        DType::F8 => tensor_from_bits::<u8, F8>(info, bytes, |bits| F8 { bits }).map(TensorValue::F8),
570        DType::F32 => tensor_from_bytes::<f32>(info, bytes).map(TensorValue::F32),
571        DType::F64 => tensor_from_bytes::<f64>(info, bytes).map(TensorValue::F64),
572        DType::Bool => tensor_from_bytes::<bool>(info, bytes).map(TensorValue::Bool),
573        DType::Bitset => tensor_from_bits::<u8, Bitset>(info, bytes, |bits| Bitset { bits }).map(TensorValue::Bitset),
574        DType::I4 => tensor_from_bits::<u8, I4>(info, bytes, |bits| I4 { bits }).map(TensorValue::I4),
575        DType::I2 => tensor_from_bits::<u8, I2>(info, bytes, |bits| I2 { bits }).map(TensorValue::I2),
576        DType::I1 => tensor_from_bits::<u8, I1>(info, bytes, |bits| I1 { bits }).map(TensorValue::I1),
577        DType::U4 => tensor_from_bits::<u8, U4>(info, bytes, |bits| U4 { bits }).map(TensorValue::U4),
578        DType::U2 => tensor_from_bits::<u8, U2>(info, bytes, |bits| U2 { bits }).map(TensorValue::U2),
579        DType::U1 => tensor_from_bits::<u8, U1>(info, bytes, |bits| U1 { bits }).map(TensorValue::U1),
580        DType::T2 => tensor_from_bits::<u8, T2>(info, bytes, |bits| T2 { bits }).map(TensorValue::T2),
581        DType::T1 => tensor_from_bits::<u8, T1>(info, bytes, |bits| T1 { bits }).map(TensorValue::T1),
582    }
583}
584
585fn tensor_from_bytes<T: Copy>(info: &VarInfo, bytes: &[u8]) -> Result<Tensor<T>> {
586    let shape = info
587        .dims
588        .iter()
589        .map(|dim| dim.parse::<usize>())
590        .collect::<std::result::Result<Vec<_>, _>>()
591        .map_err(|_| anyhow!("invalid tensor dims for {}", info.name))?;
592    let len = shape.iter().product::<usize>();
593    let expected = len * std::mem::size_of::<T>();
594    if bytes.len() != expected {
595        return Err(anyhow!(
596            "tensor {} byte length mismatch: expected {}, got {}",
597            info.name,
598            expected,
599            bytes.len()
600        ));
601    }
602    let mut out = Vec::with_capacity(len);
603    let mut cursor = 0usize;
604    while cursor < bytes.len() {
605        let end = cursor + std::mem::size_of::<T>();
606        let value = read_t::<T>(&bytes[cursor..end])?;
607        out.push(value);
608        cursor = end;
609    }
610    Tensor::from_vec_with_opts(
611        out,
612        crate::tensor::TensorOptions {
613            shape: Some(shape),
614            ..crate::tensor::TensorOptions::default()
615        },
616    )
617}
618
619fn tensor_from_bits<B: Copy, T>(
620    info: &VarInfo,
621    bytes: &[u8],
622    map: fn(B) -> T,
623) -> Result<Tensor<T>> {
624    let shape = info
625        .dims
626        .iter()
627        .map(|dim| dim.parse::<usize>())
628        .collect::<std::result::Result<Vec<_>, _>>()
629        .map_err(|_| anyhow!("invalid tensor dims for {}", info.name))?;
630    let len = shape.iter().product::<usize>();
631    if bytes.is_empty() && len == 0 {
632        return Tensor::from_vec_with_opts(
633            Vec::new(),
634            crate::tensor::TensorOptions {
635                shape: Some(shape),
636                ..crate::tensor::TensorOptions::default()
637            },
638        );
639    }
640    let mut out = Vec::with_capacity(bytes.len());
641    let mut cursor = 0usize;
642    while cursor < bytes.len() {
643        let end = cursor + std::mem::size_of::<B>();
644        let value = read_t::<B>(&bytes[cursor..end])?;
645        out.push(map(value));
646        cursor = end;
647    }
648    Tensor::from_vec_with_opts(
649        out,
650        crate::tensor::TensorOptions {
651            shape: Some(shape),
652            allow_len_mismatch: true,
653            ..crate::tensor::TensorOptions::default()
654        },
655    )
656}
657
658fn read_t<T: Copy>(bytes: &[u8]) -> Result<T> {
659    let mut value = std::mem::MaybeUninit::<T>::uninit();
660    let len = std::mem::size_of::<T>();
661    if bytes.len() != len {
662        return Err(anyhow!("invalid byte length"));
663    }
664    unsafe {
665        std::ptr::copy_nonoverlapping(bytes.as_ptr(), value.as_mut_ptr() as *mut u8, len);
666        Ok(value.assume_init())
667    }
668}
669
670struct ValueType;
671
672impl ValueType {
673    const I8: u32 = 1;
674    const I16: u32 = 2;
675    const I32: u32 = 3;
676    const I64: u32 = 4;
677    const U8: u32 = 5;
678    const U16: u32 = 6;
679    const U32: u32 = 7;
680    const U64: u32 = 8;
681    const F16: u32 = 9;
682    const F32: u32 = 10;
683    const F64: u32 = 11;
684    const BOOL: u32 = 12;
685    const BITSET: u32 = 13;
686    #[allow(dead_code)]
687    const STRING: u32 = 14;
688    const NDARRAY: u32 = 15;
689    const BF16: u32 = 16;
690    const F8: u32 = 17;
691    const I4: u32 = 18;
692    const I2: u32 = 19;
693    const I1: u32 = 20;
694    const U4: u32 = 21;
695    const U2: u32 = 22;
696    const U1: u32 = 23;
697    const T2: u32 = 24;
698    const T1: u32 = 25;
699
700    fn is_scalar(value_type: u32) -> bool {
701        value_type >= Self::I8 && value_type <= Self::T1
702    }
703
704    fn to_dtype(value_type: u32) -> Result<DType> {
705        Ok(match value_type {
706            Self::I8 => DType::I8,
707            Self::I16 => DType::I16,
708            Self::I32 => DType::I32,
709            Self::I64 => DType::I64,
710            Self::U8 => DType::U8,
711            Self::U16 => DType::U16,
712            Self::U32 => DType::U32,
713            Self::U64 => DType::U64,
714            Self::F16 => DType::F16,
715            Self::F32 => DType::F32,
716            Self::F64 => DType::F64,
717            Self::BOOL => DType::Bool,
718            Self::BITSET => DType::Bitset,
719            Self::BF16 => DType::BF16,
720            Self::F8 => DType::F8,
721            Self::I4 => DType::I4,
722            Self::I2 => DType::I2,
723            Self::I1 => DType::I1,
724            Self::U4 => DType::U4,
725            Self::U2 => DType::U2,
726            Self::U1 => DType::U1,
727            Self::T2 => DType::T2,
728            Self::T1 => DType::T1,
729            _ => return Err(anyhow!("unknown tensor dtype {}", value_type)),
730        })
731    }
732}