1#[allow(unused_imports)]
6use super::functions::*;
7
#[cfg(test)]
mod tests {
    //! Unit tests for the NPY/NPZ reader/writer helpers and the small
    //! numpy-style numeric utilities exposed by the parent module.

    use super::*;
    use crate::numpy::NpyArray;
    use crate::numpy::NpyDtype;
    use crate::numpy::NpyField;
    use crate::numpy::NpyMaskedArray;
    use crate::numpy::NpyRecordArray;
    use crate::numpy::NpySlice;
    use crate::numpy::NpzArchive;
    use crate::numpy::NpzWriter;

    /// Byte length of the NPY v1.0 preamble that precedes the header text:
    /// 6-byte magic string, 2 version bytes, and a little-endian `u16`
    /// header length.
    const NPY_V1_PREAMBLE_LEN: usize = 10;

    /// f64 data survives a write/read cycle bit-for-bit, shape included.
    #[test]
    fn test_roundtrip_f64() {
        let shape = vec![3usize, 2];
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bytes = write_npy_f64(&shape, &data);
        let (got_shape, got_data) = read_npy_f64(&bytes).expect("read_npy_f64 failed");
        assert_eq!(got_shape, shape);
        assert_eq!(got_data.len(), data.len());
        // Compare bit patterns so the check is exact rather than tolerance-based.
        for (a, b) in data.iter().zip(got_data.iter()) {
            assert_eq!(a.to_bits(), b.to_bits());
        }
    }

    /// i32 round trip, including the extreme values of the type.
    #[test]
    fn test_roundtrip_i32() {
        let shape = vec![2usize, 3];
        let data: Vec<i32> = vec![-1, 0, 1, i32::MAX, i32::MIN, 42];
        let bytes = write_npy_i32(&shape, &data);
        let (got_shape, got_data) = read_npy_i32(&bytes).expect("read_npy_i32 failed");
        assert_eq!(got_shape, shape);
        assert_eq!(got_data, data);
    }

    /// Written files start with the NPY magic string `\x93NUMPY`.
    #[test]
    fn test_magic_bytes() {
        let shape = vec![2usize];
        let data = vec![1.0f64, 2.0f64];
        let bytes = write_npy_f64(&shape, &data);
        assert_eq!(&bytes[0..6], b"\x93NUMPY");
    }

    /// Mixed-dtype NPZ archive written by `NpzWriter` can be parsed back.
    #[test]
    fn test_npz_roundtrip() {
        let mut writer = NpzWriter::new();
        let shape_f = vec![3usize, 2];
        let data_f: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape_i = vec![4usize];
        let data_i: Vec<i32> = vec![10, 20, 30, 40];
        writer.add_array_f64("matrix", &shape_f, &data_f);
        writer.add_array_i32("counts", &shape_i, &data_i);
        let bytes = writer.to_bytes();
        let recovered = NpzWriter::from_bytes(&bytes).expect("from_bytes failed");
        assert_eq!(recovered.files.len(), 2);
        // Outer Option: entry lookup; inner Result: NPY decoding.
        let (s1, d1) = recovered
            .get_f64("matrix")
            .expect("matrix not found")
            .expect("read_npy_f64 failed");
        assert_eq!(s1, shape_f);
        assert_eq!(d1, data_f);
        let (s2, d2) = recovered
            .get_i32("counts")
            .expect("counts not found")
            .expect("read_npy_i32 failed");
        assert_eq!(s2, shape_i);
        assert_eq!(d2, data_i);
    }

    /// A 3-D shape tuple is encoded and decoded in order.
    #[test]
    fn test_shape_encoding() {
        let shape = vec![5usize, 4, 3];
        let data = vec![0.0f64; 60];
        let bytes = write_npy_f64(&shape, &data);
        let (got_shape, _) = read_npy_f64(&bytes).expect("read_npy_f64 failed");
        assert_eq!(got_shape, shape);
    }

    /// Each dtype maps to the descriptor string this crate writes in headers.
    #[test]
    fn test_numpy_str() {
        assert_eq!(NpyDtype::Float64.numpy_str(), "<f8");
        assert_eq!(NpyDtype::Float32.numpy_str(), "<f4");
        assert_eq!(NpyDtype::Int32.numpy_str(), "<i4");
        assert_eq!(NpyDtype::Int64.numpy_str(), "<i8");
        assert_eq!(NpyDtype::Bool.numpy_str(), "?");
        assert_eq!(NpyDtype::Uint8.numpy_str(), "|u1");
    }

    #[test]
    fn test_1d_roundtrip() {
        let shape = vec![5usize];
        let data: Vec<f64> = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let bytes = write_npy_f64(&shape, &data);
        let (got_shape, got_data) = read_npy_f64(&bytes).expect("1d read failed");
        assert_eq!(got_shape, shape);
        assert_eq!(got_data, data);
    }

    #[test]
    fn test_roundtrip_f32() {
        let shape = vec![2usize, 3];
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bytes = write_npy_f32(&shape, &data);
        let (got_shape, got_data) = read_npy_f32(&bytes).expect("read_npy_f32 failed");
        assert_eq!(got_shape, shape);
        assert_eq!(got_data, data);
    }

    /// i64 round trip, including the extreme values of the type.
    #[test]
    fn test_roundtrip_i64() {
        let shape = vec![3usize];
        let data: Vec<i64> = vec![i64::MIN, 0, i64::MAX];
        let bytes = write_npy_i64(&shape, &data);
        let (got_shape, got_data) = read_npy_i64(&bytes).expect("read_npy_i64 failed");
        assert_eq!(got_shape, shape);
        assert_eq!(got_data, data);
    }

    /// Element sizes in bytes for every supported dtype.
    #[test]
    fn test_element_size() {
        assert_eq!(NpyDtype::Float64.element_size(), 8);
        assert_eq!(NpyDtype::Float32.element_size(), 4);
        assert_eq!(NpyDtype::Int32.element_size(), 4);
        assert_eq!(NpyDtype::Int64.element_size(), 8);
        assert_eq!(NpyDtype::Bool.element_size(), 1);
        assert_eq!(NpyDtype::Uint8.element_size(), 1);
    }

    /// Descriptor parsing accepts known strings and rejects garbage.
    #[test]
    fn test_dtype_from_str() {
        assert_eq!(NpyDtype::from_numpy_str("<f8"), Ok(NpyDtype::Float64));
        assert_eq!(NpyDtype::from_numpy_str("<f4"), Ok(NpyDtype::Float32));
        assert_eq!(NpyDtype::from_numpy_str("<i4"), Ok(NpyDtype::Int32));
        assert!(NpyDtype::from_numpy_str("bad").is_err());
    }

    /// `validate` fails when the element count disagrees with the shape.
    #[test]
    fn test_npy_array_validate() {
        let arr = NpyArray::from_f64(vec![2, 3], vec![1.0; 6]);
        assert!(arr.validate().is_ok());
        let bad = NpyArray::from_f64(vec![2, 3], vec![1.0; 5]);
        assert!(bad.validate().is_err());
    }

    /// Reshape succeeds only when the element count is preserved.
    #[test]
    fn test_npy_array_reshape() {
        let mut arr = NpyArray::from_f64(vec![2, 3], vec![1.0; 6]);
        assert!(arr.reshape(vec![3, 2]).is_ok());
        assert_eq!(arr.shape, vec![3, 2]);
        // 4 * 2 = 8 != 6 elements.
        assert!(arr.reshape(vec![4, 2]).is_err());
    }

    #[test]
    fn test_npy_array_from_f32() {
        let arr = NpyArray::from_f32(vec![4], vec![1.0f32; 4]);
        assert_eq!(arr.dtype, NpyDtype::Float32);
        assert_eq!(arr.ndim(), 1);
        assert_eq!(arr.numel(), 4);
    }

    #[test]
    fn test_npy_array_from_i32() {
        let arr = NpyArray::from_i32(vec![2, 2], vec![1, 2, 3, 4]);
        assert_eq!(arr.dtype, NpyDtype::Int32);
        assert_eq!(arr.ndim(), 2);
    }

    #[test]
    fn test_validate_shape_ok() {
        assert!(validate_shape(&[2, 3], 6).is_ok());
        assert!(validate_shape(&[5], 5).is_ok());
    }

    #[test]
    fn test_validate_shape_err() {
        assert!(validate_shape(&[2, 3], 7).is_err());
    }

    /// Row-major (C-order) flattening: index `[i, j]` of a `3x4` array is `i * 4 + j`.
    #[test]
    fn test_flat_index() {
        assert_eq!(flat_index(&[1, 2], &[3, 4]).unwrap(), 6);
        assert_eq!(flat_index(&[0, 0], &[3, 4]).unwrap(), 0);
    }

    /// An index equal to the axis length is out of range.
    #[test]
    fn test_flat_index_out_of_range() {
        assert!(flat_index(&[3, 0], &[3, 4]).is_err());
    }

    /// Inverse of `test_flat_index`: 6 = 1 * 4 + 2.
    #[test]
    fn test_unravel_index() {
        let indices = unravel_index(6, &[3, 4]).unwrap();
        assert_eq!(indices, vec![1, 2]);
    }

    /// `unravel_index` followed by `flat_index` is the identity on every valid offset.
    #[test]
    fn test_unravel_flat_roundtrip() {
        let shape = vec![3, 4, 5];
        for flat in 0..60 {
            let indices = unravel_index(flat, &shape).unwrap();
            let recovered = flat_index(&indices, &shape).unwrap();
            assert_eq!(flat, recovered, "round-trip failed for flat={flat}");
        }
    }

    /// The dtype can be sniffed from the header without decoding the payload.
    #[test]
    fn test_detect_dtype() {
        let bytes = write_npy_f64(&[2], &[1.0, 2.0]);
        assert_eq!(detect_npy_dtype(&bytes).unwrap(), NpyDtype::Float64);
        let bytes = write_npy_i32(&[3], &[1, 2, 3]);
        assert_eq!(detect_npy_dtype(&bytes).unwrap(), NpyDtype::Int32);
    }

    #[test]
    fn test_read_npy_shape() {
        let bytes = write_npy_f64(&[3, 4, 5], &vec![0.0; 60]);
        let shape = read_npy_shape(&bytes).unwrap();
        assert_eq!(shape, vec![3, 4, 5]);
    }

    /// Bookkeeping helpers on `NpzWriter`; removing a missing entry reports `false`.
    #[test]
    fn test_npz_names_contains_remove() {
        let mut w = NpzWriter::new();
        w.add_array_f64("a", &[2], &[1.0, 2.0]);
        w.add_array_i32("b", &[3], &[1, 2, 3]);
        assert_eq!(w.len(), 2);
        assert!(w.contains("a"));
        assert!(!w.contains("c"));
        let names = w.names();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
        assert!(w.remove("a"));
        assert_eq!(w.len(), 1);
        assert!(!w.contains("a"));
        assert!(!w.remove("a"));
    }

    #[test]
    fn test_npz_f32_i64() {
        let mut w = NpzWriter::new();
        w.add_array_f32("floats", &[3], &[1.0f32, 2.0, 3.0]);
        w.add_array_i64("longs", &[2], &[100i64, 200]);
        let bytes = w.to_bytes();
        let recovered = NpzWriter::from_bytes(&bytes).unwrap();
        let (shape, data) = recovered.get_f32("floats").unwrap().unwrap();
        assert_eq!(shape, vec![3]);
        assert_eq!(data, vec![1.0f32, 2.0, 3.0]);
        let (shape, data) = recovered.get_i64("longs").unwrap().unwrap();
        assert_eq!(shape, vec![2]);
        assert_eq!(data, vec![100i64, 200]);
    }

    /// A 0-d array (empty shape tuple) holds exactly one element.
    #[test]
    fn test_scalar_array() {
        let shape: Vec<usize> = vec![];
        let data = vec![42.0_f64];
        let bytes = write_npy_f64(&shape, &data);
        let (got_shape, got_data) = read_npy_f64(&bytes).unwrap();
        assert!(got_shape.is_empty());
        assert_eq!(got_data.len(), 1);
        assert_eq!(got_data[0], 42.0);
    }

    /// An `<f8` file must not be silently reinterpreted as i32.
    #[test]
    fn test_wrong_dtype_error() {
        let bytes = write_npy_f64(&[2], &[1.0, 2.0]);
        assert!(read_npy_i32(&bytes).is_err());
    }

    /// A payload shorter than `prod(shape) * element_size` is rejected.
    #[test]
    fn test_truncated_data_error() {
        let bytes = write_npy_f64(&[10], &[0.0; 10]);
        let truncated = &bytes[..bytes.len() - 10];
        assert!(read_npy_f64(truncated).is_err());
    }

    /// Corrupting the leading magic byte (`0x93`) is rejected.
    #[test]
    fn test_bad_magic_error() {
        let mut bytes = write_npy_f64(&[2], &[1.0, 2.0]);
        bytes[0] = 0;
        assert!(read_npy_f64(&bytes).is_err());
    }

    /// Row 1 of a row-major `3x4` view over `0..12` is `4..8`.
    #[test]
    fn test_npy_slice_row() {
        let data: Vec<f64> = (0..12).map(|i| i as f64).collect();
        let slice = NpySlice::new(&data, vec![3, 4]).unwrap();
        let row1 = slice.row(1).unwrap();
        assert_eq!(row1, &[4.0, 5.0, 6.0, 7.0]);
    }

    /// `[1, 2]` in a `2x3` view is flat offset 1 * 3 + 2 = 5.
    #[test]
    fn test_npy_slice_get() {
        let data: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let slice = NpySlice::new(&data, vec![2, 3]).unwrap();
        assert_eq!(slice.get(&[1, 2]).unwrap(), 5.0);
    }

    /// A shape whose product (8) differs from the data length (6) is rejected.
    #[test]
    fn test_npy_slice_shape_mismatch() {
        let data = vec![1.0; 6];
        assert!(NpySlice::new(&data, vec![2, 4]).is_err());
    }

    /// Masked elements are excluded from the mean.
    #[test]
    fn test_masked_array_mean_valid() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        // `true` masks the element out, so the mean covers 1.0, 3.0, 4.0 -> 8/3.
        let mask = vec![false, true, false, false];
        let ma = NpyMaskedArray::new(data, mask, vec![4], 1e20).unwrap();
        let mean = ma.mean_valid().unwrap();
        assert!((mean - 8.0 / 3.0).abs() < 1e-10, "mean={mean}");
    }

    /// `filled` replaces masked elements with the fill value.
    #[test]
    fn test_masked_array_filled() {
        let data = vec![1.0, 2.0];
        let mask = vec![false, true];
        let ma = NpyMaskedArray::new(data, mask, vec![2], 999.0).unwrap();
        let filled = ma.filled();
        assert_eq!(filled[0], 1.0);
        assert_eq!(filled[1], 999.0);
    }

    /// `from_data` starts with nothing masked.
    #[test]
    fn test_masked_array_count_valid() {
        let ma = NpyMaskedArray::from_data(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        assert_eq!(ma.count_valid(), 3);
    }

    #[test]
    fn test_masked_array_mask_greater_than() {
        let mut ma = NpyMaskedArray::from_data(vec![1.0, 5.0, 2.0, 10.0], vec![4]).unwrap();
        ma.mask_greater_than(4.0);
        assert!(!ma.mask[0]);
        assert!(ma.mask[1]);
        assert!(!ma.mask[2]);
        assert!(ma.mask[3]);
        assert_eq!(ma.count_valid(), 2);
    }

    /// The mean of an empty slice is `None` rather than NaN.
    #[test]
    fn test_slice_mean() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(slice_mean(&data).unwrap(), 3.0);
        assert!(slice_mean(&[]).is_none());
    }

    /// Population variance (ddof = 0): mean is 5, squared deviations sum to 32
    /// over 8 values, giving 4.
    #[test]
    fn test_slice_var() {
        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let v = slice_var(&data).unwrap();
        assert!((v - 4.0).abs() < 1e-10, "var={v}");
    }

    /// Standard deviation is the square root of the population variance above.
    #[test]
    fn test_slice_std() {
        let data = vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let s = slice_std(&data).unwrap();
        assert!((s - 2.0).abs() < 1e-10, "std={s}");
    }

    /// Returns `(min, argmin, max, argmax)`.
    #[test]
    fn test_slice_min_max() {
        let data = vec![3.0, 1.0, 4.0, 1.5, 9.0, 2.6];
        let (min_v, min_i, max_v, max_i) = slice_min_max(&data).unwrap();
        assert_eq!(min_v, 1.0);
        assert_eq!(min_i, 1);
        assert_eq!(max_v, 9.0);
        assert_eq!(max_i, 4);
    }

    #[test]
    fn test_slice_percentile_median() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let median = slice_percentile(&data, 50.0).unwrap();
        assert!((median - 3.0).abs() < 1e-10, "median={median}");
    }

    /// Values outside `[lo, hi]` are clamped to the nearest bound.
    #[test]
    fn test_slice_clip() {
        let data = vec![-1.0, 0.5, 2.0, 3.5];
        let clipped = slice_clip(&data, 0.0, 2.0);
        assert_eq!(clipped, vec![0.0, 0.5, 2.0, 2.0]);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_slice_dot() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let dot = slice_dot(&a, &b).unwrap();
        assert!((dot - 32.0).abs() < 1e-10);
    }

    /// `NpzArchive` preserves both dtype and data for every entry.
    #[test]
    fn test_npz_archive_roundtrip() {
        let mut archive = NpzArchive::new();
        archive.insert("pos", NpyArray::from_f64(vec![3], vec![1.0, 2.0, 3.0]));
        archive.insert("idx", NpyArray::from_i32(vec![2], vec![10, 20]));
        let bytes = archive.to_bytes().unwrap();
        let recovered = NpzArchive::from_bytes(&bytes).unwrap();
        assert_eq!(recovered.len(), 2);
        let pos = recovered.get("pos").unwrap();
        assert_eq!(pos.dtype, NpyDtype::Float64);
        assert_eq!(pos.data_f64, vec![1.0, 2.0, 3.0]);
        let idx = recovered.get("idx").unwrap();
        assert_eq!(idx.dtype, NpyDtype::Int32);
        assert_eq!(idx.data_i32, vec![10, 20]);
    }

    #[test]
    fn test_npz_archive_names_remove() {
        let mut archive = NpzArchive::new();
        archive.insert("a", NpyArray::from_f64(vec![1], vec![1.0]));
        archive.insert("b", NpyArray::from_f64(vec![1], vec![2.0]));
        assert!(archive.names().contains(&"a"));
        assert!(archive.remove("a"));
        assert!(!archive.names().contains(&"a"));
        assert_eq!(archive.len(), 1);
    }

    /// Records are stored field-by-field; `get_scalar` addresses one field of one record.
    #[test]
    fn test_record_array_push_and_get() {
        let fields = vec![
            NpyField::scalar("x", NpyDtype::Float64),
            NpyField::scalar("y", NpyDtype::Float64),
            NpyField::scalar("mass", NpyDtype::Float64),
        ];
        let mut ra = NpyRecordArray::new(fields);
        ra.push_record(&[1.0, 2.0, 12.0]).unwrap();
        ra.push_record(&[3.0, 4.0, 16.0]).unwrap();
        assert_eq!(ra.n_records, 2);
        assert!((ra.get_scalar(1, "mass").unwrap() - 16.0).abs() < 1e-10);
    }

    /// `column` gathers one field across all records, in insertion order.
    #[test]
    fn test_record_array_column() {
        let fields = vec![
            NpyField::scalar("vx", NpyDtype::Float64),
            NpyField::scalar("vy", NpyDtype::Float64),
        ];
        let mut ra = NpyRecordArray::new(fields);
        ra.push_record(&[0.1, 0.2]).unwrap();
        ra.push_record(&[0.3, 0.4]).unwrap();
        let col = ra.column("vx").unwrap();
        assert_eq!(col, &[0.1, 0.3]);
    }

    /// Both endpoints are included, as in `numpy.linspace`.
    #[test]
    fn test_linspace() {
        let v = linspace(0.0, 1.0, 5);
        assert_eq!(v.len(), 5);
        assert!((v[0] - 0.0).abs() < 1e-10);
        assert!((v[2] - 0.5).abs() < 1e-10);
        assert!((v[4] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_linspace_single() {
        let v = linspace(3.0, 3.0, 1);
        assert_eq!(v, vec![3.0]);
    }

    /// Half-open interval `[start, stop)`: 1.0 itself is excluded.
    #[test]
    fn test_arange() {
        let v = arange(0.0, 1.0, 0.25).unwrap();
        assert_eq!(v.len(), 4);
        assert!((v[0] - 0.0).abs() < 1e-10);
        assert!((v[3] - 0.75).abs() < 1e-10);
    }

    /// A zero step would never terminate, so it is rejected.
    #[test]
    fn test_arange_zero_step_error() {
        assert!(arange(0.0, 1.0, 0.0).is_err());
    }

    /// Exponents 0, 1, 2 in base 10.
    #[test]
    fn test_logspace() {
        let v = logspace(0.0, 2.0, 3);
        assert_eq!(v.len(), 3);
        assert!((v[0] - 1.0).abs() < 1e-10);
        assert!((v[1] - 10.0).abs() < 1e-10);
        assert!((v[2] - 100.0).abs() < 1e-10);
    }

    /// `[[1, 2, 3], [4, 5, 6]]` transposes to `[[1, 4], [2, 5], [3, 6]]` (row-major).
    #[test]
    fn test_transpose_2d() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let (t, shape) = transpose_2d(&data, &[2, 3]).unwrap();
        assert_eq!(shape, vec![3, 2]);
        assert_eq!(t, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_2d_square() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let (t, shape) = transpose_2d(&data, &[2, 2]).unwrap();
        assert_eq!(shape, vec![2, 2]);
        assert_eq!(t, vec![1.0, 3.0, 2.0, 4.0]);
    }

    /// `[[1,2,3],[4,5,6]] @ [[7,8],[9,10],[11,12]] = [[58,64],[139,154]]`.
    #[test]
    fn test_matmul_2x3_3x2() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let (c, shape) = matmul(&a, &[2, 3], &b, &[3, 2]).unwrap();
        assert_eq!(shape, vec![2, 2]);
        assert!((c[0] - 58.0).abs() < 1e-10);
        assert!((c[1] - 64.0).abs() < 1e-10);
        assert!((c[2] - 139.0).abs() < 1e-10);
        assert!((c[3] - 154.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_identity() {
        let id = vec![1.0, 0.0, 0.0, 1.0];
        let a = vec![3.0, 4.0, 5.0, 6.0];
        let (c, _) = matmul(&id, &[2, 2], &a, &[2, 2]).unwrap();
        assert_eq!(c, a);
    }

    #[test]
    fn test_save_structured_magic() {
        let mut data_bytes = Vec::new();
        for v in &[1.0f64, 2.0f64, 3.0f64, 4.0f64] {
            data_bytes.extend_from_slice(&v.to_le_bytes());
        }
        let bytes =
            NpyArray::save_structured(&[("a", "<f8"), ("b", "<f8")], 2, &data_bytes).unwrap();
        assert_eq!(&bytes[0..6], NPY_MAGIC.as_ref());
    }

    /// The structured dtype descriptor in the header names every field.
    #[test]
    fn test_save_structured_header_contains_field_names() {
        let data_bytes = vec![0u8; 8];
        let bytes = NpyArray::save_structured(&[("pressure", "<f8")], 1, &data_bytes).unwrap();
        // v1.0: little-endian u16 header length at bytes 8..10, header text follows.
        let header_len = u16::from_le_bytes([bytes[8], bytes[9]]) as usize;
        let header_end = NPY_V1_PREAMBLE_LEN + header_len;
        let header = std::str::from_utf8(&bytes[NPY_V1_PREAMBLE_LEN..header_end]).unwrap();
        assert!(
            header.contains("pressure"),
            "header should contain field name 'pressure'"
        );
    }

    #[test]
    fn test_save_structured_empty_fields_error() {
        let result = NpyArray::save_structured(&[], 0, &[]);
        assert!(result.is_err());
    }

    /// NPY v1.0 pads the header so that `len(magic) + 2 + len(length) + HEADER_LEN`
    /// (i.e. `10 + HEADER_LEN`) is a multiple of 64. That puts the array data on a
    /// 64-byte boundary. `HEADER_LEN` by itself is *not* required to be a multiple
    /// of 64; requiring it would leave the data misaligned by 10 bytes.
    #[test]
    fn test_save_structured_data_offset_aligned_64() {
        let data_bytes = vec![0u8; 16];
        let bytes = NpyArray::save_structured(&[("x", "<f8")], 2, &data_bytes).unwrap();
        // A 2-byte header length at offset 8 only applies to format version 1.x.
        assert_eq!(bytes[6], 1, "expected NPY format version 1.x");
        let header_len = u16::from_le_bytes([bytes[8], bytes[9]]) as usize;
        let data_offset = NPY_V1_PREAMBLE_LEN + header_len;
        assert_eq!(
            data_offset % 64,
            0,
            "array data should start on a 64-byte boundary (10 + header_len)"
        );
        assert_eq!(
            bytes.len(),
            data_offset + data_bytes.len(),
            "array data should immediately follow the header"
        );
    }

    /// Re-adding a name overwrites the previous entry instead of duplicating it.
    #[test]
    fn test_npz_archive_add_array_replaces() {
        let mut archive = NpzArchive::new();
        archive.add_array("v", NpyArray::from_f64(vec![2], vec![1.0, 2.0]));
        archive.add_array("v", NpyArray::from_f64(vec![3], vec![9.0, 8.0, 7.0]));
        assert_eq!(archive.len(), 1, "add_array should replace existing entry");
        assert_eq!(archive.get("v").unwrap().shape, vec![3]);
    }

    #[test]
    fn test_npz_archive_load_all_roundtrip() {
        let mut archive = NpzArchive::new();
        archive.add_array(
            "coords",
            NpyArray::from_f64(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
        );
        archive.add_array("ids", NpyArray::from_i32(vec![2], vec![0, 1]));
        let bytes = archive.to_bytes().unwrap();
        let loaded = NpzArchive::load_all(&bytes).unwrap();
        assert_eq!(loaded.len(), 2);
        let coords = loaded.get("coords").unwrap();
        assert_eq!(coords.shape, vec![2, 3]);
        assert!((coords.data_f64[5] - 6.0).abs() < 1e-12);
    }

    /// Iteration visits every entry; no ordering is asserted.
    #[test]
    fn test_npz_archive_iter() {
        let mut archive = NpzArchive::new();
        archive.add_array("a", NpyArray::from_f64(vec![1], vec![1.0]));
        archive.add_array("b", NpyArray::from_f64(vec![1], vec![2.0]));
        let names: Vec<&str> = archive.iter().map(|(n, _)| n).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));
    }

    /// On a name collision, the entry from the merged-in archive wins.
    #[test]
    fn test_npz_archive_merge() {
        let mut a = NpzArchive::new();
        a.add_array("x", NpyArray::from_f64(vec![1], vec![1.0]));
        let mut b = NpzArchive::new();
        b.add_array("y", NpyArray::from_f64(vec![1], vec![2.0]));
        b.add_array("x", NpyArray::from_f64(vec![1], vec![99.0]));
        a.merge(b);
        assert_eq!(a.len(), 2);
        assert!((a.get("x").unwrap().data_f64[0] - 99.0).abs() < 1e-12);
    }

    /// Sums element counts across dtypes: 3 + 2 * 2 = 7.
    #[test]
    fn test_npz_archive_total_elements() {
        let mut archive = NpzArchive::new();
        archive.add_array("a", NpyArray::from_f64(vec![3], vec![1.0, 2.0, 3.0]));
        archive.add_array("b", NpyArray::from_i32(vec![2, 2], vec![1, 2, 3, 4]));
        assert_eq!(archive.total_elements(), 7);
    }
}