atlas-archive-core 1.1.0

High-performance compression library with adaptive context modeling (Loom) and .nyx archives
Documentation
//! Burrows-Wheeler Transform (BWT) for Atlas preprocessing.
//! Permutes each block so that bytes followed by similar contexts cluster together,
//! which makes a subsequent Move-To-Front (MTF) stage more effective.

use crate::alloc::vec;
use crate::alloc::vec::Vec;

/// Burrows-Wheeler Transform of a single block.
///
/// Sorts all cyclic rotations of `data` and returns the last column of the sorted
/// rotation matrix together with the primary index, i.e. the row holding the
/// unrotated input. Rotations that are byte-for-byte identical (only possible for
/// periodic input) are ordered by starting offset, so the output is deterministic.
///
/// Empty input yields `(vec![], 0)`.
///
/// Rotations are ranked with cyclic prefix doubling. After a round with step `k`,
/// each rotation's rank reflects its first `2k` bytes, so at most `ceil(log2(n))`
/// rounds of an `O(n log n)` sort are needed. This stays fast on highly repetitive
/// input (e.g. long runs of zero bytes), where comparing whole rotations byte by
/// byte would scan the entire block on every comparison.
pub fn bwt_transform(data: &[u8]) -> (Vec<u8>, usize) {
    let n = data.len();
    if n == 0 {
        return (vec![], 0);
    }

    // `order` holds rotation start offsets in sorted order; `rank[i]` is the class
    // of rotation `i` over its first `k` bytes (initially k = 1: the byte itself).
    let mut order: Vec<usize> = (0..n).collect();
    let mut rank: Vec<usize> = data.iter().map(|&b| usize::from(b)).collect();
    let mut next_rank = vec![0usize; n];
    let mut k = 1;

    loop {
        // Class pair describing the first 2k bytes of rotation `i`.
        let key = |i: usize| (rank[i], rank[(i + k) % n]);
        // The offset is the final tie-break, so identical rotations end up ordered
        // by starting position.
        order.sort_unstable_by_key(|&i| (key(i), i));

        next_rank[order[0]] = 0;
        for w in 1..n {
            let (prev, cur) = (order[w - 1], order[w]);
            next_rank[cur] = next_rank[prev] + usize::from(key(prev) != key(cur));
        }
        rank.copy_from_slice(&next_rank);

        // Done when every rotation has a distinct class, or once 2k >= n. At that
        // point a key covers a full rotation, so any remaining ties are genuinely
        // identical rotations.
        if rank[order[n - 1]] == n - 1 || 2 * k >= n {
            break;
        }
        k *= 2;
    }

    let primary_index = order
        .iter()
        .position(|&start| start == 0)
        .expect("order is a permutation of 0..n, so it contains 0");
    // The last column is the byte preceding each rotation's start, cyclically.
    let bwt: Vec<u8> = order
        .iter()
        .map(|&start| data[(start + n - 1) % n])
        .collect();

    (bwt, primary_index)
}

/// Maximum number of input bytes handled as one BWT block by the blocked and
/// "safe" entry points (32 KiB).
pub const BWT_BLOCK_SIZE: usize = 1 << 15;

/// Blocked BWT transform with out-of-band primary indices.
///
/// Cuts `data` into consecutive pieces of at most [`BWT_BLOCK_SIZE`] bytes, runs
/// [`bwt_transform`] on each piece, and returns the concatenated transformed pieces
/// together with one primary index per piece, in order. Empty input yields two
/// empty vectors.
pub fn bwt_transform_blocked(data: &[u8]) -> (Vec<u8>, Vec<usize>) {
    let mut transformed = Vec::with_capacity(data.len());
    // Each piece's transformed bytes are appended as a side effect; only the
    // primary index flows through to the collected vector.
    let primaries: Vec<usize> = data
        .chunks(BWT_BLOCK_SIZE)
        .map(|piece| {
            let (piece_bwt, primary) = bwt_transform(piece);
            transformed.extend_from_slice(&piece_bwt);
            primary
        })
        .collect();
    (transformed, primaries)
}

/// Blocked BWT inverse that uses external indices.
///
/// Splits `data` into chunks of [`BWT_BLOCK_SIZE`] bytes and inverts each chunk with
/// the corresponding entry of `indices`, concatenating the results.
///
/// Decoding is best-effort. It stops and returns the chunks restored so far when
/// `indices` runs out or when an index is not a valid row for its chunk. Rejecting
/// the bad index here avoids a panic inside [`bwt_inverse`]. Extra trailing indices
/// are ignored. With the `std` feature enabled, the reason for stopping is reported
/// on stderr.
pub fn bwt_inverse_blocked(data: &[u8], indices: &[usize]) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len());
    let mut idx_iter = indices.iter().copied();

    for chunk in data.chunks(BWT_BLOCK_SIZE) {
        let idx = match idx_iter.next() {
            Some(idx) => idx,
            None => {
                // Should not happen if indices match chunks.
                #[cfg(feature = "std")]
                std::eprintln!("[BWT Inverse Blocked] Error: Not enough indices");
                break;
            }
        };
        // `chunks` never yields an empty slice, so a valid primary index must be
        // strictly below the chunk length.
        if idx >= chunk.len() {
            #[cfg(feature = "std")]
            std::eprintln!(
                "[BWT Inverse Blocked] Error: index {} out of range for chunk of {} bytes",
                idx,
                chunk.len()
            );
            break;
        }
        out.extend_from_slice(&bwt_inverse(chunk, idx));
    }

    out
}

/// Inverse Burrows-Wheeler Transform via the LF (last-to-first) mapping.
///
/// `bwt` is the last column of the sorted rotation matrix (as produced by
/// `bwt_transform`), and `primary_index` is the row holding the original block.
/// Returns the reconstructed block. Empty input yields an empty vector regardless
/// of `primary_index`.
///
/// # Panics
///
/// Panics if `bwt` is non-empty and `primary_index >= bwt.len()`.
pub fn bwt_inverse(bwt: &[u8], primary_index: usize) -> Vec<u8> {
    let n = bwt.len();
    if n == 0 {
        return vec![];
    }
    assert!(
        primary_index < n,
        "bwt_inverse: primary index {} out of range for block of {} bytes",
        primary_index,
        n
    );

    // first[c] = number of bytes in the block smaller than c, i.e. the first row of
    // the sorted matrix whose first column is c.
    let mut count = [0usize; 256];
    for &b in bwt {
        count[usize::from(b)] += 1;
    }
    let mut first = [0usize; 256];
    let mut running = 0;
    for (slot, &c) in first.iter_mut().zip(count.iter()) {
        *slot = running;
        running += c;
    }

    // The k-th row starting with byte c corresponds to the k-th occurrence of c in
    // the last column. `next[row]` is therefore the row of the rotation that starts
    // one byte later (up to the ordering of identical rotations), and
    // `bwt[next[row]]` is the first byte of `row`'s rotation.
    let mut next = vec![0usize; n];
    let mut seen = [0usize; 256];
    for (row, &b) in bwt.iter().enumerate() {
        let c = usize::from(b);
        next[first[c] + seen[c]] = row;
        seen[c] += 1;
    }

    // Walk forward from the primary row, emitting the original bytes in order.
    let mut out = Vec::with_capacity(n);
    let mut row = next[primary_index];
    for _ in 0..n {
        out.push(bwt[row]);
        row = next[row];
    }

    out
}

/// Safe BWT encoding using blocking.
///
/// Splits `data` into chunks of at most [`BWT_BLOCK_SIZE`] bytes and transforms each
/// with [`bwt_transform`]. Every chunk's primary index is embedded in the output
/// stream, so the second element of the returned tuple is always `0`; it is kept for
/// signature compatibility with [`bwt_transform`].
///
/// Format (all integers are little-endian `u32`):
/// `[NumChunks] ([Len] [PrimaryIdx] [Data; Len])*`.
///
/// Empty input produces an empty stream with no header.
pub fn bwt_encode_safe(data: &[u8]) -> (Vec<u8>, usize) {
    if data.is_empty() {
        return (vec![], 0);
    }

    let chunks = data.chunks(BWT_BLOCK_SIZE);
    let num_chunks = chunks.len();

    // Exact output size: 4-byte stream header, 8 header bytes for every chunk
    // (including a trailing partial one), plus the payload.
    let mut out = Vec::with_capacity(4 + num_chunks * 8 + data.len());

    // The `as u32` casts are lossless in practice. A chunk holds at most
    // BWT_BLOCK_SIZE bytes and its primary index is below its length. The chunk
    // count only exceeds u32::MAX for inputs larger than u32::MAX * BWT_BLOCK_SIZE bytes.
    out.extend_from_slice(&(num_chunks as u32).to_le_bytes());

    for chunk in chunks {
        let (bwt_chunk, primary) = bwt_transform(chunk);
        out.extend_from_slice(&(bwt_chunk.len() as u32).to_le_bytes());
        out.extend_from_slice(&(primary as u32).to_le_bytes());
        out.extend_from_slice(&bwt_chunk);
    }

    (out, 0) // Primary indices live in the stream; 0 is a placeholder.
}

/// Safe BWT decoding of a stream produced by `bwt_encode_safe`.
///
/// Stream layout (all integers are little-endian `u32`):
/// `[NumChunks] ([Len] [PrimaryIdx] [Data; Len])*`.
///
/// `_ignore_idx` is unused because each chunk's primary index is embedded in the
/// stream. The parameter keeps the signature interchangeable with single-block
/// decoders.
///
/// Decoding is best-effort and rejects malformed input instead of panicking.
/// - A stream shorter than the 4-byte header yields an empty vector.
/// - Decoding stops at the first chunk whose header or payload is truncated, or
///   whose primary index is outside the chunk, and returns the chunks decoded
///   before it.
///
/// With the `std` feature enabled, the reason for stopping is reported on stderr.
pub fn bwt_decode_safe(data: &[u8], _ignore_idx: usize) -> Vec<u8> {
    /// Reads the little-endian `u32` at `data[at..at + 4]`; callers check bounds first.
    fn read_u32_le(data: &[u8], at: usize) -> usize {
        let mut word = [0u8; 4];
        word.copy_from_slice(&data[at..at + 4]);
        u32::from_le_bytes(word) as usize
    }

    if data.len() < 4 {
        #[cfg(feature = "std")]
        std::eprintln!("[BWT Decode] Error: data too short ({})", data.len());
        return vec![];
    }

    let num_chunks = read_u32_le(data, 0);
    let mut pos = 4;
    let mut out = Vec::new();

    for _ in 0..num_chunks {
        // `pos <= data.len()` holds throughout, so these subtractions cannot
        // underflow. Unlike `pos + len`, they also cannot overflow on 32-bit targets.
        if data.len() - pos < 8 {
            #[cfg(feature = "std")]
            std::eprintln!(
                "[BWT Decode] Break: pos={} + 8 > data_len={}",
                pos,
                data.len()
            );
            break; // Truncated chunk header.
        }
        let len = read_u32_le(data, pos);
        let idx = read_u32_le(data, pos + 4);
        pos += 8;

        if data.len() - pos < len {
            #[cfg(feature = "std")]
            std::eprintln!(
                "[BWT Decode] Break: pos={} + len={} > data_len={}",
                pos,
                len,
                data.len()
            );
            break; // Truncated chunk payload.
        }

        // bwt_inverse requires idx < len for a non-empty chunk; a corrupt header
        // must not turn into a panic.
        if len > 0 && idx >= len {
            #[cfg(feature = "std")]
            std::eprintln!(
                "[BWT Decode] Break: primary idx={} >= len={} at pos={}",
                idx,
                len,
                pos
            );
            break;
        }

        let chunk_data = &data[pos..pos + len];
        pos += len;

        out.extend_from_slice(&bwt_inverse(chunk_data, idx));
    }

    out
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random bytes (LCG), so multi-block tests need no
    /// external crates and stay fast in debug builds.
    fn pseudo_random_bytes(len: usize, mut seed: u32) -> Vec<u8> {
        (0..len)
            .map(|_| {
                seed = seed.wrapping_mul(1_103_515_245).wrapping_add(12_345);
                (seed >> 16) as u8
            })
            .collect()
    }

    #[test]
    fn test_bwt_roundtrip_simple() {
        let data = b"banana";
        let (bwt, idx) = bwt_transform(data);
        let restored = bwt_inverse(&bwt, idx);
        assert_eq!(data.to_vec(), restored, "Simple roundtrip failed");
    }

    #[test]
    fn test_bwt_known_output() {
        // Sorted rotations of "banana": abanan, anaban, ananab, banana, nabana, nanaba.
        assert_eq!(bwt_transform(b"banana"), (b"nnbaaa".to_vec(), 3));
    }

    #[test]
    fn test_bwt_empty_and_single_byte() {
        assert_eq!(bwt_transform(b""), (Vec::new(), 0));
        assert!(bwt_inverse(&[], 0).is_empty());
        assert_eq!(bwt_transform(b"z"), (b"z".to_vec(), 0));
        assert_eq!(bwt_inverse(b"z", 0), b"z".to_vec());
    }

    #[test]
    fn test_bwt_roundtrip_periodic() {
        // Periodic inputs contain identical rotations, exercising the tie-break.
        let mut cases: Vec<Vec<u8>> = Vec::new();
        cases.push(b"abab".to_vec());
        cases.push(b"abcabcabc".to_vec());
        cases.push([0u8; 1024].to_vec());
        for data in &cases {
            let (bwt, idx) = bwt_transform(data);
            assert_eq!(
                &bwt_inverse(&bwt, idx),
                data,
                "periodic roundtrip failed for len {}",
                data.len()
            );
        }
    }

    #[test]
    fn test_bwt_roundtrip_8kb() {
        let base = b"Compression Test Data. ";
        let mut data = Vec::with_capacity(8192);
        while data.len() < 8192 {
            data.extend_from_slice(base);
        }
        data.truncate(8192);
        let (bwt, idx) = bwt_transform(&data);
        let restored = bwt_inverse(&bwt, idx);
        assert_eq!(data, restored, "8KB roundtrip failed");
    }

    #[test]
    fn test_blocked_roundtrip_multi_chunk() {
        let data = pseudo_random_bytes(BWT_BLOCK_SIZE + 1000, 7);
        let (bwt, indices) = bwt_transform_blocked(&data);
        assert_eq!(indices.len(), 2);
        assert_eq!(bwt_inverse_blocked(&bwt, &indices), data);
    }

    #[test]
    fn test_safe_roundtrip_and_truncation() {
        let data = pseudo_random_bytes(BWT_BLOCK_SIZE + 1000, 11);
        let (encoded, dummy) = bwt_encode_safe(&data);
        assert_eq!(dummy, 0);
        assert_eq!(encoded.len(), 4 + 2 * 8 + data.len());
        assert_eq!(bwt_decode_safe(&encoded, 0), data);

        // Dropping the last byte truncates the second chunk; only the first survives.
        let partial = bwt_decode_safe(&encoded[..encoded.len() - 1], 0);
        assert_eq!(partial.as_slice(), &data[..BWT_BLOCK_SIZE]);
    }

    #[test]
    fn test_safe_decode_short_input() {
        assert!(bwt_decode_safe(&[], 0).is_empty());
        assert!(bwt_decode_safe(&[1, 0, 0], 0).is_empty());
    }
}