import asyncio
import time
from dataclasses import dataclass
from typing import cast
import icechunk
import numpy as np
import zarr
from dask.distributed import Client
@dataclass
class Task:
    """A single writer's unit of work, serialized and shipped to a dask worker."""

    # Keyword arguments forwarded to icechunk.StorageConfig.s3_from_config.
    storage_config: dict
    # Keyword arguments forwarded to icechunk.StoreConfig.
    store_config: dict
    # (row-slice, col-slice) region of the shared array this task writes.
    area: tuple[slice, slice]
    # Seed passed to np.random.seed so the task's data is reproducible
    # on both the writer and the verifying reader.
    seed: int
# Number of chunks along each dimension of the square test array.
CHUNKS_PER_DIM = 10
# Edge length (in elements) of each square chunk (see chunk_shape below).
CHUNK_DIM_SIZE = 10
# Each task writes a square region CHUNKS_PER_TASK chunks wide per dimension.
CHUNKS_PER_TASK = 2
async def mk_store(mode: str, task: Task):
    """Open the IcechunkStore described by *task* against MinIO-style S3.

    NOTE(review): the ``mode`` argument is currently ignored — the store is
    always opened with mode="a". Confirm whether callers expect their mode
    ("w"/"r"/"r+") to actually be applied.
    """
    credentials = icechunk.S3Credentials(
        access_key_id="minio123",
        secret_access_key="minio123",
    )
    storage = icechunk.StorageConfig.s3_from_config(
        **task.storage_config,
        credentials=credentials,
    )
    return await icechunk.IcechunkStore.open(
        storage=storage,
        mode="a",
        config=icechunk.StoreConfig(**task.store_config),
    )
def generate_task_array(task: "Task", size: int = 1000):
    """Deterministically generate the random data block for *task*.

    Seeds numpy's global RNG with ``task.seed`` so the same task always
    yields the same array — the writers and the final verification loop
    both rely on this reproducibility.

    Parameters
    ----------
    task : Task
        Supplies ``area`` (a pair of slices) and ``seed``.
    size : int, optional
        Upper bound used to resolve each slice via ``slice.indices``.
        Defaults to 1000, preserving the previously hard-coded value.

    Returns
    -------
    np.ndarray
        Float64 array of uniform [0, 1) samples shaped to ``task.area``.
    """
    np.random.seed(task.seed)
    # slice.indices(size) clamps open/negative slice bounds to [0, size),
    # giving the concrete extent of each dimension.
    nx = len(range(*task.area[0].indices(size)))
    ny = len(range(*task.area[1].indices(size)))
    return np.random.rand(nx, ny)
async def execute_task(task: Task):
    """Write this task's random block into the shared array.

    Returns the store's serialized changeset bytes so the coordinator can
    merge them in a distributed commit.
    """
    store = await mk_store("w", task)
    root = zarr.group(store=store, overwrite=False)
    target = cast(zarr.Array, root["array"])
    target[task.area] = generate_task_array(task)
    return store.change_set_bytes()
def run_task(task: Task):
    """Synchronous entry point for dask workers: drive execute_task to completion."""
    coroutine = execute_task(task)
    return asyncio.run(coroutine)
async def test_distributed_writers():
    """End-to-end distributed-writers test.

    Dask workers each write a disjoint rectangular region of one shared
    icechunk-backed zarr array, the coordinator gathers their serialized
    changesets and merges them with ``distributed_commit``, then the store
    is reopened and every region is verified against the generator.
    """
    client = Client(n_workers=8)
    storage_config = {
        "bucket": "testbucket",
        # Unique prefix per run so repeated test runs don't collide.
        "prefix": "python-distributed-writers-test__" + str(time.time()),
        "endpoint_url": "http://localhost:9000",
        "region": "us-east-1",
        "allow_http": True,
    }
    # Tiny threshold — presumably forces chunk payloads out of inline
    # storage; confirm against icechunk StoreConfig docs.
    store_config = {"inline_chunk_threshold_bytes": 5}

    # Tile the n x n array into square regions CHUNKS_PER_TASK chunks wide,
    # one (row-slice, col-slice) pair per task; min() clamps the last tile.
    span = CHUNKS_PER_TASK * CHUNK_DIM_SIZE
    extent = CHUNK_DIM_SIZE * CHUNKS_PER_DIM
    ranges = [
        (slice(x, min(x + span, extent)), slice(y, min(y + span, extent)))
        for x in range(0, extent, span)
        for y in range(0, extent, span)
    ]
    tasks = [
        Task(
            storage_config=storage_config,
            store_config=store_config,
            area=area,
            seed=idx,
        )
        for idx, area in enumerate(ranges)
    ]

    # Create the target array and commit it before any worker writes.
    store = await mk_store("r+", tasks[0])
    group = zarr.group(store=store, overwrite=True)
    n = CHUNKS_PER_DIM * CHUNK_DIM_SIZE
    array = group.create_array(
        "array",
        shape=(n, n),
        chunk_shape=(CHUNK_DIM_SIZE, CHUNK_DIM_SIZE),
        dtype="f8",
        fill_value=float("nan"),
    )
    _first_snap = await store.commit("array created")

    # Fan the write tasks out; each worker returns its changeset bytes.
    map_result = client.map(run_task, tasks)
    change_sets_bytes = client.gather(map_result)
    commit_res = await store.distributed_commit("distributed commit", change_sets_bytes)
    assert commit_res

    # Reopen the store and verify the committed contents.
    store = await mk_store("r", tasks[0])
    all_keys = [key async for key in store.list_prefix("/")]
    # 1 + 1 + one key per chunk — presumably root metadata + array
    # metadata + chunk objects; verify against the icechunk key layout.
    assert len(all_keys) == 1 + 1 + CHUNKS_PER_DIM * CHUNKS_PER_DIM
    group = zarr.group(store=store, overwrite=False)
    # BUG FIX: read through the reopened store. Previously the stale
    # ``array`` handle bound before the distributed commit was used here.
    array = cast(zarr.Array, group["array"])
    for task in tasks:
        actual = array[task.area]
        expected = generate_task_array(task)
        np.testing.assert_array_equal(actual, expected)