from collections.abc import AsyncGenerator, Iterable
from typing import Any, Self
from zarr.abc.store import AccessMode, ByteRangeRequest, Store
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.common import AccessModeLiteral, BytesLike
from zarr.core.sync import SyncMixin
from ._icechunk_python import (
PyIcechunkStore,
S3Credentials,
SnapshotMetadata,
StorageConfig,
StoreConfig,
VirtualRefConfig,
__version__,
pyicechunk_store_create,
pyicechunk_store_exists,
pyicechunk_store_from_bytes,
pyicechunk_store_open_existing,
)
# Public names exported by this module's star-import, in their original order.
__all__ = [
    "__version__",
    "IcechunkStore",
    "StorageConfig",
    "S3Credentials",
    "SnapshotMetadata",
    "StoreConfig",
    "VirtualRefConfig",
]
class IcechunkStore(Store, SyncMixin):
    """Zarr ``Store`` implementation that wraps a ``PyIcechunkStore``.

    Obtain instances through the ``open``, ``open_existing`` or ``create``
    class methods rather than the constructor. Almost every method delegates
    directly to the wrapped ``PyIcechunkStore``.
    """

    # Wrapped store object that every operation below delegates to.
    _store: PyIcechunkStore

    @classmethod
    async def open(cls, *args: Any, **kwargs: Any) -> Self:
        """Open or create an IcechunkStore according to the requested mode.

        Keyword arguments consumed here:
            mode: access-mode literal; defaults to ``"r"``.
            storage: the ``StorageConfig`` locating the store (required).

        All remaining positional and keyword arguments are forwarded to
        ``create`` / ``open_existing``.

        Behaviour by access mode:
            * overwrite: create a new store; an existing store is an error.
            * create or update: open the store if it exists, else create it.
            * anything else: open the existing store.

        Raises:
            ValueError: if ``storage`` is not given, or if the mode requests
                overwrite while a store already exists at ``storage``.
        """
        mode = kwargs.pop("mode", "r")
        access_mode = AccessMode.from_literal(mode)
        if "storage" not in kwargs:
            raise ValueError(
                "Storage configuration is required. Pass a Storage object to construct an IcechunkStore"
            )
        storage = kwargs.pop("storage")

        store_exists = await pyicechunk_store_exists(storage)
        if access_mode.overwrite:
            # Even in overwrite mode an existing store is never replaced.
            if store_exists:
                raise ValueError(
                    "Store already exists and overwrite is not allowed for IcechunkStore"
                )
            store = await cls.create(storage, mode, *args, **kwargs)
        elif access_mode.create or access_mode.update:
            if store_exists:
                store = await cls.open_existing(storage, mode, *args, **kwargs)
            else:
                store = await cls.create(storage, mode, *args, **kwargs)
        else:
            store = await cls.open_existing(storage, mode, *args, **kwargs)

        store._is_open = True
        return store

    def __init__(
        self,
        store: PyIcechunkStore,
        mode: AccessModeLiteral = "r",
        *args: Any,
        **kwargs: Any,
    ):
        """Wrap an already-constructed ``PyIcechunkStore``.

        ``mode``, ``args`` and ``kwargs`` are passed to the base ``Store``
        constructor.

        Raises:
            ValueError: if ``store`` is None.
        """
        super().__init__(mode, *args, **kwargs)
        if store is None:
            raise ValueError(
                "An IcechunkStore should not be created with the default constructor, instead use either the create or open_existing class methods."
            )
        self._store = store

    @classmethod
    async def open_existing(
        cls,
        storage: StorageConfig,
        mode: AccessModeLiteral = "r",
        config: StoreConfig | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> Self:
        """Open an existing store located by ``storage``.

        The native store is opened read-only exactly when ``mode == "r"``.
        When ``config`` is not given (or is falsy), a default ``StoreConfig()``
        is used. Extra arguments are forwarded to the constructor.
        """
        config = config or StoreConfig()
        read_only = mode == "r"
        store = await pyicechunk_store_open_existing(
            storage, read_only=read_only, config=config
        )
        # Forward the extras by unpacking them. Passing ``args=``/``kwargs=``
        # would bundle them into two bogus keyword arguments instead.
        return cls(store, mode, *args, **kwargs)

    @classmethod
    async def create(
        cls,
        storage: StorageConfig,
        mode: AccessModeLiteral = "w",
        config: StoreConfig | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> Self:
        """Create a new store at the location described by ``storage``.

        When ``config`` is not given (or is falsy), a default ``StoreConfig()``
        is used. Extra arguments are forwarded to the constructor.
        """
        config = config or StoreConfig()
        store = await pyicechunk_store_create(storage, config=config)
        # Forward the extras by unpacking them (see ``open_existing``).
        return cls(store, mode, *args, **kwargs)

    def with_mode(self, mode: AccessModeLiteral) -> Self:
        """Return a new IcechunkStore over the same data with access ``mode``.

        The native store is switched to read-only exactly when ``mode == "r"``.
        """
        # NOTE(review): unlike ``open`` and ``__setstate__``, the returned store
        # is not marked ``_is_open = True``. Confirm this is intended.
        read_only = mode == "r"
        new_store = self._store.with_mode(read_only)
        return self.__class__(new_store, mode=mode)

    def __eq__(self, value: object) -> bool:
        """Return True when ``value`` is an instance of this class wrapping an equal native store."""
        if not isinstance(value, self.__class__):
            return False
        return self._store == value._store

    def __getstate__(self) -> object:
        """Return picklable state: the serialized native store plus the access mode."""
        store_repr = self._store.as_bytes()
        return {"store": store_repr, "mode": self.mode}

    def __setstate__(self, state: Any) -> None:
        """Rebuild the store from state produced by ``__getstate__``.

        ``__init__`` is not run during unpickling, so the base-class mode and
        the open flag are restored here as well.
        """
        store_repr = state["store"]
        mode = state["mode"]
        # ``__getstate__`` records ``self.mode``. On zarr's ``Store`` that is an
        # ``AccessMode``, not a literal, so comparing it to ``"r"`` was always
        # False and read-only stores came back writable. Accept either form.
        access_mode = (
            mode if isinstance(mode, AccessMode) else AccessMode.from_literal(mode)
        )
        self._mode = access_mode
        self._store = pyicechunk_store_from_bytes(store_repr, access_mode.readonly)
        self._is_open = True

    @property
    def snapshot_id(self) -> str:
        """Snapshot id reported by the native store."""
        return self._store.snapshot_id

    def change_set_bytes(self) -> bytes:
        """Return the native store's serialized change set."""
        return self._store.change_set_bytes()

    @property
    def branch(self) -> str | None:
        """Branch reported by the native store, or None."""
        return self._store.branch

    async def checkout(
        self,
        snapshot_id: str | None = None,
        branch: str | None = None,
        tag: str | None = None,
    ) -> None:
        """Check out exactly one of ``snapshot_id``, ``branch`` or ``tag``.

        Raises:
            ValueError: if more than one, or none, of the three is given.
        """
        if snapshot_id is not None:
            if branch is not None or tag is not None:
                raise ValueError(
                    "only one of snapshot_id, branch, or tag may be specified"
                )
            return await self._store.checkout_snapshot(snapshot_id)
        if branch is not None:
            if tag is not None:
                raise ValueError(
                    "only one of snapshot_id, branch, or tag may be specified"
                )
            return await self._store.checkout_branch(branch)
        if tag is not None:
            return await self._store.checkout_tag(tag)
        raise ValueError("a snapshot_id, branch, or tag must be specified")

    async def commit(self, message: str) -> str:
        """Commit with ``message``; returns the native store's result string."""
        return await self._store.commit(message)

    async def distributed_commit(
        self, message: str, other_change_set_bytes: list[bytes]
    ) -> str:
        """Commit together with serialized change sets from other stores.

        ``other_change_set_bytes`` presumably comes from ``change_set_bytes``
        on other instances. TODO confirm.
        """
        return await self._store.distributed_commit(message, other_change_set_bytes)

    @property
    def has_uncommitted_changes(self) -> bool:
        """Whether the native store reports uncommitted changes."""
        return self._store.has_uncommitted_changes

    async def reset(self) -> None:
        """Delegate ``reset`` to the native store."""
        return await self._store.reset()

    async def new_branch(self, branch_name: str) -> str:
        """Create branch ``branch_name``; returns the native store's result string."""
        return await self._store.new_branch(branch_name)

    async def tag(self, tag_name: str, snapshot_id: str) -> None:
        """Tag ``snapshot_id`` as ``tag_name`` via the native store."""
        return await self._store.tag(tag_name, snapshot_id=snapshot_id)

    def ancestry(self) -> AsyncGenerator[SnapshotMetadata, None]:
        """Return the native store's async iterator of ``SnapshotMetadata``."""
        return self._store.ancestry()

    async def empty(self) -> bool:
        """Whether the native store reports itself empty."""
        return await self._store.empty()

    async def clear(self) -> None:
        """Delegate ``clear`` to the native store."""
        return await self._store.clear()

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: tuple[int | None, int | None] | None = None,
    ) -> Buffer | None:
        """Read ``key`` (optionally limited to ``byte_range``) into a Buffer.

        Returns None when the native store raises ``ValueError``, presumably
        because the key is missing. TODO confirm.
        """
        try:
            result = await self._store.get(key, byte_range)
        except ValueError:
            return None
        return prototype.buffer.from_bytes(result)

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRangeRequest]],
    ) -> list[Buffer | None]:
        """Read several (key, byte-range) pairs, preserving input order.

        A None entry from the native store stays None in the output instead of
        being passed to ``from_bytes``, matching the declared return type.
        """
        # The native call is given a concrete list rather than an arbitrary iterable.
        result = await self._store.get_partial_values(list(key_ranges))
        return [None if r is None else prototype.buffer.from_bytes(r) for r in result]

    async def exists(self, key: str) -> bool:
        """Whether ``key`` exists in the native store."""
        return await self._store.exists(key)

    @property
    def supports_writes(self) -> bool:
        """Whether the native store reports write support."""
        return self._store.supports_writes

    async def set(self, key: str, value: Buffer) -> None:
        """Store ``value`` under ``key``."""
        return await self._store.set(key, value.to_bytes())

    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
        """Store ``value`` under ``key`` via the native ``set_if_not_exists``."""
        return await self._store.set_if_not_exists(key, value.to_bytes())

    async def set_virtual_ref(
        self, key: str, location: str, *, offset: int, length: int
    ) -> None:
        """Point ``key`` at ``length`` bytes starting at ``offset`` in ``location``."""
        return await self._store.set_virtual_ref(key, location, offset, length)

    async def delete(self, key: str) -> None:
        """Delete ``key`` from the native store."""
        return await self._store.delete(key)

    @property
    def supports_partial_writes(self) -> bool:
        """Whether the native store reports partial-write support."""
        return self._store.supports_partial_writes

    async def set_partial_values(
        self, key_start_values: Iterable[tuple[str, int, BytesLike]]
    ) -> None:
        """Write several (key, start, bytes) entries via the native store."""
        return await self._store.set_partial_values(list(key_start_values))

    @property
    def supports_listing(self) -> bool:
        """Whether the native store reports listing support."""
        return self._store.supports_listing

    @property
    def supports_deletes(self) -> bool:
        """Whether the native store reports delete support."""
        return self._store.supports_deletes

    def list(self) -> AsyncGenerator[str, None]:
        """Return the native store's async iterator over all keys."""
        return self._store.list()

    def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
        """Return the native store's async iterator over keys under ``prefix``."""
        return self._store.list_prefix(prefix)

    def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
        """Return the native store's async iterator for the ``prefix`` directory listing."""
        return self._store.list_dir(prefix)