mousiki 0.2.1

Pure Rust Opus codec.
Documentation
#![allow(dead_code)]

use crate::mlp_data::{
    LAYER0_BIAS, LAYER0_WEIGHTS, LAYER1_BIAS, LAYER1_RECUR_WEIGHTS, LAYER1_WEIGHTS, LAYER2_BIAS,
    LAYER2_WEIGHTS,
};
use libm::fmaf;
/// Scaling factor applied to all dense and GRU outputs.
pub(crate) const WEIGHTS_SCALE: f32 = 1.0 / 128.0;

/// Upper bound on the number of neurons handled by the analysis GRU.
pub(crate) const MAX_NEURONS: usize = 32;

#[derive(Clone, Copy, Debug)]
pub(crate) struct AnalysisDenseLayer {
    pub bias: &'static [i8],
    pub input_weights: &'static [i8],
    pub nb_inputs: usize,
    pub nb_neurons: usize,
    pub sigmoid: bool,
}

#[derive(Clone, Copy, Debug)]
pub(crate) struct AnalysisGRULayer {
    pub bias: &'static [i8],
    pub input_weights: &'static [i8],
    pub recurrent_weights: &'static [i8],
    pub nb_inputs: usize,
    pub nb_neurons: usize,
}

#[cfg(test)]
mod gru_trace {
    extern crate std;

    use core::sync::atomic::{AtomicIsize, Ordering};
    use std::env;
    use std::sync::OnceLock;

    pub(crate) struct TraceConfig {
        frame: Option<usize>,
        want_bits: bool,
    }

    static TRACE_CONFIG: OnceLock<Option<TraceConfig>> = OnceLock::new();
    static TRACE_FRAME: AtomicIsize = AtomicIsize::new(-1);

    fn env_truthy(name: &str) -> bool {
        match env::var(name) {
            Ok(value) => !value.is_empty() && value != "0",
            Err(_) => false,
        }
    }

    pub(crate) fn set_frame(frame_idx: usize) {
        TRACE_FRAME.store(frame_idx as isize, Ordering::Relaxed);
    }

    fn current_frame() -> Option<usize> {
        let value = TRACE_FRAME.load(Ordering::Relaxed);
        if value >= 0 {
            Some(value as usize)
        } else {
            None
        }
    }

    pub(crate) fn config() -> Option<&'static TraceConfig> {
        TRACE_CONFIG
            .get_or_init(|| {
                let enabled = env_truthy("ANALYSIS_TRACE_GRU")
                    || env_truthy("ANALYSIS_TRACE_GRU_FRAME")
                    || env_truthy("ANALYSIS_TRACE_GRU_BITS");
                if !enabled {
                    return None;
                }
                let frame = env::var("ANALYSIS_TRACE_GRU_FRAME")
                    .ok()
                    .and_then(|value| value.parse::<usize>().ok());
                Some(TraceConfig {
                    frame,
                    want_bits: env_truthy("ANALYSIS_TRACE_GRU_BITS"),
                })
            })
            .as_ref()
    }

    pub(crate) fn should_dump(cfg: &TraceConfig) -> Option<usize> {
        let frame_idx = current_frame()?;
        if cfg.frame.map_or(true, |value| value == frame_idx) {
            Some(frame_idx)
        } else {
            None
        }
    }

    pub(crate) fn dump_vec(cfg: &TraceConfig, frame_idx: usize, label: &str, values: &[f32]) {
        crate::test_trace::trace_println!("analysis_gru[{frame_idx}].{label}.len={}", values.len());
        for (idx, &value) in values.iter().enumerate() {
            crate::test_trace::trace_println!(
                "analysis_gru[{frame_idx}].{label}[{idx}]={:.9e}",
                value as f64
            );
            if cfg.want_bits {
                crate::test_trace::trace_println!(
                    "analysis_gru[{frame_idx}].{label}_bits[{idx}]=0x{:08x}",
                    value.to_bits()
                );
            }
        }
    }
}

#[cfg(test)]
pub(crate) fn set_gru_trace_frame(frame_idx: usize) {
    gru_trace::set_frame(frame_idx);
}

#[inline]
fn fmadd(a: f32, b: f32, c: f32) -> f32 {
    // Keep fused semantics to match the reference C build's contraction.
    fmaf(a, b, c)
}

#[allow(clippy::excessive_precision)]
#[inline]
fn tansig_approx(x: f32) -> f32 {
    const N0: f32 = 952.528_015_14;
    const N1: f32 = 96.392_356_87;
    const N2: f32 = 0.608_630_42;
    const D0: f32 = 952.723_999_02;
    const D1: f32 = 413.368_011_47;
    const D2: f32 = 11.886_009_22;

    let x2 = x * x;
    let num = fmadd(fmadd(N2, x2, N1), x2, N0);
    let den = fmadd(fmadd(D2, x2, D1), x2, D0);
    let value = num * x / den;
    value.clamp(-1.0, 1.0)
}

#[inline]
fn sigmoid_approx(x: f32) -> f32 {
    0.5 + 0.5 * tansig_approx(0.5 * x)
}

fn gemm_accum(
    out: &mut [f32],
    weights: &[i8],
    rows: usize,
    cols: usize,
    col_stride: usize,
    x: &[f32],
) {
    debug_assert!(out.len() >= rows);
    debug_assert!(x.len() >= cols);
    let required = rows + col_stride.saturating_mul(cols.saturating_sub(1));
    debug_assert!(weights.len() >= required);

    for i in 0..rows {
        let mut acc = out[i];
        for j in 0..cols {
            let weight = f32::from(weights[j * col_stride + i]);
            // Mirror C's multiply-accumulate contraction.
            acc = fmaf(weight, x[j], acc);
        }
        out[i] = acc;
    }
}

pub(crate) fn analysis_compute_dense(
    layer: &AnalysisDenseLayer,
    output: &mut [f32],
    input: &[f32],
) {
    let m = layer.nb_inputs;
    let n = layer.nb_neurons;
    assert!(input.len() >= m, "dense input shorter than expected");
    assert!(output.len() >= n, "dense output shorter than expected");
    assert_eq!(
        layer.bias.len(),
        n,
        "dense bias must contain one entry per neuron"
    );
    assert!(
        layer.input_weights.len() >= n * m,
        "dense weights buffer is too small"
    );

    for (dst, &bias) in output.iter_mut().take(n).zip(layer.bias.iter()) {
        *dst = f32::from(bias);
    }
    gemm_accum(output, layer.input_weights, n, m, n, input);

    if layer.sigmoid {
        for value in output.iter_mut().take(n) {
            *value = sigmoid_approx(*value * WEIGHTS_SCALE);
        }
    } else {
        for value in output.iter_mut().take(n) {
            *value = tansig_approx(*value * WEIGHTS_SCALE);
        }
    }
}

pub(crate) fn analysis_compute_gru(gru: &AnalysisGRULayer, state: &mut [f32], input: &[f32]) {
    let m = gru.nb_inputs;
    let n = gru.nb_neurons;
    assert!(n <= MAX_NEURONS, "GRU exceeds max neuron count");
    assert!(state.len() >= n, "GRU state shorter than expected");
    assert!(input.len() >= m, "GRU input shorter than expected");
    assert_eq!(
        gru.bias.len(),
        3 * n,
        "GRU bias must contain three slices of length `nb_neurons`"
    );
    let stride = 3 * n;
    assert!(
        gru.input_weights.len() >= stride * m,
        "GRU input weights buffer is too small"
    );
    assert!(
        gru.recurrent_weights.len() >= stride * n,
        "GRU recurrent weights buffer is too small"
    );

    let mut tmp = [0.0f32; MAX_NEURONS];
    let mut z = [0.0f32; MAX_NEURONS];
    let mut r = [0.0f32; MAX_NEURONS];
    let mut h = [0.0f32; MAX_NEURONS];
    let mut h_pre = [0.0f32; MAX_NEURONS];
    let mut h_act = [0.0f32; MAX_NEURONS];
    let mut h_mix1 = [0.0f32; MAX_NEURONS];
    let mut h_mix2 = [0.0f32; MAX_NEURONS];

    // Update gate.
    for (dst, &bias) in z.iter_mut().zip(&gru.bias[..n]) {
        *dst = f32::from(bias);
    }
    gemm_accum(&mut z, gru.input_weights, n, m, stride, input);
    gemm_accum(&mut z, gru.recurrent_weights, n, n, stride, state);
    for value in z.iter_mut().take(n) {
        *value = sigmoid_approx(*value * WEIGHTS_SCALE);
    }

    // Reset gate.
    for (dst, &bias) in r.iter_mut().zip(&gru.bias[n..2 * n]) {
        *dst = f32::from(bias);
    }
    gemm_accum(&mut r, &gru.input_weights[n..], n, m, stride, input);
    gemm_accum(&mut r, &gru.recurrent_weights[n..], n, n, stride, state);
    for value in r.iter_mut().take(n) {
        *value = sigmoid_approx(*value * WEIGHTS_SCALE);
    }

    // Candidate output.
    for (dst, &bias) in h.iter_mut().zip(&gru.bias[2 * n..3 * n]) {
        *dst = f32::from(bias);
    }
    for i in 0..n {
        tmp[i] = state[i] * r[i];
    }
    gemm_accum(&mut h, &gru.input_weights[2 * n..], n, m, stride, input);
    gemm_accum(&mut h, &gru.recurrent_weights[2 * n..], n, n, stride, &tmp);
    h_pre[..n].copy_from_slice(&h[..n]);
    for i in 0..n {
        h_act[i] = tansig_approx(h[i] * WEIGHTS_SCALE);
    }
    for i in 0..n {
        h_mix1[i] = z[i] * state[i];
        h_mix2[i] = (1.0 - z[i]) * h_act[i];
        h[i] = h_mix1[i] + h_mix2[i];
    }
    #[cfg(test)]
    if let Some(cfg) = gru_trace::config() {
        if let Some(frame_idx) = gru_trace::should_dump(cfg) {
            gru_trace::dump_vec(cfg, frame_idx, "z", &z[..n]);
            gru_trace::dump_vec(cfg, frame_idx, "r", &r[..n]);
            gru_trace::dump_vec(cfg, frame_idx, "h_pre", &h_pre[..n]);
            gru_trace::dump_vec(cfg, frame_idx, "h_act", &h_act[..n]);
            gru_trace::dump_vec(cfg, frame_idx, "h_mix1", &h_mix1[..n]);
            gru_trace::dump_vec(cfg, frame_idx, "h_mix2", &h_mix2[..n]);
            gru_trace::dump_vec(cfg, frame_idx, "h_post", &h[..n]);
        }
    }
    state[..n].copy_from_slice(&h[..n]);
}

pub(crate) const LAYER0: AnalysisDenseLayer = AnalysisDenseLayer {
    bias: &LAYER0_BIAS,
    input_weights: &LAYER0_WEIGHTS,
    nb_inputs: 25,
    nb_neurons: 32,
    sigmoid: false,
};

pub(crate) const LAYER1: AnalysisGRULayer = AnalysisGRULayer {
    bias: &LAYER1_BIAS,
    input_weights: &LAYER1_WEIGHTS,
    recurrent_weights: &LAYER1_RECUR_WEIGHTS,
    nb_inputs: 32,
    nb_neurons: 24,
};

pub(crate) const LAYER2: AnalysisDenseLayer = AnalysisDenseLayer {
    bias: &LAYER2_BIAS,
    input_weights: &LAYER2_WEIGHTS,
    nb_inputs: 24,
    nb_neurons: 2,
    sigmoid: true,
};

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dense_layer_matches_manual_product() {
        let layer = AnalysisDenseLayer {
            bias: &[0, 0],
            input_weights: &[1, -1, 2, -2],
            nb_inputs: 2,
            nb_neurons: 2,
            sigmoid: false,
        };
        let mut output = [0.0f32; 2];
        let input = [1.0f32, -1.0];

        analysis_compute_dense(&layer, &mut output, &input);

        let expected0 = tansig_approx((1.0 - 2.0) * WEIGHTS_SCALE);
        let expected1 = tansig_approx((-1.0 + 2.0) * WEIGHTS_SCALE);
        assert!((output[0] - expected0).abs() < 1e-6);
        assert!((output[1] - expected1).abs() < 1e-6);
    }

    #[test]
    fn gru_updates_state_with_expected_shape() {
        // Two-neuron GRU with a single input; only the candidate gate is biased.
        const BIAS: [i8; 6] = [0, 0, 0, 0, 64, 64];
        const WEIGHTS: [i8; 6] = [1, 2, 0, 0, 0, 0];
        const RECUR: [i8; 12] = [0; 12];
        let gru = AnalysisGRULayer {
            bias: &BIAS,
            input_weights: &WEIGHTS,
            recurrent_weights: &RECUR,
            nb_inputs: 1,
            nb_neurons: 2,
        };
        let mut state = [0.0f32; 2];
        let input = [1.0f32];

        analysis_compute_gru(&gru, &mut state, &input);

        // Larger update weight on the second neuron damps the biased candidate
        // a bit more than the first.
        assert!(state[0] > 0.0);
        assert!(state[1] > 0.0);
        assert!(state[0] > state[1]);
    }

    #[test]
    fn fmadd_uses_fma_semantics() {
        let mut seed: u32 = 0x9e37_79b9;
        let mut found = false;
        for _ in 0..200_000 {
            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
            let a_bits = (seed & 0x007f_ffff) | 0x3f00_0000;
            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
            let b_bits = (seed & 0x007f_ffff) | 0x3f80_0000;
            seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
            let c_bits = (seed & 0x007f_ffff) | 0x3f00_0000;

            let a = f32::from_bits(a_bits);
            let b = f32::from_bits(b_bits);
            let c = f32::from_bits(c_bits);
            let fused = fmaf(a, b, c);
            let unfused = a * b + c;
            if fused.to_bits() != unfused.to_bits() {
                assert_eq!(fmadd(a, b, c).to_bits(), fused.to_bits());
                found = true;
                break;
            }
        }
        assert!(found, "expected to find FMA-sensitive inputs");
    }
}