fits-well 0.1.1

A blazing-fast reader and writer for FITS (Flexible Image Transport System) files, targeting the full FITS 4.0 standard.
//! `RICE_1` tile codec (a port of cfitsio's `fits_rdecomp` bitstream layout).

use crate::bitpix::Bitpix;
use crate::header::Header;
use crate::keyword::key;

/// Rice block size and pixel width, from the `ZNAMEi`/`ZVALi` parameters.
#[derive(Debug, Clone, Copy)]
pub(super) struct RiceParams {
    /// Number of pixels coded per Rice block; each block carries its own `fs`
    /// selector in the bitstream.
    pub blocksize: usize,
    /// Width of one pixel in bytes. It sets the bit width (`8 * bytepix`) of the
    /// literal first pixel and of blocks stored uncompressed.
    pub bytepix: usize,
}

/// Rice parameters from the `ZNAMEi`/`ZVALi` keyword pairs.
///
/// Scans `ZNAME1`, `ZNAME2`, … up to the first missing index. Defaults are a block
/// size of 32 and `bytepix = |ZBITPIX|/8`.
///
/// The header is untrusted, so the values are sanitized:
/// - A `BLOCKSIZE` below 1 is raised to 1, and an oversized one saturates.
/// - A `BYTEPIX` other than 1, 2, 4 or 8 (the byte widths of FITS integer pixels)
///   is ignored and the default stands. An arbitrary width would overflow the
///   codec's bit-width arithmetic (`8 * bytepix`, `64 - nbits`).
pub(super) fn rice_params(header: &Header, zbitpix: Bitpix) -> RiceParams {
    let mut params = RiceParams {
        blocksize: 32,
        bytepix: zbitpix.elem_size(),
    };
    let mut i = 1;
    while let Some(name) = header.get_text(key!("ZNAME{i}").as_str()) {
        if let Some(v) = header.get_integer(key!("ZVAL{i}").as_str()) {
            match name {
                // Saturate rather than truncate where `usize` is narrower than the value.
                "BLOCKSIZE" => {
                    params.blocksize = usize::try_from(v.max(1)).unwrap_or(usize::MAX)
                }
                "BYTEPIX" if matches!(v, 1 | 2 | 4 | 8) => params.bytepix = v as usize,
                _ => {}
            }
        }
        i += 1;
    }
    params
}

/// Decode a `RICE_1` tile of `nx` integer values into `out` (cleared first; a reused
/// buffer, so steady-state decode allocates nothing).
///
/// `bytepix` is the pixel width in bytes and `blocksize` the number of pixels per
/// Rice block. Both may come straight from an untrusted header, so they are
/// sanitized before use:
/// - `bytepix` is clamped to `1..=8`. Zero or wider widths overflow the
///   sign-extension shift (`64 - nbits`).
/// - `blocksize` is raised to at least 1. A zero block size never advances, so the
///   block loop would spin forever.
pub(super) fn rice_decode_into(
    bytes: &[u8],
    nx: usize,
    bytepix: usize,
    blocksize: usize,
    out: &mut Vec<i64>,
) {
    let bytepix = bytepix.clamp(1, 8);
    let blocksize = blocksize.max(1);
    let nbits_pp = (8 * bytepix) as u32;
    // Width of the per-block `fs` selector, and the `fs` value that marks a block
    // stored uncompressed (the encoder writes these same pairs).
    let (fsbits, fsmax) = match bytepix {
        1 => (3u32, 6u32),
        2 => (4, 14),
        _ => (5, 25), // 4-byte (and wider) pixels
    };
    let mask = if nbits_pp >= 64 {
        u64::MAX
    } else {
        (1u64 << nbits_pp) - 1
    };

    let mut br = BitReader::new(bytes);
    let mut lastpix = br.read(nbits_pp); // literal first pixel (big-endian)
    out.clear();
    out.reserve(nx);
    let mut i = 0;
    while i < nx {
        // The selector is stored biased by one; 0 means an all-zero block (fs = -1).
        let fs = br.read(fsbits) as i64 - 1;
        // `nx - i` form: cannot overflow even for a huge `blocksize`.
        let imax = i + blocksize.min(nx - i);
        for _ in i..imax {
            let diff = if fs < 0 {
                0
            } else if fs as u32 == fsmax {
                br.read(nbits_pp) // uncompressed block
            } else {
                (br.read_zeros() << fs) | br.read(fs as u32)
            };
            // Undo the zigzag mapping, then the differencing (modular at pixel width).
            let d = if diff & 1 == 1 {
                !(diff >> 1)
            } else {
                diff >> 1
            };
            lastpix = lastpix.wrapping_add(d) & mask;
            out.push(sign_extend(lastpix, nbits_pp));
        }
        i = imax;
    }
}

/// Interpret the low `nbits` of `v` as a two's-complement signed value.
///
/// Widths outside `1..=63` are handled without shift overflow:
/// - `nbits == 0` yields 0 (an empty field).
/// - `nbits >= 64` reinterprets all 64 bits as `i64`.
fn sign_extend(v: u64, nbits: u32) -> i64 {
    match nbits {
        0 => 0,
        1..=63 => {
            // Move the field's sign bit to bit 63, then arithmetic-shift back down.
            let shift = 64 - nbits;
            ((v << shift) as i64) >> shift
        }
        _ => v as i64,
    }
}

/// Encode `values` as a `RICE_1` tile (a port of cfitsio's `fits_rcomp`),
/// parameterized by `bytepix` (1/2/4; wider widths use the 4-byte selector
/// layout). Differences are taken modulo the pixel width so the stream round-trips
/// through [`rice_decode_into`].
///
/// Out-of-range arguments are sanitized instead of panicking or hanging:
/// - `bytepix` is clamped to `1..=8`. Zero or wider widths overflow the shift
///   arithmetic.
/// - `blocksize` is raised to at least 1. A zero block size never advances through
///   `values`.
pub(super) fn rice_encode(values: &[i64], bytepix: usize, blocksize: usize) -> Vec<u8> {
    let bytepix = bytepix.clamp(1, 8);
    let blocksize = blocksize.max(1);
    let nbits = (8 * bytepix) as u32;
    // Width of the per-block `fs` selector, and the `fs` value that marks a block
    // stored uncompressed.
    let (fsbits, fsmax) = match bytepix {
        1 => (3i32, 6i32),
        2 => (4, 14),
        _ => (5, 25),
    };
    let mask: u64 = if nbits >= 64 {
        u64::MAX
    } else {
        (1u64 << nbits) - 1
    };
    // Shift that moves an `nbits`-wide field to the top of a u64, so an arithmetic
    // shift back down sign-extends it.
    let align = 64 - nbits;

    let mut bo = BitOutput::new();
    // Rice output is at most a few bytes per pixel; reserve a pixel's worth up front
    // so the bitstream rarely reallocates mid-tile.
    bo.out.reserve(values.len());
    let first = (values.first().copied().unwrap_or(0) as u64) & mask;
    bo.output_nbits(first as i64, nbits as i32);
    let mut lastpix = first;

    // One difference buffer reused across blocks (cleared each block) rather than a
    // fresh allocation per block, since a tile has thousands of blocks. Capped at the
    // tile length so an oversized `blocksize` does not over-allocate.
    let mut diffs: Vec<u64> = Vec::with_capacity(blocksize.min(values.len()));
    for block in values.chunks(blocksize) {
        let thisblock = block.len();
        diffs.clear();
        let mut pixelsum = 0.0f64;
        for &value in block {
            let next = (value as u64) & mask;
            // Difference reduced to the pixel width, reinterpreted as signed, then
            // zigzag-mapped (0, -1, 1, -2, … → 0, 1, 2, 3, …). The xor form equals
            // cfitsio's `~(pdiff << 1)` for negatives and cannot overflow, even for
            // `i64::MIN` at 64-bit width.
            let raw = next.wrapping_sub(lastpix) & mask;
            let s = ((raw << align) as i64) >> align;
            let d = ((s << 1) ^ (s >> 63)) as u64;
            diffs.push(d);
            pixelsum += d as f64;
            lastpix = next;
        }

        // fs ≈ floor(log2(mean mapped difference)). cfitsio computes `thisblock/2`
        // with integer division; matching it keeps short, odd-length final blocks
        // byte-identical to cfitsio output.
        let dpsum = ((pixelsum - (thisblock / 2) as f64 - 1.0) / thisblock as f64).max(0.0);
        let fs = 64 - ((dpsum as u64) >> 1).leading_zeros() as i32;

        if fs >= fsmax {
            // Escape: store the mapped differences verbatim at full pixel width.
            bo.output_nbits((fsmax + 1) as i64, fsbits);
            for &d in &diffs {
                bo.output_nbits(d as i64, nbits as i32);
            }
        } else if fs == 0 && pixelsum == 0.0 {
            // Every difference is zero, so the selector alone encodes the block.
            bo.output_nbits(0, fsbits);
        } else {
            bo.output_nbits((fs + 1) as i64, fsbits);
            let fsmask = (1i64 << fs) - 1;
            for &d in &diffs {
                bo.output_rice_value(d as i64, fs, fsmask);
            }
        }
    }
    bo.done();
    bo.out
}

/// MSB-first bit output, mirroring cfitsio's `Buffer`/`output_nbits`.
struct BitOutput {
    /// Completed output bytes.
    out: Vec<u8>,
    /// Pending bits not yet flushed to `out`; only the low bits are meaningful.
    bitbuffer: i64,
    /// Free bit positions left in the current partial byte (8 = nothing pending).
    bits_to_go: i32,
}

impl BitOutput {
    fn new() -> Self {
        BitOutput {
            out: Vec::new(),
            bitbuffer: 0,
            bits_to_go: 8,
        }
    }

    /// Append the low `n` bits of `bits`, most significant first (`0 ≤ n ≤ 64`).
    ///
    /// The cfitsio-derived core below handles at most 32 bits per call, because its
    /// mask saturates at 32. Wider fields, such as the 64-bit literals needed when
    /// `bytepix = 8`, are therefore emitted as the high `n - 32` bits followed by
    /// the low 32.
    fn output_nbits(&mut self, bits: i64, mut n: i32) {
        if n > 32 {
            self.output_nbits(bits >> 32, n - 32);
            n = 32;
        }
        let mask = |k: i32| {
            if k >= 32 {
                0xFFFF_FFFFi64
            } else {
                (1i64 << k) - 1
            }
        };
        let mut lb = self.bitbuffer;
        let mut ltg = self.bits_to_go;
        if ltg + n > 32 {
            // Flush the current partial byte first, as cfitsio does for large `n`.
            lb <<= ltg;
            lb |= (bits >> (n - ltg)) & mask(ltg);
            self.out.push((lb & 0xff) as u8);
            n -= ltg;
            ltg = 8;
        }
        lb <<= n;
        lb |= bits & mask(n);
        ltg -= n;
        while ltg <= 0 {
            self.out.push(((lb >> (-ltg)) & 0xff) as u8);
            ltg += 8;
        }
        self.bitbuffer = lb;
        self.bits_to_go = ltg;
    }

    /// Output one Rice-coded value: `top = v >> fs` zero bits, a 1, then the low
    /// `fs` bits of `v`.
    fn output_rice_value(&mut self, v: i64, fs: i32, fsmask: i64) {
        let top = v >> fs;
        if (self.bits_to_go as i64) > top {
            // The zero run plus its terminating 1 fit in the current partial byte.
            self.bitbuffer <<= top + 1;
            self.bitbuffer |= 1;
            self.bits_to_go -= (top + 1) as i32;
        } else {
            // Fill the current byte with zeros, emit whole zero bytes, then start a
            // new byte holding the rest of the run and the terminating 1.
            self.bitbuffer <<= self.bits_to_go;
            self.out.push((self.bitbuffer & 0xff) as u8);
            let mut t = top - self.bits_to_go as i64;
            while t >= 8 {
                self.out.push(0);
                t -= 8;
            }
            self.bitbuffer = 1;
            self.bits_to_go = 7 - t as i32;
        }
        if fs > 0 {
            self.bitbuffer <<= fs;
            self.bitbuffer |= v & fsmask;
            self.bits_to_go -= fs;
            while self.bits_to_go <= 0 {
                self.out
                    .push(((self.bitbuffer >> (-self.bits_to_go)) & 0xff) as u8);
                self.bits_to_go += 8;
            }
        }
    }

    /// Flush any partial byte, zero-padded on the right.
    fn done(&mut self) {
        if self.bits_to_go < 8 {
            self.out
                .push(((self.bitbuffer << self.bits_to_go) & 0xff) as u8);
        }
    }
}

/// A MSB-first bit reader over a compressed byte stream.
pub(super) struct BitReader<'a> {
    bytes: &'a [u8],
    /// Index of the next byte to load into `acc` (may run past the end at EOF).
    pos: usize,
    /// Bit accumulator; the low `nbits` bits are the unread ones, MSB-first.
    acc: u64,
    nbits: u32,
}

impl<'a> BitReader<'a> {
    pub(super) fn new(bytes: &'a [u8]) -> Self {
        BitReader {
            bytes,
            pos: 0,
            acc: 0,
            nbits: 0,
        }
    }

    /// Read `n` bits MSB-first (`n ≤ 64`); past end-of-input reads as zero bits.
    ///
    /// `fill` can leave as few as 57 bits buffered. Fields wider than 32 bits, such
    /// as the 64-bit literals of `bytepix = 8` tiles, are therefore assembled from
    /// two reads: the high `n - 32` bits, then the low 32.
    pub(super) fn read(&mut self, n: u32) -> u64 {
        if n == 0 {
            // Avoids `acc >> 64` when the accumulator happens to be full.
            return 0;
        }
        if n > 32 {
            let high = self.read(n - 32);
            return (high << 32) | self.read(32);
        }
        if self.nbits < n {
            self.fill();
        }
        self.nbits -= n;
        // 1 ≤ n ≤ 32 here, so neither shift can overflow.
        (self.acc >> self.nbits) & ((1u64 << n) - 1)
    }

    /// Top up the accumulator so a subsequent `read` of up to 32 bits needs no
    /// refill, or the input is exhausted (past which bytes read as zero). Loads a
    /// whole 8-byte word at once when the accumulator is empty and a word remains —
    /// the common mid-stream case — instead of eight separate bounds-checked loads.
    #[inline]
    fn fill(&mut self) {
        if self.nbits == 0 && self.pos + 8 <= self.bytes.len() {
            let word = self.bytes[self.pos..self.pos + 8].try_into().unwrap();
            self.acc = u64::from_be_bytes(word);
            self.pos += 8;
            self.nbits = 64;
            return;
        }
        // Load whole bytes until another would overflow the 64-bit accumulator;
        // that leaves ≥ 57 bits, enough for any single ≤ 32-bit read.
        while self.nbits <= 56 {
            let byte = self.bytes.get(self.pos).copied().unwrap_or(0);
            self.pos += 1;
            self.acc = (self.acc << 8) | byte as u64;
            self.nbits += 8;
        }
    }

    /// Count and consume leading zero bits up to (and including) the next 1.
    ///
    /// Scans the zero run a whole word at a time via `leading_zeros` rather than one
    /// `read(1)` per bit — the unary quotient decode is the hot path of Rice decode.
    /// Stops once the real input is exhausted (a truncated tile with no terminating
    /// 1 bit would otherwise loop forever — a DoS on untrusted bytes); the exact
    /// count past EOF is unspecified (the data is corrupt), only termination matters.
    pub(super) fn read_zeros(&mut self) -> u64 {
        let mut z = 0u64;
        loop {
            if self.nbits == 0 {
                // Refill one byte; at EOF there is no terminating 1 bit, so stop.
                if self.pos >= self.bytes.len() {
                    return z;
                }
                self.acc = (self.acc << 8) | self.bytes[self.pos] as u64;
                self.pos += 1;
                self.nbits += 8;
            }
            // Left-align the valid low `nbits` bits so the next-to-read bit is the
            // MSB, then count zeros up to the first 1 (capped at the valid bits, since
            // the shifted-in low bits read as zero).
            let run = (self.acc << (64 - self.nbits))
                .leading_zeros()
                .min(self.nbits);
            if run < self.nbits {
                // Terminating 1 found within the valid bits: consume the zeros + the 1.
                self.nbits -= run + 1;
                return z + run as u64;
            }
            // All valid bits were zero: consume them and refill on the next pass.
            z += self.nbits as u64;
            self.nbits = 0;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::{rice_decode_into, BitReader};

    #[test]
    fn bit_reader_reads_msb_first() {
        // 1011_0010 1111_0000, consumed MSB-first as fields of 1, 3, 4 and 4 bits.
        let mut reader = BitReader::new(&[0b1011_0010, 0b1111_0000]);
        for (width, expected) in [(1, 1), (3, 0b011), (4, 0b0010), (4, 0b1111)] {
            assert_eq!(reader.read(width), expected);
        }
    }

    #[test]
    fn read_zeros_counts_runs_across_bytes_and_leftover_bits() {
        // 0000_0000 1000_0000: an 8-bit run covering the whole first byte ends at the
        // leading 1 of the second byte, which exercises a refill mid-run. That consumes
        // nine bits. The 7 trailing zeros then run into EOF, where the count is
        // unspecified (here it is simply those 7 zeros).
        let mut reader = BitReader::new(&[0x00, 0x80]);
        assert_eq!(reader.read_zeros(), 8);
        assert_eq!(reader.read_zeros(), 7);

        // 0000_0001: seven zeros and the terminating 1 inside a single byte.
        let mut single = BitReader::new(&[0x01]);
        assert_eq!(single.read_zeros(), 7);

        // Runs that start from bits left over by an earlier `read`:
        // - After read(4), the rest of 0000_1000 starts with the 1, so the run is 0.
        // - The remaining three zeros plus the first zero of 0100_0000 give a run of 4.
        let mut leftover = BitReader::new(&[0x08, 0x40]);
        assert_eq!(leftover.read(4), 0);
        assert_eq!(leftover.read_zeros(), 0);
        assert_eq!(leftover.read_zeros(), 4);
    }

    #[test]
    fn truncated_stream_terminates_instead_of_hanging() {
        // byte0 is the literal 8-bit first pixel. byte1 = 0b001_00000 supplies the
        // 3-bit fs field 001 (fs = 0), so decoding enters a zero-run that never meets
        // its terminating 1 before EOF (`read` zero-fills past the end). Without the
        // exhaustion guard in `read_zeros` this would never return, so reaching the
        // assert at all is the guarantee.
        let mut decoded = Vec::new();
        rice_decode_into(&[0x00, 0x20], 2, 1, 32, &mut decoded);
        assert_eq!(decoded.len(), 2);
    }
}