1use std::collections::HashMap;
11use std::io::{Read as IoRead, Write as IoWrite};
12
13use crate::{Error, Result};
14
/// One DICOM data element: the `(group, element)` tag number, its value
/// representation (VR) code, and the raw value bytes.
#[derive(Debug, Clone)]
pub struct DicomTag {
    /// Tag group number.
    pub group: u16,
    /// Tag element number within `group`.
    pub element: u16,
    /// VR code (e.g. `"CS"`, `"DS"`, `"FD"`), stored exactly as given and not validated.
    pub value_representation: String,
    /// Raw value bytes.
    pub data: Vec<u8>,
}

impl DicomTag {
    /// Builds a tag from already-encoded value bytes.
    pub fn new(group: u16, element: u16, vr: impl Into<String>, data: Vec<u8>) -> Self {
        let value_representation = vr.into();
        Self {
            group,
            element,
            value_representation,
            data,
        }
    }

    /// Builds a tag whose value is the UTF-8 bytes of `value`.
    ///
    /// No DICOM padding is added.
    pub fn from_str(group: u16, element: u16, vr: impl Into<String>, value: &str) -> Self {
        let bytes = Vec::from(value.as_bytes());
        Self::new(group, element, vr, bytes)
    }

    /// Builds a tag whose value is `value` encoded as 8 little-endian IEEE-754 bytes.
    ///
    /// The encoding is the same whatever `vr` says.
    pub fn from_f64(group: u16, element: u16, vr: impl Into<String>, value: f64) -> Self {
        let bytes = Vec::from(value.to_le_bytes());
        Self::new(group, element, vr, bytes)
    }
}
51
52#[derive(Debug, Clone, Default)]
54pub struct DicomDataset {
55 pub tags: HashMap<(u16, u16), DicomTag>,
57}
58
59impl DicomDataset {
60 pub fn new() -> Self {
62 Self::default()
63 }
64
65 pub fn insert(&mut self, tag: DicomTag) {
67 self.tags.insert((tag.group, tag.element), tag);
68 }
69
70 pub fn get_string(&self, group: u16, elem: u16) -> Option<String> {
74 let tag = self.tags.get(&(group, elem))?;
75 String::from_utf8(tag.data.clone()).ok()
76 }
77
78 pub fn get_f64(&self, group: u16, elem: u16) -> Option<f64> {
85 let tag = self.tags.get(&(group, elem))?;
86 let vr = tag.value_representation.as_str();
87 if vr == "FD" || vr == "FL" {
88 if tag.data.len() == 8 {
90 let bytes: [u8; 8] = tag.data[..8].try_into().ok()?;
91 return Some(f64::from_le_bytes(bytes));
92 }
93 }
94 if let Ok(s) = String::from_utf8(tag.data.clone())
96 && let Ok(v) = s.trim().parse::<f64>()
97 {
98 return Some(v);
99 }
100 if tag.data.len() == 8 {
102 let bytes: [u8; 8] = tag.data[..8].try_into().ok()?;
103 return Some(f64::from_le_bytes(bytes));
104 }
105 None
106 }
107}
108
/// A dense 3-D grid of signed 16-bit stored values with per-axis voxel spacing.
#[derive(Debug, Clone)]
pub struct VoxelVolume {
    /// Grid size `[nx, ny, nz]`.
    pub dimensions: [usize; 3],
    /// Voxel size along x, y and z. [`VoxelVolume::physical_volume_mm3`] treats these as mm.
    pub voxel_spacing: [f64; 3],
    /// Stored values in x-fastest order: index `z * ny * nx + y * nx + x`.
    pub pixel_data: Vec<i16>,
}

impl VoxelVolume {
    /// Creates a zero-filled volume of the given size and spacing.
    ///
    /// # Panics
    /// Panics if `nx * ny * nz` overflows `usize`. Such a buffer could never be allocated,
    /// and wrapping silently in release builds would create a volume that is too small.
    pub fn new(dimensions: [usize; 3], voxel_spacing: [f64; 3]) -> Self {
        let n = dimensions
            .iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d))
            .expect("voxel count overflows usize");
        Self {
            dimensions,
            voxel_spacing,
            pixel_data: vec![0; n],
        }
    }

    /// Applies a linear rescale to a stored value: `pixel * slope + intercept`.
    pub fn to_hounsfield(pixel: i16, slope: f64, intercept: f64) -> f64 {
        f64::from(pixel) * slope + intercept
    }

    /// Number of voxels implied by `dimensions` (`nx * ny * nz`).
    pub fn voxel_count(&self) -> usize {
        self.dimensions[0] * self.dimensions[1] * self.dimensions[2]
    }

    /// Total physical volume: voxel count times the product of the three spacings.
    pub fn physical_volume_mm3(&self) -> f64 {
        self.voxel_count() as f64
            * self.voxel_spacing[0]
            * self.voxel_spacing[1]
            * self.voxel_spacing[2]
    }

    /// Returns the stored value at `(x, y, z)`.
    ///
    /// Returns `None` when the coordinate is outside `dimensions`. Because the fields are
    /// public, it also returns `None` when `pixel_data` is shorter than `dimensions`
    /// implies, rather than panicking.
    pub fn get(&self, x: usize, y: usize, z: usize) -> Option<i16> {
        let [nx, ny, nz] = self.dimensions;
        if x >= nx || y >= ny || z >= nz {
            return None;
        }
        self.pixel_data.get(z * ny * nx + y * nx + x).copied()
    }
}
165
/// Per-voxel class labels for a flattened volume.
#[derive(Debug, Clone)]
pub struct Segmentation {
    /// One label per voxel. 0 is the value used by [`Segmentation::new`].
    pub labels: Vec<u8>,
    /// Declared number of classes. It is stored only and not checked against `labels`.
    pub n_classes: usize,
}

impl Segmentation {
    /// Creates a segmentation of `n_voxels` voxels, all labelled 0.
    pub fn new(n_voxels: usize, n_classes: usize) -> Self {
        let labels = vec![0u8; n_voxels];
        Self { labels, n_classes }
    }

    /// Physical volume of voxels carrying `label`: count × spacing[0] × spacing[1] × spacing[2].
    pub fn compute_volume(&self, label: u8, spacing: [f64; 3]) -> f64 {
        let [sx, sy, sz] = spacing;
        self.count_label(label) as f64 * sx * sy * sz
    }

    /// Fraction of voxels carrying `label`, or `0.0` for an empty segmentation.
    pub fn label_fraction(&self, label: u8) -> f64 {
        match self.labels.len() {
            0 => 0.0,
            total => self.count_label(label) as f64 / total as f64,
        }
    }

    /// Number of voxels whose label equals `label`.
    fn count_label(&self, label: u8) -> usize {
        self.labels.iter().copied().filter(|&l| l == label).count()
    }
}
203
204#[derive(Debug, Clone)]
212pub struct NiftiHeader {
213 pub dim: [usize; 7],
216 pub pixdim: [f64; 7],
218 pub datatype: u16,
220}
221
222impl NiftiHeader {
223 pub fn new_3d(nx: usize, ny: usize, nz: usize, dx: f64, dy: f64, dz: f64) -> Self {
225 Self {
226 dim: [3, nx, ny, nz, 1, 1, 1],
227 pixdim: [1.0, dx, dy, dz, 0.0, 0.0, 0.0],
228 datatype: 4, }
230 }
231
232 pub fn write_header(&self, path: &str) -> Result<()> {
237 let mut file = std::fs::File::create(path).map_err(Error::Io)?;
238 for d in &self.dim {
239 file.write_all(&(*d as u64).to_le_bytes())
240 .map_err(Error::Io)?;
241 }
242 for p in &self.pixdim {
243 file.write_all(&p.to_le_bytes()).map_err(Error::Io)?;
244 }
245 file.write_all(&self.datatype.to_le_bytes())
246 .map_err(Error::Io)?;
247 Ok(())
248 }
249
250 pub fn read_header(path: &str) -> Result<Self> {
252 let mut file = std::fs::File::open(path).map_err(Error::Io)?;
253 let mut buf = [0u8; 114];
254 file.read_exact(&mut buf).map_err(Error::Io)?;
255 let mut dim = [0usize; 7];
256 for (i, d) in dim.iter_mut().enumerate() {
257 let bytes: [u8; 8] = buf[i * 8..(i + 1) * 8]
258 .try_into()
259 .map_err(|_| Error::Parse("dim bytes".into()))?;
260 *d = u64::from_le_bytes(bytes) as usize;
261 }
262 let offset = 7 * 8;
263 let mut pixdim = [0f64; 7];
264 for (i, p) in pixdim.iter_mut().enumerate() {
265 let bytes: [u8; 8] = buf[offset + i * 8..offset + (i + 1) * 8]
266 .try_into()
267 .map_err(|_| Error::Parse("pixdim bytes".into()))?;
268 *p = f64::from_le_bytes(bytes);
269 }
270 let dt_offset = offset + 7 * 8;
271 let datatype = u16::from_le_bytes([buf[dt_offset], buf[dt_offset + 1]]);
272 Ok(Self {
273 dim,
274 pixdim,
275 datatype,
276 })
277 }
278}
279
/// Shape of a simulated MRI phantom. Length units are not fixed by this type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PhantomGeometry {
    /// A sphere with the given radius.
    Sphere(f64),
    /// A right circular cylinder.
    Cylinder {
        /// Radius of the circular cross-section.
        radius: f64,
        /// Length along the cylinder axis.
        height: f64,
    },
}
295
296#[derive(Debug, Clone)]
300pub struct MriPhantom {
301 pub geometry: PhantomGeometry,
303 pub t1: f64,
305 pub t2: f64,
307 pub proton_density: f64,
309}
310
311impl MriPhantom {
312 pub fn new(geometry: PhantomGeometry, t1: f64, t2: f64, proton_density: f64) -> Self {
314 Self {
315 geometry,
316 t1,
317 t2,
318 proton_density,
319 }
320 }
321
322 pub fn simulate_signal(&self, te: f64, tr: f64) -> f64 {
328 mri_signal_se(self.proton_density, self.t1, self.t2, tr, te)
329 }
330}
331
/// Maps a Hounsfield-unit value to a coarse material class.
///
/// Each band includes its lower bound and excludes its upper bound:
///
/// | HU range          | class                 |
/// |-------------------|-----------------------|
/// | `< -950`          | `"air"`               |
/// | `[-950, -100)`    | `"lung"`              |
/// | `[-100, 20)`      | `"fat/soft_tissue"`   |
/// | `[20, 400)`       | `"soft_tissue/blood"` |
/// | `[400, 1000)`     | `"bone"`              |
/// | `>= 1000`         | `"dense_bone/metal"`  |
///
/// A NaN input returns `"unknown"`. Without that check, NaN would fail every `<`
/// comparison and be misreported as dense bone/metal.
pub fn hounsfield_to_material(hu: f64) -> &'static str {
    // (exclusive upper bound, class), in ascending order of bound.
    const BANDS: [(f64, &str); 5] = [
        (-950.0, "air"),
        (-100.0, "lung"),
        (20.0, "fat/soft_tissue"),
        (400.0, "soft_tissue/blood"),
        (1000.0, "bone"),
    ];
    if hu.is_nan() {
        return "unknown";
    }
    BANDS
        .iter()
        .find(|&&(upper, _)| hu < upper)
        .map_or("dense_bone/metal", |&(_, name)| name)
}
358
/// Spin-echo signal model: `rho * (1 - exp(-tr / t1)) * exp(-te / t2)`.
///
/// Returns `0.0` when `t1` or `t2` is zero or negative. `t1`, `t2`, `tr` and `te` must
/// share one time unit.
pub fn mri_signal_se(rho: f64, t1: f64, t2: f64, tr: f64, te: f64) -> f64 {
    let degenerate = t1 <= 0.0 || t2 <= 0.0;
    if degenerate {
        return 0.0;
    }
    let t1_recovery = 1.0 - (-tr / t1).exp();
    let t2_decay = (-te / t2).exp();
    rho * t1_recovery * t2_decay
}
370
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-9;

    /// Returns a path in the OS temp directory that is unique to this process.
    ///
    /// Concurrent runs of the suite therefore cannot clobber each other's files, and the
    /// tests do not assume a Unix `/tmp`.
    fn scratch_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{name}", std::process::id()))
    }

    /// The NIfTI helpers take `&str` paths.
    fn path_str(path: &std::path::Path) -> &str {
        path.to_str().expect("temp path should be valid UTF-8")
    }

    // ---------------------------------------------------------------- DICOM

    #[test]
    fn test_dicom_tag_from_str() {
        let t = DicomTag::from_str(0x0008, 0x0060, "CS", "CT");
        assert_eq!(t.group, 0x0008);
        assert_eq!(t.element, 0x0060);
        assert_eq!(t.value_representation, "CS");
        assert_eq!(&t.data, b"CT");
    }

    #[test]
    fn test_dicom_tag_from_f64() {
        let t = DicomTag::from_f64(0x0028, 0x0030, "DS", 1.5);
        assert_eq!(t.data.len(), 8);
        let v = f64::from_le_bytes(t.data[..8].try_into().unwrap());
        assert!((v - 1.5).abs() < EPS);
    }

    #[test]
    fn test_dataset_get_string() {
        let mut ds = DicomDataset::new();
        ds.insert(DicomTag::from_str(0x0010, 0x0010, "PN", "Smith^John"));
        assert_eq!(ds.get_string(0x0010, 0x0010), Some("Smith^John".to_string()));
    }

    #[test]
    fn test_dataset_get_string_missing() {
        let ds = DicomDataset::new();
        assert!(ds.get_string(0x0001, 0x0001).is_none());
    }

    #[test]
    fn test_dataset_get_f64_binary() {
        let mut ds = DicomDataset::new();
        ds.insert(DicomTag::from_f64(0x0028, 0x1053, "FD", 3.125));
        let v = ds.get_f64(0x0028, 0x1053).unwrap();
        assert!((v - 3.125).abs() < EPS);
    }

    #[test]
    fn test_dataset_get_f64_string() {
        let mut ds = DicomDataset::new();
        ds.insert(DicomTag::from_str(0x0028, 0x1052, "DS", " 42.5 "));
        let v = ds.get_f64(0x0028, 0x1052).unwrap();
        assert!((v - 42.5).abs() < EPS);
    }

    #[test]
    fn test_dataset_insert_overwrites() {
        let mut ds = DicomDataset::new();
        ds.insert(DicomTag::from_str(0x0010, 0x0010, "PN", "Old"));
        ds.insert(DicomTag::from_str(0x0010, 0x0010, "PN", "New"));
        assert_eq!(ds.get_string(0x0010, 0x0010), Some("New".to_string()));
    }

    // ---------------------------------------------------------------- Voxel volume

    #[test]
    fn test_voxel_volume_count() {
        let v = VoxelVolume::new([4, 5, 6], [1.0; 3]);
        assert_eq!(v.voxel_count(), 120);
    }

    #[test]
    fn test_voxel_volume_physical_volume() {
        let v = VoxelVolume::new([10, 10, 10], [2.0, 2.0, 2.0]);
        assert!((v.physical_volume_mm3() - 8000.0).abs() < EPS);
    }

    #[test]
    fn test_to_hounsfield_water() {
        // Identity rescale: a stored 0 maps to 0 HU (water).
        let hu = VoxelVolume::to_hounsfield(0, 1.0, 0.0);
        assert!(hu.abs() < EPS);
    }

    #[test]
    fn test_to_hounsfield_rescale() {
        // 700 * 1.0 - 1024 = -324 HU.
        let hu = VoxelVolume::to_hounsfield(700, 1.0, -1024.0);
        assert!((hu + 324.0).abs() < EPS);
    }

    #[test]
    fn test_voxel_get_in_bounds() {
        let v = VoxelVolume::new([3, 3, 3], [1.0; 3]);
        assert_eq!(v.get(0, 0, 0), Some(0));
    }

    #[test]
    fn test_voxel_get_out_of_bounds() {
        let v = VoxelVolume::new([3, 3, 3], [1.0; 3]);
        assert!(v.get(10, 0, 0).is_none());
    }

    // ---------------------------------------------------------------- HU classification

    #[test]
    fn test_hu_material_air() {
        assert_eq!(hounsfield_to_material(-1000.0), "air");
    }

    #[test]
    fn test_hu_material_lung() {
        assert_eq!(hounsfield_to_material(-500.0), "lung");
    }

    #[test]
    fn test_hu_material_fat() {
        assert_eq!(hounsfield_to_material(-50.0), "fat/soft_tissue");
    }

    #[test]
    fn test_hu_material_soft_tissue() {
        assert_eq!(hounsfield_to_material(50.0), "soft_tissue/blood");
    }

    #[test]
    fn test_hu_material_bone() {
        assert_eq!(hounsfield_to_material(700.0), "bone");
    }

    #[test]
    fn test_hu_material_dense_bone() {
        assert_eq!(hounsfield_to_material(1500.0), "dense_bone/metal");
    }

    // ---------------------------------------------------------------- Segmentation

    #[test]
    fn test_segmentation_volume_zero() {
        let seg = Segmentation::new(100, 3);
        let vol = seg.compute_volume(1, [1.0; 3]);
        assert!(vol.abs() < EPS);
    }

    #[test]
    fn test_segmentation_volume_all_labelled() {
        let mut seg = Segmentation::new(8, 1);
        seg.labels = vec![1; 8];
        let vol = seg.compute_volume(1, [2.0, 2.0, 2.0]);
        assert!((vol - 64.0).abs() < EPS);
    }

    #[test]
    fn test_segmentation_label_fraction() {
        let mut seg = Segmentation::new(10, 2);
        seg.labels[0] = 1;
        seg.labels[1] = 1;
        assert!((seg.label_fraction(1) - 0.2).abs() < EPS);
    }

    #[test]
    fn test_segmentation_empty() {
        let seg = Segmentation::new(0, 1);
        assert!(seg.label_fraction(1).abs() < EPS);
    }

    // ---------------------------------------------------------------- MRI signal

    #[test]
    fn test_mri_signal_long_tr_short_te() {
        // TR >> T1 gives full recovery and TE = 0 gives no decay, so S ≈ rho.
        let s = mri_signal_se(1.0, 500.0, 100.0, 1e9, 0.0);
        assert!((s - 1.0).abs() < 1e-6, "signal should be ~rho: {s}");
    }

    #[test]
    fn test_mri_signal_zero_rho() {
        assert!(mri_signal_se(0.0, 500.0, 100.0, 1000.0, 10.0).abs() < EPS);
    }

    #[test]
    fn test_mri_signal_invalid_t1() {
        assert!(mri_signal_se(1.0, 0.0, 100.0, 1000.0, 10.0).abs() < EPS);
    }

    #[test]
    fn test_mri_signal_invalid_t2() {
        assert!(mri_signal_se(1.0, 500.0, 0.0, 1000.0, 10.0).abs() < EPS);
    }

    #[test]
    fn test_mri_signal_t1_weighting() {
        // A short TR relative to T1 makes recovery, and hence signal, depend on T1.
        let s_low = mri_signal_se(1.0, 300.0, 100.0, 600.0, 10.0);
        let s_high = mri_signal_se(1.0, 1500.0, 100.0, 600.0, 10.0);
        assert!(s_low > s_high, "lower T1 should give higher T1-weighted signal");
    }

    #[test]
    fn test_mri_signal_t2_weighting() {
        // Same long TR; only the echo time differs.
        let s_short = mri_signal_se(1.0, 500.0, 80.0, 2000.0, 10.0);
        let s_long = mri_signal_se(1.0, 500.0, 80.0, 2000.0, 100.0);
        assert!(s_short > s_long, "short TE should give higher signal");
    }

    // ---------------------------------------------------------------- Phantom

    #[test]
    fn test_phantom_simulate_signal() {
        let p = MriPhantom::new(PhantomGeometry::Sphere(50.0), 800.0, 80.0, 1.0);
        let s = p.simulate_signal(10.0, 2000.0);
        assert!(s > 0.0 && s <= 1.0);
    }

    #[test]
    fn test_phantom_sphere_geometry() {
        let p = MriPhantom::new(PhantomGeometry::Sphere(25.0), 500.0, 60.0, 0.8);
        match p.geometry {
            PhantomGeometry::Sphere(r) => assert!((r - 25.0).abs() < EPS),
            other => panic!("expected sphere, got {other:?}"),
        }
    }

    #[test]
    fn test_phantom_cylinder_geometry() {
        let geometry = PhantomGeometry::Cylinder {
            radius: 30.0,
            height: 100.0,
        };
        let p = MriPhantom::new(geometry, 1000.0, 100.0, 1.0);
        match p.geometry {
            PhantomGeometry::Cylinder { radius, height } => {
                assert!((radius - 30.0).abs() < EPS);
                assert!((height - 100.0).abs() < EPS);
            }
            other => panic!("expected cylinder, got {other:?}"),
        }
    }

    // ---------------------------------------------------------------- Header I/O

    #[test]
    fn test_nifti_roundtrip() {
        let path = scratch_path("test_nifti_header.bin");
        let hdr = NiftiHeader::new_3d(64, 128, 32, 0.5, 0.5, 1.0);
        hdr.write_header(path_str(&path)).unwrap();
        let loaded = NiftiHeader::read_header(path_str(&path)).unwrap();
        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_file(&path);
        assert_eq!(loaded.dim[0], 3);
        assert_eq!(loaded.dim[1], 64);
        assert_eq!(loaded.dim[2], 128);
        assert_eq!(loaded.dim[3], 32);
        assert!((loaded.pixdim[1] - 0.5).abs() < EPS);
        assert!((loaded.pixdim[3] - 1.0).abs() < EPS);
        assert_eq!(loaded.datatype, 4);
    }

    #[test]
    fn test_nifti_write_nonexistent_dir_fails() {
        let path = scratch_path("nonexistent_dir_xyz").join("header.bin");
        let hdr = NiftiHeader::new_3d(10, 10, 10, 1.0, 1.0, 1.0);
        assert!(hdr.write_header(path_str(&path)).is_err());
    }

    #[test]
    fn test_nifti_read_nonexistent_fails() {
        let path = scratch_path("does_not_exist_nifti.bin");
        assert!(NiftiHeader::read_header(path_str(&path)).is_err());
    }

    #[test]
    fn test_nifti_multiple_roundtrips() {
        for i in 0..3_u8 {
            let path = scratch_path(&format!("test_nifti_{i}.bin"));
            let nx = 10 + usize::from(i) * 5;
            let dx = 1.0 + f64::from(i);
            let hdr = NiftiHeader::new_3d(nx, 20, 30, dx, 1.0, 1.0);
            hdr.write_header(path_str(&path)).unwrap();
            let loaded = NiftiHeader::read_header(path_str(&path)).unwrap();
            let _ = std::fs::remove_file(&path);
            assert_eq!(loaded.dim[1], nx);
            assert!((loaded.pixdim[1] - dx).abs() < EPS);
        }
    }
}