quadrank 0.2.0

Fast rank over binary and size-4 DNA alphabets.
Documentation
use std::{any::type_name, sync::OnceLock};

use clap::Parser;
use prefetch_index::prefetch_index;
use quadrank::{
    binary::{
        self,
        blocks::{
            BinaryBlock16, BinaryBlock16Spider, BinaryBlock16x2, BinaryBlock32x2, BinaryBlock64x2,
        },
    },
    ext::genedex,
    ext::sux::*,
    quad::{
        QuadRank, QuadRanker,
        blocks::{QuadBlock16, QuadBlock24_8, QuadBlock64},
    },
};
use sux::prelude::Rank9;

/// Query sets: one list of query positions per benchmark thread.
type QS = Vec<Vec<usize>>;

/// Thread counts to benchmark; initialised once in `main` and read via `wait()`.
static THREADS: OnceLock<Vec<usize>> = OnceLock::new();

/// Number of times each measurement is repeated by `time_fn`.
const REPEATS: usize = 3;

/// Runs `f` once per query list, each list on its own scoped thread, and
/// returns the elapsed wall-clock time in nanoseconds per query.
///
/// The measured time includes spawning and joining the threads. The divisor is
/// the total number of queries over all lists, so lists of different lengths
/// are accounted for correctly. When there are no queries at all, `0.0` is
/// returned instead of panicking or dividing by zero.
fn time_fn_mt(queries: &[Vec<usize>], f: impl Fn(&[usize]) + Sync + Copy) -> f64 {
    let total_queries: usize = queries.iter().map(Vec::len).sum();

    let start = std::time::Instant::now();
    std::thread::scope(|scope| {
        // Share `f` by reference: it is `Sync` but not necessarily `Send`.
        let f = &f;
        for thread_queries in queries {
            scope.spawn(move || {
                f(thread_queries);
            });
        }
    });
    let elapsed_ns = start.elapsed().as_nanos() as f64;

    if total_queries == 0 {
        0.0
    } else {
        elapsed_ns / total_queries as f64
    }
}

/// Times `f` on the first `t` query lists and reports the median of `REPEATS` runs.
///
/// Exclude max: it might be slow.
/// Exclude min: it might be too fast from short-term boost frequency.
///
/// The result (nanoseconds per query) is written right-aligned to stderr and
/// as a comma-prefixed CSV field to stdout.
///
/// Panics if `t > queries.len()`.
fn time_fn(queries: &QS, t: usize, f: impl Fn(&[usize]) + Sync + Copy) {
    let mut nss: Vec<f64> = (0..REPEATS)
        .map(|_| time_fn_mt(&queries[..t], f))
        .collect();
    // `total_cmp` gives a total order on floats, so a NaN/inf sample cannot panic the sort.
    nss.sort_by(f64::total_cmp);
    // Middle element, derived from `REPEATS` so it stays the median if the constant changes.
    let ns = nss[REPEATS / 2];
    eprint!(" {ns:>8.3}");
    print!(",{ns:.5}");
}

/// Batch size used by `time_batch` and prefetch look-ahead distance used by `time_stream`.
const BATCH: usize = 32;

/// Times `f` with each query made dependent on the previous result, so that
/// consecutive queries cannot overlap.
fn time_latency(queries: &QS, t: usize, f: impl Fn(usize) -> u64 + Sync + Copy) {
    let dependent_chain = |positions: &[usize]| {
        // Low bit of the ranks seen so far; folded into the next query position.
        let mut carry = 0;
        for &pos in positions {
            let rank = f(pos ^ carry);
            carry ^= (rank & 1) as usize;
        }
    };
    time_fn(queries, t, dependent_chain);
}

/// Times `f` over independent queries issued back-to-back in a plain loop.
fn time_loop(queries: &QS, t: usize, f: impl Fn(usize) -> u64 + Sync + Copy) {
    let independent = |positions: &[usize]| {
        positions.iter().for_each(|&pos| {
            f(pos);
        });
    };
    time_fn(queries, t, independent);
}

/// Times `f` in batches of `BATCH` queries: every position of a batch is
/// prefetched first, then every position of that batch is queried.
///
/// Queries that do not fill a whole batch at the end of a list are skipped.
#[allow(unused)]
fn time_batch(
    queries: &QS,
    t: usize,
    prefetch: impl Fn(usize) + Sync,
    f: impl Fn(usize) -> u64 + Sync,
) {
    time_fn(queries, t, |positions| {
        let (batches, _incomplete_tail) = positions.as_chunks::<BATCH>();
        batches.iter().for_each(|batch| {
            batch.iter().for_each(|&pos| prefetch(pos));
            batch.iter().for_each(|&pos| {
                f(pos);
            });
        });
    })
}

/// Times `f` while prefetching `BATCH` queries ahead of the one being answered.
///
/// For every index `i`, the position at `i + BATCH` is prefetched before the
/// query at `i` is executed. The final `BATCH` entries of each list are only
/// ever prefetched, never queried. Lists with at most `BATCH` entries run no
/// queries at all.
fn time_stream(
    queries: &QS,
    t: usize,
    prefetch: impl Fn(usize) + Sync,
    f: impl Fn(usize) -> u64 + Sync,
) {
    time_fn(queries, t, |queries| {
        // `saturating_sub` keeps the range empty for short lists; a plain `-`
        // would underflow and lead to out-of-bounds unchecked reads below.
        let end = queries.len().saturating_sub(BATCH);
        for i in 0..end {
            // SAFETY: `i < queries.len() - BATCH`, so both `i` and `i + BATCH`
            // are valid indices into `queries`.
            unsafe {
                let q = *queries.get_unchecked(i);
                let ahead = *queries.get_unchecked(i + BATCH);
                // Prefetch next cacheline of queries.
                // NOTE(review): `i + 2 * BATCH` can point past the end of the
                // slice; presumably `prefetch_index` tolerates that — confirm.
                prefetch_index(queries, i + 2 * BATCH);
                prefetch(ahead);
                f(q);
            }
        }
    });
}

/// Runs the three measurements for one ranker and thread count, in column
/// order: latency, loop, stream.
///
/// When `stream` is false the streaming measurement is skipped and a `0`
/// placeholder is printed in its column instead.
fn time_trip(
    queries: &QS,
    t: usize,
    prefetch: impl Fn(usize) + Sync + Copy,
    f: impl Fn(usize) -> u64 + Sync + Copy,
    stream: bool,
) {
    time_latency(queries, t, f);
    time_loop(queries, t, f);

    if !stream {
        // Keep the table and CSV columns aligned when there is nothing to measure.
        eprint!(" {:>8.3}", 0);
        print!(",{:.5}", 0);
        return;
    }

    time_stream(queries, t, prefetch, f);
}

fn bench_header() {
    eprint!("{:<60} {:>11} {:>6} |", "Ranker", "n", "size",);
    for t in THREADS.wait() {
        eprint!(" {:>7}t {:>8} {:>8} |", t, "", "");
    }
    eprintln!();
    eprint!("{:<60} {:>11} {:>6} |", "", "", "");
    for _t in THREADS.wait() {
        eprint!(" {:>8} {:>8} {:>8} |", "latncy", "loop", "stream",);
    }
    eprintln!();
    print!("ranker,sigma,n,rel_size,count4");
    for t in THREADS.wait() {
        print!(",latency_{},loop_{},stream_{}", t, t, t);
    }
    println!();
}

/// Benchmarks one size-4 alphabet ranker `R` built from `packed_seq`.
///
/// Two result rows are produced: first with `count4 = true` (timing
/// `rank4_unchecked`, reporting element `[0]` of its result), then with
/// `count4 = false` (timing `rank1_unchecked` for the symbol `q & 3`).
/// Each row goes to stderr as a table line and to stdout as a CSV line.
fn bench_one_quad<R: QuadRanker>(packed_seq: &[u64], queries: &QS) {
    let ranker = R::new_packed(packed_seq);

    // Everything below is identical for both rows, so compute it once rather
    // than recompiling the regex on every iteration.
    // Strip module paths (`a::b::`) from the type name for readability.
    let path_prefix = regex::Regex::new(r"[a-zA-Z0-9_]+::").unwrap();
    let name = path_prefix.replace_all(type_name::<R>(), "");
    // Two bits per symbol, so each 64-bit word holds 32 symbols.
    let n = packed_seq.len() * 64 / 2;
    // Size of the rank structure in bits relative to the input size in bits.
    let rel_size = (ranker.size() * 8) as f64 / (packed_seq.len() * 64) as f64;

    for count4 in [true, false] {
        eprint!("{name:<60} ");
        eprint!("{n:>11} ");
        eprint!("{rel_size:>5.3}x |");
        print!("\"{name}\",4,{n},{rel_size:>.3},{}", count4 as u8);

        for &t in THREADS.wait() {
            if count4 {
                time_trip(
                    queries,
                    t,
                    |q| ranker.prefetch4(q),
                    |q| std::hint::black_box(unsafe { ranker.rank4_unchecked(q) })[0],
                    true,
                );
            } else {
                time_trip(
                    queries,
                    t,
                    |q| ranker.prefetch1(q, q as u8 & 3),
                    |q| std::hint::black_box(unsafe { ranker.rank1_unchecked(q, q as u8 & 3) }),
                    true,
                );
            }
            eprint!(" |");
        }
        eprintln!();
        println!();
    }
}

/// Benchmarks one binary ranker `R` built from `packed_seq`, printing a table
/// row to stderr and a CSV row to stdout.
///
/// The streaming measurement is only run when `R::HAS_PREFETCH` is true.
fn bench_one_binary<R: binary::BiRanker>(packed_seq: &[u64], queries: &QS) {
    // Strip module paths (`a::b::`) from the type name for readability.
    let path_prefix = regex::Regex::new(r"[a-zA-Z0-9_]+::").unwrap();
    let name = path_prefix.replace_all(type_name::<R>(), |_: &regex::Captures| "".to_string());
    let input_bits = packed_seq.len() * 64;
    let n = input_bits;

    // Name and length are printed before the ranker is constructed.
    eprint!("{name:<60} ");
    eprint!("{n:>11} ");

    let ranker = R::new_packed(packed_seq);
    let ranker_bits = ranker.size() * 8;
    let rel_size = ranker_bits as f64 / input_bits as f64;
    eprint!("{rel_size:>5.3}x |");
    print!("\"{name}\",2,{n},{rel_size:>.3},0");

    THREADS.wait().iter().for_each(|&t| {
        time_trip(
            queries,
            t,
            |q| ranker.prefetch(q),
            |q| std::hint::black_box(unsafe { ranker.rank_unchecked(q) }),
            R::HAS_PREFETCH,
        );
        eprint!(" |");
    });

    eprintln!();
    println!();
}

#[inline(never)]
fn bench_quad(seq: &[u64], queries: &QS) {
    bench_header();

    bench_one_quad::<qwt::RSQVector256>(seq, queries);
    bench_one_quad::<qwt::RSQVector512>(seq, queries);

    bench_one_quad::<genedex::Flat64>(seq, queries);
    bench_one_quad::<genedex::Flat512>(seq, queries);
    bench_one_quad::<genedex::Condensed64>(seq, queries);
    bench_one_quad::<genedex::Condensed512>(seq, queries);

    bench_one_quad::<QuadRank<QuadBlock64>>(seq, queries);
    bench_one_quad::<QuadRank<QuadBlock24_8>>(seq, queries);
    bench_one_quad::<QuadRank<QuadBlock16>>(seq, queries);

    // use quadrank::quad::super_blocks::ShiftPairedSB;
    // bench_one_quad::<Ranker<QuadBlock24_8, ShiftPairedSB, SimdCount11B>>(seq, queries);
    // bench_one_quad::<Ranker<QuadBlock16, ShiftPairedSB, NoCount>>(seq, queries);
}

#[inline(never)]
fn bench_binary(seq: &[u64], queries: &QS) {
    bench_header();

    bench_one_binary::<qwt::RSNarrow>(seq, queries);
    bench_one_binary::<qwt::RSWide>(seq, queries);

    bench_one_binary::<genedex::Flat512>(seq, queries);
    bench_one_binary::<genedex::Condensed64>(seq, queries);
    bench_one_binary::<genedex::Condensed512>(seq, queries);

    bench_one_binary::<bitm::RankSelect101111>(seq, queries);

    bench_one_binary::<Rank9>(seq, queries);
    bench_one_binary::<RankSmall0>(seq, queries);
    bench_one_binary::<RankSmall1>(seq, queries);
    bench_one_binary::<RankSmall2>(seq, queries);
    bench_one_binary::<RankSmall3>(seq, queries);
    bench_one_binary::<RankSmall4>(seq, queries);

    bench_one_binary::<binary::BiRank<BinaryBlock64x2>>(seq, queries);
    bench_one_binary::<binary::BiRank<BinaryBlock32x2>>(seq, queries);
    bench_one_binary::<binary::BiRank<BinaryBlock16x2>>(seq, queries);
    bench_one_binary::<binary::BiRank<BinaryBlock16>>(seq, queries);
    bench_one_binary::<binary::BiRank<BinaryBlock16Spider>>(seq, queries);
    // bench_one_binary::<binary::Ranker<BinaryBlock16Spider2>>(seq, queries);

    // bench_one_binary::<binary::Ranker<BinaryBlock32x2, ShiftPairedSB>>(seq, queries);
    // bench_one_binary::<binary::Ranker<BinaryBlock16x2, ShiftPairedSB>>(seq, queries);
    // bench_one_binary::<binary::Ranker<BinaryBlock16, ShiftPairedSB>>(seq, queries);
}

/// Benchmark rank queries on binary (-b) and size-4 alphabet (-q) rankers.
#[derive(clap::Parser)]
struct Args {
    /// Max number of threads; smaller counts are obtained by repeated halving down to 1 (default 12)
    #[clap(short = 'j')]
    threads: Option<usize>,
    /// Currently unused (only referenced by a disabled size sweep)
    #[clap(long)]
    to: Option<usize>,
    /// Benchmark a single input of this many bytes instead of the built-in sizes
    #[clap(short = 'n')]
    n: Option<usize>,
    /// Run the binary-alphabet benchmarks
    #[clap(short = 'b')]
    binary: bool,
    /// Run the size-4 alphabet benchmarks
    #[clap(short = 'q')]
    quad: bool,
}

fn main() {
    let args = Args::parse();

    // queries per thread
    let q = 10_000_000;

    // size in bytes
    let mut sizes = vec![
        (1 << 17), // L2, 128 KiB
        (1 << 32), // RAM, 4 GiB
    ];

    // let mut sizes = (13..=args.to.unwrap_or(32))
    //     .map(|i| 1usize << i)
    //     .collect::<Vec<_>>();

    THREADS
        .set({
            let mut ts = vec![];
            let mut t = args.threads.unwrap_or(12);
            loop {
                ts.push(t);
                if t == 1 {
                    break;
                }
                t /= 2;
            }
            ts.reverse();
            ts
        })
        .unwrap();

    if let Some(n) = args.n {
        sizes = vec![n];
    }

    for size in sizes {
        eprintln!(
            "size = {} bytes = {} bits = {} bp",
            size,
            size * 8,
            size * 4
        );
        let seq = vec![
            0b1110010011100100111001001110010011100100111001001110010011100100;
            size.div_ceil(8)
        ];

        if args.binary {
            let n = size * 8;
            let queries = (0..*THREADS.wait().last().unwrap())
                .map(|_| (0..q).map(|_| rand::random_range(2..n)).collect::<Vec<_>>())
                .collect::<Vec<_>>();

            bench_binary(&seq, &queries);
        }
        if args.quad {
            let n = size * 4;
            let queries = (0..*THREADS.wait().last().unwrap())
                .map(|_| (0..q).map(|_| rand::random_range(2..n)).collect::<Vec<_>>())
                .collect::<Vec<_>>();

            bench_quad(&seq, &queries);
        }
    }
}

// TODO: 40bit support