Skip to main content

nautilus_serialization/arrow/
custom.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this code except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Custom data: registration and dynamic decoding.
17//!
18//! - **Registration:** Call [`ensure_custom_data_registered::<T>()`] once (e.g. before using the
19//!   catalog) for each custom data type `T` produced by the `#[custom_data]` macro. When Python
20//!   support is enabled, also call `nautilus_model::data::register_rust_extractor::<T>()`.
21//! - **Decoder:** [`CustomDataDecoder`] provides [`ArrowSchemaProvider`] and
22//!   [`DecodeDataFromRecordBatch`] for Parquet-backed custom data decoded at runtime by type name.
23//!   Types must be registered via [`ensure_custom_data_registered::<T>()`] before use.
24
25use std::sync::Arc;
26
27use arrow::record_batch::RecordBatch;
28use nautilus_model::data::{
29    ArrowDecoder, ArrowEncoder, CustomData, CustomDataTrait, Data, DataType,
30    decode_custom_from_arrow, ensure_arrow_registered, ensure_custom_data_json_registered,
31    get_arrow_schema,
32};
33
34use super::{ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch};
35
36/// Trait for custom data types that support Arrow schema and record batch encoding.
37/// Used as a type bound by the `#[custom_data]` macro; catalog encoding goes through
38/// the registry, not this trait directly.
39///
40/// Implemented by the `#[custom_data]` macro for Rust custom data types. Python custom
41/// types use the registry encoder registered by `register_custom_data_class` instead.
42pub trait CustomDataSerialize: CustomDataTrait {
43    /// Returns the Arrow schema for this custom data type.
44    ///
45    /// # Errors
46    /// Returns an error if schema construction fails.
47    fn schema(&self) -> anyhow::Result<arrow::datatypes::Schema>;
48
49    /// Encodes a batch of custom data items to an Arrow RecordBatch.
50    ///
51    /// # Errors
52    /// Returns an error if encoding fails (e.g. type mismatch or Arrow error).
53    fn encode_record_batch(
54        &self,
55        items: &[Arc<dyn CustomDataTrait>],
56    ) -> anyhow::Result<RecordBatch>;
57}
58
59/// Registers a custom data type in the JSON and Arrow registries. Call once per type
60/// (e.g. at catalog decode or before querying custom data).
61///
62/// Each distinct type `T` is registered at most once (per process). Safe to call
63/// multiple times for the same `T`.
64///
65/// When Python support is enabled, also call
66/// `nautilus_model::data::register_rust_extractor::<T>()` for types exposed to Python.
67pub fn ensure_custom_data_registered<T>()
68where
69    T: CustomDataTrait
70        + ArrowSchemaProvider
71        + EncodeToRecordBatch
72        + DecodeDataFromRecordBatch
73        + Clone
74        + Send
75        + Sync
76        + 'static,
77{
78    let type_name = T::type_name_static();
79
80    // Skip if already registered
81    if get_arrow_schema(type_name).is_some() {
82        return;
83    }
84
85    let _ = ensure_custom_data_json_registered::<T>();
86
87    let schema = Arc::new(T::get_schema(None));
88
89    let encoder: ArrowEncoder = Box::new(|items: &[Arc<dyn CustomDataTrait>]| {
90        let typed: Result<Vec<T>, _> = items
91            .iter()
92            .map(|b| {
93                b.as_any()
94                    .downcast_ref::<T>()
95                    .cloned()
96                    .ok_or_else(|| anyhow::anyhow!("Expected {}", T::type_name_static()))
97            })
98            .collect();
99        let typed = typed?;
100        let metadata = typed
101            .first()
102            .map(EncodeToRecordBatch::metadata)
103            .unwrap_or_default();
104        EncodeToRecordBatch::encode_batch(&metadata, &typed).map_err(|e| anyhow::anyhow!("{e}"))
105    });
106
107    let decoder: ArrowDecoder = Box::new(|metadata, batch| {
108        T::decode_data_batch(metadata, batch).map_err(|e| anyhow::anyhow!("{e}"))
109    });
110
111    let _ = ensure_arrow_registered(type_name, schema, encoder, decoder);
112}
113
114/// Decoder for custom data types that are identified at runtime by metadata (e.g. `type_name`).
115///
116/// Only Rust-registered custom types (e.g. `RustTestCustomData`, `MacroYieldCurveData`) can be
117/// decoded. Unknown types return an error.
118///
119/// **Important:** The caller must ensure that any Rust custom data types are registered
120/// via [`ensure_custom_data_registered::<T>()`] before use.
121#[derive(Debug)]
122pub struct CustomDataDecoder;
123
124impl ArrowSchemaProvider for CustomDataDecoder {
125    fn get_schema(
126        metadata: Option<std::collections::HashMap<String, String>>,
127    ) -> arrow::datatypes::Schema {
128        if let Some(metadata) = metadata
129            && let Some(type_name) = metadata.get("type_name")
130            && let Some(schema) = get_arrow_schema(type_name)
131        {
132            return (*schema).clone();
133        }
134
135        // Unknown type - return minimal schema (caller should not use this for decode)
136        arrow::datatypes::Schema::new(vec![arrow::datatypes::Field::new(
137            "dummy",
138            arrow::datatypes::DataType::Int64,
139            true,
140        )])
141    }
142}
143
144/// Strips the data_type column from a record batch and returns the parsed DataType.
145/// Returns (batch, None) if there is no data_type column.
146fn strip_data_type_column(
147    batch: &RecordBatch,
148) -> Result<(RecordBatch, Option<DataType>), super::EncodingError> {
149    use super::extract_column_string;
150
151    let Some(data_type_col_idx) = batch
152        .schema()
153        .fields()
154        .iter()
155        .position(|f| f.name() == "data_type")
156    else {
157        return Ok((batch.clone(), None));
158    };
159
160    if batch.num_rows() == 0 {
161        return Ok((batch.clone(), None));
162    }
163
164    let cols = batch.columns();
165    let data_type = if cols[data_type_col_idx].is_null(0) {
166        None
167    } else {
168        let string_col =
169            extract_column_string(cols, "data_type", data_type_col_idx).map_err(|e| {
170                super::EncodingError::ParseError("custom_data", format!("data_type column: {e}"))
171            })?;
172        let first_value = string_col.value(0);
173        Some(
174            DataType::from_persistence_json(first_value)
175                .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?,
176        )
177    };
178
179    let new_fields: Vec<_> = batch
180        .schema()
181        .fields()
182        .iter()
183        .enumerate()
184        .filter(|(i, _)| *i != data_type_col_idx)
185        .map(|(_, f)| f.clone())
186        .collect();
187    let new_columns: Vec<Arc<dyn arrow::array::Array>> = batch
188        .columns()
189        .iter()
190        .enumerate()
191        .filter(|(i, _)| *i != data_type_col_idx)
192        .map(|(_, c)| Arc::clone(c))
193        .collect();
194    let new_schema =
195        arrow::datatypes::Schema::new_with_metadata(new_fields, batch.schema().metadata().clone());
196    let stripped_batch = RecordBatch::try_new(Arc::new(new_schema), new_columns)
197        .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;
198
199    Ok((stripped_batch, data_type))
200}
201
202impl DecodeDataFromRecordBatch for CustomDataDecoder {
203    fn decode_data_batch(
204        metadata: &std::collections::HashMap<String, String>,
205        record_batch: RecordBatch,
206    ) -> Result<Vec<Data>, super::EncodingError> {
207        let type_name = metadata
208            .get("type_name")
209            .cloned()
210            .unwrap_or_else(|| "Unknown".to_string());
211
212        let (batch_to_decode, restored_data_type) = strip_data_type_column(&record_batch)?;
213
214        if batch_to_decode.num_rows() == 0 {
215            return Ok(Vec::new());
216        }
217
218        let data = match decode_custom_from_arrow(&type_name, metadata, batch_to_decode) {
219            Ok(Some(d)) => d,
220            Ok(None) => {
221                return Err(super::EncodingError::ParseError(
222                    "custom_data",
223                    format!(
224                        "unknown custom data type '{type_name}'; only Rust-registered types are supported"
225                    ),
226                ));
227            }
228            Err(e) => {
229                return Err(super::EncodingError::ParseError(
230                    "custom_data",
231                    format!("decode_custom_from_arrow: {e}"),
232                ));
233            }
234        };
235
236        if let Some(dt) = restored_data_type {
237            Ok(data
238                .into_iter()
239                .map(|d| {
240                    if let Data::Custom(c) = d {
241                        Data::Custom(CustomData::new(Arc::clone(&c.data), dt.clone()))
242                    } else {
243                        d
244                    }
245                })
246                .collect())
247        } else {
248            Ok(data)
249        }
250    }
251}