1use std::cell::{Ref, RefCell};
16use std::rc::Rc;
17
18use crate::accumulator::binned_sum_f64;
19use crate::complex::ComplexF64;
20use crate::error::RuntimeError;
21use crate::value::Bf16;
22
/// Element type of a [`TypedStorage`] buffer.
///
/// The dtype fixes the per-element byte width (`DType::byte_width`) and how the
/// little-endian element bytes are decoded. Snapshot tags are assigned explicitly in
/// `DType::snap_tag`, independently of the declaration order below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 64-bit float, 8 bytes per element.
    F64,
    /// 32-bit float, 4 bytes per element.
    F32,
    /// Signed 64-bit integer, 8 bytes per element.
    I64,
    /// Signed 32-bit integer, 4 bytes per element.
    I32,
    /// Unsigned 8-bit integer, 1 byte per element.
    U8,
    /// Boolean, 1 byte per element; constructors write 0/1 and any nonzero byte
    /// reads back as `true`.
    Bool,
    /// 16-bit `Bf16` value stored as its raw `u16` bits, 2 bytes per element;
    /// converted to and from `f64` through `f32`.
    Bf16,
    /// 16-bit half-precision float (`crate::f16::F16`), 2 bytes per element
    /// (presumably IEEE-754 binary16 — confirm in `crate::f16`).
    F16,
    /// Complex number stored as two consecutive `f64`s (real, then imaginary),
    /// 16 bytes per element.
    Complex,
}
54
55impl DType {
56 pub fn byte_width(&self) -> usize {
58 match self {
59 DType::F64 | DType::I64 => 8,
60 DType::F32 | DType::I32 => 4,
61 DType::Bf16 | DType::F16 => 2,
62 DType::U8 | DType::Bool => 1,
63 DType::Complex => 16,
64 }
65 }
66
67 pub fn name(&self) -> &'static str {
69 match self {
70 DType::F64 => "f64",
71 DType::F32 => "f32",
72 DType::I64 => "i64",
73 DType::I32 => "i32",
74 DType::U8 => "u8",
75 DType::Bool => "bool",
76 DType::Bf16 => "bf16",
77 DType::F16 => "f16",
78 DType::Complex => "complex",
79 }
80 }
81
82 pub fn is_float(&self) -> bool {
84 matches!(self, DType::F64 | DType::F32 | DType::Bf16 | DType::F16)
85 }
86
87 pub fn is_int(&self) -> bool {
89 matches!(self, DType::I64 | DType::I32 | DType::U8)
90 }
91
92 pub fn is_numeric(&self) -> bool {
94 !matches!(self, DType::Bool)
95 }
96
97 pub fn snap_tag(&self) -> u8 {
99 match self {
100 DType::F64 => 0,
101 DType::F32 => 1,
102 DType::I64 => 2,
103 DType::I32 => 3,
104 DType::U8 => 4,
105 DType::Bool => 5,
106 DType::Bf16 => 6,
107 DType::F16 => 7,
108 DType::Complex => 8,
109 }
110 }
111
112 pub fn from_snap_tag(tag: u8) -> Result<Self, String> {
114 match tag {
115 0 => Ok(DType::F64),
116 1 => Ok(DType::F32),
117 2 => Ok(DType::I64),
118 3 => Ok(DType::I32),
119 4 => Ok(DType::U8),
120 5 => Ok(DType::Bool),
121 6 => Ok(DType::Bf16),
122 7 => Ok(DType::F16),
123 8 => Ok(DType::Complex),
124 _ => Err(format!("unknown dtype snap tag: {tag}")),
125 }
126 }
127}
128
129impl std::fmt::Display for DType {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 write!(f, "{}", self.name())
132 }
133}
134
/// A flat, dtype-tagged buffer of `len` elements stored as raw little-endian bytes.
///
/// The byte buffer sits behind `Rc<RefCell<..>>`. `clone()` shares it (refcount bump,
/// no copy), while `deep_clone()` copies it. Mutating methods such as `set_from_f64`
/// first call `make_unique`, which copies the buffer when its strong count is above
/// one, so a write through one handle is not seen by handles it was cloned from
/// (copy-on-write).
///
/// Constructors size `bytes` as `len * dtype.byte_width()`; element accessors rely on
/// that relationship. Because it holds an `Rc`, this type is neither `Send` nor `Sync`.
#[derive(Debug)]
pub struct TypedStorage {
    /// Little-endian element bytes, shared between shallow clones.
    bytes: Rc<RefCell<Vec<u8>>>,
    /// Element type; determines the byte width and decoding of `bytes`.
    dtype: DType,
    /// Number of elements (not bytes).
    len: usize,
}
163
164impl TypedStorage {
165 pub fn zeros(dtype: DType, len: usize) -> Self {
169 let nbytes = len * dtype.byte_width();
170 TypedStorage {
171 bytes: Rc::new(RefCell::new(vec![0u8; nbytes])),
172 dtype,
173 len,
174 }
175 }
176
177 pub fn from_bytes(bytes: Vec<u8>, dtype: DType, len: usize) -> Result<Self, String> {
180 let expected = len * dtype.byte_width();
181 if bytes.len() != expected {
182 return Err(format!(
183 "TypedStorage::from_bytes: expected {} bytes ({} × {} elements), got {}",
184 expected,
185 dtype.byte_width(),
186 len,
187 bytes.len()
188 ));
189 }
190 Ok(TypedStorage {
191 bytes: Rc::new(RefCell::new(bytes)),
192 dtype,
193 len,
194 })
195 }
196
197 pub fn from_f64_vec(data: Vec<f64>) -> Self {
199 let len = data.len();
200 let bytes = f64_vec_to_bytes(data);
201 TypedStorage {
202 bytes: Rc::new(RefCell::new(bytes)),
203 dtype: DType::F64,
204 len,
205 }
206 }
207
208 pub fn from_i64_vec(data: Vec<i64>) -> Self {
210 let len = data.len();
211 let bytes = i64_vec_to_bytes(data);
212 TypedStorage {
213 bytes: Rc::new(RefCell::new(bytes)),
214 dtype: DType::I64,
215 len,
216 }
217 }
218
219 pub fn from_f32_vec(data: Vec<f32>) -> Self {
221 let len = data.len();
222 let bytes = f32_vec_to_bytes(data);
223 TypedStorage {
224 bytes: Rc::new(RefCell::new(bytes)),
225 dtype: DType::F32,
226 len,
227 }
228 }
229
230 pub fn from_i32_vec(data: Vec<i32>) -> Self {
232 let len = data.len();
233 let bytes = i32_vec_to_bytes(data);
234 TypedStorage {
235 bytes: Rc::new(RefCell::new(bytes)),
236 dtype: DType::I32,
237 len,
238 }
239 }
240
241 pub fn from_u8_vec(data: Vec<u8>) -> Self {
243 let len = data.len();
244 TypedStorage {
245 bytes: Rc::new(RefCell::new(data)),
246 dtype: DType::U8,
247 len,
248 }
249 }
250
251 pub fn from_bool_vec(data: Vec<bool>) -> Self {
253 let len = data.len();
254 let bytes: Vec<u8> = data.iter().map(|&b| if b { 1u8 } else { 0u8 }).collect();
255 TypedStorage {
256 bytes: Rc::new(RefCell::new(bytes)),
257 dtype: DType::Bool,
258 len,
259 }
260 }
261
262 pub fn from_complex_vec(data: Vec<ComplexF64>) -> Self {
264 let len = data.len();
265 let mut bytes = Vec::with_capacity(len * 16);
266 for c in &data {
267 bytes.extend_from_slice(&c.re.to_le_bytes());
268 bytes.extend_from_slice(&c.im.to_le_bytes());
269 }
270 TypedStorage {
271 bytes: Rc::new(RefCell::new(bytes)),
272 dtype: DType::Complex,
273 len,
274 }
275 }
276
277 pub fn from_bf16_vec(data: Vec<Bf16>) -> Self {
279 let len = data.len();
280 let mut bytes = Vec::with_capacity(len * 2);
281 for v in &data {
282 bytes.extend_from_slice(&v.0.to_le_bytes());
283 }
284 TypedStorage {
285 bytes: Rc::new(RefCell::new(bytes)),
286 dtype: DType::Bf16,
287 len,
288 }
289 }
290
291 pub fn dtype(&self) -> DType {
295 self.dtype
296 }
297
298 pub fn len(&self) -> usize {
300 self.len
301 }
302
303 pub fn is_empty(&self) -> bool {
305 self.len == 0
306 }
307
308 pub fn byte_len(&self) -> usize {
310 self.len * self.dtype.byte_width()
311 }
312
313 pub fn refcount(&self) -> usize {
315 Rc::strong_count(&self.bytes)
316 }
317
318 pub fn borrow_bytes(&self) -> Ref<Vec<u8>> {
320 self.bytes.borrow()
321 }
322
323 pub fn to_bytes(&self) -> Vec<u8> {
325 self.bytes.borrow().clone()
326 }
327
328 pub fn as_f64_vec(&self) -> Vec<f64> {
332 assert_eq!(self.dtype, DType::F64, "as_f64_vec: dtype is {}", self.dtype);
333 bytes_to_f64_vec(&self.bytes.borrow())
334 }
335
336 pub fn as_i64_vec(&self) -> Vec<i64> {
338 assert_eq!(self.dtype, DType::I64, "as_i64_vec: dtype is {}", self.dtype);
339 bytes_to_i64_vec(&self.bytes.borrow())
340 }
341
342 pub fn as_f32_vec(&self) -> Vec<f32> {
344 assert_eq!(self.dtype, DType::F32, "as_f32_vec: dtype is {}", self.dtype);
345 bytes_to_f32_vec(&self.bytes.borrow())
346 }
347
348 pub fn as_i32_vec(&self) -> Vec<i32> {
350 assert_eq!(self.dtype, DType::I32, "as_i32_vec: dtype is {}", self.dtype);
351 bytes_to_i32_vec(&self.bytes.borrow())
352 }
353
354 pub fn as_bool_vec(&self) -> Vec<bool> {
356 assert_eq!(self.dtype, DType::Bool, "as_bool_vec: dtype is {}", self.dtype);
357 self.bytes.borrow().iter().map(|&b| b != 0).collect()
358 }
359
360 pub fn as_u8_vec(&self) -> Vec<u8> {
362 assert_eq!(self.dtype, DType::U8, "as_u8_vec: dtype is {}", self.dtype);
363 self.bytes.borrow().clone()
364 }
365
366 pub fn as_complex_vec(&self) -> Vec<ComplexF64> {
368 assert_eq!(self.dtype, DType::Complex, "as_complex_vec: dtype is {}", self.dtype);
369 let raw = self.bytes.borrow();
370 let mut result = Vec::with_capacity(self.len);
371 for i in 0..self.len {
372 let off = i * 16;
373 let re = f64::from_le_bytes(raw[off..off + 8].try_into().unwrap());
374 let im = f64::from_le_bytes(raw[off + 8..off + 16].try_into().unwrap());
375 result.push(ComplexF64 { re, im });
376 }
377 result
378 }
379
380 pub fn as_bf16_vec(&self) -> Vec<Bf16> {
382 assert_eq!(self.dtype, DType::Bf16, "as_bf16_vec: dtype is {}", self.dtype);
383 let raw = self.bytes.borrow();
384 let mut result = Vec::with_capacity(self.len);
385 for i in 0..self.len {
386 let off = i * 2;
387 let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
388 result.push(Bf16(bits));
389 }
390 result
391 }
392
393 pub fn to_f64_vec(&self) -> Vec<f64> {
396 match self.dtype {
397 DType::F64 => self.as_f64_vec(),
398 DType::F32 => self.as_f32_vec().into_iter().map(|v| v as f64).collect(),
399 DType::I64 => self.as_i64_vec().into_iter().map(|v| v as f64).collect(),
400 DType::I32 => self.as_i32_vec().into_iter().map(|v| v as f64).collect(),
401 DType::U8 => self.as_u8_vec().into_iter().map(|v| v as f64).collect(),
402 DType::Bool => self.as_bool_vec().into_iter().map(|v| if v { 1.0 } else { 0.0 }).collect(),
403 DType::Bf16 => self.as_bf16_vec().into_iter().map(|v| v.to_f32() as f64).collect(),
404 DType::F16 => {
405 let raw = self.bytes.borrow();
406 let mut result = Vec::with_capacity(self.len);
407 for i in 0..self.len {
408 let off = i * 2;
409 let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
410 result.push(crate::f16::F16(bits).to_f64());
411 }
412 result
413 }
414 DType::Complex => {
415 self.as_complex_vec().into_iter().map(|c| c.re).collect()
417 }
418 }
419 }
420
421 pub fn get_as_f64(&self, idx: usize) -> Result<f64, RuntimeError> {
425 if idx >= self.len {
426 return Err(RuntimeError::IndexOutOfBounds { index: idx, length: self.len });
427 }
428 let raw = self.bytes.borrow();
429 let bw = self.dtype.byte_width();
430 let off = idx * bw;
431 Ok(match self.dtype {
432 DType::F64 => f64::from_le_bytes(raw[off..off + 8].try_into().unwrap()),
433 DType::F32 => f32::from_le_bytes(raw[off..off + 4].try_into().unwrap()) as f64,
434 DType::I64 => i64::from_le_bytes(raw[off..off + 8].try_into().unwrap()) as f64,
435 DType::I32 => i32::from_le_bytes(raw[off..off + 4].try_into().unwrap()) as f64,
436 DType::U8 => raw[off] as f64,
437 DType::Bool => if raw[off] != 0 { 1.0 } else { 0.0 },
438 DType::Bf16 => {
439 let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
440 Bf16(bits).to_f32() as f64
441 }
442 DType::F16 => {
443 let bits = u16::from_le_bytes(raw[off..off + 2].try_into().unwrap());
444 crate::f16::F16(bits).to_f64()
445 }
446 DType::Complex => {
447 f64::from_le_bytes(raw[off..off + 8].try_into().unwrap()) }
449 })
450 }
451
452 pub fn set_from_f64(&mut self, idx: usize, val: f64) -> Result<(), RuntimeError> {
455 if idx >= self.len {
456 return Err(RuntimeError::IndexOutOfBounds { index: idx, length: self.len });
457 }
458 self.make_unique();
459 let bw = self.dtype.byte_width();
460 let off = idx * bw;
461 let mut raw = self.bytes.borrow_mut();
462 match self.dtype {
463 DType::F64 => raw[off..off + 8].copy_from_slice(&val.to_le_bytes()),
464 DType::F32 => raw[off..off + 4].copy_from_slice(&(val as f32).to_le_bytes()),
465 DType::I64 => raw[off..off + 8].copy_from_slice(&(val as i64).to_le_bytes()),
466 DType::I32 => raw[off..off + 4].copy_from_slice(&(val as i32).to_le_bytes()),
467 DType::U8 => raw[off] = val as u8,
468 DType::Bool => raw[off] = if val != 0.0 { 1 } else { 0 },
469 DType::Bf16 => {
470 let bits = Bf16::from_f32(val as f32).0;
471 raw[off..off + 2].copy_from_slice(&bits.to_le_bytes());
472 }
473 DType::F16 => {
474 let bits = crate::f16::F16::from_f64(val).0;
475 raw[off..off + 2].copy_from_slice(&bits.to_le_bytes());
476 }
477 DType::Complex => {
478 raw[off..off + 8].copy_from_slice(&val.to_le_bytes());
479 raw[off + 8..off + 16].copy_from_slice(&0.0f64.to_le_bytes());
480 }
481 }
482 Ok(())
483 }
484
485 pub fn make_unique(&mut self) {
489 if Rc::strong_count(&self.bytes) > 1 {
490 let data = self.bytes.borrow().clone();
491 self.bytes = Rc::new(RefCell::new(data));
492 }
493 }
494
495 pub fn deep_clone(&self) -> TypedStorage {
497 TypedStorage {
498 bytes: Rc::new(RefCell::new(self.bytes.borrow().clone())),
499 dtype: self.dtype,
500 len: self.len,
501 }
502 }
503
504 pub fn sum_f64(&self) -> f64 {
508 let data = self.to_f64_vec();
509 if self.dtype.is_float() || self.dtype == DType::Complex {
510 binned_sum_f64(&data)
511 } else {
512 data.iter().sum()
514 }
515 }
516
517 pub fn mean_f64(&self) -> f64 {
519 if self.len == 0 {
520 return f64::NAN;
521 }
522 self.sum_f64() / self.len as f64
523 }
524
525 pub fn cast(&self, target: DType) -> TypedStorage {
529 if self.dtype == target {
530 return self.deep_clone();
531 }
532 let f64_data = self.to_f64_vec();
533 match target {
534 DType::F64 => TypedStorage::from_f64_vec(f64_data),
535 DType::F32 => TypedStorage::from_f32_vec(f64_data.into_iter().map(|v| v as f32).collect()),
536 DType::I64 => TypedStorage::from_i64_vec(f64_data.into_iter().map(|v| v as i64).collect()),
537 DType::I32 => TypedStorage::from_i32_vec(f64_data.into_iter().map(|v| v as i32).collect()),
538 DType::U8 => TypedStorage::from_u8_vec(f64_data.into_iter().map(|v| v as u8).collect()),
539 DType::Bool => TypedStorage::from_bool_vec(f64_data.into_iter().map(|v| v != 0.0).collect()),
540 DType::Bf16 => TypedStorage::from_bf16_vec(f64_data.into_iter().map(|v| Bf16::from_f32(v as f32)).collect()),
541 DType::F16 => {
542 let mut bytes = Vec::with_capacity(f64_data.len() * 2);
543 for v in &f64_data {
544 let bits = crate::f16::F16::from_f64(*v).0;
545 bytes.extend_from_slice(&bits.to_le_bytes());
546 }
547 TypedStorage {
548 bytes: Rc::new(RefCell::new(bytes)),
549 dtype: DType::F16,
550 len: f64_data.len(),
551 }
552 }
553 DType::Complex => TypedStorage::from_complex_vec(
554 f64_data.into_iter().map(|v| ComplexF64 { re: v, im: 0.0 }).collect()
555 ),
556 }
557 }
558}
559
560impl Clone for TypedStorage {
561 fn clone(&self) -> Self {
563 TypedStorage {
564 bytes: Rc::clone(&self.bytes),
565 dtype: self.dtype,
566 len: self.len,
567 }
568 }
569}
570
/// Serialises `data` into one contiguous little-endian buffer (8 bytes per value).
fn f64_vec_to_bytes(data: Vec<f64>) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() * 8);
    out.extend(data.into_iter().flat_map(f64::to_le_bytes));
    out
}
582
/// Decodes consecutive little-endian 8-byte words as `f64`.
///
/// Trailing bytes that do not form a whole word are ignored.
fn bytes_to_f64_vec(bytes: &[u8]) -> Vec<f64> {
    bytes
        .chunks_exact(8)
        .map(|chunk| {
            let mut word = [0u8; 8];
            word.copy_from_slice(chunk);
            f64::from_le_bytes(word)
        })
        .collect()
}
592
/// Serialises `data` into one contiguous little-endian buffer (8 bytes per value).
fn i64_vec_to_bytes(data: Vec<i64>) -> Vec<u8> {
    data.iter()
        .fold(Vec::with_capacity(data.len() * 8), |mut acc, value| {
            acc.extend_from_slice(&value.to_le_bytes());
            acc
        })
}
600
/// Decodes consecutive little-endian 8-byte words as `i64`.
///
/// Trailing bytes that do not form a whole word are ignored.
fn bytes_to_i64_vec(bytes: &[u8]) -> Vec<i64> {
    let mut values = Vec::with_capacity(bytes.len() / 8);
    for chunk in bytes.chunks_exact(8) {
        let mut word = [0u8; 8];
        word.copy_from_slice(chunk);
        values.push(i64::from_le_bytes(word));
    }
    values
}
610
/// Serialises `data` into one contiguous little-endian buffer (4 bytes per value).
fn f32_vec_to_bytes(data: Vec<f32>) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() * 4);
    for value in data {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}
618
/// Decodes consecutive little-endian 4-byte words as `f32`.
///
/// Trailing bytes that do not form a whole word are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|chunk| {
            let mut word = [0u8; 4];
            word.copy_from_slice(chunk);
            f32::from_le_bytes(word)
        })
        .collect()
}
628
/// Serialises `data` into one contiguous little-endian buffer (4 bytes per value).
fn i32_vec_to_bytes(data: Vec<i32>) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() * 4);
    out.extend(data.iter().copied().flat_map(i32::to_le_bytes));
    out
}
636
/// Decodes consecutive little-endian 4-byte words as `i32`.
///
/// Trailing bytes that do not form a whole word are ignored.
fn bytes_to_i32_vec(bytes: &[u8]) -> Vec<i32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        let mut word = [0u8; 4];
        word.copy_from_slice(chunk);
        values.push(i32::from_le_bytes(word));
    }
    values
}
646
#[cfg(test)]
mod tests {
    use super::*;

    /// Every dtype, listed in snap-tag order (0..=8).
    const ALL_DTYPES: [DType; 9] = [
        DType::F64,
        DType::F32,
        DType::I64,
        DType::I32,
        DType::U8,
        DType::Bool,
        DType::Bf16,
        DType::F16,
        DType::Complex,
    ];

    #[test]
    fn test_dtype_byte_width() {
        assert_eq!(DType::F64.byte_width(), 8);
        assert_eq!(DType::F32.byte_width(), 4);
        assert_eq!(DType::I64.byte_width(), 8);
        assert_eq!(DType::I32.byte_width(), 4);
        assert_eq!(DType::U8.byte_width(), 1);
        assert_eq!(DType::Bool.byte_width(), 1);
        assert_eq!(DType::Bf16.byte_width(), 2);
        assert_eq!(DType::F16.byte_width(), 2);
        assert_eq!(DType::Complex.byte_width(), 16);
    }

    #[test]
    fn test_dtype_snap_roundtrip() {
        for dt in ALL_DTYPES.iter().copied() {
            assert_eq!(DType::from_snap_tag(dt.snap_tag()).unwrap(), dt);
        }
    }

    #[test]
    fn test_dtype_unknown_snap_tag() {
        assert!(DType::from_snap_tag(9).is_err());
        assert!(DType::from_snap_tag(u8::MAX).is_err());
    }

    #[test]
    fn test_dtype_name_and_display() {
        assert_eq!(DType::Bf16.name(), "bf16");
        assert_eq!(DType::Complex.to_string(), "complex");
    }

    #[test]
    fn test_f64_storage_roundtrip() {
        let data = vec![1.5, -2.3, 0.0, f64::INFINITY, f64::NEG_INFINITY];
        let storage = TypedStorage::from_f64_vec(data.clone());
        assert_eq!(storage.dtype(), DType::F64);
        assert_eq!(storage.len(), 5);
        assert_eq!(storage.as_f64_vec(), data);
    }

    #[test]
    fn test_i64_storage_roundtrip() {
        let data = vec![1i64, -2, 0, i64::MAX, i64::MIN];
        let storage = TypedStorage::from_i64_vec(data.clone());
        assert_eq!(storage.dtype(), DType::I64);
        assert_eq!(storage.as_i64_vec(), data);
    }

    #[test]
    fn test_f32_storage_roundtrip() {
        // 3.25 rather than 3.14: the latter trips clippy's deny-by-default
        // `approx_constant` lint (approximation of PI).
        let data = vec![1.0f32, -2.5, 0.0, 3.25];
        let storage = TypedStorage::from_f32_vec(data.clone());
        assert_eq!(storage.dtype(), DType::F32);
        assert_eq!(storage.as_f32_vec(), data);
    }

    #[test]
    fn test_i32_storage_roundtrip() {
        let data = vec![42i32, -1, 0, i32::MAX];
        let storage = TypedStorage::from_i32_vec(data.clone());
        assert_eq!(storage.as_i32_vec(), data);
    }

    #[test]
    fn test_u8_storage_roundtrip() {
        let data = vec![0u8, 127, 255];
        let storage = TypedStorage::from_u8_vec(data.clone());
        assert_eq!(storage.as_u8_vec(), data);
    }

    #[test]
    fn test_bool_storage_roundtrip() {
        let data = vec![true, false, true, true, false];
        let storage = TypedStorage::from_bool_vec(data.clone());
        assert_eq!(storage.as_bool_vec(), data);
    }

    #[test]
    fn test_complex_storage_roundtrip() {
        let data = vec![
            ComplexF64 { re: 1.0, im: 2.0 },
            ComplexF64 { re: -3.0, im: 0.5 },
        ];
        let storage = TypedStorage::from_complex_vec(data.clone());
        let back = storage.as_complex_vec();
        assert_eq!(back.len(), 2);
        assert_eq!(back[0].re, 1.0);
        assert_eq!(back[0].im, 2.0);
        assert_eq!(back[1].re, -3.0);
        assert_eq!(back[1].im, 0.5);
    }

    #[test]
    fn test_bf16_storage_roundtrip() {
        let data = vec![Bf16::from_f32(1.0), Bf16::from_f32(-0.5)];
        let storage = TypedStorage::from_bf16_vec(data.clone());
        let back = storage.as_bf16_vec();
        assert_eq!(back[0].to_f32(), 1.0);
        assert_eq!(back[1].to_f32(), -0.5);
    }

    #[test]
    fn test_cow_semantics() {
        let s1 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        let s2 = s1.clone();
        assert_eq!(s1.refcount(), 2);
        assert_eq!(s2.refcount(), 2);

        let s3 = s1.deep_clone();
        assert_eq!(s3.refcount(), 1);
        // deep_clone must neither share nor bump the refcount of the original buffer.
        assert_eq!(s1.refcount(), 2);
    }

    #[test]
    fn test_cow_mutation() {
        let s1 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        let mut s2 = s1.clone();
        assert_eq!(s1.refcount(), 2);

        s2.set_from_f64(0, 99.0).unwrap();
        // The write un-shared s2, so each handle now owns its own buffer...
        assert_eq!(s1.refcount(), 1);
        assert_eq!(s2.refcount(), 1);
        // ...and the original data is untouched.
        assert_eq!(s1.as_f64_vec()[0], 1.0);
        assert_eq!(s2.as_f64_vec()[0], 99.0);
    }

    #[test]
    fn test_get_set_f64() {
        let mut storage = TypedStorage::from_f64_vec(vec![10.0, 20.0, 30.0]);
        assert_eq!(storage.get_as_f64(0).unwrap(), 10.0);
        assert_eq!(storage.get_as_f64(2).unwrap(), 30.0);
        assert!(storage.get_as_f64(3).is_err());

        storage.set_from_f64(1, 99.0).unwrap();
        assert_eq!(storage.get_as_f64(1).unwrap(), 99.0);
    }

    #[test]
    fn test_set_out_of_bounds() {
        let mut storage = TypedStorage::from_f64_vec(vec![10.0, 20.0, 30.0]);
        assert!(storage.set_from_f64(3, 1.0).is_err());
        assert_eq!(storage.as_f64_vec(), vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_get_set_i64() {
        let mut storage = TypedStorage::from_i64_vec(vec![10, 20, 30]);
        assert_eq!(storage.get_as_f64(0).unwrap(), 10.0);
        storage.set_from_f64(1, 42.0).unwrap();
        assert_eq!(storage.as_i64_vec()[1], 42);
    }

    #[test]
    fn test_complex_get_set_uses_real_part() {
        let mut storage = TypedStorage::from_complex_vec(vec![ComplexF64 { re: 1.0, im: 2.0 }]);
        assert_eq!(storage.get_as_f64(0).unwrap(), 1.0);
        storage.set_from_f64(0, 5.0).unwrap();
        let back = storage.as_complex_vec();
        assert_eq!(back[0].re, 5.0);
        assert_eq!(back[0].im, 0.0);
    }

    #[test]
    fn test_to_f64_vec_conversion() {
        let storage = TypedStorage::from_i32_vec(vec![1, 2, 3]);
        assert_eq!(storage.to_f64_vec(), vec![1.0, 2.0, 3.0]);

        let storage = TypedStorage::from_bool_vec(vec![true, false, true]);
        assert_eq!(storage.to_f64_vec(), vec![1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_sum_f64() {
        let storage = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0, 4.0]);
        assert!((storage.sum_f64() - 10.0).abs() < 1e-12);

        let storage = TypedStorage::from_i64_vec(vec![1, 2, 3, 4]);
        assert!((storage.sum_f64() - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_mean_f64() {
        let storage = TypedStorage::from_i32_vec(vec![1, 2, 3, 4]);
        assert_eq!(storage.mean_f64(), 2.5);
        assert!(TypedStorage::from_f64_vec(Vec::new()).mean_f64().is_nan());
    }

    #[test]
    fn test_cast_f64_to_i64() {
        let s = TypedStorage::from_f64_vec(vec![1.5, -2.7, 3.0]);
        let c = s.cast(DType::I64);
        assert_eq!(c.dtype(), DType::I64);
        assert_eq!(c.as_i64_vec(), vec![1, -2, 3]);
    }

    #[test]
    fn test_cast_i64_to_f32() {
        let s = TypedStorage::from_i64_vec(vec![1, 2, 3]);
        let c = s.cast(DType::F32);
        assert_eq!(c.dtype(), DType::F32);
        assert_eq!(c.as_f32_vec(), vec![1.0f32, 2.0, 3.0]);
    }

    #[test]
    fn test_cast_f64_to_u8_saturates() {
        let s = TypedStorage::from_f64_vec(vec![-1.0, 300.0, 7.9]);
        assert_eq!(s.cast(DType::U8).as_u8_vec(), vec![0, 255, 7]);
    }

    #[test]
    fn test_cast_to_bool() {
        let s = TypedStorage::from_f64_vec(vec![0.0, 2.5, -1.0, f64::NAN]);
        assert_eq!(s.cast(DType::Bool).as_bool_vec(), vec![false, true, true, true]);
    }

    #[test]
    fn test_cast_same_dtype_is_independent_copy() {
        let s = TypedStorage::from_i64_vec(vec![1, 2]);
        let c = s.cast(DType::I64);
        assert_eq!(c.refcount(), 1);
        assert_eq!(s.refcount(), 1);
        assert_eq!(c.as_i64_vec(), vec![1, 2]);
    }

    #[test]
    fn test_zeros_all_dtypes() {
        for dt in ALL_DTYPES.iter().copied() {
            let s = TypedStorage::zeros(dt, 10);
            assert_eq!(s.len(), 10);
            assert_eq!(s.byte_len(), 10 * dt.byte_width());
            for i in 0..10 {
                assert_eq!(s.get_as_f64(i).unwrap(), 0.0, "dtype {}, index {}", dt, i);
            }
        }
    }

    #[test]
    fn test_byte_determinism() {
        let s1 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        let s2 = TypedStorage::from_f64_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(s1.to_bytes(), s2.to_bytes());

        let s3 = TypedStorage::from_i64_vec(vec![42, -1, 0]);
        let s4 = TypedStorage::from_i64_vec(vec![42, -1, 0]);
        assert_eq!(s3.to_bytes(), s4.to_bytes());
    }

    #[test]
    fn test_from_bytes_roundtrip() {
        let original = TypedStorage::from_f64_vec(vec![1.5, -2.3, 0.0]);
        let bytes = original.to_bytes();
        let restored = TypedStorage::from_bytes(bytes, DType::F64, 3).unwrap();
        assert_eq!(original.as_f64_vec(), restored.as_f64_vec());
    }

    #[test]
    fn test_from_bytes_size_mismatch() {
        assert!(TypedStorage::from_bytes(vec![0u8; 10], DType::F64, 2).is_err());
    }
}