// poulpy-cpu-ref 0.6.0
//
// Portable reference CPU implementations of the poulpy-hal open extension points.
// See the crate documentation for details.
use poulpy_ckks::ckks_backend_test_suite;

/// Rotation amounts handed to the `ckks_backend_test_suite!` invocations
/// (via `rotations = super::ATK_ROTATIONS`).
///
/// NOTE(review): "ATK" presumably stands for automorphism key — confirm.
const ATK_ROTATIONS: &[i64] = &[1_i64, 7_i64];

/// Encodes a complex vector with the `FFT64ReimTable<f64>` reim encoder, decodes it
/// back, and checks that both the real and imaginary parts are recovered to within an
/// absolute error of `1e-10`.
///
/// Setup: `n = 16`, `m = n / 2 = 8` complex values (real parts in `re_in`, imaginary
/// parts in `im_in`), plaintext allocated with `Base2K(16)`, `log_delta = 40` and
/// `log_budget = 10`.
#[test]
fn encode_decode_reim_roundtrip() {
    use crate::FFT64ReimTable;
    use poulpy_ckks::encoding::reim::Encoder;
    use poulpy_ckks::layouts::CKKSModuleAlloc;

    /// Largest element-wise absolute difference between `a` and `b`.
    ///
    /// NaN is propagated instead of discarded. `f64::max` returns the non-NaN operand,
    /// so folding with it would report zero error for a NaN output and let the test
    /// pass silently. A NaN result here fails the `< bound` check below.
    fn max_abs_err(a: &[f64], b: &[f64]) -> f64 {
        // `zip` would silently truncate to the shorter slice; make a mismatch loud.
        assert_eq!(a.len(), b.len(), "slice length mismatch: {} vs {}", a.len(), b.len());
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0_f64, |acc, e| if e.is_nan() || e > acc { e } else { acc })
    }

    let n = 16usize;
    // Number of complex values encoded: half of `n`.
    let m = n / 2;

    // Real parts ramp up over [0, 1); imaginary parts are their negations.
    let re_in: Vec<f64> = (0..m).map(|i| (i as f64) / (m as f64)).collect();
    let im_in: Vec<f64> = (0..m).map(|i| -((i as f64) / (m as f64))).collect();

    let encoder = Encoder::<FFT64ReimTable<f64>>::new::<f64>(m).unwrap();

    let host_module = poulpy_hal::layouts::Module::<poulpy_hal::layouts::HostBytesBackend>::new(n as u64);
    let mut pt = host_module.ckks_pt_vec_alloc(
        poulpy_core::layouts::Base2K(16),
        poulpy_ckks::CKKSMeta {
            log_delta: 40,
            log_budget: 10,
        },
    );
    encoder.encode_reim(&mut pt, &re_in, &im_in).unwrap();

    let mut re_out = vec![0.0f64; m];
    let mut im_out = vec![0.0f64; m];
    encoder.decode_reim(&pt, &mut re_out, &mut im_out).unwrap();

    let bound = 1e-10;
    let err_re = max_abs_err(&re_in, &re_out);
    let err_im = max_abs_err(&im_in, &im_out);
    assert!(err_re < bound, "re max_err={err_re:.2e} exceeds bound={bound:.2e}");
    assert!(err_im < bound, "im max_err={err_im:.2e} exceeds bound={bound:.2e}");
}

// Instantiates `ckks_backend_test_suite!` (imported from `poulpy_ckks`) as module
// `fft64_f64`. The arguments are:
//   - backend `crate::FFT64Ref`;
//   - `f64` scalars;
//   - encoder `crate::FFT64ReimTable<f64>`;
//   - parameter set `poulpy_ckks::test_suite::FFT64_PARAMS_F64`;
//   - rotation amounts `ATK_ROTATIONS`.
// `ATK_ROTATIONS` is written as `super::ATK_ROTATIONS`, presumably because the macro
// evaluates it inside the generated module — confirm against the macro definition.
ckks_backend_test_suite!(
    mod fft64_f64,
    backend = crate::FFT64Ref,
    scalar = f64,
    encoder = crate::FFT64ReimTable<f64>,
    params = poulpy_ckks::test_suite::FFT64_PARAMS_F64,
    rotations = super::ATK_ROTATIONS,
);

// Instantiates `ckks_backend_test_suite!` as module `ntt120_f64`. The arguments are:
//   - backend `crate::NTT120Ref`;
//   - `f64` scalars;
//   - parameter set `poulpy_ckks::test_suite::NTT120_PARAMS_F64`;
//   - rotation amounts `ATK_ROTATIONS`.
// The encoder is still `crate::FFT64ReimTable<f64>` even though the backend is NTT-based.
// NOTE(review): presumably the slot encoding is independent of the arithmetic backend —
// confirm.
ckks_backend_test_suite!(
    mod ntt120_f64,
    backend = crate::NTT120Ref,
    scalar = f64,
    encoder = crate::FFT64ReimTable<f64>,
    params = poulpy_ckks::test_suite::NTT120_PARAMS_F64,
    rotations = super::ATK_ROTATIONS,
);

// Instantiates `ckks_backend_test_suite!` as module `ntt120_f128`. The arguments are:
//   - backend `crate::NTT120Ref`;
//   - `f128::f128` scalars (from the external `f128` crate);
//   - encoder `crate::FFT64ReimTable<f128::f128>`;
//   - parameter set `poulpy_ckks::test_suite::NTT120_PARAMS_F128`;
//   - rotation amounts `ATK_ROTATIONS`.
ckks_backend_test_suite!(
    mod ntt120_f128,
    backend = crate::NTT120Ref,
    scalar = f128::f128,
    encoder = crate::FFT64ReimTable<f128::f128>,
    params = poulpy_ckks::test_suite::NTT120_PARAMS_F128,
    rotations = super::ATK_ROTATIONS,
);