bed-reader 1.0.6

Read and write the PLINK BED format, simply and efficiently.
Documentation
import logging
from datetime import datetime
from pathlib import Path

import numpy as np
import pytest

from bed_reader import (
    file_aat_piece_f64_orderf,
    file_ata_piece_f64_orderf,
    file_b_less_aatbx,
    file_dot_piece,
)
from bed_reader._open_bed import get_num_threads, open_bed  # noqa


def file_ata(filename, offset, iid_count, sid_count, sid_step):
    ata = np.full((sid_count, sid_count), np.nan)
    for sid_index in range(0, sid_count, sid_step):
        sid_range_len = min(sid_step, sid_count - sid_index)
        ata_piece = np.full((sid_count - sid_index, sid_range_len), np.nan)
        if sid_index % 2 == 0:  # test new and old method
            file_ata_piece_f64_orderf(
                str(filename),
                offset,
                iid_count,
                sid_count,
                sid_index,
                ata_piece,
                num_threads=get_num_threads(None),
                log_frequency=sid_range_len,
            )
        else:
            file_dot_piece(
                str(filename),
                offset,
                iid_count,
                sid_index,
                ata_piece,
                num_threads=get_num_threads(None),
                log_frequency=sid_range_len,
            )
        ata[sid_index:, sid_index : sid_index + sid_range_len] = ata_piece
    for sid_index in range(sid_count):
        ata[sid_index, sid_index + 1 :] = ata[sid_index + 1 :, sid_index]
    return ata


def file_aat(filename, offset, iid_count, sid_count, iid_step):
    aat = np.full((iid_count, iid_count), np.nan)
    for iid_index in range(0, iid_count, iid_step):
        iid_range_len = min(iid_step, iid_count - iid_index)
        aat_piece = np.full((iid_count - iid_index, iid_range_len), np.nan)
        file_aat_piece_f64_orderf(
            str(filename),
            offset,
            iid_count,
            sid_count,
            iid_index,
            aat_piece,
            num_threads=get_num_threads(None),
            log_frequency=iid_range_len,
        )
        aat[iid_index:, iid_index : iid_index + iid_range_len] = aat_piece
    for iid_index in range(iid_count):
        aat[iid_index, iid_index + 1 :] = aat[iid_index + 1 :, iid_index]

    return aat


def write_read_test_file_ata(iid_count, sid_count, sid_step, tmp_path) -> None:
    offset = 640
    file_path = tmp_path / f"{iid_count}x{sid_count}_o{offset}_array.memmap"

    mm = np.memmap(
        file_path,
        dtype="float64",
        mode="w+",
        offset=offset,
        shape=(iid_count, sid_count),
        order="F",
    )
    mm[:] = np.linspace(0, 1, mm.size).reshape(mm.shape)
    mm.flush()

    out_val = file_ata(file_path, offset, iid_count, sid_count, sid_step)
    expected = mm.T.dot(mm)
    assert np.allclose(expected, out_val, equal_nan=True)


def write_read_test_file_aat(
    iid_count, sid_count, iid_step, tmp_path, skip_if_there=False,
):
    offset = 640
    file_path = tmp_path / f"{iid_count}x{sid_count}_o{offset}_array.memmap"
    if skip_if_there and file_path.exists():
        logging.info(f"'{file_path}' exists")
        mm = np.memmap(
            file_path,
            dtype="float64",
            mode="r",
            offset=offset,
            shape=(iid_count, sid_count),
            order="F",
        )
    else:
        logging.info(f"Creating '{file_path}'")
        mm = np.memmap(
            file_path,
            dtype="float64",
            mode="w+",
            offset=offset,
            shape=(iid_count, sid_count),
            order="F",
        )
        mm[:] = np.linspace(0, 1, mm.size).reshape(mm.shape)
        mm.flush()
        logging.info(f"Finished creating '{file_path}'")

    start_time = datetime.now()
    out_val = file_aat(file_path, offset, iid_count, sid_count, iid_step)
    delta = datetime.now() - start_time
    expected = mm.dot(mm.T)
    assert np.allclose(expected, out_val, equal_nan=True)
    return delta


def test_file_ata_medium(tmp_path) -> None:
    write_read_test_file_ata(100, 1000, 33, tmp_path)


def test_file_aat_medium(tmp_path) -> None:
    write_read_test_file_aat(100, 1000, 34, tmp_path)


# # Too slow
# def test_file_ata_giant(tmp_path):
#     write_read_test_file_ata(100_000, 10_000, 1000, tmp_path)

# # Too slow
# def test_file_aat_giant(tmp_path):
#     write_read_test_file_ata(1_000, 10_000, 100, tmp_path)


def test_file_ata_small(shared_datadir) -> None:
    filename = shared_datadir / "small_array.memmap"

    out_val = file_ata(filename, 0, 2, 3, 2)

    expected = np.array([[17.0, 22.0, 27.0], [22.0, 29.0, 36.0], [27.0, 36.0, 45.0]])
    assert np.allclose(expected, out_val, equal_nan=True)


def test_file_aat_small(shared_datadir) -> None:
    filename = shared_datadir / "small_array.memmap"

    out_val = file_aat(filename, 0, iid_count=2, sid_count=3, iid_step=1)

    expected = np.array([[14.0, 32.0], [32.0, 77.0]])
    assert np.allclose(expected, out_val, equal_nan=True)


def mmultfile_b_less_aatb(a_snp_mem_map, b, log_frequency=0, force_python_only=False):
    # Without memory efficiency
    #   a=a_snp_mem_map.val
    #   aTb = np.dot(a.T,b)
    #   aaTb = b-np.dot(a,aTb)
    #   return aTb, aaTb

    if force_python_only:
        aTb = np.zeros(
            (a_snp_mem_map.shape[1], b.shape[1]),
        )  # b can be destroyed. Is everything is in best order, i.e. F vs C
        aaTb = b.copy()
        b_mem = np.array(b, order="F")
        with open(a_snp_mem_map.filename, "rb") as U_fp:
            U_fp.seek(a_snp_mem_map.offset)
            for i in range(a_snp_mem_map.shape[1]):
                a_mem = np.fromfile(
                    U_fp, dtype=np.float64, count=a_snp_mem_map.shape[0],
                )
                if log_frequency > 0 and i % log_frequency == 0:
                    logging.info(f"{i}/{a_snp_mem_map.shape[1]}")
                aTb[i, :] = np.dot(a_mem, b_mem)
                aaTb -= np.dot(a_mem.reshape(-1, 1), aTb[i : i + 1, :])
    else:
        b1 = np.array(b, order="F")
        aTb = np.zeros((a_snp_mem_map.shape[1], b.shape[1]))
        aaTb = np.array(b1, order="F")

        file_b_less_aatbx(
            str(a_snp_mem_map.filename),
            a_snp_mem_map.offset,
            a_snp_mem_map.shape[0],  # row count
            b1,  # B copy 1 in "F" order
            aaTb,  # B copy 2 in "F" order
            aTb,  # result
            num_threads=get_num_threads(None),
            log_frequency=log_frequency,
        )

    return aTb, aaTb


def write_read_test_file_b_less_aatbx(
    iid_count, a_sid_count, b_sid_count, log_frequency, tmp_path, do_both=True,
) -> None:
    offset = 640
    file_path = tmp_path / f"{iid_count}x{a_sid_count}_o{offset}_array.memmap"
    mm = np.memmap(
        file_path,
        dtype="float64",
        mode="w+",
        offset=offset,
        shape=(iid_count, a_sid_count),
        order="F",
    )
    mm[:] = np.linspace(0, 1, mm.size).reshape(mm.shape)
    mm.flush()

    b = np.array(
        np.linspace(0, 1, iid_count * b_sid_count).reshape(
            (iid_count, b_sid_count), order="F",
        ),
    )
    b_again = b.copy()

    logging.info("Calling Rust")
    aTb, aaTb = mmultfile_b_less_aatb(
        mm, b_again, log_frequency, force_python_only=False,
    )

    if do_both:
        logging.info("Calling Python")
        aTb_python, aaTb_python = mmultfile_b_less_aatb(
            mm, b, log_frequency, force_python_only=True,
        )

        if (
            not np.abs(aTb_python - aTb).max() < 1e-8
            or not np.abs(aaTb_python - aaTb).max() < 1e-8
        ):
            msg = "Expect Python and Rust to get the same mmultfile_b_less_aatb answer"
            raise AssertionError(
                msg,
            )


def test_file_b_less_aatbx_medium(tmp_path) -> None:
    write_read_test_file_b_less_aatbx(500, 400, 100, 10, tmp_path, do_both=True)


def test_file_b_less_aatbx_medium2(tmp_path) -> None:
    write_read_test_file_b_less_aatbx(5_000, 400, 100, 100, tmp_path, do_both=True)


# Slow and doesn't check answer
# def test_file_b_less_aatbx_2(tmp_path):
#     write_read_test_file_b_less_aatbx(50_000, 4000, 1000, 100, tmp_path, do_both=False)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    shared_datadir = Path(r"D:\OneDrive\programs\bed-reader\bed_reader\tests\data")
    tmp_path = Path(r"m:/deldir/tests")

    if False:
        small2_array = np.array(
            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
            order="F",
        )
        mm = np.memmap(
            tmp_path / "small2_array.memmap",
            dtype="float64",
            mode="w+",
            shape=(3, 4),
            order="F",
        )
        mm[:] = small2_array[:]
        mm.flush()

    if False:
        small_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], order="F")
        mm = np.memmap(
            tmp_path / "small_array.memmap",
            dtype="float64",
            mode="w+",
            shape=(2, 3),
            order="F",
        )
        mm[:] = small_array[:]
        mm.flush()

    if False:
        mm = np.memmap(
            tmp_path / "100x1000_o640_array.memmap",
            dtype="float64",
            mode="w+",
            offset=640,
            shape=(100, 1000),
            order="F",
        )
        total = mm.shape[0] * mm.shape[1]
        lin = np.linspace(0, 1, total).reshape(mm.shape)

        mm[:] = lin[:]
        mm.flush()

    if False:
        mm = np.memmap(
            tmp_path / "1000x10000_o640_array.memmap",
            dtype="float64",
            mode="w+",
            offset=640,
            shape=(1000, 10_000),
            order="F",
        )
        total = mm.shape[0] * mm.shape[1]
        lin = np.linspace(0, 1, total).reshape(mm.shape)

        mm[:] = lin[:]
        mm.flush()

    if False:
        mm = np.memmap(
            tmp_path / "10_000x100_000_o6400_array.memmap",
            dtype="float64",
            mode="w+",
            offset=640,
            shape=(10_000, 100_000),
            order="F",
        )
        total = mm.shape[0] * mm.shape[1]
        lin = np.linspace(0, 1, total).reshape(mm.shape)

        mm[:] = lin[:]
        mm.flush()

    if False:
        mm = np.memmap(
            tmp_path / "100_000x10_000_o6400_array.memmap",
            dtype="float64",
            mode="w+",
            offset=640,
            shape=(100_000, 10_000),
            order="F",
        )
        total = mm.shape[0] * mm.shape[1]
        lin = np.linspace(0, 1, total).reshape(mm.shape)

        mm[:] = lin[:]
        mm.flush()

    # test_file_b_less_aatbx_2(tmp_path)
    # test_zero_files(tmp_path)
    # test_index(shared_datadir)
    # test_c_reader_bed(shared_datadir)
    # test_read1(shared_datadir)
    pytest.main([__file__])