1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use std::io::Write;

use crate::bit_reader::BitReader;
use crate::bit_writer::BitWriter;
use crate::data_types::UnsignedLike;
use crate::errors::PcoResult;

/// Starting values ("moments") captured during delta encoding; they are needed
/// to reverse that encoding.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct DeltaMoments<U: UnsignedLike> {
  /// One moment per round of delta encoding, so the length of this vector is
  /// the delta encoding order.
  pub moments: Vec<U>,
}

impl<U: UnsignedLike> DeltaMoments<U> {
  fn new(moments: Vec<U>) -> Self {
    DeltaMoments { moments }
  }

  /// Reads `order` moments from `reader`, each stored with `U::BITS` bits.
  pub fn parse_from(reader: &mut BitReader, order: usize) -> PcoResult<Self> {
    // grow one element at a time rather than preallocating from `order`
    let mut moments = Vec::new();
    while moments.len() < order {
      let moment = reader.read_uint::<U>(U::BITS);
      moments.push(moment);
    }
    Ok(Self::new(moments))
  }

  /// Writes every moment to `writer` with `U::BITS` bits each, matching the
  /// layout read by [`Self::parse_from`].
  pub fn write_to<W: Write>(&self, writer: &mut BitWriter<W>) {
    self.moments.iter().for_each(|&moment| {
      writer.write_uint(moment, U::BITS);
    });
  }

  /// The delta encoding order, i.e. how many moments are stored.
  pub fn order(&self) -> usize {
    self.moments.len()
  }
}

/// Adds `U::MID` to every value in place, with wrapping arithmetic.
///
/// Without this, deltas in, say, [-5, 5] would be split out of order into
/// [U::MAX - 4, U::MAX] and [0, 5].
///
/// The same call serves both directions:
/// * encoding: unsigned deltas -> (effectively) signed deltas
/// * decoding: signed deltas -> unsigned deltas
///
/// The encoder and decoder in this module both apply it, so it is relied on to
/// undo itself (presumably `U::MID` is half the value range, so adding it twice
/// wraps back to the starting value — confirm against `UnsignedLike`).
#[inline(never)]
pub fn toggle_center_in_place<U: UnsignedLike>(unsigneds: &mut [U]) {
  unsigneds
    .iter_mut()
    .for_each(|x| *x = x.wrapping_add(U::MID));
}

/// Replaces every element except the last with its forward difference
/// `unsigneds[i + 1] - unsigneds[i]` (wrapping), in place.
///
/// The final element is left as-is; `encode_in_place` truncates it away after
/// each round. Slices of length 0 or 1 are not modified.
fn first_order_encode_in_place<U: UnsignedLike>(unsigneds: &mut [U]) {
  let len = unsigneds.len();
  // `1..len` is empty for len <= 1, so no separate empty-slice guard is needed
  for next in 1..len {
    let prev = next - 1;
    unsigneds[prev] = unsigneds[next].wrapping_sub(unsigneds[prev]);
  }
}

/// Applies `order` rounds of first-order delta encoding to `latents` in place,
/// then centers the resulting deltas with `toggle_center_in_place`.
///
/// Used for a single page, so the delta moments (the first element seen at
/// each round, or zero once the slice is exhausted) are returned rather than
/// mutated.
///
/// After this returns, only the first `latents.len() - order` entries
/// (saturating at 0) hold the final centered deltas; the trailing entries are
/// stale leftovers from intermediate rounds.
///
/// When `order == 0`, `latents` is left untouched and empty moments are
/// returned.
#[inline(never)]
pub fn encode_in_place<U: UnsignedLike>(mut latents: &mut [U], order: usize) -> DeltaMoments<U> {
  // TODO this function could be made faster by doing all steps on mini batches
  // of ~512 at a time
  if order == 0 {
    // exit early so we don't toggle to signed values
    return DeltaMoments::default();
  }

  let mut page_moments = Vec::with_capacity(order);
  for _ in 0..order {
    // the current head is this round's moment; an exhausted slice yields zero
    page_moments.push(latents.first().copied().unwrap_or(U::ZERO));

    first_order_encode_in_place(latents);
    // each round leaves one fewer meaningful delta; drop the stale tail entry
    let truncated_len = latents.len().saturating_sub(1);
    latents = &mut latents[0..truncated_len];
  }
  toggle_center_in_place(latents);

  DeltaMoments::new(page_moments)
}

/// Undoes one round of delta encoding in place via a running (wrapping) sum.
///
/// Each element is replaced by the value of `*moment` just before that
/// element's delta is added to it. On return `*moment` holds the running value
/// one past the end, so a following batch can continue from it.
fn first_order_decode_in_place<U: UnsignedLike>(moment: &mut U, unsigneds: &mut [U]) {
  for slot in unsigneds {
    let delta = std::mem::replace(slot, *moment);
    *moment = moment.wrapping_add(delta);
  }
}

/// Reverses [`encode_in_place`] for one batch of `unsigneds`, in place.
///
/// Used for a single batch, so `delta_moments` is mutated: afterward each
/// moment holds the running value needed to continue decoding the next
/// contiguous batch.
///
/// With no moments (order 0) this is a no-op.
#[inline(never)]
pub fn decode_in_place<U: UnsignedLike>(delta_moments: &mut DeltaMoments<U>, unsigneds: &mut [U]) {
  let moments = &mut delta_moments.moments;
  if moments.is_empty() {
    // order 0: the encoder never centered anything, so don't toggle here
    return;
  }

  // undo the centering applied by the encoder
  toggle_center_in_place(unsigneds);
  // integrate from the last-recorded moment back to the first
  moments
    .iter_mut()
    .rev()
    .for_each(|moment| first_order_decode_in_place(moment, unsigneds));
}

#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_delta_encode_decode() {
    let orig_unsigneds: Vec<u32> = vec![2, 2, 1, u32::MAX, 0];
    let mut deltas = orig_unsigneds.clone();
    let order = 2;
    let zero_delta = u32::MID;
    let mut moments = encode_in_place(&mut deltas, order);

    // add back some padding we lose during compression
    for _ in 0..order {
      deltas.push(zero_delta);
    }

    // decode in two batches; the moments carry state from one to the next
    decode_in_place::<u32>(&mut moments, &mut deltas[..3]);
    assert_eq!(&deltas[..3], &orig_unsigneds[..3]);

    decode_in_place::<u32>(&mut moments, &mut deltas[3..]);
    assert_eq!(&deltas[3..5], &orig_unsigneds[3..5]);
  }

  #[test]
  fn test_round_trip_single_batch_all_orders() {
    let orig: Vec<u32> = vec![2, 2, 1, u32::MAX, 0, 17, 3, u32::MID];
    for order in 0..=4 {
      let mut latents = orig.clone();
      let mut moments = encode_in_place(&mut latents, order);
      assert_eq!(moments.order(), order);
      // the stale trailing entries never influence the decoded values, so the
      // whole buffer can be decoded in one batch without padding
      decode_in_place::<u32>(&mut moments, &mut latents);
      assert_eq!(latents, orig, "order {}", order);
    }
  }

  #[test]
  fn test_order_zero_is_identity() {
    let orig: Vec<u32> = vec![7, 3, u32::MAX];
    let mut latents = orig.clone();
    let mut moments = encode_in_place(&mut latents, 0);
    assert_eq!(moments.order(), 0);
    assert_eq!(latents, orig);
    decode_in_place::<u32>(&mut moments, &mut latents);
    assert_eq!(latents, orig);
  }

  #[test]
  fn test_fewer_latents_than_order() {
    let orig: Vec<u32> = vec![5];
    let mut latents = orig.clone();
    let mut moments = encode_in_place(&mut latents, 3);
    // rounds that run on an exhausted slice record a zero moment
    assert_eq!(moments.moments, vec![5, 0, 0]);
    decode_in_place::<u32>(&mut moments, &mut latents);
    assert_eq!(latents, orig);
  }

  #[test]
  fn test_empty_latents() {
    let mut latents: Vec<u32> = Vec::new();
    let mut moments = encode_in_place(&mut latents, 2);
    assert_eq!(moments.moments, vec![0, 0]);
    decode_in_place::<u32>(&mut moments, &mut latents);
    assert!(latents.is_empty());
  }

  #[test]
  fn test_toggle_center_is_involution() {
    let orig: Vec<u32> = vec![0, 1, u32::MID, u32::MAX];
    let mut xs = orig.clone();
    toggle_center_in_place(&mut xs);
    assert_eq!(xs[0], u32::MID);
    toggle_center_in_place(&mut xs);
    assert_eq!(xs, orig);
  }
}