ragged-buffer 0.4.8

Efficient RaggedBuffer datatype that implements 3D arrays with variable-length 2nd dimension.
Documentation
from typing import Any, Generic, List, Protocol, Type, TypeVar, Union, cast, overload
from numpy.typing import NDArray
import numpy as np

from .ragged_buffer import (
    RaggedBufferF32,
    RaggedBufferI64,
    RaggedBufferBool,
    translate_rotate,
)

ScalarType = TypeVar("ScalarType", bound=np.generic)


class RaggedBuffer(Generic[ScalarType]):
    def __init__(self, features: int) -> None:
        raise ValueError(
            "RaggedBuffer is an abstract class, use RaggedBufferF32 or RaggedBufferI64"
        )

    @classmethod
    def from_array(cls, x: NDArray[ScalarType]) -> "RaggedBuffer[ScalarType]":
        if x.dtype == np.float32:
            return RaggedBufferF32.from_array(x)
        elif x.dtype == np.int64:
            return RaggedBufferI64.from_array(x)
        elif x.dtype == np.bool_:
            return RaggedBufferBool.from_array(x)
        else:
            raise ValueError(
                f"Unsupported dtype {x.dtype}. Only float32 and int64 are currently supported."
            )

    @classmethod
    def from_flattened(
        cls, flattened: NDArray[ScalarType], lengths: NDArray[np.int64]
    ) -> "RaggedBuffer[ScalarType]":
        if flattened.dtype == np.float32:
            return RaggedBufferF32.from_flattened(flattened, lengths)
        elif flattened.dtype == np.int64:
            return RaggedBufferI64.from_flattened(flattened, lengths)
        elif flattened.dtype == np.bool_:
            return RaggedBufferBool.from_flattened(flattened, lengths)
        else:
            raise ValueError(
                f"Unsupported dtype {flattened.dtype}. Only float32 and int64 are currently supported."
            )


def cat(
    buffers: List[RaggedBuffer[ScalarType]], dim: int = 0
) -> RaggedBuffer[ScalarType]:
    if len(buffers) == 0:
        raise ValueError("Can't concatenate an empty list of buffers")
    else:
        if isinstance(buffers[0], RaggedBufferF32):
            return RaggedBufferF32.cat(buffers, dim)
        elif isinstance(buffers[0], RaggedBufferI64):
            return RaggedBufferI64.cat(buffers, dim)
        elif isinstance(buffers[0], RaggedBufferBool):
            return RaggedBufferBool.cat(buffers, dim)
        else:
            raise TypeError(f"Type {type(buffers[0])} is not a RaggedBuffer")