use std::sync::Arc;
use arrow::record_batch::RecordBatch;
use nautilus_model::data::{
ArrowDecoder, ArrowEncoder, CustomData, CustomDataTrait, Data, DataType,
decode_custom_from_arrow, ensure_arrow_registered, ensure_custom_data_json_registered,
get_arrow_schema,
};
use super::{ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch};
/// Extension of [`CustomDataTrait`] for custom data types that can describe and
/// serialize themselves to Arrow.
pub trait CustomDataSerialize: CustomDataTrait {
    /// Returns the Arrow schema describing this type's record-batch layout.
    ///
    /// # Errors
    ///
    /// Returns an error if the schema cannot be produced.
    fn schema(&self) -> anyhow::Result<arrow::datatypes::Schema>;

    /// Encodes the given items into a single Arrow [`RecordBatch`].
    ///
    /// Implementations are expected to downcast each boxed item to the concrete
    /// type — TODO confirm against implementors; not visible from this trait alone.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails.
    fn encode_record_batch(
        &self,
        items: &[Arc<dyn CustomDataTrait>],
    ) -> anyhow::Result<RecordBatch>;
}
/// Registers the Arrow (and JSON) codecs for custom data type `T` in the global
/// registry, if not already present.
///
/// Idempotent: returns immediately when an Arrow schema for `T` is already
/// registered. Registration results are intentionally ignored (best-effort).
pub fn ensure_custom_data_registered<T>()
where
    T: CustomDataTrait
        + ArrowSchemaProvider
        + EncodeToRecordBatch
        + DecodeDataFromRecordBatch
        + Clone
        + Send
        + Sync
        + 'static,
{
    let type_name = T::type_name_static();

    // Fast path: an Arrow schema for this type means it was registered before.
    // NOTE(review): this also skips the JSON registration below — assumes both
    // are always registered together; confirm against the registry semantics.
    if get_arrow_schema(type_name).is_some() {
        return;
    }

    // Best-effort JSON codec registration; failure is deliberately ignored.
    let _ = ensure_custom_data_json_registered::<T>();

    // Encoder: downcast every boxed item to `T`, then batch-encode.
    let encoder: ArrowEncoder = Box::new(|items: &[Arc<dyn CustomDataTrait>]| {
        let typed = items
            .iter()
            .map(|item| {
                item.as_any()
                    .downcast_ref::<T>()
                    .cloned()
                    .ok_or_else(|| anyhow::anyhow!("Expected {}", T::type_name_static()))
            })
            .collect::<anyhow::Result<Vec<T>>>()?;

        // Metadata is taken from the first element; empty input yields defaults.
        let metadata = match typed.first() {
            Some(first) => EncodeToRecordBatch::metadata(first),
            None => Default::default(),
        };

        EncodeToRecordBatch::encode_batch(&metadata, &typed).map_err(|e| anyhow::anyhow!("{e}"))
    });

    // Decoder: delegate straight to the type's batch decoder.
    let decoder: ArrowDecoder = Box::new(|metadata, batch| {
        T::decode_data_batch(metadata, batch).map_err(|e| anyhow::anyhow!("{e}"))
    });

    let schema = Arc::new(T::get_schema(None));
    let _ = ensure_arrow_registered(type_name, schema, encoder, decoder);
}
/// Decoder for custom data stored in Arrow record batches; resolves the
/// concrete type at runtime via the schema registry (see the trait impls below).
#[derive(Debug)]
pub struct CustomDataDecoder;
impl ArrowSchemaProvider for CustomDataDecoder {
    /// Looks up the registered Arrow schema for the `type_name` entry in
    /// `metadata`; falls back to a single-column placeholder schema when the
    /// type is unknown or no metadata was supplied.
    fn get_schema(
        metadata: Option<std::collections::HashMap<String, String>>,
    ) -> arrow::datatypes::Schema {
        let registered = metadata
            .as_ref()
            .and_then(|m| m.get("type_name"))
            .and_then(|name| get_arrow_schema(name));

        if let Some(schema) = registered {
            return (*schema).clone();
        }

        // Placeholder schema used when the concrete type cannot be resolved.
        arrow::datatypes::Schema::new(vec![arrow::datatypes::Field::new(
            "dummy",
            arrow::datatypes::DataType::Int64,
            true,
        )])
    }
}
/// Removes the `data_type` string column from `batch`, if present, and parses
/// the first row's value into a [`DataType`].
///
/// Returns the (possibly unchanged) batch together with the parsed `DataType`.
/// The batch is returned untouched (and `None` for the type) when the column
/// is absent or the batch has no rows. Only row 0 is inspected — assumes all
/// rows carry the same `data_type` value; TODO confirm against the writer.
///
/// # Errors
///
/// Returns a [`super::EncodingError::ParseError`] if the column cannot be read
/// as strings, the JSON payload does not parse, or the stripped batch cannot
/// be reassembled.
fn strip_data_type_column(
    batch: &RecordBatch,
) -> Result<(RecordBatch, Option<DataType>), super::EncodingError> {
    use super::extract_column_string;

    let schema = batch.schema();
    let col_idx = match schema.fields().iter().position(|f| f.name() == "data_type") {
        Some(idx) => idx,
        None => return Ok((batch.clone(), None)),
    };

    // With no rows there is no value to parse; leave the batch as-is.
    if batch.num_rows() == 0 {
        return Ok((batch.clone(), None));
    }

    let string_col = extract_column_string(batch.columns(), "data_type", col_idx).map_err(|e| {
        super::EncodingError::ParseError("custom_data", format!("data_type column: {e}"))
    })?;
    let data_type = DataType::from_persistence_json(string_col.value(0))
        .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;

    // Rebuild schema and columns with the `data_type` column filtered out,
    // preserving the original schema-level metadata.
    let kept_fields: Vec<_> = schema
        .fields()
        .iter()
        .enumerate()
        .filter_map(|(i, f)| (i != col_idx).then(|| f.clone()))
        .collect();
    let kept_columns: Vec<Arc<dyn arrow::array::Array>> = batch
        .columns()
        .iter()
        .enumerate()
        .filter_map(|(i, c)| (i != col_idx).then(|| Arc::clone(c)))
        .collect();

    let stripped_schema =
        arrow::datatypes::Schema::new_with_metadata(kept_fields, schema.metadata().clone());
    let stripped = RecordBatch::try_new(Arc::new(stripped_schema), kept_columns)
        .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;

    Ok((stripped, Some(data_type)))
}
impl DecodeDataFromRecordBatch for CustomDataDecoder {
    /// Decodes a record batch into [`Data`] values via the runtime registry,
    /// keyed by the `type_name` metadata entry.
    ///
    /// Any `data_type` column embedded in the batch is stripped before decoding
    /// and its value re-attached to every decoded `Data::Custom` element.
    ///
    /// # Errors
    ///
    /// Returns a [`super::EncodingError::ParseError`] when the type is not
    /// registered or the registered decoder fails.
    fn decode_data_batch(
        metadata: &std::collections::HashMap<String, String>,
        record_batch: RecordBatch,
    ) -> Result<Vec<Data>, super::EncodingError> {
        let type_name = match metadata.get("type_name") {
            Some(name) => name.clone(),
            None => "Unknown".to_string(),
        };

        let (batch_to_decode, restored_data_type) = strip_data_type_column(&record_batch)?;
        if batch_to_decode.num_rows() == 0 {
            return Ok(Vec::new());
        }

        let data = match decode_custom_from_arrow(&type_name, metadata, batch_to_decode) {
            Ok(Some(items)) => items,
            // `None` means the registry has no decoder for this type name.
            Ok(None) => {
                return Err(super::EncodingError::ParseError(
                    "custom_data",
                    format!(
                        "unknown custom data type '{type_name}'; only Rust-registered types are supported"
                    ),
                ));
            }
            Err(e) => {
                return Err(super::EncodingError::ParseError(
                    "custom_data",
                    format!("decode_custom_from_arrow: {e}"),
                ));
            }
        };

        // No stripped data_type to restore — return the decoded data directly.
        let Some(dt) = restored_data_type else {
            return Ok(data);
        };

        // Re-attach the stripped DataType to every custom element; other
        // variants pass through untouched.
        let restored = data
            .into_iter()
            .map(|item| match item {
                Data::Custom(custom) => {
                    Data::Custom(CustomData::new(Arc::clone(&custom.data), dt.clone()))
                }
                other => other,
            })
            .collect();
        Ok(restored)
    }
}