nautilus_serialization/arrow/
custom.rs1use std::sync::Arc;
26
27use arrow::record_batch::RecordBatch;
28use nautilus_model::data::{
29 ArrowDecoder, ArrowEncoder, CustomData, CustomDataTrait, Data, DataType,
30 decode_custom_from_arrow, ensure_arrow_registered, ensure_custom_data_json_registered,
31 get_arrow_schema,
32};
33
34use super::{ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch};
35
36pub trait CustomDataSerialize: CustomDataTrait {
43 fn schema(&self) -> anyhow::Result<arrow::datatypes::Schema>;
48
49 fn encode_record_batch(
54 &self,
55 items: &[Arc<dyn CustomDataTrait>],
56 ) -> anyhow::Result<RecordBatch>;
57}
58
59pub fn ensure_custom_data_registered<T>()
68where
69 T: CustomDataTrait
70 + ArrowSchemaProvider
71 + EncodeToRecordBatch
72 + DecodeDataFromRecordBatch
73 + Clone
74 + Send
75 + Sync
76 + 'static,
77{
78 let type_name = T::type_name_static();
79
80 if get_arrow_schema(type_name).is_some() {
82 return;
83 }
84
85 let _ = ensure_custom_data_json_registered::<T>();
86
87 let schema = Arc::new(T::get_schema(None));
88
89 let encoder: ArrowEncoder = Box::new(|items: &[Arc<dyn CustomDataTrait>]| {
90 let typed: Result<Vec<T>, _> = items
91 .iter()
92 .map(|b| {
93 b.as_any()
94 .downcast_ref::<T>()
95 .cloned()
96 .ok_or_else(|| anyhow::anyhow!("Expected {}", T::type_name_static()))
97 })
98 .collect();
99 let typed = typed?;
100 let metadata = typed
101 .first()
102 .map(EncodeToRecordBatch::metadata)
103 .unwrap_or_default();
104 EncodeToRecordBatch::encode_batch(&metadata, &typed).map_err(|e| anyhow::anyhow!("{e}"))
105 });
106
107 let decoder: ArrowDecoder = Box::new(|metadata, batch| {
108 T::decode_data_batch(metadata, batch).map_err(|e| anyhow::anyhow!("{e}"))
109 });
110
111 let _ = ensure_arrow_registered(type_name, schema, encoder, decoder);
112}
113
114#[derive(Debug)]
122pub struct CustomDataDecoder;
123
124impl ArrowSchemaProvider for CustomDataDecoder {
125 fn get_schema(
126 metadata: Option<std::collections::HashMap<String, String>>,
127 ) -> arrow::datatypes::Schema {
128 if let Some(metadata) = metadata
129 && let Some(type_name) = metadata.get("type_name")
130 && let Some(schema) = get_arrow_schema(type_name)
131 {
132 return (*schema).clone();
133 }
134
135 arrow::datatypes::Schema::new(vec![arrow::datatypes::Field::new(
137 "dummy",
138 arrow::datatypes::DataType::Int64,
139 true,
140 )])
141 }
142}
143
144fn strip_data_type_column(
147 batch: &RecordBatch,
148) -> Result<(RecordBatch, Option<DataType>), super::EncodingError> {
149 use super::extract_column_string;
150
151 let Some(data_type_col_idx) = batch
152 .schema()
153 .fields()
154 .iter()
155 .position(|f| f.name() == "data_type")
156 else {
157 return Ok((batch.clone(), None));
158 };
159
160 if batch.num_rows() == 0 {
161 return Ok((batch.clone(), None));
162 }
163
164 let cols = batch.columns();
165 let data_type = if cols[data_type_col_idx].is_null(0) {
166 None
167 } else {
168 let string_col =
169 extract_column_string(cols, "data_type", data_type_col_idx).map_err(|e| {
170 super::EncodingError::ParseError("custom_data", format!("data_type column: {e}"))
171 })?;
172 let first_value = string_col.value(0);
173 Some(
174 DataType::from_persistence_json(first_value)
175 .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?,
176 )
177 };
178
179 let new_fields: Vec<_> = batch
180 .schema()
181 .fields()
182 .iter()
183 .enumerate()
184 .filter(|(i, _)| *i != data_type_col_idx)
185 .map(|(_, f)| f.clone())
186 .collect();
187 let new_columns: Vec<Arc<dyn arrow::array::Array>> = batch
188 .columns()
189 .iter()
190 .enumerate()
191 .filter(|(i, _)| *i != data_type_col_idx)
192 .map(|(_, c)| Arc::clone(c))
193 .collect();
194 let new_schema =
195 arrow::datatypes::Schema::new_with_metadata(new_fields, batch.schema().metadata().clone());
196 let stripped_batch = RecordBatch::try_new(Arc::new(new_schema), new_columns)
197 .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;
198
199 Ok((stripped_batch, data_type))
200}
201
202impl DecodeDataFromRecordBatch for CustomDataDecoder {
203 fn decode_data_batch(
204 metadata: &std::collections::HashMap<String, String>,
205 record_batch: RecordBatch,
206 ) -> Result<Vec<Data>, super::EncodingError> {
207 let type_name = metadata
208 .get("type_name")
209 .cloned()
210 .unwrap_or_else(|| "Unknown".to_string());
211
212 let (batch_to_decode, restored_data_type) = strip_data_type_column(&record_batch)?;
213
214 if batch_to_decode.num_rows() == 0 {
215 return Ok(Vec::new());
216 }
217
218 let data = match decode_custom_from_arrow(&type_name, metadata, batch_to_decode) {
219 Ok(Some(d)) => d,
220 Ok(None) => {
221 return Err(super::EncodingError::ParseError(
222 "custom_data",
223 format!(
224 "unknown custom data type '{type_name}'; only Rust-registered types are supported"
225 ),
226 ));
227 }
228 Err(e) => {
229 return Err(super::EncodingError::ParseError(
230 "custom_data",
231 format!("decode_custom_from_arrow: {e}"),
232 ));
233 }
234 };
235
236 if let Some(dt) = restored_data_type {
237 Ok(data
238 .into_iter()
239 .map(|d| {
240 if let Data::Custom(c) = d {
241 Data::Custom(CustomData::new(Arc::clone(&c.data), dt.clone()))
242 } else {
243 d
244 }
245 })
246 .collect())
247 } else {
248 Ok(data)
249 }
250 }
251}