use crate::api::Direction;
use crate::kernel::Float;
/// Parameters of a sparse problem: size `n`, sparsity `k`, a [`Direction`], and a noise
/// tolerance expressed in the element type `T`.
///
/// Construct with `SparseProblem::new` (optionally followed by `with_noise_tolerance`);
/// the `Default` value has `n == 0` and `k == 0` and is therefore not valid.
#[derive(Clone, Debug)]
pub struct SparseProblem<T: Float> {
    /// Problem size; a valid problem has `n > 0`.
    pub n: usize,
    /// Sparsity parameter; a valid problem has `1 <= k <= n`.
    /// Presumably the number of significant entries — TODO confirm against callers.
    pub k: usize,
    /// Direction to run in (`Direction::Forward` for the `Default` value).
    pub direction: Direction,
    /// Noise tolerance; starts at `T::ZERO` unless set via `with_noise_tolerance`.
    pub noise_tolerance: T,
}
impl<T: Float> SparseProblem<T> {
    /// Creates a problem of size `n` with sparsity `k` and the given `direction`.
    ///
    /// `noise_tolerance` starts at `T::ZERO`; override it with
    /// [`with_noise_tolerance`](Self::with_noise_tolerance). Arguments are not validated —
    /// check [`is_valid`](Self::is_valid) before relying on them.
    pub fn new(n: usize, k: usize, direction: Direction) -> Self {
        Self {
            n,
            k,
            direction,
            noise_tolerance: T::ZERO,
        }
    }

    /// Builder-style setter: returns `self` with `noise_tolerance` set to `tolerance`.
    #[must_use]
    pub fn with_noise_tolerance(mut self, tolerance: T) -> Self {
        self.noise_tolerance = tolerance;
        self
    }

    /// Returns `true` when `0 < k <= n`.
    pub fn is_valid(&self) -> bool {
        self.n > 0 && self.k > 0 && self.k <= self.n
    }

    /// Heuristic threshold: returns `true` when `k < n / 16` (integer division) and
    /// `n >= 128`. Presumably used to decide whether a sparse algorithm is worth using —
    /// TODO confirm against callers.
    pub fn is_sparse_beneficial(&self) -> bool {
        self.k < self.n / 16 && self.n >= 128
    }

    /// Number of buckets: `3 * k`, raised to at least 16, then capped at `n`.
    ///
    /// The cap is applied last, so the result is `n` whenever `n < 16` (including 0).
    pub fn optimal_buckets(&self) -> usize {
        const OVERSAMPLING: usize = 3;
        const MIN_BUCKETS: usize = 16;
        // Saturate so a huge `k` cannot overflow (debug panic / release wrap-around).
        OVERSAMPLING
            .saturating_mul(self.k)
            .max(MIN_BUCKETS)
            .min(self.n)
    }

    /// Number of repetitions: `ceil(log2(n / k))`, clamped to `1..=10`.
    ///
    /// Degenerate inputs do not panic. `k == 0` with `n > 0` yields 10, `n == k == 0`
    /// yields 1, and `k > n` yields 1.
    pub fn optimal_repetitions(&self) -> usize {
        const MAX_REPETITIONS: usize = 10;
        let ratio = self.n as f64 / self.k as f64;
        // `log2` rounds once, unlike `ln(x) / ln(2)`, whose quotient of two rounded logs
        // can land just above an integer and make `ceil` add a spurious repetition.
        // Float-to-int `as` saturates: +inf (k == 0) -> usize::MAX; NaN (0/0) and
        // non-positive values (k >= n) -> 0. The clamp maps those into 1..=10.
        let reps = ratio.log2().ceil() as usize;
        reps.clamp(1, MAX_REPETITIONS)
    }

    /// Pairwise-coprime moduli whose product is at least
    /// [`optimal_buckets`](Self::optimal_buckets).
    ///
    /// Takes consecutive primes starting at 2 until their product reaches the bucket count.
    /// It always returns at least two factors, falling back to `[2, 3]` when fewer would
    /// suffice.
    pub fn crt_factors(&self) -> Vec<usize> {
        // The product of these primes exceeds u64::MAX, so the loop below always reaches
        // any `usize` bucket count. A table ending at 31 tops out at ~2.0e11 and could
        // leave the product short of the bucket count.
        const SMALL_PRIMES: [usize; 16] = [
            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
        ];
        let buckets = self.optimal_buckets();
        let mut factors = Vec::new();
        let mut product = 1usize;
        for &p in &SMALL_PRIMES {
            if product >= buckets {
                break;
            }
            factors.push(p);
            // Saturate instead of overflowing (reachable on 32-bit targets); a saturated
            // product is already >= every possible bucket count.
            product = product.saturating_mul(p);
        }
        if factors.len() < 2 {
            // Always hand back at least two moduli.
            factors = vec![2, 3];
        }
        factors
    }
}
impl<T: Float> Default for SparseProblem<T> {
    /// An empty placeholder problem: `n == 0`, `k == 0`, `Direction::Forward`, and a
    /// noise tolerance of `T::ZERO`.
    ///
    /// The result does not satisfy `is_valid`, which requires `n > 0` and `k > 0`.
    fn default() -> Self {
        // `new` fills `noise_tolerance` with `T::ZERO`, matching the field-by-field form.
        Self::new(0, 0, Direction::Forward)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_problem_creation() {
        let problem: SparseProblem<f64> = SparseProblem::new(1024, 10, Direction::Forward);
        assert_eq!(problem.n, 1024);
        assert_eq!(problem.k, 10);
        assert_eq!(problem.noise_tolerance, 0.0);
        assert!(problem.is_valid());
    }

    #[test]
    fn test_with_noise_tolerance() {
        let problem: SparseProblem<f64> =
            SparseProblem::new(1024, 10, Direction::Forward).with_noise_tolerance(1e-6);
        assert_eq!(problem.noise_tolerance, 1e-6);
    }

    #[test]
    fn test_invalid_problems() {
        let zero_k: SparseProblem<f64> = SparseProblem::new(1024, 0, Direction::Forward);
        assert!(!zero_k.is_valid());
        let k_above_n: SparseProblem<f64> = SparseProblem::new(8, 9, Direction::Forward);
        assert!(!k_above_n.is_valid());
        let default: SparseProblem<f64> = SparseProblem::default();
        assert!(!default.is_valid());
    }

    #[test]
    fn test_sparse_beneficial() {
        let problem1: SparseProblem<f64> = SparseProblem::new(1024, 10, Direction::Forward);
        assert!(problem1.is_sparse_beneficial());
        let problem2: SparseProblem<f64> = SparseProblem::new(1024, 512, Direction::Forward);
        assert!(!problem2.is_sparse_beneficial());
        let problem3: SparseProblem<f64> = SparseProblem::new(64, 4, Direction::Forward);
        assert!(!problem3.is_sparse_beneficial());
    }

    #[test]
    fn test_optimal_buckets() {
        let problem: SparseProblem<f64> = SparseProblem::new(1024, 10, Direction::Forward);
        let buckets = problem.optimal_buckets();
        assert!(buckets >= 16);
        assert!(buckets <= 1024);
        // 3 * k when that lies between the floor and the cap.
        assert_eq!(buckets, 30);
        // Floor of 16 for very small k.
        let tiny_k: SparseProblem<f64> = SparseProblem::new(1024, 2, Direction::Forward);
        assert_eq!(tiny_k.optimal_buckets(), 16);
        // The cap at n wins over the floor when n < 16.
        let small_n: SparseProblem<f64> = SparseProblem::new(8, 1, Direction::Forward);
        assert_eq!(small_n.optimal_buckets(), 8);
    }

    #[test]
    fn test_optimal_repetitions() {
        // ceil(log2(102.4)) == 7
        let typical: SparseProblem<f64> = SparseProblem::new(1024, 10, Direction::Forward);
        assert_eq!(typical.optimal_repetitions(), 7);
        // n == k gives log2(1) == 0, raised to the minimum of 1.
        let dense: SparseProblem<f64> = SparseProblem::new(1024, 1024, Direction::Forward);
        assert_eq!(dense.optimal_repetitions(), 1);
        // Capped at 10.
        let very_sparse: SparseProblem<f64> = SparseProblem::new(1024, 1, Direction::Forward);
        assert_eq!(very_sparse.optimal_repetitions(), 10);
    }

    #[test]
    fn test_crt_factors() {
        let problem: SparseProblem<f64> = SparseProblem::new(1024, 10, Direction::Forward);
        let factors = problem.crt_factors();
        assert!(factors.len() >= 2);
        assert_eq!(factors, vec![2, 3, 5]);
        for (i, &a) in factors.iter().enumerate() {
            for &b in &factors[i + 1..] {
                assert_eq!(gcd(a, b), 1, "{} and {} are not coprime", a, b);
            }
        }
        // The moduli must jointly cover the bucket range.
        let product: usize = factors.iter().product();
        assert!(product >= problem.optimal_buckets());
    }

    #[test]
    fn test_crt_factors_fallback() {
        // A single bucket needs no prime at all; the fallback still returns two moduli.
        let single: SparseProblem<f64> = SparseProblem::new(1, 1, Direction::Forward);
        assert_eq!(single.crt_factors(), vec![2, 3]);
        // Four buckets: 2 alone is too small, 2 * 3 == 6 >= 4.
        let four: SparseProblem<f64> = SparseProblem::new(4, 1, Direction::Forward);
        assert_eq!(four.crt_factors(), vec![2, 3]);
    }

    /// Greatest common divisor via the iterative Euclidean algorithm.
    fn gcd(mut a: usize, mut b: usize) -> usize {
        while b != 0 {
            let remainder = a % b;
            a = b;
            b = remainder;
        }
        a
    }
}