use std::{
array, cmp,
mem::{self, MaybeUninit},
};
use crate::{
constants::{Bitlen, DeltaLookback},
data_types::Latent,
metadata::DeltaLookbackConfig,
FULL_BATCH_N,
};
// Total number of candidate lookbacks evaluated for each latent; this is the
// length of the `proposed_lookbacks` array.
const PROPOSED_LOOKBACKS: usize = 16;
// The proposal array is split by index:
// * `choose_lookbacks` keeps its most recently chosen lookbacks in the
//   `REPEATING_LOOKBACKS` slots that start at index `BRUTE_LOOKBACKS`;
// * `hash_lookup` writes its proposals starting at index
//   `BRUTE_LOOKBACKS + REPEATING_LOOKBACKS`.
const BRUTE_LOOKBACKS: usize = 6;
const REPEATING_LOOKBACKS: usize = 4;
// Right-shift amounts applied to a latent before hashing it. Each coarseness
// gets its own region of the hash table and contributes 3 proposals (the
// shifted value and its two neighbors), so the 2 coarsenesses fill the
// remaining 16 - 6 - 4 = 6 proposal slots.
const COARSENESSES: [Bitlen; 2] = [0, 8];
/// Proposes lookbacks for the latent at index `i` by checking where similar
/// values were most recently recorded.
///
/// For every entry of `COARSENESSES`, `l` is right-shifted by that amount and
/// the shifted value together with its two wrapping neighbors is hashed into
/// that coarseness's region of `idx_hash_table` (regions are `hash_table_n`
/// entries long and laid out back to back). A table entry holds the index last
/// recorded under that hash, so `i - entry` is a candidate lookback. A
/// candidate larger than `window_n` is replaced by `min(slot, i)`, where `slot`
/// is the position in `proposed_lookbacks` being filled. Candidates go into
/// consecutive slots starting at `BRUTE_LOOKBACKS + REPEATING_LOOKBACKS`.
/// After the three reads for a coarseness, `i` is recorded under the hash of
/// the shifted value itself (not its neighbors).
///
/// The caller must pass a nonzero `hash_table_n` (expected to be a power of
/// two, since `hash_table_n - 1` is used as a bit mask) and a table holding at
/// least `COARSENESSES.len() * hash_table_n` entries, none greater than `i`.
fn hash_lookup(
  l: u64,
  i: usize,
  hash_table_n: usize,
  window_n: usize,
  idx_hash_table: &mut [usize],
  proposed_lookbacks: &mut [usize; PROPOSED_LOOKBACKS],
) {
  let mask = hash_table_n - 1;
  // Multiplicative mixer; the multiplier is close to 2^64 divided by the
  // golden ratio. The result is always at most `mask`.
  let hash_bucket = |bucket: u64| -> usize {
    let mixed = (bucket ^ (bucket >> 32)).wrapping_mul(11400714819323197441);
    let folded = mixed ^ (mixed >> 32);
    folded as usize & mask
  };

  let mut slot = BRUTE_LOOKBACKS + REPEATING_LOOKBACKS;
  for (level, &coarseness) in COARSENESSES.iter().enumerate() {
    let region_start = level * hash_table_n;
    let center = l >> coarseness;
    let neighbor_hashes =
      [center.wrapping_sub(1), center, center.wrapping_add(1)].map(hash_bucket);

    // All three reads happen before this coarseness's write below.
    for hash in neighbor_hashes {
      // SAFETY: `hash <= hash_table_n - 1` and `level < COARSENESSES.len()`,
      // so the index is below `COARSENESSES.len() * hash_table_n`, which the
      // caller guarantees is within the table.
      let last_seen_idx = unsafe { *idx_hash_table.get_unchecked(region_start + hash) };
      let distance = i - last_seen_idx;
      proposed_lookbacks[slot] = if distance > window_n {
        slot.min(i)
      } else {
        distance
      };
      slot += 1;
    }

    // SAFETY: same bound as for the reads above.
    unsafe {
      *idx_hash_table.get_unchecked_mut(region_start + neighbor_hashes[1]) = i;
    }
  }
}
/// Returns the proposed lookback with the highest "goodness" for latent `l`
/// located at index `i`.
///
/// The goodness of a proposal is the sum of
/// * the bit length of `lookback_counts[lookback - 1]`, and
/// * the number of leading zeros of the smaller of the two wrapping
///   differences between `l` and `latents[i - lookback]`.
///
/// When several proposals tie, the earliest one in `proposed_lookbacks` wins.
/// If no proposal has goodness above 0, the result is 0.
///
/// The caller must ensure every proposal satisfies `1 <= lookback <= i` and
/// `lookback <= lookback_counts.len()`, because both slices are indexed
/// without bounds checks.
#[inline(never)]
fn find_best_lookback<L: Latent>(
  l: L,
  i: usize,
  latents: &[L],
  proposed_lookbacks: &[usize; PROPOSED_LOOKBACKS],
  lookback_counts: &mut [u32],
) -> usize {
  let goodness_of = |lookback: usize| {
    // SAFETY: the caller guarantees `1 <= lookback <= i`, `i < latents.len()`
    // and `lookback <= lookback_counts.len()`.
    let (times_counted, other) = unsafe {
      (
        *lookback_counts.get_unchecked(lookback - 1),
        *latents.get_unchecked(i - lookback),
      )
    };
    let count_bits = Bitlen::BITS - times_counted.leading_zeros();
    let abs_delta = L::min(l.wrapping_sub(other), other.wrapping_sub(l));
    count_bits + abs_delta.leading_zeros()
  };

  // (goodness, lookback) of the best proposal seen so far.
  let mut winner = (0, 0_usize);
  for &candidate in proposed_lookbacks {
    let goodness = goodness_of(candidate);
    // Strict comparison keeps the first of any equally good proposals.
    if goodness > winner.0 {
      winner = (goodness, candidate);
    }
  }
  winner.1
}
/// Chooses a lookback for every latent after the first `config.state_n()`.
///
/// Element `k` of the result is the lookback chosen for `latents[k + state_n]`.
/// For each such latent a fixed-size set of proposals is assembled (the most
/// recent positions, recently chosen lookbacks, and positions found by
/// `hash_lookup`), and `find_best_lookback` picks among them.
///
/// Returns an empty vector when `latents.len() <= state_n`.
///
/// # Panics
///
/// Panics if there are more than `state_n` latents and `config.window_n()` is
/// smaller than `PROPOSED_LOOKBACKS`.
#[inline(never)]
pub fn choose_lookbacks<L: Latent>(
  config: DeltaLookbackConfig,
  latents: &[L],
) -> Vec<DeltaLookback> {
  let state_n = config.state_n();
  if latents.len() <= state_n {
    return vec![];
  }

  // The hash table gets twice as many slots per coarseness as the window.
  let hash_table_n_log = config.window_n_log + 1;
  let hash_table_n = 1 << hash_table_n_log;
  let window_n = config.window_n();
  assert!(
    window_n >= PROPOSED_LOOKBACKS,
    "we do not support tiny windows during compression"
  );

  // Counts start at 1, so every proposal has nonzero goodness in
  // `find_best_lookback`.
  let mut lookback_counts = vec![1_u32; window_n.min(latents.len())];
  let mut lookbacks: Vec<MaybeUninit<DeltaLookback>> =
    vec![MaybeUninit::uninit(); latents.len() - state_n];
  let mut idx_hash_table = vec![0_usize; COARSENESSES.len() * hash_table_n];
  // Until enough latents have been visited, cap proposals at `state_n` so
  // that none of them reaches before the start of `latents`.
  let mut proposed_lookbacks = array::from_fn::<_, PROPOSED_LOOKBACKS, _>(|i| (i + 1).min(state_n));
  let mut best_lookback = 1;
  let mut repeating_lookback_idx: usize = 0;
  for i in state_n..latents.len() {
    let l = latents[i];

    // Unlock one more brute-force lookback as soon as it becomes valid.
    let new_brute_lookback = i.min(PROPOSED_LOOKBACKS);
    proposed_lookbacks[new_brute_lookback - 1] = new_brute_lookback;

    hash_lookup(
      l.to_u64(),
      i,
      hash_table_n,
      window_n,
      &mut idx_hash_table,
      &mut proposed_lookbacks,
    );
    let new_best_lookback = find_best_lookback(
      l,
      i,
      latents,
      &proposed_lookbacks,
      &mut lookback_counts,
    );

    // Only advance to the next repeating slot when the choice changed, so a
    // run of identical choices occupies a single slot.
    if new_best_lookback != best_lookback {
      repeating_lookback_idx += 1;
    }
    proposed_lookbacks[BRUTE_LOOKBACKS + (repeating_lookback_idx) % REPEATING_LOOKBACKS] =
      new_best_lookback;
    best_lookback = new_best_lookback;
    lookbacks[i - state_n] = MaybeUninit::new(best_lookback as DeltaLookback);
    lookback_counts[best_lookback - 1] += 1;
  }

  // Convert `Vec<MaybeUninit<T>>` into `Vec<T>` without copying. This goes
  // through `from_raw_parts` rather than `mem::transmute`, because transmuting
  // a `Vec` relies on its unspecified field layout.
  let mut lookbacks = mem::ManuallyDrop::new(lookbacks);
  // SAFETY: the loop above wrote every index in `0..latents.len() - state_n`,
  // so all elements are initialized. `MaybeUninit<DeltaLookback>` has the same
  // size and alignment as `DeltaLookback`, and the pointer, length and capacity
  // come from a live `Vec` that `ManuallyDrop` prevents from being freed here.
  unsafe {
    Vec::from_raw_parts(
      lookbacks.as_mut_ptr().cast::<DeltaLookback>(),
      lookbacks.len(),
      lookbacks.capacity(),
    )
  }
}
/// Replaces each latent after the first `state_n` with its wrapping difference
/// from the latent `lookback` positions earlier, then passes the whole slice
/// through `super::toggle_center_in_place`.
///
/// `lookbacks[k]` is the lookback used for `latents[k + state_n]`.
///
/// Returns the delta state: a vector of length `state_n` whose trailing
/// `min(state_n, latents.len())` entries are copies of the leading latents
/// (taken after delta encoding, which does not modify them); any entries in
/// front of those are `L::ZERO`.
///
/// # Panics
///
/// Panics if `lookbacks` is too short or a lookback points before the start of
/// `latents`.
#[inline(never)]
pub fn encode_in_place<L: Latent>(
  config: DeltaLookbackConfig,
  lookbacks: &[DeltaLookback],
  latents: &mut [L],
) -> Vec<L> {
  let state_n = config.state_n();
  let kept_n = cmp::min(state_n, latents.len());

  // Walk from the back so that every referenced (earlier) latent still holds
  // its original value when it is subtracted.
  let mut idx = latents.len();
  while idx > kept_n {
    idx -= 1;
    let distance = lookbacks[idx - state_n] as usize;
    let reference = latents[idx - distance];
    latents[idx] = latents[idx].wrapping_sub(reference);
  }

  // Zero padding first, then the untouched leading latents.
  let mut state = Vec::with_capacity(state_n);
  state.resize(state_n - kept_n, L::ZERO);
  state.extend_from_slice(&latents[..kept_n]);

  super::toggle_center_in_place(latents);
  state
}
/// Allocates the window buffer used by `decode_in_place` and returns it with
/// the initial write position.
///
/// The buffer holds `2 * max(window_n, FULL_BATCH_N)` entries, all `L::ZERO`
/// except that `state` is copied so that it ends right before index
/// `window_n`. The returned position is `window_n`.
///
/// # Panics
///
/// Panics if `state` is longer than the window.
pub fn new_window_buffer_and_pos<L: Latent>(
  config: DeltaLookbackConfig,
  state: &[L],
) -> (Vec<L>, usize) {
  let window_n = config.window_n();
  let capacity = 2 * cmp::max(FULL_BATCH_N, window_n);
  let mut buffer = vec![L::ZERO; capacity];

  let state_start = window_n - state.len();
  buffer[state_start..window_n].copy_from_slice(state);

  (buffer, window_n)
}
/// Reverses lookback delta encoding for one batch, using `window_buffer` as
/// the history of already-decoded values.
///
/// `latents` is first passed through `super::toggle_center_in_place`. Then,
/// for each `(latent, lookback)` pair, the decoded value
/// `latent + window_buffer[pos - lookback]` (wrapping) is written to
/// `window_buffer[pos]`, starting at `*window_buffer_pos`. If the batch would
/// not fit in the remaining buffer, the last `window_n` entries are first
/// moved to the front and writing restarts at `window_n`.
///
/// `lookbacks` may be shorter than `latents`; only as many values as there
/// are pairs get decoded, but `latents.len()` values are always copied out.
/// On return, `latents` holds the buffer contents shifted back by `state_n`
/// positions (so a stream starts with the delta state), and
/// `*window_buffer_pos` has advanced by `latents.len()`.
///
/// Returns `true` if any lookback exceeded `window_n`; each such lookback is
/// replaced by 1 rather than causing a panic.
///
/// # Panics
///
/// Panics if the write position is before `window_n`, or if the batch does
/// not fit in `window_buffer` even after cycling it. Both indicate a caller
/// bug rather than corrupt data.
pub fn decode_in_place<L: Latent>(
  config: DeltaLookbackConfig,
  lookbacks: &[DeltaLookback],
  window_buffer_pos: &mut usize,
  window_buffer: &mut [L],
  latents: &mut [L],
) -> bool {
  super::toggle_center_in_place(latents);

  let (window_n, state_n) = (config.window_n(), config.state_n());
  let mut start_pos = *window_buffer_pos;
  let batch_n = latents.len();
  if start_pos + batch_n > window_buffer.len() {
    // Out of room: keep only the most recent window and continue after it.
    window_buffer.copy_within(start_pos - window_n..start_pos, 0);
    start_pos = window_n;
  }
  let end_pos = start_pos + batch_n;

  // Checked once per batch so the per-element accesses below can stay
  // unchecked without risking undefined behavior.
  assert!(
    start_pos >= window_n && end_pos <= window_buffer.len(),
    "window buffer position or batch size violates decode invariants"
  );

  let mut has_oob_lookbacks = false;
  for (i, (&latent, &lookback)) in latents.iter().zip(lookbacks).enumerate() {
    let pos = start_pos + i;
    // Report out-of-window lookbacks to the caller instead of panicking or
    // silently clamping; 1 is substituted so decoding can continue.
    let lookback = if lookback <= window_n as DeltaLookback {
      lookback as usize
    } else {
      has_oob_lookbacks = true;
      1
    };
    // SAFETY: `i < batch_n`, so `pos < end_pos <= window_buffer.len()`; and
    // `lookback <= window_n <= start_pos <= pos`, so `pos - lookback` neither
    // underflows nor exceeds `pos`.
    unsafe {
      *window_buffer.get_unchecked_mut(pos) =
        latent.wrapping_add(*window_buffer.get_unchecked(pos - lookback));
    }
  }

  latents.copy_from_slice(&window_buffer[start_pos - state_n..end_pos - state_n]);
  *window_buffer_pos = end_pos;
  has_oob_lookbacks
}
#[cfg(test)]
mod tests {
  use super::*;
  use crate::metadata::DeltaLookbackConfig;

  /// Chooses lookbacks for a mostly constant sequence with a few outliers,
  /// encodes it, and checks that decoding restores the original latents.
  #[test]
  fn test_lookback_encode_decode() {
    // 100 copies of 100, with a handful of distinct values sprinkled in.
    let mut original_latents = vec![100_u32; 100];
    for (idx, value) in [(1, 200), (2, 201), (3, 202), (5, 203), (15, 204), (50, 205)] {
      original_latents[idx] = value;
    }

    let config = DeltaLookbackConfig {
      state_n_log: 1,
      window_n_log: 4,
    };
    let window_n = config.window_n();
    let state_n = config.state_n();
    assert_eq!(window_n, 16);
    assert_eq!(state_n, 2);

    let lookbacks = choose_lookbacks(config, &original_latents);
    // (index into `original_latents`, expected lookback); the lookback for
    // latent `j` is stored at `lookbacks[j - state_n]`.
    for (latent_idx, expected) in [(2, 1), (4, 4), (15, 10), (50, 1)] {
      assert_eq!(lookbacks[latent_idx - state_n], expected);
    }

    let mut deltas = original_latents.clone();
    let state = encode_in_place(config, &lookbacks, &mut deltas);
    assert_eq!(state, vec![100, 200]);

    // The decoder's input is the deltas without the leading state, padded at
    // the end with `state_n` junk values to restore the original length.
    let mut deltas_to_decode: Vec<u32> = deltas[state_n..].to_vec();
    deltas_to_decode.resize(deltas.len(), 1337);

    let (mut window_buffer, mut pos) = new_window_buffer_and_pos(config, &state);
    assert_eq!(pos, window_n);

    let has_oob_lookbacks = decode_in_place(
      config,
      &lookbacks,
      &mut pos,
      &mut window_buffer,
      &mut deltas_to_decode,
    );
    assert!(!has_oob_lookbacks);
    assert_eq!(deltas_to_decode, original_latents);
    assert_eq!(pos, window_n + original_latents.len());
  }

  /// A lookback larger than the window must be reported through the return
  /// value rather than causing a panic.
  #[test]
  fn test_corrupt_lookbacks_do_not_panic() {
    let config = DeltaLookbackConfig {
      window_n_log: 2,
      state_n_log: 0,
    };
    let delta_state = vec![0_u32];
    // The first lookback (5) exceeds the window of 4.
    let lookbacks = vec![5, 1, 1, 1];
    let mut latents = vec![1_u32, 2, 3, 4];

    let (mut window_buffer, mut pos) = new_window_buffer_and_pos(config, &delta_state);
    let has_oob_lookbacks = decode_in_place(
      config,
      &lookbacks,
      &mut pos,
      &mut window_buffer,
      &mut latents,
    );
    assert!(has_oob_lookbacks);
  }
}