import pytest
import pyarrow as pa
import asyncio
try:
import arrow_zerobus_sdk_wrapper
from arrow_zerobus_sdk_wrapper import (
ZerobusWrapper,
WrapperConfiguration,
TransmissionResult,
ConfigurationError,
AuthenticationError,
ConnectionError,
ConversionError,
TransmissionError,
RetryExhausted,
TokenRefreshError,
)
except ImportError:
pytestmark = pytest.mark.skip("arrow_zerobus_sdk_wrapper not available")
def test_error_translation():
    """Check that the wrapper's exception types are ordinary Python exceptions.

    Also checks that a ConfigurationError / AuthenticationError built from a
    message reports that same message through ``str()``.
    """
    exported_errors = (
        ConfigurationError,
        AuthenticationError,
        ConnectionError,
        ConversionError,
        TransmissionError,
        RetryExhausted,
        TokenRefreshError,
    )
    for error_type in exported_errors:
        assert issubclass(error_type, Exception)

    samples = (
        (ConfigurationError, "test config error"),
        (AuthenticationError, "test auth error"),
    )
    for error_type, text in samples:
        instance = error_type(text)
        assert isinstance(instance, Exception)
        assert str(instance) == text
def test_error_translation_all_types():
    """Instantiate every wrapper exception type and check its message.

    Each class is built with a descriptive message. The instance must be an
    ``Exception`` and its ``str()`` must equal that message.
    """

    def check_roundtrip(exc_type, text):
        # Build one exception and confirm the message survives str().
        instance = exc_type(text)
        assert isinstance(instance, Exception)
        assert str(instance) == text

    messages_by_type = {
        ConfigurationError: "Configuration error",
        AuthenticationError: "Authentication error",
        ConnectionError: "Connection error",
        ConversionError: "Conversion error",
        TransmissionError: "Transmission error",
        RetryExhausted: "Retry exhausted",
        TokenRefreshError: "Token refresh error",
    }
    for exc_type, text in messages_by_type.items():
        check_roundtrip(exc_type, text)
@pytest.mark.asyncio
async def test_async_context_manager():
    """Enter and exit a ZerobusWrapper as an async context manager.

    Setup or entering the context may fail with ConfigurationError,
    ConnectionError or ImportError (presumably because the test endpoint is not
    reachable and no credentials are given). Those are tolerated. Any other
    exception, including a failed assertion inside the ``async with``, now
    propagates with its original traceback instead of being caught and
    re-asserted.
    """
    try:
        config = WrapperConfiguration(
            endpoint="https://test.cloud.databricks.com",
            table_name="test_table",
        )
        wrapper = ZerobusWrapper(config)
        async with wrapper:
            assert wrapper is not None
    except (ConfigurationError, ConnectionError, ImportError):
        # Expected failure modes for this offline test; deliberately tolerated.
        pass
@pytest.mark.asyncio
async def test_concurrent_python_operations():
    """Create several wrappers concurrently through ``asyncio.gather``.

    Each task builds its own configuration and wrapper. ``gather`` runs with
    ``return_exceptions=True``, so a failing task contributes its exception to
    ``results`` instead of raising. Every such exception must be one of the
    error types tolerated elsewhere in this module for an offline endpoint.
    """

    async def create_wrapper():
        # No try/except here on purpose. Converting failures into ``None``
        # would stop them from reaching ``gather`` and make the error-type
        # check below dead code.
        config = WrapperConfiguration(
            endpoint="https://test.cloud.databricks.com",
            table_name="test_table",
        )
        return ZerobusWrapper(config)

    tasks = [create_wrapper() for _ in range(5)]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    assert len(results) == 5
    for result in results:
        if isinstance(result, Exception):
            assert isinstance(
                result, (ConfigurationError, ConnectionError, ImportError)
            )
def test_record_batch_conversion():
    """Build a three-column RecordBatch and, best-effort, a wrapper next to it.

    The batch shape assertions always run. Building the configuration and
    wrapper is allowed to fail, presumably because no live endpoint or
    credentials exist in unit tests. The post-construction assertions now run
    in ``else:`` so a failure there is reported rather than swallowed.

    NOTE(review): the batch is never handed to the wrapper, so no conversion
    is actually exercised despite the test's name.
    """
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("name", pa.string()),
            pa.field("score", pa.float64()),
        ]
    )
    arrays = [
        pa.array([1, 2, 3], type=pa.int64()),
        pa.array(["Alice", "Bob", "Charlie"], type=pa.string()),
        pa.array([95.5, 87.0, 92.5], type=pa.float64()),
    ]
    batch = pa.RecordBatch.from_arrays(arrays, schema=schema)
    assert batch.num_rows == 3
    assert batch.num_columns == 3
    assert len(batch.schema) == 3

    try:
        config = WrapperConfiguration(
            endpoint="https://test.cloud.databricks.com",
            table_name="test_table",
        )
        wrapper = ZerobusWrapper(config)
    except Exception:
        # Deliberately best-effort: construction failures are tolerated here.
        pass
    else:
        assert batch is not None
        assert wrapper is not None
def test_record_batch_various_types():
    """Build a RecordBatch that covers several primitive Arrow types.

    There are six columns (int32, int64, float32, float64, string, bool),
    each holding three values.
    """
    # (column name, Arrow type, values) for each column, in schema order.
    column_specs = [
        ("int32", pa.int32(), [1, 2, 3]),
        ("int64", pa.int64(), [10, 20, 30]),
        ("float32", pa.float32(), [1.5, 2.5, 3.5]),
        ("float64", pa.float64(), [10.5, 20.5, 30.5]),
        ("string", pa.string(), ["a", "b", "c"]),
        ("bool", pa.bool_(), [True, False, True]),
    ]
    schema = pa.schema(
        [pa.field(name, arrow_type) for name, arrow_type, _ in column_specs]
    )
    columns = [
        pa.array(values, type=arrow_type) for _, arrow_type, values in column_specs
    ]
    batch = pa.RecordBatch.from_arrays(columns, schema=schema)
    assert batch.num_rows == 3
    assert batch.num_columns == 6
def test_record_batch_with_nulls():
    """Build a RecordBatch whose columns contain nulls and check them.

    Counting rows and columns alone would pass even if nulls were dropped or
    coerced. This test also checks each column's null count and its values in
    order.
    """
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("name", pa.string()),
        ]
    )
    arrays = [
        pa.array([1, None, 3], type=pa.int64()),
        pa.array(["Alice", "Bob", None], type=pa.string()),
    ]
    batch = pa.RecordBatch.from_arrays(arrays, schema=schema)
    assert batch.num_rows == 3
    assert batch.num_columns == 2
    # Each column was given exactly one null; it must stay in the same slot.
    assert batch.column(0).null_count == 1
    assert batch.column(1).null_count == 1
    assert batch.column(0).to_pylist() == [1, None, 3]
    assert batch.column(1).to_pylist() == ["Alice", "Bob", None]
def test_wrapper_configuration_methods():
    """Construct WrapperConfiguration with and without client credentials.

    The minimal configuration must expose ``endpoint`` and ``table_name``
    attributes. The configuration that also passes ``client_id`` and
    ``client_secret`` only needs to construct successfully.
    """
    endpoint = "https://test.cloud.databricks.com"
    table = "test_table"

    basic = WrapperConfiguration(endpoint=endpoint, table_name=table)
    for attribute in ("endpoint", "table_name"):
        assert hasattr(basic, attribute)

    with_credentials = WrapperConfiguration(
        endpoint=endpoint,
        table_name=table,
        client_id="test_id",
        client_secret="test_secret",
    )
    assert with_credentials is not None
def test_transmission_result_structure():
    """Check that ``TransmissionResult`` is exported and bound to an object.

    NOTE(review): despite the name, no fields of the result are inspected here.
    """
    result_type = TransmissionResult
    assert result_type is not None
def test_wrapper_initialization_with_options():
    """Best-effort construction of wrappers from a minimal and a fuller config.

    The fuller config adds ``client_id``, ``client_secret`` and
    ``unity_catalog_url``. Building a config or wrapper is allowed to fail
    (presumably no reachable endpoint or valid credentials in unit tests).
    Assertions about a wrapper that was built run in ``else:`` so they are not
    silently swallowed.
    """
    option_sets = (
        {},
        {
            "client_id": "test_id",
            "client_secret": "test_secret",
            "unity_catalog_url": "https://unity-catalog-url",
        },
    )
    for extra_options in option_sets:
        try:
            config = WrapperConfiguration(
                endpoint="https://test.cloud.databricks.com",
                table_name="test_table",
                **extra_options,
            )
            wrapper = ZerobusWrapper(config)
        except Exception:
            # Deliberately best-effort: construction failures are tolerated.
            continue
        else:
            assert wrapper is not None
def test_error_hierarchy():
    """Every exception type imported from the wrapper derives from ``Exception``."""
    for error_type in (
        ConfigurationError,
        AuthenticationError,
        ConnectionError,
        ConversionError,
        TransmissionError,
        RetryExhausted,
        TokenRefreshError,
    ):
        assert issubclass(error_type, Exception)
def test_configuration_validation_from_python():
    """Call ``validate()`` on a well-formed and a malformed configuration.

    For each configuration, ``validate()`` may succeed or raise. If it raises,
    the exception must be a ``ConfigurationError``.

    NOTE(review): the malformed-endpoint case does not *require* an error, so
    the test still passes if validation accepts ``"invalid-endpoint"``.
    Consider ``pytest.raises(ConfigurationError)`` once that behavior is
    confirmed.
    """
    for endpoint in ("https://test.cloud.databricks.com", "invalid-endpoint"):
        candidate = WrapperConfiguration(
            endpoint=endpoint,
            table_name="test_table",
        )
        try:
            candidate.validate()
        except Exception as exc:
            assert isinstance(exc, ConfigurationError)
@pytest.mark.asyncio
async def test_async_send_batch():
    """Send a small RecordBatch through ``ZerobusWrapper.send_batch``.

    Building the configuration and wrapper is best-effort. If it fails, there
    is nothing to send and the test ends. Once a wrapper exists, ``send_batch``
    must either return an object exposing ``success``, ``attempts`` and
    ``batch_size_bytes``, or raise ConfigurationError, AuthenticationError or
    ConnectionError. These checks are deliberately outside the best-effort
    ``try`` so a violation fails the test instead of being swallowed.

    NOTE(review): if the SDK can also raise e.g. TransmissionError or
    RetryExhausted against the offline test endpoint, this test will now
    report it; decide whether those belong in the tolerated tuple.
    """
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("name", pa.string()),
        ]
    )
    arrays = [
        pa.array([1, 2, 3], type=pa.int64()),
        pa.array(["Alice", "Bob", "Charlie"], type=pa.string()),
    ]
    batch = pa.RecordBatch.from_arrays(arrays, schema=schema)

    try:
        config = WrapperConfiguration(
            endpoint="https://test.cloud.databricks.com",
            table_name="test_table",
        )
        wrapper = ZerobusWrapper(config)
    except Exception:
        # Setup failures are deliberately tolerated; nothing to send without a wrapper.
        return

    try:
        result = await wrapper.send_batch(batch)
    except (ConfigurationError, AuthenticationError, ConnectionError):
        # Presumably expected without a reachable endpoint / valid credentials.
        return
    assert hasattr(result, "success")
    assert hasattr(result, "attempts")
    assert hasattr(result, "batch_size_bytes")
def test_module_imports():
    """Check that the package exposes its public names.

    These are the wrapper, the configuration, the result type and every error
    type, including ``ZerobusError``.
    """
    public_names = (
        "ZerobusWrapper",
        "WrapperConfiguration",
        "TransmissionResult",
        "ZerobusError",
        "ConfigurationError",
        "AuthenticationError",
        "ConnectionError",
        "ConversionError",
        "TransmissionError",
        "RetryExhausted",
        "TokenRefreshError",
    )
    for attr_name in public_names:
        assert hasattr(arrow_zerobus_sdk_wrapper, attr_name)
def test_pyarrow_compatibility():
    """Build a RecordBatch and round-trip it through the Arrow IPC stream format.

    Serializing to a non-empty buffer only proves that *something* was written.
    Reading the stream back and comparing it with the source batch confirms
    the encoded bytes are a valid, lossless representation.
    """
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("name", pa.string()),
        ]
    )
    id_array = pa.array([1, 2, 3], type=pa.int64())
    name_array = pa.array(["Alice", "Bob", "Charlie"], type=pa.string())
    batch = pa.RecordBatch.from_arrays([id_array, name_array], schema=schema)
    assert isinstance(batch, pa.RecordBatch)
    assert batch.num_rows == 3
    assert batch.num_columns == 2

    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)
    sink_bytes = sink.getvalue()
    assert len(sink_bytes) > 0

    # Decode the stream and check the data survived unchanged.
    roundtrip = pa.ipc.open_stream(sink_bytes).read_all()
    assert roundtrip.equals(pa.Table.from_batches([batch]))