aprender-core 0.31.2

Next-generation machine learning library in pure Rust

/// Dequantize `Q6_K` format (K-quants).
///
/// `Q6_K` stores super blocks of 256 elements. Each super block is
/// ql (128 bytes) + qh (64 bytes) + scales (16 bytes) + d (f16) = 210 bytes.
///
/// Delegates to `trueno_quant::dequantize_q6_k_to_f32` — the single source of truth.
/// This wrapper only validates that enough bytes are available before delegating.
///
/// # Errors
///
/// Returns `AprenderError::FormatError` when the super blocks needed for
/// `num_elements` values, starting at byte offset `start`, do not fit inside
/// `data` (including when the required end offset would overflow `usize`).
#[ensures(ret.as_ref().map_or(true, |v| v.len() == num_elements))]
pub(crate) fn dequantize_q6_k(data: &[u8], start: usize, num_elements: usize) -> Result<Vec<f32>> {
    const SUPER_BLOCK_SIZE: usize = 256;
    const SUPER_BLOCK_BYTES: usize = 210;

    let num_blocks = num_elements.div_ceil(SUPER_BLOCK_SIZE);

    // Checked arithmetic: a huge `start` must be reported as a format error
    // rather than overflowing (debug panic / release wraparound past the check).
    let fits = num_blocks
        .checked_mul(SUPER_BLOCK_BYTES)
        .and_then(|total_bytes| start.checked_add(total_bytes))
        .is_some_and(|end| end <= data.len());
    if !fits {
        return Err(AprenderError::FormatError {
            message: "Q6_K data exceeds file size".to_string(),
        });
    }

    Ok(trueno_quant::dequantize_q6_k_to_f32(
        &data[start..],
        num_elements,
    ))
}

/// Dequantize `Q4_1` format.
///
/// `Q4_1` stores blocks of 32 elements; each block is an f16 scale, an f16 min
/// and 16 bytes of packed 4-bit quants (20 bytes total). Every element is
/// reconstructed as `quant * scale + min`.
///
/// PMAT-231 FIX: Element order matches llama.cpp/GGML layout:
/// - Low nibbles first (elements 0-15)
/// - High nibbles second (elements 16-31)
///
/// # Errors
///
/// Returns `AprenderError::FormatError` when the blocks needed for
/// `num_elements` values, starting at byte offset `start`, do not fit inside
/// `data` (including when the required end offset would overflow `usize`).
#[ensures(ret.as_ref().map_or(true, |v| v.len() == num_elements))]
pub fn dequantize_q4_1(data: &[u8], start: usize, num_elements: usize) -> Result<Vec<f32>> {
    const BLOCK_SIZE: usize = 32;
    const BLOCK_BYTES: usize = 2 + 2 + 16; // f16 scale + f16 min + 16 bytes

    let num_blocks = num_elements.div_ceil(BLOCK_SIZE);

    // Checked arithmetic: a huge `start` must be reported as a format error
    // rather than overflowing (debug panic / release wraparound past the check).
    let Some(end) = num_blocks
        .checked_mul(BLOCK_BYTES)
        .and_then(|total_bytes| start.checked_add(total_bytes))
        .filter(|&end| end <= data.len())
    else {
        return Err(AprenderError::FormatError {
            message: "Q4_1 data exceeds file size".to_string(),
        });
    };

    let mut result = Vec::with_capacity(num_elements);

    // `end - start` is an exact multiple of BLOCK_BYTES, so this visits
    // exactly `num_blocks` blocks with no remainder.
    for block in data[start..end].chunks_exact(BLOCK_BYTES) {
        // GH-186 FIX: Use safe_f16_scale to clamp NaN/Inf/subnormal
        let scale = safe_f16_scale(u16::from_le_bytes([block[0], block[1]]));
        let min = safe_f16_scale(u16::from_le_bytes([block[2], block[3]]));
        let quants = &block[4..];

        // PMAT-231: Low nibbles first (elements 0-15)
        result.extend(
            quants
                .iter()
                .map(|&byte| f32::from(byte & 0x0F) * scale + min),
        );

        // PMAT-231: High nibbles second (elements 16-31)
        result.extend(
            quants
                .iter()
                .map(|&byte| f32::from(byte >> 4) * scale + min),
        );
    }

    // The last block may be only partially used; drop the padding elements.
    result.truncate(num_elements);
    Ok(result)
}

/// Dequantize `Q2_K` format (K-quants).
///
/// `Q2_K` stores super blocks of 256 elements. Each super block is read as
/// scales (16 bytes) + qs (64 bytes) + d (f16) + dmin (f16) = 84 bytes.
/// Every scales byte carries a 4-bit scale (low nibble) and a 4-bit min
/// (high nibble) for one sub-block of 16 elements; each element is
/// reconstructed as `q * (d * scale) - (dmin * min)`.
///
/// # Errors
///
/// Returns `AprenderError::FormatError` when the super blocks needed for
/// `num_elements` values, starting at byte offset `start`, do not fit inside
/// `data` (including when the required end offset would overflow `usize`).
#[ensures(ret.as_ref().map_or(true, |v| v.len() == num_elements))]
pub(crate) fn dequantize_q2_k(data: &[u8], start: usize, num_elements: usize) -> Result<Vec<f32>> {
    const SUPER_BLOCK_SIZE: usize = 256;
    const SUPER_BLOCK_BYTES: usize = 16 + 64 + 2 + 2; // scales, qs, d, dmin = 84 bytes

    let num_blocks = num_elements.div_ceil(SUPER_BLOCK_SIZE);

    // Checked arithmetic: a huge `start` must be reported as a format error
    // rather than overflowing (debug panic / release wraparound past the check).
    let Some(end) = num_blocks
        .checked_mul(SUPER_BLOCK_BYTES)
        .and_then(|total_bytes| start.checked_add(total_bytes))
        .filter(|&end| end <= data.len())
    else {
        return Err(AprenderError::FormatError {
            message: "Q2_K data exceeds file size".to_string(),
        });
    };

    let mut result = Vec::with_capacity(num_elements);

    // `end - start` is an exact multiple of SUPER_BLOCK_BYTES, so this visits
    // exactly `num_blocks` super blocks with no remainder.
    for block in data[start..end].chunks_exact(SUPER_BLOCK_BYTES) {
        // Scales (16 bytes = 16 4-bit scale/min pairs)
        let (scales_bytes, rest) = block.split_at(16);
        // qs (64 bytes = 256 2-bit values), followed by d and dmin
        let (qs, tail) = rest.split_at(64);

        // GH-186 FIX: Use safe_f16_scale to clamp NaN/Inf/subnormal
        let d = safe_f16_scale(u16::from_le_bytes([tail[0], tail[1]]));
        let dmin = safe_f16_scale(u16::from_le_bytes([tail[2], tail[3]]));

        // Dequantize 16 sub-blocks of 16 elements: sub-block `j` uses scale
        // byte `j` and the four consecutive quant bytes `qs[4j..4j + 4]`.
        //
        // NOTE(review): the ggml reference (`dequantize_row_q2_K`) appears to
        // interleave elements differently (32-byte strides of `qs`, one 2-bit
        // plane at a time). Confirm this ordering against reference output
        // before relying on it for real GGUF tensors.
        for (&sc_byte, q_bytes) in scales_bytes.iter().zip(qs.chunks_exact(4)) {
            let scale = d * f32::from(sc_byte & 0x0F);
            let min_val = dmin * f32::from(sc_byte >> 4);

            for &q_byte in q_bytes {
                // Four 2-bit quants per byte, least-significant pair first.
                for k in 0..4 {
                    let q = (q_byte >> (k * 2)) & 0x03;
                    result.push(f32::from(q) * scale - min_val);
                }
            }
        }
    }

    // The last super block may be only partially used; drop the padding elements.
    result.truncate(num_elements);
    Ok(result)
}

/// Dequantize `Q3_K` format (K-quants).
///
/// `Q3_K` stores super blocks of 256 elements. Each super block is read as
/// hmask (32 bytes) + qs (64 bytes) + scales (12 bytes) + d (f16) = 110 bytes.
/// A 3-bit quant is assembled from 2 low bits in `qs` and 1 high bit in
/// `hmask`, re-centred by subtracting 4, and multiplied by `d` and the
/// sub-block scale.
///
/// # Errors
///
/// Returns `AprenderError::FormatError` when the super blocks needed for
/// `num_elements` values, starting at byte offset `start`, do not fit inside
/// `data` (including when the required end offset would overflow `usize`).
#[ensures(ret.as_ref().map_or(true, |v| v.len() == num_elements))]
pub(crate) fn dequantize_q3_k(data: &[u8], start: usize, num_elements: usize) -> Result<Vec<f32>> {
    const SUPER_BLOCK_SIZE: usize = 256;
    const SUPER_BLOCK_BYTES: usize = 32 + 64 + 12 + 2; // hmask, qs, scales, d = 110 bytes

    let num_blocks = num_elements.div_ceil(SUPER_BLOCK_SIZE);

    // Checked arithmetic: a huge `start` must be reported as a format error
    // rather than overflowing (debug panic / release wraparound past the check).
    let Some(end) = num_blocks
        .checked_mul(SUPER_BLOCK_BYTES)
        .and_then(|total_bytes| start.checked_add(total_bytes))
        .filter(|&end| end <= data.len())
    else {
        return Err(AprenderError::FormatError {
            message: "Q3_K data exceeds file size".to_string(),
        });
    };

    let mut result = Vec::with_capacity(num_elements);

    // `end - start` is an exact multiple of SUPER_BLOCK_BYTES, so this visits
    // exactly `num_blocks` super blocks with no remainder.
    for block in data[start..end].chunks_exact(SUPER_BLOCK_BYTES) {
        // hmask (32 bytes = 256 high bits)
        let (hmask, rest) = block.split_at(32);
        // qs (64 bytes = 256 low 2-bit values)
        let (qs, rest) = rest.split_at(64);
        // scales (12 bytes), followed by the 2-byte f16 `d`
        let (scales_bytes, d_bytes) = rest.split_at(12);

        // GH-186 FIX: Use safe_f16_scale to clamp NaN/Inf/subnormal
        let d = safe_f16_scale(u16::from_le_bytes([d_bytes[0], d_bytes[1]]));

        // Unpack scales: one signed 4-bit value (offset by 8) per nibble of
        // the first 8 scale bytes; the remaining 4 bytes are not read.
        //
        // NOTE(review): the ggml reference (`dequantize_row_q3_K`) appears to
        // use 6-bit scales spread over all 12 bytes with an offset of 32.
        // Confirm whether this 4-bit unpacking is an intentional approximation.
        let mut scales = [0i8; 16];
        for (i, &packed) in scales_bytes[..8].iter().enumerate() {
            scales[i] = (packed & 0x0F) as i8 - 8;
            scales[i + 8] = (packed >> 4) as i8 - 8;
        }

        // Dequantize: element `j` takes its low 2 bits from `qs[j / 4]`, its
        // high bit from `hmask[j / 8]`, and its scale from sub-block `j / 16`.
        //
        // NOTE(review): the ggml reference appears to interleave elements and
        // interpret the hmask bit differently. Confirm against reference
        // output before relying on this for real GGUF tensors.
        for j in 0..256 {
            let sub_block = j / 16;
            let q_idx = j / 4;
            let q_shift = (j % 4) * 2;
            let h_idx = j / 8;
            let h_shift = j % 8;

            let q_low = (qs[q_idx] >> q_shift) & 0x03;
            let q_high = ((hmask[h_idx] >> h_shift) & 1) << 2;
            // Values 0..=7 re-centred to -4..=3.
            let q = (q_low | q_high) as i8 - 4;

            result.push(d * f32::from(scales[sub_block]) * f32::from(q));
        }
    }

    // The last super block may be only partially used; drop the padding elements.
    result.truncate(num_elements);
    Ok(result)
}

/// Approximate dequantization for I-quants (IQ2, IQ3, IQ4).
///
/// These formats use importance-weighted quantization with lookup tables.
/// For import purposes, we approximate with a simple linear mapping:
/// every output element is derived from a single raw byte as
/// `(byte - 128) * 0.01`. This is NOT correct dequantization; it only
/// allows import to proceed.
///
/// The byte for element `i` of block `b` is read at
/// `start + b * bytes_per_block + (i * bits_per_element) / 8`, where blocks
/// hold 256 elements and `bytes_per_block` is the packed payload plus 4 bytes
/// of assumed scale overhead. Positions that fall outside `data` (or whose
/// offset would overflow `usize`) produce `0.0`.
///
/// Always returns exactly `num_elements` values.
pub(crate) fn dequantize_iq_approximate(
    data: &[u8],
    start: usize,
    num_elements: usize,
    dtype: u32,
) -> Vec<f32> {
    // I-quants have variable block sizes and complex lookup tables.
    // Approximate by treating as low-bit quantization with estimated scale.
    const BLOCK_SIZE: usize = 256;
    // Small scale for approximate values.
    const SCALE: f32 = 0.01;

    // NOTE(review): these codes do not seem to line up with the upstream
    // `ggml_type` enum (where, as far as I can tell, 13-15 are Q5_K/Q6_K/Q8_K
    // and the IQ types start at 16) — confirm which numbering callers pass.
    let bits_per_element: usize = match dtype {
        13..=15 => 2, // IQ2_XXS, IQ2_XS, IQ2_S
        16 | 17 => 3, // IQ3_XXS, IQ3_S
        18 => 1,      // IQ1_S
        _ => 4,       // IQ4_NL, IQ4_XS, and default
    };

    let bytes_per_block = (BLOCK_SIZE * bits_per_element).div_ceil(8) + 4; // data + scale overhead

    (0..num_elements)
        .map(|element| {
            let block_idx = element / BLOCK_SIZE;
            let within_block = element % BLOCK_SIZE;

            // Checked arithmetic: an offset that overflows `usize` is treated
            // like any other out-of-range position instead of panicking or
            // wrapping around to an unrelated byte.
            let byte_idx = block_idx
                .checked_mul(bytes_per_block)
                .and_then(|block_offset| start.checked_add(block_offset))
                .and_then(|block_start| {
                    block_start.checked_add(within_block * bits_per_element / 8)
                });

            // Map the raw byte to a roughly centered distribution.
            byte_idx
                .and_then(|idx| data.get(idx))
                .map_or(0.0, |&byte_val| (f32::from(byte_val) - 128.0) * SCALE)
        })
        .collect()
}