// atlas-archive-core 1.1.0
//
// High-performance compression library with adaptive context modeling (Loom) and .nyx archives.
use crate::ans::ProbModel;
use core::f32::consts::E;
#[cfg(feature = "simd")]
use core::simd::{SimdFloat, f32x8};
use ndarray::prelude::*;

use crate::PlcConfig;

/// Byte-level next-symbol predictor built from fixed, identity-like weights.
///
/// Despite the name there is no attention step: [`TransformerPredictor::predict`]
/// mean-pools the embeddings of a trailing window of bytes, projects the pooled vector
/// to 256 scores, and converts those scores into a scaled frequency table via softmax.
pub struct TransformerPredictor {
    /// Hyper-parameters; `model_dim` and `window_size` drive the layer shapes and window.
    config: PlcConfig,
    /// Embedding table `W_e` with shape `[256, model_dim]`; row `b` is byte `b`'s vector.
    embedding: Array2<f32>,
    /// Output projection `W_out`, stored transposed as `[256, model_dim]` so that row `b`
    /// holds the weights producing byte `b`'s score in one contiguous run.
    w_out: Array2<f32>,
}

impl TransformerPredictor {
    /// Total that every frequency table handed to [`ProbModel::from_scaled_freqs`] sums to.
    const FREQ_TOTAL: u32 = 1 << 16;

    /// Builds a predictor with fixed "identity-like" weights.
    ///
    /// Byte `i` is embedded as a one-hot vector with `1.0` at feature `i % model_dim`, and
    /// output row `i` carries a `10.0` at that same feature. Bytes whose indices are
    /// congruent modulo `model_dim` therefore share a feature and receive identical scores.
    ///
    /// With `config.model_dim == 0` both matrices are empty and every prediction is uniform.
    pub fn new(config: PlcConfig) -> Self {
        let d_model = config.model_dim;
        let mut embedding = Array2::zeros((256, d_model));
        let mut w_out = Array2::zeros((256, d_model));

        // `i % 0` would panic; with no features there is nothing to initialise.
        if d_model > 0 {
            for i in 0..256 {
                embedding[[i, i % d_model]] = 1.0;
                // w_out is stored transposed: row i holds the weights for byte i's score.
                w_out[[i, i % d_model]] = 10.0;
            }
        }

        Self {
            embedding,
            w_out,
            config,
        }
    }

    /// Predicts the distribution of the NEXT byte from the tail of `history`.
    ///
    /// 1. Mean-pools the embeddings of the last `config.window_size` bytes. An empty
    ///    window pools to the zero vector, which yields a uniform distribution.
    /// 2. Projects the pooled vector through `w_out` to 256 logits.
    /// 3. Applies a max-shifted softmax. Exponentials are evaluated in `f32` and
    ///    accumulated in `f64`.
    /// 4. Converts the result to integer frequencies that sum to exactly 65536, each at
    ///    least 1. When `mixer_probs` is given, it is blended in first.
    ///
    /// `mixer_probs` is blended per symbol at a 1:99 ratio (transformer:mixer) using integer
    /// arithmetic. NOTE(review): its entries are assumed to be on the same 65536 scale
    /// (TODO confirm with the caller). Each blended entry is clamped to `[1, 65536]` so
    /// out-of-range counts cannot overflow; in-range inputs are unaffected by the clamp.
    pub fn predict(&self, history: &[u8], mixer_probs: Option<[u64; 256]>) -> ProbModel {
        let pooled = self.pool(history);
        let logits = self.project(&pooled);
        let (exps, sum_exp) = Self::softmax_terms(&logits);
        let total = f64::from(Self::FREQ_TOTAL);

        let mut freqs = [0u32; 256];
        match mixer_probs {
            Some(mixer) => {
                for ((freq, &e), &count) in freqs.iter_mut().zip(exps.iter()).zip(mixer.iter())
                {
                    // e <= sum_exp, so t_p <= 65536.
                    let t_p = (e / sum_exp * total) as u64;
                    // Weighted blend: give high weight to the mixer (high-order matches).
                    // u64 + saturation: no overflow even for out-of-contract counts.
                    let blended = t_p.saturating_add(count.saturating_mul(99)) / 100;
                    // Clamped to <= FREQ_TOTAL, so the cast is lossless.
                    *freq = blended.clamp(1, u64::from(Self::FREQ_TOTAL)) as u32;
                }
            }
            None => {
                let scale = total / sum_exp;
                for (freq, &e) in freqs.iter_mut().zip(exps.iter()) {
                    *freq = ((e * scale) as u32).max(1);
                }
            }
        }

        Self::normalize_freqs(&mut freqs);
        ProbModel::from_scaled_freqs(freqs)
    }

    /// Averages the embedding rows of the last `config.window_size` bytes of `history`.
    fn pool(&self, history: &[u8]) -> Array1<f32> {
        let start = history.len().saturating_sub(self.config.window_size);
        let ctx = &history[start..];
        let mut pooled: Array1<f32> = Array1::zeros(self.config.model_dim);

        if ctx.is_empty() {
            return pooled;
        }

        // Sum straight out of the embedding table; no intermediate [len, d_model] matrix.
        for &b in ctx {
            let row = self.embedding.row(b as usize);
            for (acc, &val) in pooled.iter_mut().zip(row.iter()) {
                *acc += val;
            }
        }
        let inv_len = 1.0 / ctx.len() as f32;
        for val in pooled.iter_mut() {
            *val *= inv_len;
        }
        pooled
    }

    /// Computes `logits[i] = dot(w_out.row(i), pooled)` for all 256 symbols.
    ///
    /// The per-row summation order is fixed (vector lanes, then scalar tail, or a plain
    /// left-to-right fold), so `f32` results stay bit-identical to earlier versions.
    fn project(&self, pooled: &Array1<f32>) -> [f32; 256] {
        let mut logits = [0.0f32; 256];

        #[cfg(feature = "simd")]
        {
            let d_model = self.config.model_dim;
            let w_slice = self
                .w_out
                .as_slice()
                .expect("w_out is built in standard (row-major) layout by `new`");
            let p_slice = pooled
                .as_slice()
                .expect("pooled is a freshly allocated contiguous vector");

            for (i, logit) in logits.iter_mut().enumerate() {
                let row = &w_slice[i * d_model..(i + 1) * d_model];

                let mut lanes = f32x8::splat(0.0);
                let mut k = 0;
                while k + 8 <= d_model {
                    let r_vec = f32x8::from_slice(&row[k..k + 8]);
                    let p_vec = f32x8::from_slice(&p_slice[k..k + 8]);
                    lanes += r_vec * p_vec;
                    k += 8;
                }
                let mut sum = lanes.reduce_sum();
                for (&w, &p) in row[k..].iter().zip(&p_slice[k..]) {
                    sum += w * p;
                }
                *logit = sum;
            }
        }

        #[cfg(not(feature = "simd"))]
        {
            // Row views avoid requiring a contiguous slice.
            for (logit, row) in logits.iter_mut().zip(self.w_out.outer_iter()) {
                *logit = row
                    .iter()
                    .zip(pooled.iter())
                    .fold(0.0f32, |acc, (&w, &p)| acc + w * p);
            }
        }

        logits
    }

    /// Returns the max-shifted exponentials of `logits` and their sum.
    ///
    /// `E.powf` (rather than `f32::exp`) is kept on purpose: swapping it could change
    /// rounding and thus the emitted frequency tables.
    fn softmax_terms(logits: &[f32; 256]) -> ([f64; 256], f64) {
        // Strict `>` scan from index 0 (not `f32::max`) to keep the original NaN behaviour.
        let max_logit = logits[1..]
            .iter()
            .fold(logits[0], |max, &l| if l > max { l } else { max });

        let mut exps = [0.0f64; 256];
        let mut sum_exp = 0.0f64;
        for (slot, &l) in exps.iter_mut().zip(logits.iter()) {
            let e = E.powf(l - max_logit) as f64;
            *slot = e;
            sum_exp += e;
        }
        (exps, sum_exp)
    }

    /// Adjusts `freqs` in place so it sums to exactly `FREQ_TOTAL`, keeping every entry >= 1.
    ///
    /// Expects every entry to already be in `[1, FREQ_TOTAL]`.
    /// - Deficit: added entirely to the largest entry (lowest index on ties).
    /// - Surplus: removed scanning from symbol 0 upward, NOT from the largest symbols,
    ///   never taking an entry below 1.
    ///
    /// 256 entries at 1 sum to 256 < `FREQ_TOTAL`, so one pass always clears the surplus.
    fn normalize_freqs(freqs: &mut [u32; 256]) {
        let total = Self::FREQ_TOTAL;
        let acc: u32 = freqs.iter().sum();

        if acc < total {
            let best = freqs
                .iter()
                .enumerate()
                .fold(0, |best, (i, &f)| if f > freqs[best] { i } else { best });
            freqs[best] += total - acc;
        } else if acc > total {
            let mut surplus = acc - total;
            for freq in freqs.iter_mut() {
                if surplus == 0 {
                    break;
                }
                let take = freq.saturating_sub(1).min(surplus);
                *freq -= take;
                surplus -= take;
            }
        }
    }
}