use std::{fmt::Debug, ops::Deref};
use arrow::{
array::{make_array, ArrayData, ArrayRef},
datatypes::{DataType, FieldRef, SchemaRef},
pyarrow::{ArrowException, PyArrowType, ToPyArrow},
record_batch::RecordBatch,
};
use ella::{table::Publisher, tensor::DataFrame};
use futures::SinkExt;
use pyo3::{
exceptions::{PyKeyError, PyValueError},
prelude::*,
types::{PyDict, PyList, PyTuple},
};
use crate::{dataframe::PyDataFrame, utils::wait_for_future};
// Python-facing publisher (exposed to Python under the name `Publisher`).
// Rows passed to `write` are buffered in per-column Python lists and sent to
// the wrapped `Publisher` as Arrow record batches.
#[derive(Debug)]
#[pyclass(name = "Publisher")]
pub struct PyPublisher {
// Number of rows currently buffered in `columns`.
len: usize,
// Buffered row count at which the buffer is written out without being forced.
capacity: usize,
// Dict mapping each schema field name to the Python list that buffers that
// column's values.
columns: Py<PyDict>,
// Arrow schema obtained from `inner` when the publisher is constructed.
schema: SchemaRef,
// Underlying publisher that record batches are sent to.
inner: Publisher,
// Name of the first schema field when that field is a timestamp column;
// when set, an omitted value for it is filled in with `crate::now`.
time: Option<String>,
}
// Methods exposed to Python on `Publisher`.
//
// Plain `//` comments are used here on purpose: `///` comments on `#[pymethods]`
// items become the Python-visible docstrings.
#[pymethods]
impl PyPublisher {
    // Buffers one row, given either positionally or by keyword, and writes the
    // buffer out once it has reached capacity.
    #[pyo3(signature = (*args, **kwargs))]
    fn write(&mut self, py: Python, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<()> {
        self.append_row(py, args, kwargs)?;
        // The row only counts once it has been appended to the column buffers.
        self.len += 1;
        // Not forced: this only writes when the buffer is full.
        self.maybe_write(py, false)
    }

    // Writes out any buffered rows, then flushes the inner publisher.
    fn flush(&mut self, py: Python) -> PyResult<()> {
        self.maybe_write(py, true)?;
        let pending = self.inner.flush();
        wait_for_future(py, pending)?;
        Ok(())
    }

    // Sends a whole batch at once. Accepts a single pyarrow record batch, a
    // single data frame, or one column of values per schema field (positional
    // or keyword).
    #[pyo3(signature = (*args, **kwargs))]
    fn write_batch(&mut self, py: Python, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<()> {
        // Rows buffered by `write` are sent before the new batch.
        self.maybe_write(py, true)?;

        // A lone positional argument may already be a complete batch.
        if kwargs.is_none() && args.len() == 1 {
            let only = &args[0];
            if let Ok(batch) = only.extract::<PyArrowType<RecordBatch>>() {
                wait_for_future(py, self.inner.send(batch.0))?;
                return Ok(());
            }
            if let Ok(df) = only.extract::<PyDataFrame>() {
                wait_for_future(py, self.inner.send(DataFrame::from(df).into()))?;
                return Ok(());
            }
        }

        // Otherwise each argument is a column of values to convert to Arrow.
        let arrays = self
            .map_args(py, args, kwargs)?
            .into_iter()
            .map(|(items, field)| items_to_array(py, items, field))
            .collect::<PyResult<Vec<_>>>()?;
        let batch = RecordBatch::try_new(self.schema.clone(), arrays)
            .map_err(|e| ArrowException::new_err(e.to_string()))?;
        wait_for_future(py, self.inner.send(batch))?;
        Ok(())
    }

    // Writes out any buffered rows, then closes the inner publisher.
    fn close(&mut self, py: Python) -> PyResult<()> {
        self.maybe_write(py, true)?;
        let pending = self.inner.close();
        wait_for_future(py, pending)?;
        Ok(())
    }

    // Context-manager entry: the publisher itself is the managed object.
    fn __enter__<'py>(this: PyRef<'py, Self>, _py: Python<'py>) -> PyResult<PyRef<'py, Self>> {
        Ok(this)
    }

    // Context-manager exit: closes the publisher. The exception details
    // passed in by Python are not inspected.
    fn __exit__(
        &mut self,
        py: Python,
        _exc_type: &PyAny,
        _exc_value: &PyAny,
        _traceback: &PyAny,
    ) -> PyResult<()> {
        self.close(py)
    }
}
impl Drop for PyPublisher {
    /// Closes the publisher when the value is destroyed.
    ///
    /// A destructor has no caller to raise an exception into. A failure while
    /// closing is therefore reported through Python's unraisable-exception
    /// hook rather than being left set as the interpreter's current error,
    /// where it could surface later in unrelated code.
    fn drop(&mut self) {
        Python::with_gil(|py| {
            if let Err(err) = self.close(py) {
                err.write_unraisable(py, None);
            }
        })
    }
}
impl PyPublisher {
/// Wraps `inner` in a row-buffering publisher.
///
/// The schema is taken from `inner`, and one empty Python list is created per
/// schema field to hold buffered values. If the first field is a timestamp
/// column, its name is recorded as the implicit time column. `capacity` is the
/// buffered row count at which rows are written out.
///
/// # Errors
///
/// Returns an error if a buffer list cannot be stored in the column dict.
pub fn new(py: Python, inner: Publisher, capacity: usize) -> PyResult<Self> {
    let schema = inner.arrow_schema().clone();

    let buffers = PyDict::new(py);
    for field in schema.fields() {
        buffers.set_item(field.name(), PyList::empty(py).to_object(py))?;
    }

    // Only the leading field is considered for the implicit time column.
    let time = schema
        .fields()
        .first()
        .filter(|field| matches!(field.data_type(), DataType::Timestamp(_, _)))
        .map(|field| field.name().clone());

    Ok(Self {
        len: 0,
        capacity,
        columns: buffers.into(),
        schema,
        inner,
        time,
    })
}
/// Writes the buffered rows out and empties the buffer when appropriate.
///
/// Rows are written when the buffer holds at least `capacity` rows, or when
/// `force` is true. An empty buffer is never written, even with a `capacity`
/// of zero, so no zero-row batch is sent.
///
/// # Errors
///
/// Returns any error from `write_rows` or from clearing a column list. If
/// `write_rows` fails, the buffer and row count are left untouched.
fn maybe_write(&mut self, py: Python, force: bool) -> PyResult<()> {
    // Nothing buffered: there is nothing to send, whatever the capacity.
    if self.len == 0 {
        return Ok(());
    }
    if !force && self.len < self.capacity {
        return Ok(());
    }

    self.write_rows(py)?;

    // Reset the buffers only after the rows were handed over successfully.
    for (_, col) in self.columns.as_ref(py).iter() {
        col.call_method0("clear")?;
    }
    self.len = 0;
    Ok(())
}
/// Converts every buffered column to an Arrow array and sends the resulting
/// record batch to the inner publisher.
///
/// The buffers themselves are not cleared here.
///
/// # Panics
///
/// Panics if the column dict has no entry for a schema field.
///
/// # Errors
///
/// Returns an error if a column cannot be converted, if the record batch
/// cannot be built (raised as `ArrowException`), or if sending fails.
fn write_rows(&mut self, py: Python) -> PyResult<()> {
    let buffers = self.columns.as_ref(py);
    // Walk the schema so the arrays come out in schema order.
    let arrays = self
        .schema
        .fields()
        .iter()
        .map(|field| {
            let items = buffers.get_item(field.name()).unwrap();
            items_to_array(py, items, field)
        })
        .collect::<PyResult<Vec<_>>>()?;
    let batch = RecordBatch::try_new(self.schema.clone(), arrays)
        .map_err(|e| ArrowException::new_err(e.to_string()))?;
    wait_for_future(py, self.inner.send(batch))?;
    Ok(())
}
/// Appends one row's values to the per-column buffer lists.
///
/// Each value is matched to its schema field by `map_args` and appended to
/// the list stored under that field's name. The buffered row count is not
/// touched here.
///
/// # Errors
///
/// Returns any error from `map_args`, from looking up a column buffer, or if
/// a buffer is not a Python list.
fn append_row(&self, py: Python, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<()> {
    let pairs = self.map_args(py, args, kwargs)?;
    let buffers = self.columns.as_ref(py);
    pairs
        .into_iter()
        .try_for_each(|(value, field)| -> PyResult<()> {
            let entry = PyAny::get_item(buffers.deref(), field.name())?;
            let list = entry.downcast::<PyList>()?;
            list.append(value)
        })
}
/// Pairs every supplied value with the schema field it belongs to.
///
/// Values may be given positionally (in schema order) or by keyword, but not
/// both. When the publisher has an implicit time column (`self.time`) and no
/// value is supplied for it, `crate::now` provides one; in the keyword case
/// that value is inserted into `kwargs`.
///
/// The returned pairs are always in schema order, regardless of the order of
/// the keywords, so callers can build record batch columns directly from
/// them.
///
/// # Errors
///
/// Returns `ValueError` if positional and keyword arguments are mixed or the
/// number of values does not match the schema, and `KeyError` if a keyword
/// does not name a schema field.
///
/// # Panics
///
/// Panics if a keyword key is not a string.
fn map_args<'a>(
    &'a self,
    py: Python<'a>,
    args: &'a PyTuple,
    kwargs: Option<&'a PyDict>,
) -> PyResult<Vec<(&'a PyAny, &'a FieldRef)>> {
    let expected = self.schema.fields.len();

    if !args.is_empty() {
        if kwargs.is_some() {
            return Err(PyValueError::new_err(
                "publisher accepts position arguments or keyword arguments but not both",
            ));
        }
        if args.len() == expected {
            Ok(args.iter().zip(self.schema.fields()).collect())
        } else if args.len() + 1 == expected && self.time.is_some() {
            // One value short with an implicit time column: prepend "now".
            Ok(std::iter::once(crate::now(py)?)
                .chain(args)
                .zip(self.schema.fields())
                .collect())
        } else {
            Err(PyValueError::new_err(format!(
                "publisher expects {} arguments but received {}",
                expected,
                args.len(),
            )))
        }
    } else if let Some(kwargs) = kwargs {
        if let Some(time) = &self.time {
            if !(kwargs.contains(time)?) {
                kwargs.set_item(time, crate::now(py)?)?;
            }
        }
        if kwargs.len() != expected {
            return Err(PyValueError::new_err(format!(
                "publisher expects {} arguments but received {}",
                expected,
                kwargs.len(),
            )));
        }

        // Place each value at its field's schema position. Dict iteration
        // order is the caller's keyword order (with an auto-filled time value
        // last), which generally differs from the schema order.
        let mut slots: Vec<Option<&'a PyAny>> = vec![None; expected];
        for (k, v) in kwargs.iter() {
            let name = k
                .extract::<&str>()
                .expect("keyword argument key should always be string");
            let idx = self
                .schema
                .index_of(name)
                .map_err(|_| PyKeyError::new_err(format!("column not found: \"{name}\"")))?;
            slots[idx] = Some(v);
        }

        slots
            .into_iter()
            .zip(self.schema.fields())
            .map(|(value, field)| {
                // Every slot is expected to be filled after the checks above;
                // report an error rather than panic if one is not.
                value.map(|value| (value, field)).ok_or_else(|| {
                    PyValueError::new_err(format!(
                        "missing value for column \"{}\"",
                        field.name()
                    ))
                })
            })
            .collect()
    } else {
        Err(PyValueError::new_err(format!(
            "publisher expects {} arguments but received 0",
            expected,
        )))
    }
}
}
/// Converts a Python object holding one column's values into an Arrow array
/// for `field`.
///
/// Three cases are handled, in order:
/// 1. `items` already extracts as Arrow array data: it is used as-is.
/// 2. `field` is a `FixedSizeList`: `items` is passed through
///    `numpy.ma.asarray` and `pyarrow.FixedShapeTensorArray.from_numpy_ndarray`.
/// 3. Otherwise: `pyarrow.array(items)` is cast to the field's data type.
///
/// `pyarrow` and `numpy` are imported only on the paths that use them, so
/// `numpy` is not required for non-tensor columns.
///
/// # Errors
///
/// Returns an error if a required Python module cannot be imported or if any
/// of the Python-side conversions fail.
fn items_to_array(py: Python, items: &PyAny, field: &FieldRef) -> PyResult<ArrayRef> {
    // Already Arrow data: no conversion is needed.
    if let Ok(data) = items.extract::<PyArrowType<ArrayData>>() {
        return Ok(make_array(data.0));
    }

    let pyarrow = py.import("pyarrow")?;
    let data = if matches!(field.data_type(), DataType::FixedSizeList(_, _)) {
        // Only this path goes through numpy.
        let numpy = py.import("numpy")?;
        let items = numpy.getattr("ma")?.getattr("asarray")?.call1((items,))?;
        let array = pyarrow
            .getattr("FixedShapeTensorArray")?
            .call_method1("from_numpy_ndarray", (items,))?
            .extract::<PyArrowType<ArrayData>>()?;
        array.0
    } else {
        let pytype = field.data_type().to_pyarrow(py)?;
        let array = pyarrow
            .getattr("array")?
            .call1((items,))?
            .call_method1("cast", (pytype,))?
            .extract::<PyArrowType<ArrayData>>()?;
        array.0
    };
    Ok(make_array(data))
}