Skip to main content

cjc_runtime/
tensor_dtype.rs

1//! Typed tensor infrastructure: DType enum and TypedStorage.
2//!
3//! This module provides the byte-first storage layer for multi-typed tensors.
//! The core idea: raw bytes are the primary representation, and typed values are
//! decoded from them on demand. This enables serialization (snap) as a plain byte
//! copy, a single contiguous buffer per tensor, and memory-efficient storage for
//! non-f64 types.
7//!
8//! ## Byte-First Philosophy
9//!
10//! - `TypedStorage` stores raw `Vec<u8>` + a `DType` tag
//! - Typed access via `as_f64_vec()`, `as_i64_vec()`, etc. decodes the bytes into an owned `Vec`
12//! - Serialization = memcpy the byte buffer (no conversion)
13//! - COW semantics via `Rc<RefCell<Vec<u8>>>` (same pattern as Buffer<T>)
14
15use std::cell::{Ref, RefCell};
16use std::rc::Rc;
17
18use crate::accumulator::binned_sum_f64;
19use crate::complex::ComplexF64;
20use crate::error::RuntimeError;
21use crate::value::Bf16;
22
23// ---------------------------------------------------------------------------
24// DType — element type tag for typed tensors
25// ---------------------------------------------------------------------------
26
/// Element type tag for typed tensor storage.
///
/// The tag decides how a raw byte buffer is decoded. Every width is fixed and
/// platform-independent; multi-byte values are stored little-endian.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 64-bit IEEE 754 float — 8 bytes.
    F64,
    /// 32-bit IEEE 754 float — 4 bytes.
    F32,
    /// Signed 64-bit integer — 8 bytes.
    I64,
    /// Signed 32-bit integer — 4 bytes.
    I32,
    /// Unsigned 8-bit integer — 1 byte.
    U8,
    /// Boolean occupying one whole byte (0x00 = false, 0x01 = true).
    /// Deliberately not bit-packed, for simplicity; a packed-bit BoolTensor
    /// is a possible later optimization.
    Bool,
    /// bfloat16 ("brain float") — 2 bytes.
    Bf16,
    /// IEEE 754 binary16 half-precision float — 2 bytes.
    F16,
    /// Complex number as an f64 pair — 16 bytes (8 real, then 8 imaginary).
    Complex,
}

impl DType {
    /// Decoding table for snap tags: index `i` holds the dtype whose tag is `i`.
    /// Must stay in agreement with `snap_tag` below.
    const BY_SNAP_TAG: [DType; 9] = [
        DType::F64,
        DType::F32,
        DType::I64,
        DType::I32,
        DType::U8,
        DType::Bool,
        DType::Bf16,
        DType::F16,
        DType::Complex,
    ];

    /// Size of one element of this dtype, in bytes.
    pub fn byte_width(&self) -> usize {
        match *self {
            Self::Complex => 16,
            Self::F64 | Self::I64 => 8,
            Self::F32 | Self::I32 => 4,
            Self::F16 | Self::Bf16 => 2,
            Self::Bool | Self::U8 => 1,
        }
    }

    /// Short lowercase name, used by `Display` and in error messages.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::F64 => "f64",
            Self::F32 => "f32",
            Self::I64 => "i64",
            Self::I32 => "i32",
            Self::U8 => "u8",
            Self::Bool => "bool",
            Self::Bf16 => "bf16",
            Self::F16 => "f16",
            Self::Complex => "complex",
        }
    }

    /// True for the real floating-point dtypes (F64, F32, F16, Bf16).
    /// Complex is not counted as a float here.
    pub fn is_float(&self) -> bool {
        matches!(self, Self::F64 | Self::F32 | Self::F16 | Self::Bf16)
    }

    /// True for the integer dtypes (I64, I32, U8); Bool is not an integer.
    pub fn is_int(&self) -> bool {
        matches!(self, Self::I64 | Self::I32 | Self::U8)
    }

    /// True for every dtype that supports arithmetic, i.e. everything but Bool.
    pub fn is_numeric(&self) -> bool {
        *self != Self::Bool
    }

    /// The byte written to snap streams to identify this dtype.
    pub fn snap_tag(&self) -> u8 {
        match *self {
            Self::F64 => 0,
            Self::F32 => 1,
            Self::I64 => 2,
            Self::I32 => 3,
            Self::U8 => 4,
            Self::Bool => 5,
            Self::Bf16 => 6,
            Self::F16 => 7,
            Self::Complex => 8,
        }
    }

    /// Inverse of `snap_tag`; fails on any byte that is not a known tag.
    pub fn from_snap_tag(tag: u8) -> Result<Self, String> {
        Self::BY_SNAP_TAG
            .get(usize::from(tag))
            .copied()
            .ok_or_else(|| format!("unknown dtype snap tag: {tag}"))
    }
}

impl std::fmt::Display for DType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
134
135// ---------------------------------------------------------------------------
136// TypedStorage — byte-first tensor backing store
137// ---------------------------------------------------------------------------
138
139/// Byte-first tensor storage with COW (copy-on-write) semantics.
140///
141/// The raw byte buffer is the canonical representation. Typed views are
142/// created on demand via `as_f64()`, `as_i64()`, etc. This enables:
143///
144/// - Zero-copy snap serialization (bytes ARE the encoded form)
145/// - Memory-mapped I/O (load bytes, interpret in-place)
146/// - SIMD-friendly aligned byte buffers
147/// - Memory-efficient storage (f32 = 50% of f64, bool = 12.5%)
148///
149/// COW: Cloning a `TypedStorage` increments the Rc refcount (zero copy).
150/// Mutation triggers a deep copy if shared.
151#[derive(Debug)]
152pub struct TypedStorage {
153    /// Raw byte buffer. Alignment: elements are naturally aligned within
154    /// the Vec<u8> because Vec guarantees pointer alignment ≥ 8 bytes
155    /// (on 64-bit platforms). For SIMD (16-byte alignment), use
156    /// AlignedByteSlice for hot paths.
157    bytes: Rc<RefCell<Vec<u8>>>,
158    /// Element type determines byte interpretation.
159    dtype: DType,
160    /// Number of logical elements (NOT bytes).
161    len: usize,
162}
163
164impl TypedStorage {
165    // -- Construction -------------------------------------------------------
166
167    /// Create storage filled with zeros.
168    pub fn zeros(dtype: DType, len: usize) -> Self {
169        let nbytes = len * dtype.byte_width();
170        TypedStorage {
171            bytes: Rc::new(RefCell::new(vec![0u8; nbytes])),
172            dtype,
173            len,
174        }
175    }
176
177    /// Create storage from an existing byte buffer.
178    /// Returns error if byte length doesn't match dtype × element count.
179    pub fn from_bytes(bytes: Vec<u8>, dtype: DType, len: usize) -> Result<Self, String> {
180        let expected = len * dtype.byte_width();
181        if bytes.len() != expected {
182            return Err(format!(
183                "TypedStorage::from_bytes: expected {} bytes ({} × {} elements), got {}",
184                expected,
185                dtype.byte_width(),
186                len,
187                bytes.len()
188            ));
189        }
190        Ok(TypedStorage {
191            bytes: Rc::new(RefCell::new(bytes)),
192            dtype,
193            len,
194        })
195    }
196
197    /// Create f64 storage from a Vec<f64>.
198    pub fn from_f64_vec(data: Vec<f64>) -> Self {
199        let len = data.len();
200        let bytes = f64_vec_to_bytes(data);
201        TypedStorage {
202            bytes: Rc::new(RefCell::new(bytes)),
203            dtype: DType::F64,
204            len,
205        }
206    }
207
208    /// Create i64 storage from a Vec<i64>.
209    pub fn from_i64_vec(data: Vec<i64>) -> Self {
210        let len = data.len();
211        let bytes = i64_vec_to_bytes(data);
212        TypedStorage {
213            bytes: Rc::new(RefCell::new(bytes)),
214            dtype: DType::I64,
215            len,
216        }
217    }
218
219    /// Create f32 storage from a Vec<f32>.
220    pub fn from_f32_vec(data: Vec<f32>) -> Self {
221        let len = data.len();
222        let bytes = f32_vec_to_bytes(data);
223        TypedStorage {
224            bytes: Rc::new(RefCell::new(bytes)),
225            dtype: DType::F32,
226            len,
227        }
228    }
229
230    /// Create i32 storage from a Vec<i32>.
231    pub fn from_i32_vec(data: Vec<i32>) -> Self {
232        let len = data.len();
233        let bytes = i32_vec_to_bytes(data);
234        TypedStorage {
235            bytes: Rc::new(RefCell::new(bytes)),
236            dtype: DType::I32,
237            len,
238        }
239    }
240
241    /// Create u8 storage from a Vec<u8>.
242    pub fn from_u8_vec(data: Vec<u8>) -> Self {
243        let len = data.len();
244        TypedStorage {
245            bytes: Rc::new(RefCell::new(data)),
246            dtype: DType::U8,
247            len,
248        }
249    }
250
251    /// Create bool storage from a Vec<bool>.
252    pub fn from_bool_vec(data: Vec<bool>) -> Self {
253        let len = data.len();
254        let bytes: Vec<u8> = data.iter().map(|&b| if b { 1u8 } else { 0u8 }).collect();
255        TypedStorage {
256            bytes: Rc::new(RefCell::new(bytes)),
257            dtype: DType::Bool,
258            len,
259        }
260    }
261
262    /// Create complex storage from a Vec<ComplexF64>.
263    pub fn from_complex_vec(data: Vec<ComplexF64>) -> Self {
264        let len = data.len();
265        let mut bytes = Vec::with_capacity(len * 16);
266        for c in &data {
267            bytes.extend_from_slice(&c.re.to_le_bytes());
268            bytes.extend_from_slice(&c.im.to_le_bytes());
269        }
270        TypedStorage {
271            bytes: Rc::new(RefCell::new(bytes)),
272            dtype: DType::Complex,
273            len,
274        }
275    }
276
277    /// Create bf16 storage from a Vec<Bf16>.
278    pub fn from_bf16_vec(data: Vec<Bf16>) -> Self {
279        let len = data.len();
280        let mut bytes = Vec::with_capacity(len * 2);
281        for v in &data {
282            bytes.extend_from_slice(&v.0.to_le_bytes());
283        }
284        TypedStorage {
285            bytes: Rc::new(RefCell::new(bytes)),
286            dtype: DType::Bf16,
287            len,
288        }
289    }
290
291    // -- Accessors ----------------------------------------------------------
292
293    /// Element type.
294    pub fn dtype(&self) -> DType {
295        self.dtype
296    }
297
298    /// Number of logical elements.
299    pub fn len(&self) -> usize {
300        self.len
301    }
302
303    /// Whether storage is empty.
304    pub fn is_empty(&self) -> bool {
305        self.len == 0
306    }
307
308    /// Total byte count of the raw buffer.
309    pub fn byte_len(&self) -> usize {
310        self.len * self.dtype.byte_width()
311    }
312
313    /// Number of live references to the underlying byte buffer.
314    pub fn refcount(&self) -> usize {
315        Rc::strong_count(&self.bytes)
316    }
317
318    /// Borrow the raw byte buffer.
319    pub fn borrow_bytes(&self) -> Ref<Vec<u8>> {
320        self.bytes.borrow()
321    }
322
323    /// Clone the raw bytes out (for serialization, etc.).
324    pub fn to_bytes(&self) -> Vec<u8> {
325        self.bytes.borrow().clone()
326    }
327
328    // -- Typed views (read-only) -------------------------------------------
329
330    /// Interpret bytes as f64 slice. Panics if dtype != F64.
331    pub fn as_f64_vec(&self) -> Vec<f64> {
332        assert_eq!(self.dtype, DType::F64, "as_f64_vec: dtype is {}", self.dtype);
333        bytes_to_f64_vec(&self.bytes.borrow())
334    }
335
336    /// Interpret bytes as i64 slice. Panics if dtype != I64.
337    pub fn as_i64_vec(&self) -> Vec<i64> {
338        assert_eq!(self.dtype, DType::I64, "as_i64_vec: dtype is {}", self.dtype);
339        bytes_to_i64_vec(&self.bytes.borrow())
340    }
341
342    /// Interpret bytes as f32 slice. Panics if dtype != F32.
343    pub fn as_f32_vec(&self) -> Vec<f32> {
344        assert_eq!(self.dtype, DType::F32, "as_f32_vec: dtype is {}", self.dtype);
345        bytes_to_f32_vec(&self.bytes.borrow())
346    }
347
348    /// Interpret bytes as i32 slice. Panics if dtype != I32.
349    pub fn as_i32_vec(&self) -> Vec<i32> {
350        assert_eq!(self.dtype, DType::I32, "as_i32_vec: dtype is {}", self.dtype);
351        bytes_to_i32_vec(&self.bytes.borrow())
352    }
353
354    /// Interpret bytes as bool slice. Panics if dtype != Bool.
355    pub fn as_bool_vec(&self) -> Vec<bool> {
356        assert_eq!(self.dtype, DType::Bool, "as_bool_vec: dtype is {}", self.dtype);
357        self.bytes.borrow().iter().map(|&b| b != 0).collect()
358    }
359
360    /// Interpret bytes as u8 slice (trivial — bytes ARE u8). Panics if dtype != U8.
361    pub fn as_u8_vec(&self) -> Vec<u8> {
362        assert_eq!(self.dtype, DType::U8, "as_u8_vec: dtype is {}", self.dtype);
363        self.bytes.borrow().clone()
364    }
365
366    /// Interpret bytes as ComplexF64 slice. Panics if dtype != Complex.
367    pub fn as_complex_vec(&self) -> Vec<ComplexF64> {
368        assert_eq!(self.dtype, DType::Complex, "as_complex_vec: dtype is {}", self.dtype);
369        let raw = self.bytes.borrow();
370        let mut result = Vec::with_capacity(self.len);
371        for i in 0..self.len {
372            let off = i * 16;
373            let re = f64::from_le_bytes(raw[off..off + 8].try_into().unwrap());
374            let im = f64::from_le_bytes(raw[off + 8..off + 16].try_into().unwrap());
375            result.push(ComplexF64 { re, im });
376        }
377        result
378    }
379
380    /// Interpret bytes as Bf16 slice. Panics if dtype != Bf16.
381    pub fn as_bf16_vec(&self) -> Vec<Bf16> {
382        assert_eq!(self.dtype, DType::Bf16, "as_bf16_vec: dtype is {}", self.dtype);
383        let raw = self.bytes.borrow();
384        let mut result = Vec::with_capacity(self.len);
385        for i in 0..self.len {
386            let off = i * 2;
387            let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
388            result.push(Bf16(bits));
389        }
390        result
391    }
392
393    /// Convert any numeric dtype to f64 vec (for operations that need f64).
394    /// Bool: false→0.0, true→1.0.
395    pub fn to_f64_vec(&self) -> Vec<f64> {
396        match self.dtype {
397            DType::F64 => self.as_f64_vec(),
398            DType::F32 => self.as_f32_vec().into_iter().map(|v| v as f64).collect(),
399            DType::I64 => self.as_i64_vec().into_iter().map(|v| v as f64).collect(),
400            DType::I32 => self.as_i32_vec().into_iter().map(|v| v as f64).collect(),
401            DType::U8 => self.as_u8_vec().into_iter().map(|v| v as f64).collect(),
402            DType::Bool => self.as_bool_vec().into_iter().map(|v| if v { 1.0 } else { 0.0 }).collect(),
403            DType::Bf16 => self.as_bf16_vec().into_iter().map(|v| v.to_f32() as f64).collect(),
404            DType::F16 => {
405                let raw = self.bytes.borrow();
406                let mut result = Vec::with_capacity(self.len);
407                for i in 0..self.len {
408                    let off = i * 2;
409                    let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
410                    result.push(crate::f16::F16(bits).to_f64());
411                }
412                result
413            }
414            DType::Complex => {
415                // Return real parts only for scalar operations
416                self.as_complex_vec().into_iter().map(|c| c.re).collect()
417            }
418        }
419    }
420
421    // -- Element access -----------------------------------------------------
422
423    /// Get a single f64 value at index. Works for any numeric dtype (converts).
424    pub fn get_as_f64(&self, idx: usize) -> Result<f64, RuntimeError> {
425        if idx >= self.len {
426            return Err(RuntimeError::IndexOutOfBounds { index: idx, length: self.len });
427        }
428        let raw = self.bytes.borrow();
429        let bw = self.dtype.byte_width();
430        let off = idx * bw;
431        Ok(match self.dtype {
432            DType::F64 => f64::from_le_bytes(raw[off..off + 8].try_into().unwrap()),
433            DType::F32 => f32::from_le_bytes(raw[off..off + 4].try_into().unwrap()) as f64,
434            DType::I64 => i64::from_le_bytes(raw[off..off + 8].try_into().unwrap()) as f64,
435            DType::I32 => i32::from_le_bytes(raw[off..off + 4].try_into().unwrap()) as f64,
436            DType::U8 => raw[off] as f64,
437            DType::Bool => if raw[off] != 0 { 1.0 } else { 0.0 },
438            DType::Bf16 => {
439                let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
440                Bf16(bits).to_f32() as f64
441            }
442            DType::F16 => {
443                let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
444                crate::f16::F16(bits).to_f64()
445            }
446            DType::Complex => {
447                f64::from_le_bytes(raw[off..off + 8].try_into().unwrap()) // real part
448            }
449        })
450    }
451
452    /// Set a single f64 value at index. Converts to storage dtype.
453    /// Triggers COW if shared.
454    pub fn set_from_f64(&mut self, idx: usize, val: f64) -> Result<(), RuntimeError> {
455        if idx >= self.len {
456            return Err(RuntimeError::IndexOutOfBounds { index: idx, length: self.len });
457        }
458        self.make_unique();
459        let bw = self.dtype.byte_width();
460        let off = idx * bw;
461        let mut raw = self.bytes.borrow_mut();
462        match self.dtype {
463            DType::F64 => raw[off..off + 8].copy_from_slice(&val.to_le_bytes()),
464            DType::F32 => raw[off..off + 4].copy_from_slice(&(val as f32).to_le_bytes()),
465            DType::I64 => raw[off..off + 8].copy_from_slice(&(val as i64).to_le_bytes()),
466            DType::I32 => raw[off..off + 4].copy_from_slice(&(val as i32).to_le_bytes()),
467            DType::U8 => raw[off] = val as u8,
468            DType::Bool => raw[off] = if val != 0.0 { 1 } else { 0 },
469            DType::Bf16 => {
470                let bits = Bf16::from_f32(val as f32).0;
471                raw[off..off + 2].copy_from_slice(&bits.to_le_bytes());
472            }
473            DType::F16 => {
474                let bits = crate::f16::F16::from_f64(val).0;
475                raw[off..off + 2].copy_from_slice(&bits.to_le_bytes());
476            }
477            DType::Complex => {
478                raw[off..off + 8].copy_from_slice(&val.to_le_bytes());
479                raw[off + 8..off + 16].copy_from_slice(&0.0f64.to_le_bytes());
480            }
481        }
482        Ok(())
483    }
484
485    // -- COW ----------------------------------------------------------------
486
487    /// Ensure exclusive ownership. If shared, deep-copy the byte buffer.
488    pub fn make_unique(&mut self) {
489        if Rc::strong_count(&self.bytes) > 1 {
490            let data = self.bytes.borrow().clone();
491            self.bytes = Rc::new(RefCell::new(data));
492        }
493    }
494
495    /// Force a deep copy, returning a new TypedStorage that does not share.
496    pub fn deep_clone(&self) -> TypedStorage {
497        TypedStorage {
498            bytes: Rc::new(RefCell::new(self.bytes.borrow().clone())),
499            dtype: self.dtype,
500            len: self.len,
501        }
502    }
503
504    // -- Reductions ---------------------------------------------------------
505
506    /// Sum all elements as f64. Uses BinnedAccumulator for float types.
507    pub fn sum_f64(&self) -> f64 {
508        let data = self.to_f64_vec();
509        if self.dtype.is_float() || self.dtype == DType::Complex {
510            binned_sum_f64(&data)
511        } else {
512            // Integer types: exact sum (no accumulator needed)
513            data.iter().sum()
514        }
515    }
516
517    /// Mean of all elements as f64.
518    pub fn mean_f64(&self) -> f64 {
519        if self.len == 0 {
520            return f64::NAN;
521        }
522        self.sum_f64() / self.len as f64
523    }
524
525    // -- Type casting -------------------------------------------------------
526
527    /// Cast to a different dtype. Returns a new TypedStorage.
528    pub fn cast(&self, target: DType) -> TypedStorage {
529        if self.dtype == target {
530            return self.deep_clone();
531        }
532        let f64_data = self.to_f64_vec();
533        match target {
534            DType::F64 => TypedStorage::from_f64_vec(f64_data),
535            DType::F32 => TypedStorage::from_f32_vec(f64_data.into_iter().map(|v| v as f32).collect()),
536            DType::I64 => TypedStorage::from_i64_vec(f64_data.into_iter().map(|v| v as i64).collect()),
537            DType::I32 => TypedStorage::from_i32_vec(f64_data.into_iter().map(|v| v as i32).collect()),
538            DType::U8 => TypedStorage::from_u8_vec(f64_data.into_iter().map(|v| v as u8).collect()),
539            DType::Bool => TypedStorage::from_bool_vec(f64_data.into_iter().map(|v| v != 0.0).collect()),
540            DType::Bf16 => TypedStorage::from_bf16_vec(f64_data.into_iter().map(|v| Bf16::from_f32(v as f32)).collect()),
541            DType::F16 => {
542                let mut bytes = Vec::with_capacity(f64_data.len() * 2);
543                for v in &f64_data {
544                    let bits = crate::f16::F16::from_f64(*v).0;
545                    bytes.extend_from_slice(&bits.to_le_bytes());
546                }
547                TypedStorage {
548                    bytes: Rc::new(RefCell::new(bytes)),
549                    dtype: DType::F16,
550                    len: f64_data.len(),
551                }
552            }
553            DType::Complex => TypedStorage::from_complex_vec(
554                f64_data.into_iter().map(|v| ComplexF64 { re: v, im: 0.0 }).collect()
555            ),
556        }
557    }
558}
559
560impl Clone for TypedStorage {
561    /// Cloning increments refcount — zero copy (COW).
562    fn clone(&self) -> Self {
563        TypedStorage {
564            bytes: Rc::clone(&self.bytes),
565            dtype: self.dtype,
566            len: self.len,
567        }
568    }
569}
570
571// ---------------------------------------------------------------------------
572// Byte conversion helpers (little-endian, deterministic)
573// ---------------------------------------------------------------------------
574
/// Encode `data` as consecutive little-endian f64 words (8 bytes each).
fn f64_vec_to_bytes(data: Vec<f64>) -> Vec<u8> {
    let mut out = Vec::with_capacity(8 * data.len());
    for word in data.into_iter().map(f64::to_le_bytes) {
        out.extend_from_slice(&word);
    }
    out
}
582
/// Decode consecutive little-endian f64 words. A trailing partial word
/// (fewer than 8 leftover bytes) is ignored.
fn bytes_to_f64_vec(bytes: &[u8]) -> Vec<f64> {
    bytes
        .chunks_exact(8)
        .map(|chunk| {
            let mut word = [0u8; 8];
            word.copy_from_slice(chunk);
            f64::from_le_bytes(word)
        })
        .collect()
}
592
/// Encode `data` as consecutive little-endian i64 words (8 bytes each).
fn i64_vec_to_bytes(data: Vec<i64>) -> Vec<u8> {
    let mut out = Vec::with_capacity(8 * data.len());
    for word in data.into_iter().map(i64::to_le_bytes) {
        out.extend_from_slice(&word);
    }
    out
}
600
/// Decode consecutive little-endian i64 words. A trailing partial word
/// (fewer than 8 leftover bytes) is ignored.
fn bytes_to_i64_vec(bytes: &[u8]) -> Vec<i64> {
    bytes
        .chunks_exact(8)
        .map(|chunk| {
            let mut word = [0u8; 8];
            word.copy_from_slice(chunk);
            i64::from_le_bytes(word)
        })
        .collect()
}
610
/// Encode `data` as consecutive little-endian f32 words (4 bytes each).
fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
    let mut out = Vec::with_capacity(4 * data.len());
    for word in data.into_iter().map(f32::to_le_bytes) {
        out.extend_from_slice(&word);
    }
    out
}
618
/// Decode consecutive little-endian f32 words. A trailing partial word
/// (fewer than 4 leftover bytes) is ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|chunk| {
            let mut word = [0u8; 4];
            word.copy_from_slice(chunk);
            f32::from_le_bytes(word)
        })
        .collect()
}
628
/// Encode `data` as consecutive little-endian i32 words (4 bytes each).
fn i32_vec_to_bytes(data: Vec<i32>) -> Vec<u8> {
    let mut out = Vec::with_capacity(4 * data.len());
    for word in data.into_iter().map(i32::to_le_bytes) {
        out.extend_from_slice(&word);
    }
    out
}
636
/// Decode consecutive little-endian i32 words. A trailing partial word
/// (fewer than 4 leftover bytes) is ignored.
fn bytes_to_i32_vec(bytes: &[u8]) -> Vec<i32> {
    bytes
        .chunks_exact(4)
        .map(|chunk| {
            let mut word = [0u8; 4];
            word.copy_from_slice(chunk);
            i32::from_le_bytes(word)
        })
        .collect()
}
646
647// ---------------------------------------------------------------------------
648// Tests
649// ---------------------------------------------------------------------------
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Every dtype, listed in snap-tag order.
    const ALL_DTYPES: [DType; 9] = [
        DType::F64,
        DType::F32,
        DType::I64,
        DType::I32,
        DType::U8,
        DType::Bool,
        DType::Bf16,
        DType::F16,
        DType::Complex,
    ];

    #[test]
    fn test_dtype_byte_width() {
        assert_eq!(DType::F64.byte_width(), 8);
        assert_eq!(DType::F32.byte_width(), 4);
        assert_eq!(DType::I64.byte_width(), 8);
        assert_eq!(DType::I32.byte_width(), 4);
        assert_eq!(DType::U8.byte_width(), 1);
        assert_eq!(DType::Bool.byte_width(), 1);
        assert_eq!(DType::Bf16.byte_width(), 2);
        assert_eq!(DType::F16.byte_width(), 2);
        assert_eq!(DType::Complex.byte_width(), 16);
    }

    #[test]
    fn test_dtype_snap_roundtrip() {
        // Also pins the wire format: tags are 0..=8 in declaration order.
        for (tag, dt) in ALL_DTYPES.iter().enumerate() {
            assert_eq!(dt.snap_tag() as usize, tag);
            assert_eq!(DType::from_snap_tag(dt.snap_tag()).unwrap(), *dt);
        }
    }

    #[test]
    fn test_dtype_unknown_snap_tag() {
        assert!(DType::from_snap_tag(9).is_err());
        assert!(DType::from_snap_tag(u8::MAX).is_err());
    }

    #[test]
    fn test_f64_storage_roundtrip() {
        let data = vec![1.5, -2.3, 0.0, f64::INFINITY, f64::NEG_INFINITY];
        let storage = TypedStorage::from_f64_vec(data.clone());
        assert_eq!(storage.dtype(), DType::F64);
        assert_eq!(storage.len(), 5);
        assert_eq!(storage.as_f64_vec(), data);
    }

    #[test]
    fn test_i64_storage_roundtrip() {
        let data = vec![1i64, -2, 0, i64::MAX, i64::MIN];
        let storage = TypedStorage::from_i64_vec(data.clone());
        assert_eq!(storage.dtype(), DType::I64);
        assert_eq!(storage.as_i64_vec(), data);
    }

    #[test]
    fn test_f32_storage_roundtrip() {
        let data = vec![1.0f32, -2.5, 0.0, 3.14];
        let storage = TypedStorage::from_f32_vec(data.clone());
        assert_eq!(storage.dtype(), DType::F32);
        assert_eq!(storage.as_f32_vec(), data);
    }

    #[test]
    fn test_i32_storage_roundtrip() {
        let data = vec![42i32, -1, 0, i32::MAX];
        let storage = TypedStorage::from_i32_vec(data.clone());
        assert_eq!(storage.as_i32_vec(), data);
    }

    #[test]
    fn test_u8_storage_roundtrip() {
        let data = vec![0u8, 127, 255];
        let storage = TypedStorage::from_u8_vec(data.clone());
        assert_eq!(storage.as_u8_vec(), data);
    }

    #[test]
    fn test_bool_storage_roundtrip() {
        let data = vec![true, false, true, true, false];
        let storage = TypedStorage::from_bool_vec(data.clone());
        assert_eq!(storage.as_bool_vec(), data);
    }

    #[test]
    fn test_complex_storage_roundtrip() {
        let data = vec![
            ComplexF64 { re: 1.0, im: 2.0 },
            ComplexF64 { re: -3.0, im: 0.5 },
        ];
        let storage = TypedStorage::from_complex_vec(data);
        let back = storage.as_complex_vec();
        assert_eq!(back.len(), 2);
        assert_eq!(back[0].re, 1.0);
        assert_eq!(back[0].im, 2.0);
        assert_eq!(back[1].re, -3.0);
        assert_eq!(back[1].im, 0.5);
    }

    #[test]
    fn test_bf16_storage_roundtrip() {
        let data = vec![Bf16::from_f32(1.0), Bf16::from_f32(-0.5)];
        let storage = TypedStorage::from_bf16_vec(data);
        let back = storage.as_bf16_vec();
        assert_eq!(back[0].to_f32(), 1.0);
        assert_eq!(back[1].to_f32(), -0.5);
    }

    #[test]
    fn test_little_endian_layout() {
        let s = TypedStorage::from_i32_vec(vec![1, -2]);
        assert_eq!(s.to_bytes(), vec![1, 0, 0, 0, 0xFE, 0xFF, 0xFF, 0xFF]);
    }

    #[test]
    fn test_cow_semantics() {
        let s1 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        let s2 = s1.clone();
        assert_eq!(s1.refcount(), 2);
        assert_eq!(s2.refcount(), 2);

        let s3 = s1.deep_clone();
        assert_eq!(s3.refcount(), 1);
        assert_eq!(s1.refcount(), 2); // s1 and s2 still share
    }

    #[test]
    fn test_cow_mutation() {
        let s1 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        let mut s2 = s1.clone();
        assert_eq!(s1.refcount(), 2);

        s2.set_from_f64(0, 99.0).unwrap();
        assert_eq!(s1.refcount(), 1); // s1 no longer shared
        assert_eq!(s2.refcount(), 1);
        assert_eq!(s1.as_f64_vec()[0], 1.0); // unchanged
        assert_eq!(s2.as_f64_vec()[0], 99.0); // mutated copy
    }

    #[test]
    fn test_get_set_f64() {
        let mut storage = TypedStorage::from_f64_vec(vec![10.0, 20.0, 30.0]);
        assert_eq!(storage.get_as_f64(0).unwrap(), 10.0);
        assert_eq!(storage.get_as_f64(2).unwrap(), 30.0);
        assert!(storage.get_as_f64(3).is_err());

        storage.set_from_f64(1, 99.0).unwrap();
        assert_eq!(storage.get_as_f64(1).unwrap(), 99.0);
        assert!(storage.set_from_f64(3, 1.0).is_err());
    }

    #[test]
    fn test_get_set_i64() {
        let mut storage = TypedStorage::from_i64_vec(vec![10, 20, 30]);
        assert_eq!(storage.get_as_f64(0).unwrap(), 10.0);
        storage.set_from_f64(1, 42.0).unwrap();
        assert_eq!(storage.as_i64_vec()[1], 42);
    }

    #[test]
    fn test_set_bool_and_complex() {
        let mut b = TypedStorage::from_bool_vec(vec![false]);
        b.set_from_f64(0, 2.0).unwrap();
        assert_eq!(b.as_bool_vec(), vec![true]);

        let mut c = TypedStorage::from_complex_vec(vec![ComplexF64 { re: 1.0, im: 5.0 }]);
        c.set_from_f64(0, 3.0).unwrap();
        let back = c.as_complex_vec();
        assert_eq!(back[0].re, 3.0);
        assert_eq!(back[0].im, 0.0); // imaginary part is cleared
    }

    #[test]
    fn test_to_f64_vec_conversion() {
        let storage = TypedStorage::from_i32_vec(vec![1, 2, 3]);
        assert_eq!(storage.to_f64_vec(), vec![1.0, 2.0, 3.0]);

        let storage = TypedStorage::from_bool_vec(vec![true, false, true]);
        assert_eq!(storage.to_f64_vec(), vec![1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_sum_f64() {
        let storage = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0, 4.0]);
        assert!((storage.sum_f64() - 10.0).abs() < 1e-12);

        let storage = TypedStorage::from_i64_vec(vec![1, 2, 3, 4]);
        assert!((storage.sum_f64() - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_mean_empty_is_nan() {
        assert!(TypedStorage::zeros(DType::F64, 0).mean_f64().is_nan());
    }

    #[test]
    fn test_cast_f64_to_i64() {
        let s = TypedStorage::from_f64_vec(vec![1.5, -2.7, 3.0]);
        let c = s.cast(DType::I64);
        assert_eq!(c.dtype(), DType::I64);
        assert_eq!(c.as_i64_vec(), vec![1, -2, 3]);
    }

    #[test]
    fn test_cast_i64_to_f32() {
        let s = TypedStorage::from_i64_vec(vec![1, 2, 3]);
        let c = s.cast(DType::F32);
        assert_eq!(c.dtype(), DType::F32);
        assert_eq!(c.as_f32_vec(), vec![1.0f32, 2.0, 3.0]);
    }

    #[test]
    fn test_cast_same_dtype_is_independent_copy() {
        let s = TypedStorage::from_f64_vec(vec![1.0, 2.0]);
        let c = s.cast(DType::F64);
        assert_eq!(s.refcount(), 1);
        assert_eq!(c.refcount(), 1);
        assert_eq!(c.as_f64_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_zeros_all_dtypes() {
        for &dt in ALL_DTYPES.iter() {
            let s = TypedStorage::zeros(dt, 10);
            assert_eq!(s.len(), 10);
            assert_eq!(s.byte_len(), 10 * dt.byte_width());
            assert!(s.borrow_bytes().iter().all(|&b| b == 0));
            // All zero bytes → every element decodes to zero.
            for i in 0..10 {
                assert_eq!(s.get_as_f64(i).unwrap(), 0.0, "dtype {} index {}", dt, i);
            }
        }
    }

    #[test]
    fn test_byte_determinism() {
        // Same data → identical bytes (deterministic encoding)
        let s1 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        let s2 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(s1.to_bytes(), s2.to_bytes());

        let s3 = TypedStorage::from_i64_vec(vec![42, -1, 0]);
        let s4 = TypedStorage::from_i64_vec(vec![42, -1, 0]);
        assert_eq!(s3.to_bytes(), s4.to_bytes());
    }

    #[test]
    fn test_from_bytes_roundtrip() {
        let original = TypedStorage::from_f64_vec(vec![1.5, -2.3, 0.0]);
        let bytes = original.to_bytes();
        let restored = TypedStorage::from_bytes(bytes, DType::F64, 3).unwrap();
        assert_eq!(original.as_f64_vec(), restored.as_f64_vec());
    }

    #[test]
    fn test_from_bytes_size_mismatch() {
        assert!(TypedStorage::from_bytes(vec![0u8; 10], DType::F64, 2).is_err());
    }
}