Skip to main content

flow_fcs_compress/codec/
log_quant.rs

//! Mode C: lossy log-domain quantization.
//!
//! Pipeline (encode):
//! 1. `y = asinh(x / cofactor)` per [`crate::transform::asinh`].
//! 2. Compute `y_min` and `y_max` for the chunk.
//! 3. Quantize: `q = round((y - y_min) / step)` where `step = (y_max - y_min) / (2^bits - 1)`
//!    (`step = 0` when every `y` in the chunk is equal; no codes are stored then).
//! 4. Bitpack `q` into `bits` per value (same packing primitive as Mode B).
//!
//! Decode reverses the pipeline: unpack `q`, reconstruct `y = y_min + q * step`,
//! then `x = cofactor * sinh(y)`.
//!
//! ## Loss model
//!
//! A quantization step in `y`-space translates to a roughly *relative* error in
//! `x`-space for large `|x|` (where asinh ≈ log) and a *near-absolute* error
//! near zero (where `asinh(x / cofactor) ≈ x / cofactor`). The codec is therefore
//! "log-relative" — a property practitioners want for fluorescence data
//! displayed on log axes. The reconstruction error in `y` is at most about `step / 2`.
//!
//! ## Configuration
//!
//! [`LogQuantizationConfig`] exposes two knobs:
//! - `cofactor` — passed straight to `asinh`; must be finite and > 0. Default 150 (Cytek-style).
//! - `bits` — 4..=24. Fewer bits give a smaller, lossier encoding. Default 16.
//!
//! Non-finite inputs (NaN, ±inf) are rejected with `Error::InvalidParams`.
//!
//! ## Per-chunk header (20 bytes, little-endian)
//!
//! ```text
//! [cofactor   f32 4B]
//! [y_min      f32 4B]
//! [y_step     f32 4B]   (0.0 for a constant chunk, i.e. y_max == y_min)
//! [n_values   u32 4B]
//! [bits       u8  1B]
//! [reserved   u8  1B]   (written as 0)
//! [pad        u16 2B]   (written as 0)
//! [packed bytes...]    only when y_step > 0: ceil(n_values * bits / 8) bytes
//! ```
37
38use byteorder::{ByteOrder, LittleEndian};
39
40use crate::codec::{ChannelParams, CodecId, ColumnCodec, EncodeStats};
41use crate::error::{Error, Result};
42use crate::transform::asinh::{DEFAULT_COFACTOR, forward, inverse};
43
/// Size of the fixed per-chunk header, in bytes:
/// cofactor (f32) + y_min (f32) + y_step (f32) + n_values (u32)
/// + bits (u8) + reserved (u8) + pad (u16) = 20.
const HEADER_BYTES: usize = 3 * 4 + 4 + 1 + 1 + 2;
45
46/// Tunable parameters for the lossy log-quant codec.
47#[derive(Debug, Clone, Copy)]
48pub struct LogQuantizationConfig {
49    pub cofactor: f32,
50    pub bits: u8,
51}
52
53impl Default for LogQuantizationConfig {
54    fn default() -> Self {
55        Self {
56            cofactor: DEFAULT_COFACTOR,
57            bits: 16,
58        }
59    }
60}
61
62#[derive(Debug, Clone, Copy)]
63pub struct LogQuantization {
64    pub cfg: LogQuantizationConfig,
65}
66
67impl Default for LogQuantization {
68    fn default() -> Self {
69        Self {
70            cfg: LogQuantizationConfig::default(),
71        }
72    }
73}
74
75impl LogQuantization {
76    pub fn new(cfg: LogQuantizationConfig) -> Self {
77        Self { cfg }
78    }
79}
80
81impl ColumnCodec for LogQuantization {
82    fn id(&self) -> CodecId {
83        CodecId::LogQuantization
84    }
85
86    fn encode_chunk(
87        &self,
88        input: &[f32],
89        _params: &ChannelParams,
90        out: &mut Vec<u8>,
91    ) -> Result<EncodeStats> {
92        if input.is_empty() {
93            return Err(Error::InvalidParams("LogQuantization: empty chunk"));
94        }
95        if !(4..=24).contains(&self.cfg.bits) {
96            return Err(Error::InvalidParams("LogQuantization: bits must be in 4..=24"));
97        }
98        if !self.cfg.cofactor.is_finite() || self.cfg.cofactor <= 0.0 {
99            return Err(Error::InvalidParams(
100                "LogQuantization: cofactor must be finite and > 0",
101            ));
102        }
103
104        // Pass 1: compute y values + extrema.
105        let mut ys: Vec<f32> = Vec::with_capacity(input.len());
106        let mut y_min = f32::INFINITY;
107        let mut y_max = f32::NEG_INFINITY;
108        for &x in input {
109            if !x.is_finite() {
110                return Err(Error::InvalidParams(
111                    "LogQuantization: encountered NaN or infinite input",
112                ));
113            }
114            let y = forward(x, self.cfg.cofactor);
115            y_min = y_min.min(y);
116            y_max = y_max.max(y);
117            ys.push(y);
118        }
119
120        let bits = self.cfg.bits;
121        let levels = (1u32 << bits) - 1;
122        let span = y_max - y_min;
123        let step = if span <= 0.0 {
124            0.0
125        } else {
126            span / levels as f32
127        };
128
129        // Header
130        let header_start = out.len();
131        out.resize(header_start + HEADER_BYTES, 0);
132        {
133            let h = &mut out[header_start..header_start + HEADER_BYTES];
134            LittleEndian::write_f32(&mut h[0..4], self.cfg.cofactor);
135            LittleEndian::write_f32(&mut h[4..8], y_min);
136            LittleEndian::write_f32(&mut h[8..12], step);
137            LittleEndian::write_u32(&mut h[12..16], input.len() as u32);
138            h[16] = bits;
139            h[17] = 0;
140            // pad bytes already zero
141        }
142
143        // Payload
144        if step > 0.0 {
145            let mask = if bits == 32 {
146                u32::MAX
147            } else {
148                (1u32 << bits) - 1
149            };
150            let mut staged: Vec<u32> = Vec::with_capacity(ys.len());
151            for &y in &ys {
152                let q = ((y - y_min) / step).round();
153                let q_clamped = q.clamp(0.0, levels as f32) as u32 & mask;
154                staged.push(q_clamped);
155            }
156            pack_bits_fast(&staged, bits, out);
157        }
158
159        let written = out.len() - header_start;
160        Ok(EncodeStats {
161            input_events: input.len() as u32,
162            input_bytes: (input.len() * 4) as u64,
163            output_bytes: written as u64,
164        })
165    }
166
167    fn decode_chunk(
168        &self,
169        payload: &[u8],
170        _params: &ChannelParams,
171        out: &mut [f32],
172    ) -> Result<()> {
173        if payload.len() < HEADER_BYTES {
174            return Err(Error::Truncated {
175                needed: HEADER_BYTES,
176                have: payload.len(),
177            });
178        }
179        let cofactor = LittleEndian::read_f32(&payload[0..4]);
180        let y_min = LittleEndian::read_f32(&payload[4..8]);
181        let step = LittleEndian::read_f32(&payload[8..12]);
182        let n_values = LittleEndian::read_u32(&payload[12..16]) as usize;
183        let bits = payload[16];
184
185        if !cofactor.is_finite() || cofactor <= 0.0 {
186            return Err(Error::InvalidParams(
187                "LogQuantization: payload cofactor invalid",
188            ));
189        }
190        if !(4..=24).contains(&bits) {
191            return Err(Error::InvalidParams("LogQuantization: payload bits out of range"));
192        }
193        if out.len() != n_values {
194            return Err(Error::LengthMismatch {
195                expected: n_values,
196                actual: out.len(),
197            });
198        }
199
200        if step <= 0.0 {
201            // Constant chunk
202            let x = inverse(y_min, cofactor);
203            for slot in out.iter_mut() {
204                *slot = x;
205            }
206            return Ok(());
207        }
208
209        let total_bits = n_values * bits as usize;
210        let needed = HEADER_BYTES + total_bits.div_ceil(8);
211        if payload.len() < needed {
212            return Err(Error::Truncated {
213                needed,
214                have: payload.len(),
215            });
216        }
217        let packed = &payload[HEADER_BYTES..];
218
219        // SIMD-style optimization: bulk-unpack quantization codes into a u32
220        // staging buffer, then dequantize. Two-pass form lets the compiler
221        // auto-vectorize the dequant arithmetic. For small bit widths (≤14)
222        // we additionally precompute a `cofactor * sinh(y_min + i*step)` LUT,
223        // skipping per-value sinh calls and yielding a 4–10× decode speedup.
224        let mut staging: Vec<u32> = vec![0; n_values];
225        unpack_bits_fast(packed, bits, n_values, &mut staging);
226
227        if bits <= 14 {
228            let levels = (1usize << bits).min(1 << 14);
229            let mut lut: Vec<f32> = Vec::with_capacity(levels);
230            for i in 0..levels {
231                let y = y_min + (i as f32) * step;
232                lut.push(inverse(y, cofactor));
233            }
234            for (slot, &q) in out.iter_mut().zip(staging.iter()) {
235                let idx = (q as usize).min(levels - 1);
236                *slot = lut[idx];
237            }
238        } else {
239            for (slot, &q) in out.iter_mut().zip(staging.iter()) {
240                let y = y_min + (q as f32) * step;
241                *slot = inverse(y, cofactor);
242            }
243        }
244        Ok(())
245    }
246}
247
/// Pack u32 values into a contiguous little-endian bit-stream of `width`-bit fields.
///
/// Field `i` occupies stream bits `[i * width, (i + 1) * width)`, low bits
/// first, and only the low `width` bits of each value are kept. Exactly
/// `ceil(values.len() * width / 8)` bytes are appended to `dst`; nothing is
/// appended when `width == 0`. Bit-reservoir form, mirroring the packer in
/// `adc_bitpack.rs`.
///
/// `width` must be at most 32. Wider fields cannot come from a `u32` and would
/// overflow the 64-bit reservoir.
fn pack_bits_fast(values: &[u32], width: u8, dst: &mut Vec<u8>) {
    debug_assert!(width <= 32, "pack_bits_fast: width {width} > 32");
    if width == 0 {
        return;
    }
    let width = u32::from(width);
    // Reserve the exact output size once instead of growing `dst` a word at a time.
    if let Some(total_bits) = values.len().checked_mul(width as usize) {
        dst.reserve(total_bits.div_ceil(8));
    }
    let mask = if width >= 32 { u32::MAX } else { (1u32 << width) - 1 };
    let mut buf: u64 = 0;
    let mut buf_bits: u32 = 0;
    for &v in values {
        // buf_bits < 32 before this add, so at most 63 bits are live afterwards.
        buf |= u64::from(v & mask) << buf_bits;
        buf_bits += width;
        if buf_bits >= 32 {
            dst.extend_from_slice(&(buf as u32).to_le_bytes());
            buf >>= 32;
            buf_bits -= 32;
        }
    }
    // Flush the remaining (< 32) bits, rounding the final partial byte up.
    while buf_bits > 0 {
        dst.push(buf as u8);
        buf >>= 8;
        buf_bits = buf_bits.saturating_sub(8);
    }
}
277
/// Inverse of [`pack_bits_fast`]: reads `n` consecutive `width`-bit fields
/// (low bits first) from the little-endian bit-stream `src` into `out`.
///
/// Only the first `min(n, out.len())` slots are written, and `width == 0`
/// zero-fills them. If `src` runs out in the middle of a field, the missing
/// bits read as zero. Callers should first check that `src` holds
/// `ceil(n * width / 8)` bytes.
#[inline]
fn unpack_bits_fast(src: &[u8], width: u8, n: usize, out: &mut [u32]) {
    let targets = out.iter_mut().take(n);
    if width == 0 {
        targets.for_each(|slot| *slot = 0);
        return;
    }

    let field_bits = u32::from(width);
    let field_mask: u64 = if width >= 32 {
        u64::from(u32::MAX)
    } else {
        (1u64 << width) - 1
    };
    let mut reservoir: u64 = 0;
    let mut held: u32 = 0;
    let mut cursor = 0usize;

    for slot in targets {
        // Top the reservoir up until it holds a whole field or `src` is exhausted.
        while held < field_bits {
            // Fast path: pull a whole little-endian word when it fits.
            if held + 32 <= 64 {
                if let Some(word) = src.get(cursor..cursor + 4) {
                    let word = u32::from_le_bytes([word[0], word[1], word[2], word[3]]);
                    reservoir |= u64::from(word) << held;
                    held += 32;
                    cursor += 4;
                    continue;
                }
            }
            // Tail: fewer than four bytes remain, so take them one at a time.
            match src.get(cursor) {
                Some(&byte) => {
                    reservoir |= u64::from(byte) << held;
                    held += 8;
                    cursor += 1;
                }
                None => break,
            }
        }
        *slot = (reservoir & field_mask) as u32;
        reservoir >>= width;
        held = held.saturating_sub(field_bits);
    }
}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fluorescence-like test data. Every 5th value is noise
    /// around zero (roughly ±25); the rest are spread over ~5 decades (1..1e5).
    fn synth_log_channel(n: usize, seed: u64) -> Vec<f32> {
        // Mix of small noise around zero and decade-spread positive values —
        // shaped like a typical fluorescence channel.
        let mut s = seed;
        let mut v = Vec::with_capacity(n);
        for i in 0..n {
            s = s
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let u = ((s >> 32) as u32) as f32 / u32::MAX as f32;
            let base = if i % 5 == 0 {
                (u - 0.5) * 50.0
            } else {
                10f32.powf(u * 5.0)
            };
            v.push(base);
        }
        v
    }

    fn params() -> ChannelParams {
        ChannelParams {
            name: "fluo".into(),
            stored_bits: 32,
            range: 262_144,
            log_decades: (5.0, 0.0),
            adc_bits: None,
            signed: true,
        }
    }

    #[test]
    fn round_trip_within_tolerance_at_16_bits() {
        let codec = LogQuantization::default();
        let p = params();
        let input = synth_log_channel(8192, 42);

        let mut payload = Vec::new();
        codec.encode_chunk(&input, &p, &mut payload).unwrap();
        let mut out = vec![0.0f32; input.len()];
        codec.decode_chunk(&payload, &p, &mut out).unwrap();

        let mut max_rel = 0f32;
        for (a, b) in input.iter().zip(out.iter()) {
            if a.abs() > 100.0 {
                max_rel = max_rel.max(((a - b).abs()) / a.abs());
            }
        }
        // 16-bit log-quant should give well under 0.1% relative error on log-scale data.
        assert!(max_rel < 1e-3, "max rel err = {max_rel}");
    }

    #[test]
    fn smaller_bits_smaller_payload() {
        let p = params();
        let input = synth_log_channel(4096, 7);

        let mut p16 = Vec::new();
        LogQuantization::new(LogQuantizationConfig {
            cofactor: 150.0,
            bits: 16,
        })
        .encode_chunk(&input, &p, &mut p16)
        .unwrap();

        let mut p8 = Vec::new();
        LogQuantization::new(LogQuantizationConfig {
            cofactor: 150.0,
            bits: 8,
        })
        .encode_chunk(&input, &p, &mut p8)
        .unwrap();

        assert!(
            p8.len() < p16.len(),
            "8-bit payload ({}) should be smaller than 16-bit ({})",
            p8.len(),
            p16.len()
        );
    }

    #[test]
    fn beats_raw_f32_on_log_data() {
        let p = params();
        let input = synth_log_channel(65_536, 1);
        let raw_bytes = input.len() * 4;

        // 16-bit Mode C: ~16 bits/value = 50% of raw f32. Header overhead pushes
        // it just over 50%, so this check uses a looser bound.
        let mut p16 = Vec::new();
        LogQuantization::default().encode_chunk(&input, &p, &mut p16).unwrap();
        assert!(
            p16.len() < raw_bytes,
            "16-bit Mode C ({}) should be smaller than raw f32 ({})",
            p16.len(),
            raw_bytes
        );

        // 12-bit Mode C: 12/32 = 37.5% of raw f32, decisively smaller.
        let mut p12 = Vec::new();
        LogQuantization::new(LogQuantizationConfig {
            cofactor: 150.0,
            bits: 12,
        })
        .encode_chunk(&input, &p, &mut p12)
        .unwrap();
        assert!(
            p12.len() * 2 < raw_bytes,
            "12-bit Mode C ({}) should be < 50% of raw f32 ({})",
            p12.len(),
            raw_bytes
        );
    }

    #[test]
    fn rejects_invalid_bits() {
        let codec = LogQuantization::new(LogQuantizationConfig {
            cofactor: 150.0,
            bits: 2,
        });
        let p = params();
        let input = vec![1.0f32; 64];
        let mut payload = Vec::new();
        let err = codec.encode_chunk(&input, &p, &mut payload).unwrap_err();
        assert!(matches!(err, Error::InvalidParams(_)));
    }

    #[test]
    fn rejects_nan() {
        let codec = LogQuantization::default();
        let p = params();
        let mut input = synth_log_channel(64, 1);
        input[5] = f32::NAN;
        let mut payload = Vec::new();
        let err = codec.encode_chunk(&input, &p, &mut payload).unwrap_err();
        assert!(matches!(err, Error::InvalidParams(_)));
    }

    #[test]
    fn constant_chunk_roundtrips() {
        let codec = LogQuantization::default();
        let p = params();
        let input = vec![137.5f32; 256];
        let mut payload = Vec::new();
        codec.encode_chunk(&input, &p, &mut payload).unwrap();
        // Constant chunk → header only, step = 0.
        assert_eq!(payload.len(), HEADER_BYTES);

        let mut out = vec![0.0f32; input.len()];
        codec.decode_chunk(&payload, &p, &mut out).unwrap();
        for (a, b) in input.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-3, "{} vs {}", a, b);
        }
    }

    /// Every width from 1 to 32 bits must round-trip through the packer and
    /// produce exactly `ceil(n * width / 8)` bytes.
    #[test]
    fn bitpack_round_trips_every_width() {
        let n = 37usize;
        for width in 1u8..=32 {
            let mask = if width == 32 {
                u32::MAX
            } else {
                (1u32 << width) - 1
            };
            let values: Vec<u32> = (0..n as u32)
                .map(|i| i.wrapping_mul(2_654_435_761).rotate_left(u32::from(width)) & mask)
                .collect();

            let mut packed = Vec::new();
            pack_bits_fast(&values, width, &mut packed);
            assert_eq!(packed.len(), (n * width as usize).div_ceil(8), "width {width}");

            let mut unpacked = vec![0u32; n];
            unpack_bits_fast(&packed, width, n, &mut unpacked);
            assert_eq!(unpacked, values, "width {width}");
        }
    }

    /// At 8 bits (the lookup-table decode path), each decoded value must map
    /// back to within half a quantization step of its original in y-space.
    #[test]
    fn low_bit_round_trip_stays_within_half_step_in_y() {
        let cofactor = 150.0f32;
        let codec = LogQuantization::new(LogQuantizationConfig { cofactor, bits: 8 });
        let p = params();
        let input = synth_log_channel(4096, 3);

        let mut payload = Vec::new();
        codec.encode_chunk(&input, &p, &mut payload).unwrap();
        let mut out = vec![0.0f32; input.len()];
        codec.decode_chunk(&payload, &p, &mut out).unwrap();

        let ys: Vec<f32> = input.iter().map(|&x| forward(x, cofactor)).collect();
        let y_min = ys.iter().copied().fold(f32::INFINITY, f32::min);
        let y_max = ys.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let half_step = (y_max - y_min) / 255.0 / 2.0;
        for (&y, &b) in ys.iter().zip(&out) {
            let err = (y - forward(b, cofactor)).abs();
            assert!(err <= half_step + 1e-4, "y err {err} > half step {half_step}");
        }
    }

    #[test]
    fn decode_rejects_truncated_payload() {
        let codec = LogQuantization::default();
        let p = params();
        let input = synth_log_channel(100, 9);
        let mut payload = Vec::new();
        codec.encode_chunk(&input, &p, &mut payload).unwrap();

        let mut out = vec![0.0f32; input.len()];
        for cut in [HEADER_BYTES - 1, payload.len() - 1] {
            let err = codec.decode_chunk(&payload[..cut], &p, &mut out).unwrap_err();
            assert!(matches!(err, Error::Truncated { .. }), "cut at {cut}");
        }
    }

    #[test]
    fn decode_rejects_wrong_output_length() {
        let codec = LogQuantization::default();
        let p = params();
        let input = synth_log_channel(100, 11);
        let mut payload = Vec::new();
        codec.encode_chunk(&input, &p, &mut payload).unwrap();

        let mut out = vec![0.0f32; input.len() + 1];
        let err = codec.decode_chunk(&payload, &p, &mut out).unwrap_err();
        assert!(matches!(err, Error::LengthMismatch { .. }));
    }

    /// Invalid configurations and empty chunks are rejected without touching `out`.
    #[test]
    fn rejects_bad_config_and_empty_chunk() {
        let p = params();
        let input = vec![1.0f32; 16];
        for cfg in [
            LogQuantizationConfig { cofactor: 0.0, bits: 16 },
            LogQuantizationConfig { cofactor: -1.0, bits: 16 },
            LogQuantizationConfig { cofactor: f32::NAN, bits: 16 },
            LogQuantizationConfig { cofactor: 150.0, bits: 25 },
        ] {
            let mut payload = Vec::new();
            let err = LogQuantization::new(cfg)
                .encode_chunk(&input, &p, &mut payload)
                .unwrap_err();
            assert!(matches!(err, Error::InvalidParams(_)));
            assert!(payload.is_empty());
        }

        let mut payload = Vec::new();
        let err = LogQuantization::default()
            .encode_chunk(&[], &p, &mut payload)
            .unwrap_err();
        assert!(matches!(err, Error::InvalidParams(_)));
        assert!(payload.is_empty());
    }
}