#![allow(non_local_definitions)]
use crate::config::OtlpSdkConfig;
use crate::config::WrapperConfiguration;
use crate::error::ZerobusError;
use crate::wrapper::{TransmissionResult, ZerobusWrapper};
use arrow::datatypes::DataType;
use arrow::record_batch::RecordBatch;
use pyo3::exceptions::{PyException, PyNotImplementedError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::runtime::Runtime;
/// Register all wrapper classes and exception types on the Python module `m`.
///
/// Called once from the crate's `#[pymodule]` entry point; the order of
/// registration is not significant.
pub fn register_module(_py: Python, m: &PyModule) -> PyResult<()> {
    // Core API surface.
    m.add_class::<PyZerobusWrapper>()?;
    m.add_class::<PyTransmissionResult>()?;
    m.add_class::<PyWrapperConfiguration>()?;
    // Exception classes. NOTE(review): each concrete error extends
    // PyException directly rather than PyZerobusError, so a Python
    // `except ZerobusError:` will not catch them — confirm this is intended.
    m.add_class::<PyZerobusError>()?;
    m.add_class::<PyConfigurationError>()?;
    m.add_class::<PyAuthenticationError>()?;
    m.add_class::<PyConnectionError>()?;
    m.add_class::<PyConversionError>()?;
    m.add_class::<PyTransmissionError>()?;
    m.add_class::<PyRetryExhausted>()?;
    m.add_class::<PyTokenRefreshError>()?;
    Ok(())
}
pub fn rust_error_to_python_error(error: ZerobusError) -> PyErr {
match error {
ZerobusError::ConfigurationError(msg) => PyErr::new::<PyConfigurationError, _>(msg),
ZerobusError::AuthenticationError(msg) => PyErr::new::<PyAuthenticationError, _>(msg),
ZerobusError::ConnectionError(msg) => PyErr::new::<PyConnectionError, _>(msg),
ZerobusError::ConversionError(msg) => PyErr::new::<PyConversionError, _>(msg),
ZerobusError::TransmissionError(msg) => PyErr::new::<PyTransmissionError, _>(msg),
ZerobusError::RetryExhausted(msg) => PyErr::new::<PyRetryExhausted, _>(msg),
ZerobusError::TokenRefreshError(msg) => PyErr::new::<PyTokenRefreshError, _>(msg),
}
}
/// Parse an error string of the form `"<Variant>: <message>"` back into the
/// corresponding `ZerobusError` variant.
///
/// The recognized prefix is stripped and the remainder trimmed. A string with
/// no recognized prefix defaults to `ConversionError` carrying the whole
/// message unchanged (matching the previous fallback behavior).
fn parse_error_string(error_msg: String) -> ZerobusError {
    // Prefixes paired with the variant constructor they select, checked in
    // the same order as the original if/else chain.
    let parsers: [(&str, fn(String) -> ZerobusError); 7] = [
        ("ConversionError:", ZerobusError::ConversionError),
        ("TransmissionError:", ZerobusError::TransmissionError),
        ("ConnectionError:", ZerobusError::ConnectionError),
        ("AuthenticationError:", ZerobusError::AuthenticationError),
        ("ConfigurationError:", ZerobusError::ConfigurationError),
        ("RetryExhausted:", ZerobusError::RetryExhausted),
        ("TokenRefreshError:", ZerobusError::TokenRefreshError),
    ];
    for (prefix, make) in parsers {
        if let Some(rest) = error_msg.strip_prefix(prefix) {
            return make(rest.trim().to_string());
        }
    }
    // Unrecognized shape: preserve the full message under ConversionError.
    ZerobusError::ConversionError(error_msg)
}
/// Base Python exception type, exported to Python as `ZerobusError`.
///
/// NOTE(review): the concrete error classes below all use
/// `extends=PyException` rather than extending this type, so Python code
/// writing `except ZerobusError:` will not catch them — confirm whether they
/// should extend `PyZerobusError` instead.
#[pyclass(name = "ZerobusError", extends=PyException)]
#[derive(Debug)]
pub struct PyZerobusError;
// Intentionally empty: the base class currently has no Python-visible methods.
#[pymethods]
impl PyZerobusError {
}
/// Raised when the wrapper configuration is invalid or incomplete.
#[pyclass(name = "ConfigurationError", extends=PyException)]
#[derive(Debug)]
pub struct PyConfigurationError {
    // Stored message returned by __str__.
    message: String,
}
/// Raised when authentication against the service fails.
#[pyclass(name = "AuthenticationError", extends=PyException)]
#[derive(Debug)]
pub struct PyAuthenticationError {
    message: String,
}
/// Raised when the network connection to the service fails.
#[pyclass(name = "ConnectionError", extends=PyException)]
#[derive(Debug)]
pub struct PyConnectionError {
    message: String,
}
/// Raised when Arrow/protobuf data conversion fails.
#[pyclass(name = "ConversionError", extends=PyException)]
#[derive(Debug)]
pub struct PyConversionError {
    message: String,
}
/// Raised when transmitting a batch fails.
#[pyclass(name = "TransmissionError", extends=PyException)]
#[derive(Debug)]
pub struct PyTransmissionError {
    message: String,
}
/// Raised when all retry attempts have been exhausted.
#[pyclass(name = "RetryExhausted", extends=PyException)]
#[derive(Debug)]
pub struct PyRetryExhausted {
    message: String,
}
/// Raised when refreshing an auth token fails.
#[pyclass(name = "TokenRefreshError", extends=PyException)]
#[derive(Debug)]
pub struct PyTokenRefreshError {
    message: String,
}
#[allow(dead_code)] impl PyConfigurationError {
fn new_err(msg: String) -> PyErr {
PyErr::new::<PyConfigurationError, _>(msg)
}
}
#[allow(dead_code)] impl PyAuthenticationError {
fn new_err(msg: String) -> PyErr {
PyErr::new::<PyAuthenticationError, _>(msg)
}
}
#[allow(dead_code)] impl PyConnectionError {
fn new_err(msg: String) -> PyErr {
PyErr::new::<PyConnectionError, _>(msg)
}
}
#[allow(dead_code)] impl PyConversionError {
fn new_err(msg: String) -> PyErr {
PyErr::new::<PyConversionError, _>(msg)
}
}
#[allow(dead_code)] impl PyTransmissionError {
fn new_err(msg: String) -> PyErr {
PyErr::new::<PyTransmissionError, _>(msg)
}
}
#[allow(dead_code)] impl PyRetryExhausted {
fn new_err(msg: String) -> PyErr {
PyErr::new::<PyRetryExhausted, _>(msg)
}
}
#[allow(dead_code)] impl PyTokenRefreshError {
fn new_err(msg: String) -> PyErr {
PyErr::new::<PyTokenRefreshError, _>(msg)
}
}
#[pymethods]
impl PyConfigurationError {
    /// Python-side constructor: `ConfigurationError(msg)`.
    #[new]
    fn new(msg: String) -> Self {
        Self { message: msg }
    }
    /// `str(exc)` returns the stored message.
    fn __str__(&self) -> &str {
        &self.message
    }
}
#[pymethods]
impl PyAuthenticationError {
    /// Python-side constructor: `AuthenticationError(msg)`.
    #[new]
    fn new(msg: String) -> Self {
        Self { message: msg }
    }
    /// `str(exc)` returns the stored message.
    fn __str__(&self) -> &str {
        &self.message
    }
}
#[pymethods]
impl PyConnectionError {
    /// Python-side constructor: `ConnectionError(msg)`.
    #[new]
    fn new(msg: String) -> Self {
        Self { message: msg }
    }
    /// `str(exc)` returns the stored message.
    fn __str__(&self) -> &str {
        &self.message
    }
}
#[pymethods]
impl PyConversionError {
    /// Python-side constructor: `ConversionError(msg)`.
    #[new]
    fn new(msg: String) -> Self {
        Self { message: msg }
    }
    /// `str(exc)` returns the stored message.
    fn __str__(&self) -> &str {
        &self.message
    }
}
#[pymethods]
impl PyTransmissionError {
    /// Python-side constructor: `TransmissionError(msg)`.
    #[new]
    fn new(msg: String) -> Self {
        Self { message: msg }
    }
    /// `str(exc)` returns the stored message.
    fn __str__(&self) -> &str {
        &self.message
    }
}
#[pymethods]
impl PyRetryExhausted {
    /// Python-side constructor: `RetryExhausted(msg)`.
    #[new]
    fn new(msg: String) -> Self {
        Self { message: msg }
    }
    /// `str(exc)` returns the stored message.
    fn __str__(&self) -> &str {
        &self.message
    }
}
#[pymethods]
impl PyTokenRefreshError {
    /// Python-side constructor: `TokenRefreshError(msg)`.
    #[new]
    fn new(msg: String) -> Self {
        Self { message: msg }
    }
    /// `str(exc)` returns the stored message.
    fn __str__(&self) -> &str {
        &self.message
    }
}
/// Python-visible wrapper around the Rust `WrapperConfiguration`.
///
/// Construct via keyword arguments (see `PyWrapperConfiguration::new`); the
/// inner config is cloned into each `ZerobusWrapper` that uses it.
#[pyclass(name = "WrapperConfiguration")]
#[derive(Clone)]
#[allow(non_local_definitions)]
pub struct PyWrapperConfiguration {
    // The actual configuration; all getters read through to this.
    inner: WrapperConfiguration,
}
#[pymethods]
#[allow(clippy::too_many_arguments)]
impl PyWrapperConfiguration {
    /// Build a wrapper configuration from Python keyword arguments.
    ///
    /// `endpoint` and `table_name` are required positionals; everything else
    /// is keyword-only per the `signature` attribute. Raises
    /// `ConfigurationError` when any debug format is enabled without
    /// `debug_output_dir`, and a plain `Exception` if the observability dict
    /// fails validation.
    ///
    /// NOTE(review): `debug_max_files_retained` is declared `Option<usize>`
    /// but its signature default is the bare literal `10` — confirm the pyo3
    /// version in use coerces that to `Some(10)`.
    #[new]
    #[pyo3(signature = (endpoint, table_name, *, client_id=None, client_secret=None, unity_catalog_url=None, observability_enabled=false, observability_config=None, debug_enabled=false, debug_arrow_enabled=None, debug_protobuf_enabled=None, debug_output_dir=None, debug_flush_interval_secs=5, debug_max_file_size=None, debug_max_files_retained=10, retry_max_attempts=5, retry_base_delay_ms=100, retry_max_delay_ms=30000, zerobus_writer_disabled=false))]
    pub fn new(
        endpoint: String,
        table_name: String,
        client_id: Option<String>,
        client_secret: Option<String>,
        unity_catalog_url: Option<String>,
        observability_enabled: bool,
        observability_config: Option<PyObject>,
        debug_enabled: bool,
        debug_arrow_enabled: Option<bool>,
        debug_protobuf_enabled: Option<bool>,
        debug_output_dir: Option<String>,
        debug_flush_interval_secs: u64,
        debug_max_file_size: Option<u64>,
        debug_max_files_retained: Option<usize>,
        retry_max_attempts: u32,
        retry_base_delay_ms: u64,
        retry_max_delay_ms: u64,
        zerobus_writer_disabled: bool,
    ) -> PyResult<Self> {
        let mut config = WrapperConfiguration::new(endpoint, table_name);
        // Credentials are only applied when BOTH halves are present; a lone
        // client_id or client_secret is silently ignored.
        if let (Some(cid), Some(cs)) = (client_id, client_secret) {
            config = config.with_credentials(cid, cs);
        }
        if let Some(url) = unity_catalog_url {
            config = config.with_unity_catalog(url);
        }
        if observability_enabled {
            // Observability settings arrive as a Python dict; keys that are
            // missing or of the wrong type fall back to defaults rather than
            // raising.
            let otlp_config = if let Some(config_obj) = observability_config {
                Python::with_gil(|py| {
                    let dict = config_obj.extract::<&PyDict>(py)?;
                    let endpoint = dict
                        .get_item("endpoint")?
                        .and_then(|v| v.extract::<String>().ok());
                    let output_dir = dict
                        .get_item("output_dir")?
                        .and_then(|v| v.extract::<String>().ok())
                        .map(std::path::PathBuf::from);
                    let write_interval_secs = dict
                        .get_item("write_interval_secs")?
                        .and_then(|v| v.extract::<u64>().ok())
                        .unwrap_or(5);
                    let log_level = dict
                        .get_item("log_level")?
                        .and_then(|v| v.extract::<String>().ok())
                        .unwrap_or_else(|| "info".to_string());
                    let otlp_config = OtlpSdkConfig {
                        endpoint,
                        output_dir,
                        write_interval_secs,
                        log_level,
                    };
                    // Validation errors surface as a generic Exception, not
                    // ConfigurationError.
                    otlp_config.validate().map_err(|e| {
                        PyException::new_err(format!("Invalid OTLP SDK configuration: {}", e))
                    })?;
                    Ok::<OtlpSdkConfig, PyErr>(otlp_config)
                })?
            } else {
                OtlpSdkConfig::default()
            };
            config = config.with_observability(otlp_config);
        }
        // Explicit per-format flags take effect regardless of debug_enabled.
        if let Some(arrow_enabled) = debug_arrow_enabled {
            config.debug_arrow_enabled = arrow_enabled;
        }
        if let Some(protobuf_enabled) = debug_protobuf_enabled {
            config.debug_protobuf_enabled = protobuf_enabled;
        }
        // debug_enabled with no per-format overrides turns on both formats.
        if debug_enabled {
            if debug_arrow_enabled.is_none() && debug_protobuf_enabled.is_none() {
                config.debug_arrow_enabled = true;
                config.debug_protobuf_enabled = true;
            }
            config.debug_enabled = true;
        }
        let any_debug_enabled =
            config.debug_arrow_enabled || config.debug_protobuf_enabled || config.debug_enabled;
        if any_debug_enabled {
            // Debug output needs a destination; the remaining debug knobs are
            // only applied when a directory is provided.
            if let Some(output_dir) = debug_output_dir {
                config.debug_output_dir = Some(PathBuf::from(output_dir));
                config.debug_flush_interval_secs = debug_flush_interval_secs;
                config.debug_max_file_size = debug_max_file_size;
                config.debug_max_files_retained = debug_max_files_retained;
            } else {
                return Err(PyConfigurationError::new_err(
                    "debug_output_dir is required when any debug format is enabled. \
                     Either provide debug_output_dir or disable all debug flags."
                        .to_string(),
                ));
            }
        }
        config =
            config.with_retry_config(retry_max_attempts, retry_base_delay_ms, retry_max_delay_ms);
        if zerobus_writer_disabled {
            config = config.with_zerobus_writer_disabled(true);
        }
        Ok(Self { inner: config })
    }
    /// Validate the assembled configuration, mapping Rust errors onto the
    /// Python exception classes defined in this module.
    fn validate(&self) -> PyResult<()> {
        self.inner.validate().map_err(rust_error_to_python_error)?;
        Ok(())
    }
    /// Zerobus service endpoint URL.
    #[getter]
    fn endpoint(&self) -> String {
        self.inner.zerobus_endpoint.clone()
    }
    /// Destination table name.
    #[getter]
    fn table_name(&self) -> String {
        self.inner.table_name.clone()
    }
    /// OAuth client id, copied out of its secrecy wrapper.
    #[getter]
    fn client_id(&self) -> Option<String> {
        use secrecy::ExposeSecret;
        self.inner
            .client_id
            .as_ref()
            .map(|s| s.expose_secret().to_string())
    }
    /// OAuth client secret. NOTE: this exposes the plaintext secret to Python.
    #[getter]
    fn client_secret(&self) -> Option<String> {
        use secrecy::ExposeSecret;
        self.inner
            .client_secret
            .as_ref()
            .map(|s| s.expose_secret().to_string())
    }
    /// Unity Catalog URL, if configured.
    #[getter]
    fn unity_catalog_url(&self) -> Option<String> {
        self.inner.unity_catalog_url.clone()
    }
    /// Whether the umbrella debug flag was set.
    #[getter]
    fn debug_enabled(&self) -> bool {
        self.inner.debug_enabled
    }
    /// Debug output directory as a (lossily-converted) string, if set.
    #[getter]
    fn debug_output_dir(&self) -> Option<String> {
        self.inner
            .debug_output_dir
            .as_ref()
            .map(|p| p.to_string_lossy().to_string())
    }
    /// How often debug files are flushed, in seconds.
    #[getter]
    fn debug_flush_interval_secs(&self) -> u64 {
        self.inner.debug_flush_interval_secs
    }
    /// Max size per debug file in bytes, if limited.
    #[getter]
    fn debug_max_file_size(&self) -> Option<u64> {
        self.inner.debug_max_file_size
    }
    /// Maximum number of send attempts per batch.
    #[getter]
    fn retry_max_attempts(&self) -> u32 {
        self.inner.retry_max_attempts
    }
    /// Base delay for retry backoff, in milliseconds.
    #[getter]
    fn retry_base_delay_ms(&self) -> u64 {
        self.inner.retry_base_delay_ms
    }
    /// Cap on the retry backoff delay, in milliseconds.
    #[getter]
    fn retry_max_delay_ms(&self) -> u64 {
        self.inner.retry_max_delay_ms
    }
    /// Whether OTLP observability is enabled.
    #[getter]
    fn observability_enabled(&self) -> bool {
        self.inner.observability_enabled
    }
    /// Whether the Zerobus writer is disabled (dry-run style operation).
    #[getter]
    fn zerobus_writer_disabled(&self) -> bool {
        self.inner.zerobus_writer_disabled
    }
}
/// Python-visible wrapper around a Rust `TransmissionResult`.
///
/// Exposes per-row success/failure detail plus helpers to slice the original
/// RecordBatch into its failed/successful subsets.
#[pyclass(name = "TransmissionResult")]
#[derive(Clone)]
pub struct PyTransmissionResult {
    // Public within the crate so the wrapper can construct results directly.
    #[allow(dead_code)] pub inner: TransmissionResult,
}
#[pymethods]
impl PyTransmissionResult {
    /// Construct a result from Python (useful for tests and mocks).
    ///
    /// `error` and the per-row error strings are parsed back into
    /// `ZerobusError` variants via their "Variant: message" prefixes
    /// (see `parse_error_string`). `message` is accepted for API
    /// compatibility but ignored.
    #[new]
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (success, *, error=None, attempts=1, latency_ms=None, batch_size_bytes=0, failed_rows=None, successful_rows=None, total_rows=0, successful_count=0, failed_count=0, message=None))]
    pub fn new(
        success: bool,
        error: Option<String>,
        attempts: u32,
        latency_ms: Option<u64>,
        batch_size_bytes: usize,
        failed_rows: Option<Vec<(usize, String)>>,
        successful_rows: Option<Vec<usize>>,
        total_rows: usize,
        successful_count: usize,
        failed_count: usize,
        #[allow(unused_variables)] message: Option<String>,
    ) -> Self {
        // Re-hydrate per-row error strings into typed ZerobusError values.
        let rust_failed_rows = failed_rows.map(|rows| {
            rows.into_iter()
                .map(|(idx, error_msg)| (idx, parse_error_string(error_msg)))
                .collect()
        });
        let rust_error = error.map(parse_error_string);
        Self {
            inner: TransmissionResult {
                success,
                error: rust_error,
                attempts,
                latency_ms,
                batch_size_bytes,
                failed_rows: rust_failed_rows,
                successful_rows,
                total_rows,
                successful_count,
                failed_count,
            },
        }
    }
    /// Whether the transmission succeeded overall.
    #[getter]
    pub fn success(&self) -> bool {
        self.inner.success
    }
    /// The overall error rendered as a string, if any.
    #[getter]
    pub fn error(&self) -> Option<String> {
        self.inner.error.as_ref().map(|e| e.to_string())
    }
    /// Number of send attempts made.
    #[getter]
    pub fn attempts(&self) -> u32 {
        self.inner.attempts
    }
    /// End-to-end latency in milliseconds, if measured.
    #[getter]
    pub fn latency_ms(&self) -> Option<u64> {
        self.inner.latency_ms
    }
    /// Size of the transmitted batch in bytes.
    #[getter]
    pub fn batch_size_bytes(&self) -> usize {
        self.inner.batch_size_bytes
    }
    /// Failed rows as (row_index, error_string) pairs, if any failed.
    #[getter]
    pub fn failed_rows(&self) -> Option<Vec<(usize, String)>> {
        self.inner.failed_rows.as_ref().map(|rows| {
            rows.iter()
                .map(|(idx, error)| (*idx, error.to_string()))
                .collect()
        })
    }
    /// Indices of rows that were sent successfully, if tracked.
    #[getter]
    pub fn successful_rows(&self) -> Option<Vec<usize>> {
        self.inner.successful_rows.clone()
    }
    /// Total number of rows in the batch.
    #[getter]
    pub fn total_rows(&self) -> usize {
        self.inner.total_rows
    }
    /// Count of successful rows.
    #[getter]
    pub fn successful_count(&self) -> usize {
        self.inner.successful_count
    }
    /// Count of failed rows.
    #[getter]
    pub fn failed_count(&self) -> usize {
        self.inner.failed_count
    }
    /// Indices of failed rows (empty when none failed).
    pub fn get_failed_row_indices(&self) -> Vec<usize> {
        self.inner.get_failed_row_indices()
    }
    /// Indices of successful rows (empty when none succeeded).
    pub fn get_successful_row_indices(&self) -> Vec<usize> {
        self.inner.get_successful_row_indices()
    }
    /// Return a new pyarrow.RecordBatch containing only the failed rows, or
    /// None when there were no failures. `original_batch` must be the same
    /// batch that produced this result (indices are positions within it).
    pub fn extract_failed_batch(
        &self,
        py: Python,
        original_batch: PyObject,
    ) -> PyResult<Option<PyObject>> {
        let rust_batch = pyarrow_to_rust_batch(py, original_batch)?;
        match self.inner.extract_failed_batch(&rust_batch) {
            Some(batch) => {
                let py_batch = rust_batch_to_pyarrow(py, &batch)?;
                Ok(Some(py_batch))
            }
            None => Ok(None),
        }
    }
    /// Return a new pyarrow.RecordBatch containing only the successful rows,
    /// or None when no rows succeeded. Same indexing caveat as
    /// `extract_failed_batch`.
    pub fn extract_successful_batch(
        &self,
        py: Python,
        original_batch: PyObject,
    ) -> PyResult<Option<PyObject>> {
        let rust_batch = pyarrow_to_rust_batch(py, original_batch)?;
        match self.inner.extract_successful_batch(&rust_batch) {
            Some(batch) => {
                let py_batch = rust_batch_to_pyarrow(py, &batch)?;
                Ok(Some(py_batch))
            }
            None => Ok(None),
        }
    }
    /// Indices of rows that failed with the given error type name
    /// (e.g. "ConversionError"); unknown names match nothing.
    pub fn get_failed_row_indices_by_error_type(&self, error_type: &str) -> Vec<usize> {
        self.inner
            .get_failed_row_indices_by_error_type(|error| match error_type {
                "ConversionError" => matches!(error, ZerobusError::ConversionError(_)),
                "TransmissionError" => matches!(error, ZerobusError::TransmissionError(_)),
                "ConnectionError" => matches!(error, ZerobusError::ConnectionError(_)),
                "AuthenticationError" => matches!(error, ZerobusError::AuthenticationError(_)),
                "ConfigurationError" => matches!(error, ZerobusError::ConfigurationError(_)),
                "RetryExhausted" => matches!(error, ZerobusError::RetryExhausted(_)),
                "TokenRefreshError" => matches!(error, ZerobusError::TokenRefreshError(_)),
                _ => false,
            })
    }
    /// True when some rows succeeded and some failed.
    pub fn is_partial_success(&self) -> bool {
        self.inner.is_partial_success()
    }
    /// True when at least one row failed.
    pub fn has_failed_rows(&self) -> bool {
        self.inner.has_failed_rows()
    }
    /// True when at least one row succeeded.
    pub fn has_successful_rows(&self) -> bool {
        self.inner.has_successful_rows()
    }
    /// Map from error type name to the indices of rows that failed with it.
    pub fn group_errors_by_type(&self) -> HashMap<String, Vec<usize>> {
        self.inner.group_errors_by_type()
    }
    /// Summary statistics as a Python dict: totals, counts, success/failure
    /// rates, and a nested dict of per-error-type counts.
    pub fn get_error_statistics(&self, py: Python) -> PyResult<PyObject> {
        let stats = self.inner.get_error_statistics();
        let dict = PyDict::new(py);
        dict.set_item("total_rows", stats.total_rows)?;
        dict.set_item("successful_count", stats.successful_count)?;
        dict.set_item("failed_count", stats.failed_count)?;
        dict.set_item("success_rate", stats.success_rate)?;
        dict.set_item("failure_rate", stats.failure_rate)?;
        let error_type_counts = PyDict::new(py);
        for (error_type, count) in stats.error_type_counts {
            error_type_counts.set_item(error_type, count)?;
        }
        dict.set_item("error_type_counts", error_type_counts)?;
        Ok(dict.to_object(py))
    }
    /// All per-row error messages as strings.
    pub fn get_error_messages(&self) -> Vec<String> {
        self.inner.get_error_messages()
    }
}
/// Python-visible handle to the Rust `ZerobusWrapper`.
///
/// Owns a dedicated Tokio runtime; all async wrapper operations are driven to
/// completion on it via `block_on`. Both fields are `Arc`s so clones share
/// the same wrapper and runtime.
#[pyclass(name = "ZerobusWrapper")]
#[allow(non_local_definitions)]
pub struct PyZerobusWrapper {
    inner: Arc<ZerobusWrapper>,
    runtime: Arc<Runtime>,
}
#[pymethods]
impl PyZerobusWrapper {
    /// Validate `config`, create a dedicated Tokio runtime, and run the async
    /// wrapper constructor to completion on it.
    #[new]
    fn new(config: PyWrapperConfiguration) -> PyResult<Self> {
        config.validate()?;
        let runtime = Runtime::new()
            .map_err(|e| PyException::new_err(format!("Failed to create Tokio runtime: {}", e)))?;
        let wrapper = runtime.block_on(async {
            ZerobusWrapper::new(config.inner.clone())
                .await
                .map_err(rust_error_to_python_error)
        })?;
        Ok(Self {
            inner: Arc::new(wrapper),
            runtime: Arc::new(runtime),
        })
    }
    /// Convert `batch` (a pyarrow.RecordBatch) to Arrow and send it,
    /// blocking until transmission completes.
    ///
    /// NOTE(review): `block_on` here runs while the Python GIL is held, so
    /// other Python threads stall for the duration of the send — consider
    /// `py.allow_threads` if the inner types are `Send`; confirm first.
    fn send_batch(&self, py: Python, batch: PyObject) -> PyResult<PyTransmissionResult> {
        let rust_batch = pyarrow_to_rust_batch(py, batch)?;
        let result = self
            .runtime
            .block_on(async { self.inner.send_batch(rust_batch).await });
        match result {
            Ok(transmission_result) => Ok(PyTransmissionResult {
                inner: transmission_result,
            }),
            Err(e) => Err(rust_error_to_python_error(e)),
        }
    }
    /// Flush any buffered data, blocking until complete.
    fn flush(&self, _py: Python) -> PyResult<()> {
        self.runtime
            .block_on(async { self.inner.flush().await })
            .map_err(rust_error_to_python_error)?;
        Ok(())
    }
    /// Shut the wrapper down, blocking until complete.
    fn shutdown(&self, _py: Python) -> PyResult<()> {
        self.runtime
            .block_on(async { self.inner.shutdown().await })
            .map_err(rust_error_to_python_error)?;
        Ok(())
    }
    // NOTE(review): __aenter__/__aexit__ are synchronous and return
    // non-awaitable values, but CPython's `async with` awaits their results —
    // this likely fails at runtime unless something upstream wraps them.
    // Verify with an integration test; __enter__/__exit__ may be intended.
    fn __aenter__(&self) -> PyResult<Self> {
        Ok(self.clone())
    }
    /// Context-manager exit: shuts down regardless of any in-flight exception
    /// (exception info is ignored; returning () never suppresses it).
    fn __aexit__(
        &self,
        _py: Python,
        _exc_type: PyObject,
        _exc_val: PyObject,
        _exc_tb: PyObject,
    ) -> PyResult<()> {
        self.shutdown(_py)?;
        Ok(())
    }
}
impl Clone for PyZerobusWrapper {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
runtime: Arc::clone(&self.runtime),
}
}
}
/// Convert a Python object into a Rust Arrow `RecordBatch`.
///
/// Raises `TypeError` unless `batch` is a `pyarrow.RecordBatch`. Tries the
/// IPC fast path first; any failure there silently falls back to the slower
/// element-by-element Python-API conversion.
fn pyarrow_to_rust_batch(py: Python, batch: PyObject) -> PyResult<RecordBatch> {
    let pyarrow = PyModule::import(py, "pyarrow")?;
    let record_batch_class = pyarrow.getattr("RecordBatch")?;
    let batch_ref = batch.as_ref(py);
    if !batch_ref.is_instance(record_batch_class)? {
        return Err(PyTypeError::new_err(
            "Expected pyarrow.RecordBatch, got different type",
        ));
    }
    match pyarrow_to_rust_batch_c_interface(py, batch_ref) {
        Ok(converted) => Ok(converted),
        Err(_) => pyarrow_to_rust_batch_python_api(py, batch_ref),
    }
}
fn pyarrow_to_rust_batch_c_interface(_py: Python, batch_ref: &PyAny) -> PyResult<RecordBatch> {
use arrow::ipc::reader::StreamReader;
use std::io::Cursor;
let serialized = batch_ref.call_method0("to_pybytes")?;
let bytes: Vec<u8> = serialized.extract()?;
let cursor = Cursor::new(bytes);
let mut reader = StreamReader::try_new(cursor, None)
.map_err(|e| PyException::new_err(format!("Failed to create IPC reader: {}", e)))?;
let batch = reader
.next()
.ok_or_else(|| PyException::new_err("No RecordBatch in IPC stream"))?
.map_err(|e| PyException::new_err(format!("Failed to read RecordBatch: {}", e)))?;
Ok(batch)
}
/// Slow-path conversion: walk the PyArrow schema and columns through the
/// Python API, rebuilding each field and array element-by-element in Rust.
///
/// O(rows x cols) Python calls — only used when the IPC fast path fails.
/// All fields are created as nullable regardless of the source schema.
fn pyarrow_to_rust_batch_python_api(py: Python, batch_ref: &PyAny) -> PyResult<RecordBatch> {
    use arrow::array::*;
    use arrow::datatypes::{Field, Schema};
    use std::sync::Arc;
    let schema_obj = batch_ref.getattr("schema")?;
    let num_fields = schema_obj.call_method0("__len__")?.extract::<usize>()?;
    let mut rust_fields = Vec::new();
    let mut rust_arrays = Vec::new();
    for i in 0..num_fields {
        let field_obj = schema_obj.get_item(i)?;
        let field_name = field_obj.getattr("name")?.extract::<String>()?;
        let field_type_obj = field_obj.getattr("type")?;
        // The pyarrow DataType's str() form (e.g. "int64", "double") is the
        // only type information carried across to Rust.
        let field_type_str = format!("{}", field_type_obj);
        let rust_type = pyarrow_type_to_rust_type(&field_type_str)?;
        rust_fields.push(Field::new(field_name.clone(), rust_type.clone(), true));
        let array_obj = batch_ref.call_method1("column", (i,))?;
        let rust_array = pyarrow_array_to_rust_array(py, array_obj, &rust_type)?;
        rust_arrays.push(rust_array);
    }
    let schema = Schema::new(rust_fields);
    RecordBatch::try_new(Arc::new(schema), rust_arrays)
        .map_err(|e| PyException::new_err(format!("Failed to create RecordBatch: {}", e)))
}
/// Map the textual form of a PyArrow `DataType` (its `str()` output, e.g.
/// "int64", "string", "double") to the corresponding Arrow Rust `DataType`.
///
/// Matching is exact rather than substring-based: the previous `contains`
/// checks misclassified types — "uint64" contains "int64" and was silently
/// read as signed, "large_string" was treated as plain Utf8, and "halffloat"
/// matched the "float" check. Unrecognized names now raise
/// `NotImplementedError` instead of being silently coerced.
fn pyarrow_type_to_rust_type(type_str: &str) -> PyResult<DataType> {
    match type_str {
        "int64" => Ok(DataType::Int64),
        "int32" => Ok(DataType::Int32),
        // pyarrow prints pa.string() as "string"; accept "utf8" as an alias.
        "string" | "utf8" => Ok(DataType::Utf8),
        // pyarrow prints pa.float64() as "double" and pa.float32() as "float".
        "double" | "float64" => Ok(DataType::Float64),
        "float" | "float32" => Ok(DataType::Float32),
        "bool" | "boolean" => Ok(DataType::Boolean),
        "binary" => Ok(DataType::Binary),
        other => Err(PyNotImplementedError::new_err(format!(
            "Unsupported PyArrow type: {}",
            other
        ))),
    }
}
fn pyarrow_array_to_rust_array(
_py: Python,
array_obj: &PyAny,
data_type: &DataType,
) -> PyResult<Arc<dyn arrow::array::Array>> {
use arrow::array::*;
use std::sync::Arc;
let len = array_obj.call_method0("__len__")?.extract::<usize>()?;
match data_type {
DataType::Int64 => {
let values: Vec<Option<i64>> = (0..len)
.map(|i| {
let val = array_obj.get_item(i)?;
if val.is_none() {
Ok(None)
} else {
Ok(Some(val.extract::<i64>()?))
}
})
.collect::<PyResult<Vec<_>>>()?;
Ok(Arc::new(Int64Array::from(values)))
}
DataType::Utf8 => {
let values: Vec<Option<String>> = (0..len)
.map(|i| {
let val = array_obj.get_item(i)?;
if val.is_none() {
Ok(None)
} else {
let py_str = if val.hasattr("as_py")? {
val.call_method0("as_py")?
} else {
val.call_method0("__str__")?
};
Ok(Some(py_str.extract::<String>()?))
}
})
.collect::<PyResult<Vec<_>>>()?;
Ok(Arc::new(StringArray::from(values)))
}
DataType::Float64 => {
let values: Vec<Option<f64>> = (0..len)
.map(|i| {
let val = array_obj.get_item(i)?;
if val.is_none() {
Ok(None)
} else {
Ok(Some(val.extract::<f64>()?))
}
})
.collect::<PyResult<Vec<_>>>()?;
Ok(Arc::new(Float64Array::from(values)))
}
DataType::Boolean => {
let values: Vec<Option<bool>> = (0..len)
.map(|i| {
let val = array_obj.get_item(i)?;
if val.is_none() {
Ok(None)
} else {
Ok(Some(val.extract::<bool>()?))
}
})
.collect::<PyResult<Vec<_>>>()?;
Ok(Arc::new(BooleanArray::from(values)))
}
_ => Err(PyNotImplementedError::new_err(format!(
"Array type conversion not yet implemented for: {:?}",
data_type
))),
}
}
/// Convert a Rust Arrow `RecordBatch` back into a `pyarrow.RecordBatch`.
///
/// Serializes the batch to an Arrow IPC stream in memory with Rust's
/// `StreamWriter`, then hands the bytes to pyarrow's stream reader to
/// reconstruct the batch on the Python side.
fn rust_batch_to_pyarrow(py: Python, batch: &RecordBatch) -> PyResult<PyObject> {
    use arrow::ipc::writer::StreamWriter;
    use pyo3::types::PyBytes;
    use std::io::Cursor;
    let mut buffer = Vec::new();
    let cursor = Cursor::new(&mut buffer);
    // Write schema + batch, then finish() to emit the end-of-stream marker.
    let mut writer = StreamWriter::try_new(cursor, &batch.schema())
        .map_err(|e| PyException::new_err(format!("Failed to create IPC writer: {}", e)))?;
    writer
        .write(batch)
        .map_err(|e| PyException::new_err(format!("Failed to write RecordBatch: {}", e)))?;
    writer
        .finish()
        .map_err(|e| PyException::new_err(format!("Failed to finish IPC writer: {}", e)))?;
    // Re-read the stream via pyarrow: bytes -> BufferReader -> open_stream.
    let pyarrow = PyModule::import(py, "pyarrow")?;
    let ipc_module = pyarrow.getattr("ipc")?;
    let buffer_reader_class = pyarrow.getattr("BufferReader")?;
    let ipc_bytes = PyBytes::new(py, &buffer);
    let buffer_reader = buffer_reader_class.call1((ipc_bytes,))?;
    let open_stream = ipc_module.getattr("open_stream")?;
    let stream_reader = open_stream.call1((buffer_reader,))?;
    // The stream contains exactly one batch; read it and return it.
    let read_next = stream_reader.call_method0("read_next_batch")?;
    Ok(read_next.to_object(py))
}