laddu-extensions 0.19.1

Extensions to the laddu library
Documentation
#!/usr/bin/env python3
from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path

import laddu as ld
from laddu import Dataset, Event, Vec3

# Four-momentum names for the hand-built events, which carry a single vector.
P4_NAMES = ['beam']
# Four-momentum names passed as `p4s` when reading the Parquet test file.
PARQUET_P4_NAMES = ['beam', 'proton', 'kshort1', 'kshort2']


def make_event(weight: float) -> Event:
    """Build a single-particle test event carrying the given weight.

    The event holds one four-momentum, created from ``Vec3(0, 0, weight)``
    with mass ``0.0`` and labelled with ``P4_NAMES``. The second positional
    list is empty (presumably the auxiliary values, matching ``aux_names=[]``).

    Args:
        weight: Event weight; also used as the z-component of the momentum.

    Returns:
        The constructed ``Event``.
    """
    momentum = Vec3(0.0, 0.0, weight).with_mass(0.0)
    return Event([momentum], [], weight, p4_names=P4_NAMES, aux_names=[])


def check_manual_dataset() -> None:
    """Check local and global views of a hand-built dataset.

    Creates one event per MPI process, weighted ``1.0`` through ``size``,
    wraps them in a ``Dataset`` and asserts that this rank's local view holds
    exactly the event weighted ``rank + 1`` while the global view holds every
    weight. Also checks the types exposed by the ``*_global`` accessors.

    Raises:
        AssertionError: If any count, weight, or type check fails.
    """

    def weights_of(events) -> list:
        # Collect the weight of every event yielded by the given iterable.
        return [event.weight for event in events]

    # `or` falls back to rank 0 / size 1 when laddu returns a falsy value.
    my_rank = ld.mpi.get_rank() or 0
    world_size = ld.mpi.get_size() or 1

    expected_global = [float(i + 1) for i in range(world_size)]
    own_weight = float(my_rank + 1)

    dataset = Dataset(
        [make_event(w) for w in expected_global],
        p4_names=P4_NAMES,
        aux_names=[],
    )

    local_weights = weights_of(dataset.events_local)
    global_weights = weights_of(dataset.events_global)
    # Alias of the global list, not a second read of the dataset.
    stored_weights = global_weights
    # Second pass over the local events.
    stored_local_weights = weights_of(dataset.events_local)

    # Event counts and weighted counts.
    assert dataset.n_events == world_size
    assert dataset.n_events_local == 1
    assert dataset.n_events_weighted == sum(expected_global)
    assert dataset.n_events_weighted_local == own_weight

    # Weights seen through event iteration; global order is not assumed.
    assert local_weights == [own_weight]
    assert sorted(global_weights) == expected_global
    assert sorted(stored_weights) == expected_global
    assert stored_local_weights == [own_weight]

    # Weight accessors must agree with event iteration.
    assert list(dataset.weights) == global_weights
    assert list(dataset.weights_local) == local_weights

    # Types exposed by the explicitly global accessors.
    assert isinstance(dataset.event_global(my_rank), Event)
    assert isinstance(dataset.events_global, Iterator)
    assert isinstance(dataset.weights_global, type(dataset.weights))
    assert isinstance(dataset.n_events_global, int)
    assert isinstance(dataset.n_events_weighted_global, float)


def check_parquet_dataset() -> None:
    """Check local and global views of a dataset read from Parquet.

    Reads ``py-laddu/tests/data_files/data_f32.parquet``, located relative to
    the directory three levels above this file's own directory, using
    ``PARQUET_P4_NAMES`` as the four-momentum names. Asserts that event
    iteration, the weight accessors, and the (weighted) event counts agree
    with each other for both the local and the global view.

    Raises:
        AssertionError: If any consistency or type check fails.
    """
    repo_root = Path(__file__).resolve().parents[3]
    data_path = repo_root.joinpath(
        'py-laddu', 'tests', 'data_files', 'data_f32.parquet'
    )
    dataset = ld.io.read_parquet(data_path, p4s=PARQUET_P4_NAMES)

    def weights_of(events) -> list:
        # Collect the weight of every event yielded by the given iterable.
        return [event.weight for event in events]

    local_weights = weights_of(dataset.events_local)
    global_weights = weights_of(dataset.events_global)
    # Alias of the global list, not a second read of the dataset.
    stored_weights = global_weights
    # Second pass over the local events.
    stored_local_weights = weights_of(dataset.events_local)

    n_global = len(global_weights)
    n_local = len(local_weights)

    # Iterated weights against the reported event counts.
    assert global_weights == stored_weights
    assert n_global == dataset.n_events
    assert n_global == len(stored_weights)
    assert n_local == dataset.n_events_local
    assert len(stored_local_weights) == dataset.n_events_local
    assert n_local <= n_global

    # Weight accessors and weighted counts must agree with event iteration.
    assert list(dataset.weights) == global_weights
    assert list(dataset.weights_local) == local_weights
    assert dataset.n_events_weighted == sum(global_weights)
    assert dataset.n_events_weighted_local == sum(local_weights)

    # Types exposed by the explicitly global accessors.
    assert isinstance(dataset.event_global(0), Event)
    assert isinstance(dataset.events_global, Iterator)
    assert isinstance(dataset.weights_global, type(dataset.weights))
    assert isinstance(dataset.n_events_global, int)
    assert isinstance(dataset.n_events_weighted_global, float)


def main() -> None:
    """Run both dataset checks inside a laddu MPI context.

    On success the root process prints a one-line confirmation.

    Raises:
        RuntimeError: If laddu reports that its MPI backend is unavailable.
        AssertionError: If one of the dataset checks fails.
    """
    if not ld.mpi.is_mpi_available():
        raise RuntimeError('laddu MPI backend is not available')

    with ld.mpi.MPI(trigger=True):
        for check in (check_manual_dataset, check_parquet_dataset):
            check()

        # Only the root process reports, so the message is printed once.
        if ld.mpi.is_root():
            print('python-mpi-dataset-iteration: ok')


# Run the checks only when executed as a script, not when imported.
if __name__ == '__main__':
    main()