Skip to main content

egsphsp/
lib.rs

1use std::fmt;
2use std::fs::{remove_file, File, OpenOptions};
3use std::io;
4use std::io::prelude::*;
5use std::io::{BufReader, BufWriter};
6use std::path::Path;
7
8use byteorder::{ByteOrder, LittleEndian};
9use float_cmp::ApproxEqUlps;
10use rand::{RngExt, SeedableRng, rngs::StdRng, seq::SliceRandom};
11
/// Length of the `MODE0` / `MODE2` tag that opens every phase-space file.
const MODE_LENGTH: usize = 5;
/// Meaningful bytes in a header: the mode tag plus five 4-byte numeric fields.
const HEADER_LENGTH: usize = 25;
/// Largest record size handled (MODE2 records carry an extra 4-byte ZLAST).
const MAX_RECORD_LENGTH: usize = 32;
/// Capacity of every buffered reader and writer: 1 MiB.
const BUFFER_CAPACITY: usize = 1 << 20;
/// Number of temporary batch files used by `randomize`. All of them are open
/// at the same time during the merge step, so a much larger value runs into
/// the per-process open-file limit (often around 1024).
const BATCHES: usize = 128;
17
/// Header of an EGSnrc phase-space file.
///
/// On disk the header fills one record-sized slot: the 5-byte mode tag and
/// five 4-byte little-endian fields (25 bytes), followed by zero padding up to
/// `record_size`. The on-disk field order is max energy first, then min energy.
#[derive(Debug, Copy, Clone)]
pub struct Header {
    /// Mode tag, `MODE0` or `MODE2`; MODE2 records additionally carry `zlast`.
    pub mode: [u8; 5],
    /// Number of particle records the file declares.
    pub total_particles: i32,
    /// Number of photon (uncharged) records the file declares.
    pub total_photons: i32,
    /// Lower end of the energy range stored in the header (not recomputed on read).
    pub min_energy: f32,
    /// Upper end of the energy range stored in the header (not recomputed on read).
    pub max_energy: f32,
    /// NOTE(review): presumably the number of source particles the file
    /// represents; `sample_combine` scales it by the sampling rate.
    pub total_particles_in_source: f32,
    /// Bytes per record on disk: 28 for MODE0, 32 for MODE2.
    pub record_size: u64,
    /// Whether each record includes the `zlast` field (true for MODE2).
    pub using_zlast: bool,
}
29
/// One particle record of a phase-space file.
///
/// Several fields double as flags through their sign, so prefer the accessor
/// methods (`total_energy()`, `get_weight()`, `z_positive()`, ...) over the
/// raw values.
#[derive(Debug, Copy, Clone)]
pub struct Record {
    /// EGSnrc LATCH bit field (charge, region and history bits).
    pub latch: u32,
    /// Energy; its sign flags `first_scored_by_primary_history`, so read the
    /// magnitude through `total_energy()`.
    total_energy: f32,
    pub x_cm: f32,
    pub y_cm: f32,
    pub x_cos: f32, // TODO verify these are normalized
    pub y_cos: f32,
    /// Statistical weight; the sign encodes the z direction (see `z_positive`).
    pub weight: f32, // also carries the sign of the z direction, yikes
    /// Only present in MODE2 files.
    pub zlast: Option<f32>,
}
41
/// Data-less namespace for the 3x3 matrix builders `Transform::reflection`
/// and `Transform::rotation`.
#[derive(Debug)]
pub struct Transform;
44
/// Errors produced while reading, writing or comparing phase-space files.
#[derive(Debug)]
pub enum EGSError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// The first five bytes of the file are neither `MODE0` nor `MODE2`.
    BadMode,
    /// The header's particle count does not match the file length.
    BadLength,
    /// Input files do not share the same MODE0/MODE2 format.
    ModeMismatch,
    /// Two headers differ (reported by `compare`).
    HeaderMismatch,
    /// Two records differ (reported by `compare`).
    RecordMismatch,
}

/// Result alias used throughout this crate.
pub type EGSResult<T> = Result<T, EGSError>;

impl From<io::Error> for EGSError {
    fn from(err: io::Error) -> EGSError {
        EGSError::Io(err)
    }
}

impl fmt::Display for EGSError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            // Show the wrapped I/O error's own message unchanged.
            EGSError::Io(ref err) => fmt::Display::fmt(err, f),
            EGSError::BadMode => {
                write!(
                    f,
                    "First 5 bytes of file are invalid, must be MODE0 or MODE2"
                )
            }
            EGSError::BadLength => {
                write!(
                    f,
                    "Number of total particles does not match byte length of file"
                )
            }
            EGSError::ModeMismatch => write!(f, "Input file MODE0/MODE2 do not match"),
            EGSError::HeaderMismatch => write!(f, "Headers are different"),
            EGSError::RecordMismatch => write!(f, "Records are different"),
        }
    }
}

/// Lets `EGSError` be used with `?` into `Box<dyn Error>` and exposes the
/// underlying I/O error as its source.
impl std::error::Error for EGSError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match *self {
            EGSError::Io(ref err) => Some(err),
            EGSError::BadMode
            | EGSError::BadLength
            | EGSError::ModeMismatch
            | EGSError::HeaderMismatch
            | EGSError::RecordMismatch => None,
        }
    }
}
85
86pub struct PHSPReader {
87    reader: BufReader<File>,
88    pub header: Header,
89    next_record: u64,
90}
91
92pub struct PHSPWriter {
93    writer: BufWriter<File>,
94    pub header: Header,
95}
96
97impl PHSPReader {
98    pub fn from(file: File) -> EGSResult<PHSPReader> {
99        let actual_size = (file.metadata()?).len();
100        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, file);
101        let mut buffer = [0; HEADER_LENGTH];
102        reader.read_exact(&mut buffer)?;
103        let mut mode = [0; MODE_LENGTH];
104        mode.clone_from_slice(&buffer[0..5]);
105        let header = Header {
106            mode,
107            total_particles: LittleEndian::read_i32(&buffer[5..9]),
108            total_photons: LittleEndian::read_i32(&buffer[9..13]),
109            max_energy: LittleEndian::read_f32(&buffer[13..17]),
110            min_energy: LittleEndian::read_f32(&buffer[17..21]),
111            total_particles_in_source: LittleEndian::read_f32(&buffer[21..25]),
112            using_zlast: &mode == b"MODE2",
113            record_size: if &mode == b"MODE0" {
114                28
115            } else if &mode == b"MODE2" {
116                32
117            } else {
118                return Err(EGSError::BadMode);
119            },
120        };
121        if actual_size != header.expected_size() as u64 {
122            writeln!(
123                &mut std::io::stderr(),
124                "Expected {} bytes in file, not {}",
125                header.expected_size(),
126                actual_size
127            )
128            .unwrap();
129            //return Err(EGSError::BadLength);
130        }
131        reader.consume(header.record_size as usize - HEADER_LENGTH);
132        Ok(PHSPReader {
133            reader,
134            header,
135            next_record: 0,
136        })
137    }
138    fn exhausted(&self) -> bool {
139        self.next_record >= self.header.total_particles as u64
140    }
141}
142
143impl Iterator for PHSPReader {
144    type Item = EGSResult<Record>;
145    fn next(&mut self) -> Option<EGSResult<Record>> {
146        if self.next_record >= self.header.total_particles as u64 {
147            return None;
148        }
149        let mut buffer = [0; MAX_RECORD_LENGTH];
150        match self
151            .reader
152            .read_exact(&mut buffer[..self.header.record_size as usize])
153        {
154            Ok(()) => (),
155            Err(err) => return Some(Err(EGSError::Io(err))),
156        };
157        self.next_record += 1;
158        Some(Ok(Record {
159            latch: LittleEndian::read_u32(&buffer[0..4]),
160            total_energy: LittleEndian::read_f32(&buffer[4..8]),
161            x_cm: LittleEndian::read_f32(&buffer[8..12]),
162            y_cm: LittleEndian::read_f32(&buffer[12..16]),
163            x_cos: LittleEndian::read_f32(&buffer[16..20]),
164            y_cos: LittleEndian::read_f32(&buffer[20..24]),
165            weight: LittleEndian::read_f32(&buffer[24..28]),
166            zlast: if self.header.using_zlast {
167                Some(LittleEndian::read_f32(&buffer[28..32]))
168            } else {
169                None
170            },
171        }))
172    }
173}
174
175impl PHSPWriter {
176    pub fn from(file: File, header: &Header) -> EGSResult<PHSPWriter> {
177        let mut writer = BufWriter::with_capacity(BUFFER_CAPACITY, file);
178        let mut buffer = [0; MAX_RECORD_LENGTH];
179        buffer[0..5].clone_from_slice(&header.mode);
180        LittleEndian::write_i32(&mut buffer[5..9], header.total_particles);
181        LittleEndian::write_i32(&mut buffer[9..13], header.total_photons);
182        LittleEndian::write_f32(&mut buffer[13..17], header.max_energy);
183        LittleEndian::write_f32(&mut buffer[17..21], header.min_energy);
184        LittleEndian::write_f32(&mut buffer[21..25], header.total_particles_in_source);
185        writer.write_all(&buffer[..header.record_size as usize])?;
186        Ok(PHSPWriter {
187            header: *header,
188            writer,
189        })
190    }
191
192    pub fn write(&mut self, record: &Record) -> EGSResult<()> {
193        let mut buffer = [0; 32];
194        LittleEndian::write_u32(&mut buffer[0..4], record.latch);
195        LittleEndian::write_f32(&mut buffer[4..8], record.total_energy);
196        LittleEndian::write_f32(&mut buffer[8..12], record.x_cm);
197        LittleEndian::write_f32(&mut buffer[12..16], record.y_cm);
198        LittleEndian::write_f32(&mut buffer[16..20], record.x_cos);
199        LittleEndian::write_f32(&mut buffer[20..24], record.y_cos);
200        LittleEndian::write_f32(&mut buffer[24..28], record.weight);
201        if self.header.using_zlast {
202            LittleEndian::write_f32(
203                &mut buffer[28..32],
204                record.zlast.expect("MODE2 record missing zlast"),
205            );
206        }
207        self.writer
208            .write_all(&buffer[..self.header.record_size as usize])?;
209        Ok(())
210    }
211}
212
213impl Header {
214    fn expected_size(&self) -> usize {
215        (self.total_particles as usize + 1) * self.record_size as usize
216    }
217    pub fn similar_to(&self, other: &Header) -> bool {
218        self.mode == other.mode
219            && self.total_particles == other.total_particles
220            && self.total_photons == other.total_photons
221            && self.max_energy.approx_eq_ulps(&other.max_energy, 10)
222            && self.min_energy.approx_eq_ulps(&other.min_energy, 10)
223            && self
224                .total_particles_in_source
225                .approx_eq_ulps(&other.total_particles_in_source, 2)
226    }
227    fn merge(&mut self, other: &Header) {
228        assert!(self.mode == other.mode, "Merge mode mismatch");
229        self.total_particles = self
230            .total_particles
231            .checked_add(other.total_particles)
232            .expect("Too many particles, i32 overflow");
233        self.total_photons += other.total_photons;
234        self.min_energy = self.min_energy.min(other.min_energy);
235        self.max_energy = self.max_energy.max(other.max_energy);
236        self.total_particles_in_source += other.total_particles_in_source;
237    }
238}
239
240impl Record {
241    pub fn similar_to(&self, other: &Record) -> bool {
242        self.latch == other.latch
243            && (self.total_energy() - other.total_energy()).abs() < 0.01
244            && (self.x_cm - other.x_cm).abs() < 0.01
245            && (self.y_cm - other.y_cm).abs() < 0.01
246            && (self.x_cos - other.x_cos).abs() < 0.01
247            && (self.y_cos - other.y_cos).abs() < 0.01
248            && (self.weight - other.weight).abs() < 0.01
249            && self.zlast == other.zlast
250    }
251    pub fn bremsstrahlung_or_annihilation(&self) -> bool {
252        self.latch & 1 != 0
253    }
254    pub fn bit_region(&self) -> u32 {
255        self.latch & 0xfffffe
256    }
257    pub fn region_number(&self) -> u32 {
258        // EGSnrc: region-of-origin is bits 24-28 (5 bits), used as the value
259        // after >>24. See pirs509a-beamnrc.tex:4954-4989.
260        (self.latch >> 24) & 0x1f
261    }
262    pub fn b29(&self) -> bool {
263        self.latch & (1 << 29) != 0
264    }
265    pub fn charged(&self) -> bool {
266        // EGSnrc encodes IQ in bits 29-30: electron sets bit 30, positron
267        // sets bit 29 alone (phsp_macros.mortran:$GET_E_NPASS_IQ).
268        (self.latch >> 29) & 0b11 != 0
269    }
270    pub fn crossed_multiple(&self) -> bool {
271        self.latch & (1 << 31) != 0
272    }
273    pub fn get_weight(&self) -> f32 {
274        self.weight.abs()
275    }
276    pub fn set_weight(&mut self, new_weight: f32) {
277        self.weight = new_weight * self.weight.signum();
278    }
279    pub fn total_energy(&self) -> f32 {
280        self.total_energy.abs()
281    }
282    pub fn z_positive(&self) -> bool {
283        self.weight.is_sign_positive()
284    }
285    pub fn z_cos(&self) -> f32 {
286        (1.0 - (self.x_cos * self.x_cos + self.y_cos * self.y_cos)).sqrt()
287    }
288    pub fn first_scored_by_primary_history(&self) -> bool {
289        self.total_energy.is_sign_negative()
290    }
291
292    fn translate(&mut self, x: f32, y: f32) {
293        self.x_cm += x;
294        self.y_cm += y;
295    }
296
297    fn transform(&mut self, matrix: &[[f32; 3]; 3]) {
298        let x_cm = self.x_cm;
299        let y_cm = self.y_cm;
300        self.x_cm = matrix[0][0] * x_cm + matrix[0][1] * y_cm + matrix[0][2] * 1.0;
301        self.y_cm = matrix[1][0] * x_cm + matrix[1][1] * y_cm + matrix[1][2] * 1.0;
302        let x_cos = self.x_cos;
303        let y_cos = self.y_cos;
304        let z_cos = self.z_cos();
305        self.x_cos = matrix[0][0] * x_cos + matrix[0][1] * y_cos + matrix[0][2] * z_cos;
306        self.y_cos = matrix[1][0] * x_cos + matrix[1][1] * y_cos + matrix[1][2] * z_cos;
307    }
308}
309
310impl Transform {
311    pub fn reflection(matrix: &mut [[f32; 3]; 3], x_raw: f32, y_raw: f32) {
312        let norm = (x_raw * x_raw + y_raw * y_raw).sqrt();
313        let x = x_raw / norm;
314        let y = y_raw / norm;
315        *matrix = [
316            [x * x - y * y, 2.0 * x * y, 0.0],
317            [2.0 * x * y, y * y - x * x, 0.0],
318            [0.0, 0.0, 1.0],
319        ];
320    }
321    pub fn rotation(matrix: &mut [[f32; 3]; 3], theta: f32) {
322        *matrix = [
323            [theta.cos(), -theta.sin(), 0.0],
324            [theta.sin(), theta.cos(), 0.0],
325            [0.0, 0.0, 1.0],
326        ];
327    }
328}
329
330pub fn randomize(path: &Path, seed: u64) -> EGSResult<()> {
331    let mut rng = StdRng::seed_from_u64(seed);
332    let ifile = File::open(path)?;
333    let mut reader = PHSPReader::from(ifile)?;
334    let header = reader.header;
335    let max_per_batch = reader.header.total_particles as usize / BATCHES + 1;
336    let mut batch_paths = Vec::with_capacity(BATCHES);
337    for i in 0..BATCHES {
338        let mut batch_path = path.to_path_buf();
339        batch_path.set_extension(format!("rand{}", i));
340        batch_paths.push(batch_path);
341    }
342    let mut records = Vec::with_capacity(max_per_batch);
343    for path in batch_paths.iter() {
344        for _ in 0..max_per_batch {
345            if let Some(record) = reader.next() { records.push(record.unwrap()) }
346        }
347        //let mut vec: Vec<Record> = records.collect();
348
349        records.shuffle(&mut rng);
350
351        let header = Header {
352            mode: reader.header.mode,
353            total_particles: records.len() as i32,
354            total_photons: 0,
355            max_energy: 0.0,
356            min_energy: 0.0,
357            total_particles_in_source: 0.0,
358            using_zlast: &reader.header.mode == b"MODE2",
359            record_size: reader.header.record_size,
360        };
361        let ofile = File::create(path)?;
362        let mut writer = PHSPWriter::from(ofile, &header)?;
363        for record in records.iter() {
364            writer.write(record)?;
365        }
366        records.clear();
367    }
368    drop(records);
369    let mut readers = Vec::with_capacity(BATCHES);
370    for path in batch_paths.iter() {
371        let ifile = File::open(path)?;
372        readers.push(PHSPReader::from(ifile)?);
373    }
374
375    let ofile = File::create(path)?;
376    let mut writer = PHSPWriter::from(ofile, &header)?;
377    while !readers.is_empty() {
378        readers.shuffle(&mut rng);
379        for reader in readers.iter_mut() {
380            if let Some(record) = reader.next() { writer.write(&record.unwrap())? }
381        }
382        readers.retain(|r| !r.exhausted());
383    }
384    for path in batch_paths.iter() {
385        remove_file(path)?;
386    }
387    Ok(())
388}
389
390pub fn combine(input_paths: &[&Path], output_path: &Path, delete: bool) -> EGSResult<()> {
391    assert!(!input_paths.is_empty(), "Cannot combine zero files");
392    let reader = PHSPReader::from(File::open(input_paths[0])?)?;
393    let mut final_header = reader.header;
394    for path in input_paths[1..].iter() {
395        let reader = PHSPReader::from(File::open(path)?)?;
396        final_header.merge(&reader.header);
397    }
398    println!("Final header: {:?}", final_header);
399    let ofile = File::create(output_path)?;
400    let mut writer = PHSPWriter::from(ofile, &final_header)?;
401    for path in input_paths.iter() {
402        let reader = PHSPReader::from(File::open(path)?)?;
403        for record in reader {
404            writer.write(&record.unwrap())?
405        }
406        if delete {
407            remove_file(path)?;
408        }
409    }
410    Ok(())
411}
412
413pub fn compare(path1: &Path, path2: &Path) -> EGSResult<()> {
414    let ifile1 = File::open(path1)?;
415    let ifile2 = File::open(path2)?;
416    let reader1 = PHSPReader::from(ifile1)?;
417    let reader2 = PHSPReader::from(ifile2)?;
418    println!("                   First\t\tSecond");
419    println!(
420        "Total particles:   {0: <10}\t\t{1:}",
421        reader1.header.total_particles, reader2.header.total_particles
422    );
423    println!(
424        "Total photons:     {0: <10}\t\t{1}",
425        reader1.header.total_photons, reader2.header.total_photons
426    );
427    println!(
428        "Minimum energy:    {0: <10}\t\t{1}",
429        reader1.header.min_energy, reader2.header.min_energy
430    );
431    println!(
432        "Maximum energy:    {0: <10}\t\t{1}",
433        reader1.header.max_energy, reader2.header.max_energy
434    );
435    println!(
436        "Source particles:  {0: <10}\t\t{1}",
437        reader1.header.total_particles_in_source, reader2.header.total_particles_in_source
438    );
439    if !reader1.header.similar_to(&reader2.header) {
440        println!("Headers different");
441        return Err(EGSError::HeaderMismatch);
442    } else {
443        for (record1, record2) in reader1.zip(reader2) {
444            let r1 = record1.unwrap();
445            let r2 = record2.unwrap();
446            if !r1.similar_to(&r2) {
447                println!("{:?} != {:?}", r1, r2);
448                return Err(EGSError::RecordMismatch);
449            }
450        }
451    }
452    Ok(())
453}
454
455pub fn sample_combine(ipaths: &[&Path], opath: &Path, rate: f64, seed: u64) -> EGSResult<()> {
456    assert!(!ipaths.is_empty(), "Cannot combine zero files");
457    let mut rng = StdRng::seed_from_u64(seed);
458    let mut header = Header {
459        mode: *b"MODE0",
460        record_size: 28,
461        using_zlast: false,
462        total_particles: 0,
463        total_photons: 0,
464        min_energy: 1000.0,
465        max_energy: 0.0,
466        total_particles_in_source: 0.0,
467    };
468    let mut writer = PHSPWriter::from(File::create(opath)?, &header)?;
469    for path in ipaths.iter() {
470        let reader = PHSPReader::from(File::open(path)?)?;
471        if reader.header.using_zlast {
472            return Err(EGSError::ModeMismatch);
473        }
474        println!("Found {} particles", reader.header.total_particles);
475        header.total_particles_in_source += reader.header.total_particles_in_source;
476        let records = reader.filter(|_| rng.random_bool(rate));
477        for record in records.map(|r| r.unwrap()) {
478            header.total_particles = header
479                .total_particles
480                .checked_add(1)
481                .expect("Total particles overflow");
482            if !record.charged() {
483                header.total_photons += 1;
484            }
485            let energy = record.total_energy();
486            header.min_energy = header.min_energy.min(energy);
487            header.max_energy = header.max_energy.max(energy);
488            writer.write(&record)?;
489        }
490        println!("Now have {} particles", header.total_particles);
491    }
492    header.total_particles_in_source *= rate as f32;
493    drop(writer);
494    // write out the header
495    let ofile = OpenOptions::new()
496        .write(true)
497        .create(true)
498        .truncate(false)
499        .open(opath)?;
500    PHSPWriter::from(ofile, &header)?;
501    Ok(())
502}
503
504pub fn translate(input_path: &Path, output_path: &Path, x: f32, y: f32) -> EGSResult<()> {
505    let ifile = File::open(input_path)?;
506    let reader = PHSPReader::from(ifile)?;
507    let ofile = if input_path == output_path {
508        println!(
509            "Translating {} in place by ({}, {})",
510            input_path.display(),
511            x,
512            y
513        );
514        OpenOptions::new()
515            .write(true)
516            .create(true)
517            .truncate(false)
518            .open(output_path)?
519    } else {
520        println!(
521            "Translating {} by ({}, {}) and saving to {}",
522            input_path.display(),
523            x,
524            y,
525            output_path.display()
526        );
527        File::create(output_path)?
528    };
529    let mut writer = PHSPWriter::from(ofile, &reader.header)?;
530    let n_particles = reader.header.total_particles;
531    let mut records_translated = 0;
532    for mut record in reader.map(|r| r.unwrap()) {
533        record.translate(x, y);
534        writer.write(&record)?;
535        records_translated += 1;
536    }
537    println!(
538        "Translated {} records, expected {}",
539        records_translated, n_particles
540    );
541    Ok(())
542}
543
544pub fn transform(input_path: &Path, output_path: &Path, matrix: &[[f32; 3]; 3]) -> EGSResult<()> {
545    let ifile = File::open(input_path)?;
546    let reader = PHSPReader::from(ifile)?;
547    let ofile = if input_path == output_path {
548        println!("Transforming {} in place", input_path.display());
549        OpenOptions::new()
550            .write(true)
551            .create(true)
552            .truncate(false)
553            .open(output_path)?
554    } else {
555        println!(
556            "Transforming {} and saving to {}",
557            input_path.display(),
558            output_path.display()
559        );
560        File::create(output_path)?
561    };
562    let mut writer = PHSPWriter::from(ofile, &reader.header)?;
563    let n_particles = reader.header.total_particles;
564    let mut records_transformed = 0;
565    for mut record in reader.map(|r| r.unwrap()) {
566        record.transform(matrix);
567        writer.write(&record)?;
568        records_transformed += 1;
569    }
570    println!(
571        "Transformed {} records, expected {}",
572        records_transformed, n_particles
573    );
574    Ok(())
575}
576
577pub fn reweight(
578    input_path: &Path,
579    output_path: &Path,
580    f: &dyn Fn(f32) -> f32,
581    _number_bins: usize,
582    _max_radius: f32,
583) -> EGSResult<()> {
584    if input_path == output_path {
585        println!("Reweighting in-place");
586    } else {
587        println!("Reweighting and saving to {}", output_path.display());
588    }
589
590    let reader1 = PHSPReader::from(File::open(input_path)?)?;
591    let mut sum_old_weight = 0.0_f32;
592    let mut sum_new_weight = 0.0_f32;
593    for record in reader1.map(|r| r.unwrap()) {
594        let w = record.get_weight();
595        sum_old_weight += w;
596        let r = (record.x_cm * record.x_cm + record.y_cm * record.y_cm).sqrt();
597        sum_new_weight += w * f(r);
598    }
599
600    let reader2 = PHSPReader::from(File::open(input_path)?)?;
601    let output_file = if input_path == output_path {
602        OpenOptions::new()
603            .write(true)
604            .create(true)
605            .truncate(false)
606            .open(output_path)?
607    } else {
608        File::create(output_path)?
609    };
610    let mut writer = PHSPWriter::from(output_file, &reader2.header)?;
611    let factor = sum_old_weight / sum_new_weight;
612    for mut record in reader2.map(|r| r.unwrap()) {
613        let r = (record.x_cm * record.x_cm + record.y_cm * record.y_cm).sqrt();
614        record.weight *= f(r) * factor;
615        writer.write(&record)?;
616    }
617    Ok(())
618}
619
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Keeps temp file names unique across tests running in parallel.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Returns a fresh temporary path for one test artifact.
    fn tmp_path(label: &str) -> std::path::PathBuf {
        let n = COUNTER.fetch_add(1, Ordering::SeqCst);
        let pid = std::process::id();
        std::env::temp_dir().join(format!("beamdpr_test_{}_{}_{}.egsphsp1", label, pid, n))
    }

    /// Builds a record with direction cosines (0.1, 0.2) and unit weight.
    fn make_record(latch: u32, energy: f32, x: f32, y: f32, zlast: Option<f32>) -> Record {
        Record {
            latch,
            total_energy: energy,
            x_cm: x,
            y_cm: y,
            x_cos: 0.1,
            y_cos: 0.2,
            weight: 1.0,
            zlast,
        }
    }

    /// Writes `header` and `records` to `path`. Flushes explicitly so that
    /// write errors fail the test instead of being swallowed on drop.
    fn write_phsp(path: &Path, header: &Header, records: &[Record]) {
        let f = File::create(path).unwrap();
        let mut writer = PHSPWriter::from(f, header).unwrap();
        for r in records {
            writer.write(r).unwrap();
        }
        writer.writer.flush().unwrap();
    }

    #[test]
    fn reweight_applies_radial_function_and_normalizes() {
        let input = tmp_path("reweight_in");
        let output = tmp_path("reweight_out");
        let header = Header {
            mode: *b"MODE0",
            total_particles: 3,
            total_photons: 3,
            min_energy: 1.0,
            max_energy: 1.0,
            total_particles_in_source: 10.0,
            record_size: 28,
            using_zlast: false,
        };
        let mut records = vec![
            make_record(0, 1.0, 0.0, 0.0, None),
            make_record(0, 1.0, 1.0, 0.0, None),
            make_record(0, 1.0, 2.0, 0.0, None),
        ];
        for r in records.iter_mut() {
            r.weight = 2.0;
            r.x_cos = 0.0;
            r.y_cos = 0.0;
        }
        write_phsp(&input, &header, &records);

        reweight(&input, &output, &|r| r + 1.0, 10, 5.0).unwrap();

        let reader = PHSPReader::from(File::open(&output).unwrap()).unwrap();
        let out: Vec<Record> = reader.map(|r| r.unwrap()).collect();
        let _ = remove_file(&input);
        let _ = remove_file(&output);

        // sum_old = 6, sum_new = 2*1 + 2*2 + 2*3 = 12, factor = 0.5
        // expected = original_weight * f(r) * factor = 2 * (r+1) * 0.5 = r + 1
        let expected = [1.0_f32, 2.0, 3.0];
        for (i, r) in out.iter().enumerate() {
            assert!(
                (r.weight - expected[i]).abs() < 1e-4,
                "record {}: expected weight {}, got {}",
                i,
                expected[i],
                r.weight
            );
        }
    }

    #[test]
    fn reweight_uses_abs_weight_for_normalization() {
        // The weight's sign carries the Z direction (see the `weight` field of
        // `Record`). reweight() must conserve total weight *magnitude*: summing
        // signed weights over a mix of forward- and backward-going particles
        // would make the normalization factor blow up or flip sign.
        let input = tmp_path("reweight_signed_in");
        let output = tmp_path("reweight_signed_out");
        let header = Header {
            mode: *b"MODE0",
            total_particles: 4,
            total_photons: 4,
            min_energy: 1.0,
            max_energy: 1.0,
            total_particles_in_source: 10.0,
            record_size: 28,
            using_zlast: false,
        };
        // 2 forward (weight=+1), 2 backward (weight=-1), all at r=1.
        let mut records = vec![
            make_record(0, 1.0, 1.0, 0.0, None),
            make_record(0, 1.0, 0.0, 1.0, None),
            make_record(0, 1.0, -1.0, 0.0, None),
            make_record(0, 1.0, 0.0, -1.0, None),
        ];
        records[0].weight = 1.0;
        records[1].weight = 1.0;
        records[2].weight = -1.0;
        records[3].weight = -1.0;
        for r in records.iter_mut() {
            r.x_cos = 0.0;
            r.y_cos = 0.0;
        }
        write_phsp(&input, &header, &records);

        // f(r) = 1 (constant). sum_old_|w| = 4, sum_new_|w| = 4, factor = 1.
        // After reweight, magnitudes should all be 1.0, signs preserved.
        reweight(&input, &output, &|_r| 1.0, 10, 5.0).unwrap();

        let reader = PHSPReader::from(File::open(&output).unwrap()).unwrap();
        let out: Vec<Record> = reader.map(|r| r.unwrap()).collect();
        let _ = remove_file(&input);
        let _ = remove_file(&output);

        let expected_sign = [1.0_f32, 1.0, -1.0, -1.0];
        for (i, r) in out.iter().enumerate() {
            assert!(
                r.weight.is_finite(),
                "record {}: weight became non-finite ({}) due to signed-sum normalization",
                i,
                r.weight
            );
            assert!(
                (r.weight.abs() - 1.0).abs() < 1e-4,
                "record {}: expected |weight| 1.0, got |{}| (factor was wrong)",
                i,
                r.weight
            );
            assert!(
                r.weight.signum() == expected_sign[i],
                "record {}: expected sign {}, got {}",
                i,
                expected_sign[i],
                r.weight.signum()
            );
        }
    }

    #[test]
    fn sample_combine_uses_abs_energy_for_min_max() {
        let input = tmp_path("sample_in");
        let output = tmp_path("sample_out");
        let header = Header {
            mode: *b"MODE0",
            total_particles: 3,
            total_photons: 3,
            min_energy: 0.5,
            max_energy: 3.0,
            total_particles_in_source: 10.0,
            record_size: 28,
            using_zlast: false,
        };
        let records = vec![
            make_record(0, 0.5, 0.0, 0.0, None),
            make_record(0, -3.0, 0.0, 0.0, None),
            make_record(0, 1.5, 0.0, 0.0, None),
        ];
        write_phsp(&input, &header, &records);

        sample_combine(&[&input], &output, 1.0, 0).unwrap();

        let reader = PHSPReader::from(File::open(&output).unwrap()).unwrap();
        let got = reader.header;
        let _ = remove_file(&input);
        let _ = remove_file(&output);

        assert_eq!(got.total_particles, 3);
        assert!(
            (got.max_energy - 3.0).abs() < 1e-5,
            "max_energy should be 3.0 (abs of -3.0), got {}",
            got.max_energy
        );
        assert!(
            (got.min_energy - 0.5).abs() < 1e-5,
            "min_energy should be 0.5, got {}",
            got.min_energy
        );
    }

    #[test]
    fn sample_combine_returns_mode_mismatch_on_mode2_input() {
        // sample_combine is MODE0-only; MODE2 input must produce a clean
        // EGSError::ModeMismatch rather than a panic.
        let input = tmp_path("sample_mode2_in");
        let output = tmp_path("sample_mode2_out");
        let header = Header {
            mode: *b"MODE2",
            total_particles: 1,
            total_photons: 1,
            min_energy: 0.5,
            max_energy: 1.0,
            total_particles_in_source: 1.0,
            record_size: 32,
            using_zlast: true,
        };
        let r = make_record(0, 1.0, 0.0, 0.0, Some(2.5));
        write_phsp(&input, &header, &[r]);

        let result = sample_combine(&[&input], &output, 1.0, 0);
        let _ = remove_file(&input);
        let _ = remove_file(&output);

        assert!(
            matches!(result, Err(EGSError::ModeMismatch)),
            "expected Err(ModeMismatch), got {:?}",
            result
        );
    }

    #[test]
    fn combine_concatenates_records_and_merges_headers() {
        let a = tmp_path("combine_a");
        let b = tmp_path("combine_b");
        let out = tmp_path("combine_out");
        let mut header = Header {
            mode: *b"MODE0",
            total_particles: 1,
            total_photons: 1,
            min_energy: 1.0,
            max_energy: 2.0,
            total_particles_in_source: 5.0,
            record_size: 28,
            using_zlast: false,
        };
        write_phsp(&a, &header, &[make_record(0, 1.5, 1.0, 0.0, None)]);
        header.total_particles = 2;
        header.min_energy = 0.5;
        write_phsp(
            &b,
            &header,
            &[
                make_record(0, 0.5, 2.0, 0.0, None),
                make_record(0, 2.0, 3.0, 0.0, None),
            ],
        );

        combine(&[a.as_path(), b.as_path()], &out, true).unwrap();

        let reader = PHSPReader::from(File::open(&out).unwrap()).unwrap();
        let got = reader.header;
        let xs: Vec<f32> = reader.map(|r| r.unwrap().x_cm).collect();
        let inputs_deleted = !a.exists() && !b.exists();
        let _ = remove_file(&a);
        let _ = remove_file(&b);
        let _ = remove_file(&out);

        assert_eq!(got.total_particles, 3);
        assert_eq!(got.total_photons, 2);
        assert_eq!(got.min_energy, 0.5);
        assert_eq!(got.max_energy, 2.0);
        assert!((got.total_particles_in_source - 10.0).abs() < 1e-5);
        assert_eq!(xs, vec![1.0, 2.0, 3.0]);
        assert!(inputs_deleted, "delete=true should remove the inputs");
    }

    #[test]
    fn translate_shifts_positions_into_new_file() {
        let input = tmp_path("translate_in");
        let output = tmp_path("translate_out");
        let header = Header {
            mode: *b"MODE0",
            total_particles: 2,
            total_photons: 2,
            min_energy: 1.0,
            max_energy: 1.0,
            total_particles_in_source: 2.0,
            record_size: 28,
            using_zlast: false,
        };
        write_phsp(
            &input,
            &header,
            &[
                make_record(0, 1.0, 1.0, 2.0, None),
                make_record(0, 1.0, -1.0, 0.5, None),
            ],
        );

        translate(&input, &output, 0.5, -1.0).unwrap();

        let out: Vec<Record> = PHSPReader::from(File::open(&output).unwrap())
            .unwrap()
            .map(|r| r.unwrap())
            .collect();
        let _ = remove_file(&input);
        let _ = remove_file(&output);

        assert_eq!(out.len(), 2);
        assert_eq!((out[0].x_cm, out[0].y_cm), (1.5, 1.0));
        assert_eq!((out[1].x_cm, out[1].y_cm), (-0.5, -0.5));
    }

    #[test]
    fn region_number_decodes_five_bits_at_offset_24() {
        // Per EGSnrc beamnrc docs (pirs509a-beamnrc.tex:4954-4989) and
        // beamnrc_user_macros.mortran:121-128 ($LATCH_NUMBER_OF_BITS=5):
        // region-of-origin lives in bits 24-28 and is consumed after >>24.
        // A 4-bit mask such as 0xf000000 without the shift would drop bit 28
        // and return an unshifted value.
        let all_five_bits = make_record(0x1f000000, 1.0, 0.0, 0.0, None);
        assert_eq!(
            all_five_bits.region_number(),
            31,
            "all 5 region bits set should decode to 31"
        );

        let only_bit_28 = make_record(0x10000000, 1.0, 0.0, 0.0, None);
        assert_eq!(
            only_bit_28.region_number(),
            16,
            "bit 28 alone should decode to 16 (a 0xf000000 mask would lose it)"
        );

        // Low bits (region-traversed bits 1-23, charge bits 29-30, NPASS bit 31)
        // must NOT bleed into the region-of-origin value.
        let noisy = make_record(0xff_ff_ff_ff, 1.0, 0.0, 0.0, None);
        assert_eq!(
            noisy.region_number(),
            31,
            "region_number must mask off bits 29-31 and bits 0-23"
        );
    }

    #[test]
    fn charged_is_true_for_positrons_via_bit_29() {
        // Per EGSnrc phsp_macros.mortran:234-262 ($GET_E_NPASS_IQ):
        //   bit 30 set            => electron (IQ = -1)
        //   bit 30 clear, bit 29  => positron (IQ = +1)
        //   both clear            => photon
        let electron = make_record(1 << 30, 1.0, 0.0, 0.0, None);
        let positron = make_record(1 << 29, 1.0, 0.0, 0.0, None);
        let photon = make_record(0, 1.0, 0.0, 0.0, None);

        assert!(electron.charged(), "electron (bit 30) must be charged");
        assert!(
            positron.charged(),
            "positron (bit 29 alone) must be charged, not counted as a photon"
        );
        assert!(!photon.charged(), "photon (no charge bits) must not be charged");
    }

    #[test]
    fn crossed_multiple_is_independent_of_charged() {
        let mut r = make_record(0, 1.0, 0.0, 0.0, None);
        r.latch = 1 << 30;
        assert!(r.charged());
        assert!(
            !r.crossed_multiple(),
            "crossed_multiple should be bit 31, distinct from charged (bit 30)"
        );

        r.latch = 1 << 31;
        assert!(!r.charged());
        assert!(r.crossed_multiple(), "bit 31 should mean crossed_multiple");
    }

    #[test]
    fn transform_uses_original_z_cos_for_y_row() {
        // The y row must use z_cos computed from the *original* x_cos/y_cos,
        // not from the already-updated x_cos.
        let mut record = make_record(0, 1.0, 0.0, 0.0, None);
        record.x_cos = 0.6;
        record.y_cos = 0.0;
        let original_z_cos = record.z_cos();
        let matrix = [
            [0.5, 0.0, 0.0],
            [0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0],
        ];
        record.transform(&matrix);
        let expected_y_cos = 1.0 * original_z_cos;
        assert!(
            (record.y_cos - expected_y_cos).abs() < 1e-5,
            "expected y_cos {} (from original z_cos), got {}",
            expected_y_cos,
            record.y_cos
        );
    }

    #[test]
    fn similar_to_detects_negative_difference() {
        let a = make_record(0, 1.0, 0.0, 0.0, None);
        let mut b = make_record(0, 1.0, 0.0, 0.0, None);
        b.x_cm = 100.0;
        assert!(
            !a.similar_to(&b),
            "records with x_cm differing by 100 should not be similar"
        );
    }

    #[test]
    fn mode2_writer_preserves_zlast() {
        let path = tmp_path("mode2_zlast");
        let header = Header {
            mode: *b"MODE2",
            total_particles: 2,
            total_photons: 1,
            min_energy: 0.5,
            max_energy: 2.0,
            total_particles_in_source: 100.0,
            record_size: 32,
            using_zlast: true,
        };
        let r1 = make_record(0, 1.0, 0.0, 0.0, Some(7.25));
        let r2 = make_record(1 << 30, 2.0, 1.0, 1.0, Some(-3.5));
        write_phsp(&path, &header, &[r1, r2]);

        let reader = PHSPReader::from(File::open(&path).unwrap()).unwrap();
        let records: Vec<Record> = reader.map(|r| r.unwrap()).collect();
        let _ = remove_file(&path);

        assert_eq!(records.len(), 2);
        assert_eq!(records[0].zlast, Some(7.25), "first zlast not preserved");
        assert_eq!(records[1].zlast, Some(-3.5), "second zlast not preserved");
    }
}