Skip to main content

nautilus_model/data/
registry.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Registries for custom data: JSON (de)serialization and Arrow encode/decode.
17//!
18//! Mirrors Python's `register_serializable_type` and `register_arrow` in `custom.py`.
19//! The registry only stores type name -> callbacks for lookup; each type provides
20//! its own deserialize/encode/decode via the trait or registration.
21
22use std::{collections::HashMap, sync::Arc};
23
24use arrow::{datatypes::Schema, record_batch::RecordBatch};
25use dashmap::{DashMap, mapref::entry::Entry};
26use nautilus_core::Params;
27#[cfg(feature = "python")]
28use pyo3::types::PyAnyMethods;
29
30use crate::data::{CustomData, CustomDataTrait, Data, DataType};
31
32pub type JsonDeserializer =
33    Box<dyn Fn(serde_json::Value) -> Result<Arc<dyn CustomDataTrait>, anyhow::Error> + Send + Sync>;
34pub type ArrowEncoder =
35    Box<dyn Fn(&[Arc<dyn CustomDataTrait>]) -> Result<RecordBatch, anyhow::Error> + Send + Sync>;
36pub type ArrowDecoder = Box<
37    dyn Fn(&HashMap<String, String>, RecordBatch) -> Result<Vec<Data>, anyhow::Error> + Send + Sync,
38>;
39
40struct Registries {
41    json: DashMap<String, JsonDeserializer>,
42    arrow: DashMap<String, (Arc<Schema>, ArrowEncoder, ArrowDecoder)>,
43}
44
45fn registries() -> &'static Registries {
46    static REGISTRIES: std::sync::OnceLock<Registries> = std::sync::OnceLock::new();
47    REGISTRIES.get_or_init(|| Registries {
48        json: DashMap::new(),
49        arrow: DashMap::new(),
50    })
51}
52
53/// Registers a JSON deserializer for the given custom data type name.
54/// When `Data::deserialize` sees this type name, it will call this function.
55///
56/// # Errors
57/// Returns an error if the type is already registered.
58pub fn register_json_deserializer(
59    type_name: &str,
60    deserializer: JsonDeserializer,
61) -> Result<(), anyhow::Error> {
62    let reg = registries();
63    match reg.json.entry(type_name.to_string()) {
64        Entry::Occupied(_) => {
65            anyhow::bail!("Custom data type \"{type_name}\" is already registered for JSON");
66        }
67        Entry::Vacant(v) => {
68            v.insert(deserializer);
69            Ok(())
70        }
71    }
72}
73
74/// Registers a JSON deserializer for the given custom data type name if not already registered.
75/// If the type is already registered, returns `Ok(())` without overwriting (idempotent).
76/// Use this where repeated registration can occur (e.g. module init).
77///
78/// # Errors
79/// Does not return an error (idempotent insert into `DashMap`).
80pub fn ensure_json_deserializer_registered(
81    type_name: &str,
82    deserializer: JsonDeserializer,
83) -> Result<(), anyhow::Error> {
84    let reg = registries();
85    reg.json
86        .entry(type_name.to_string())
87        .or_insert_with(|| deserializer);
88    Ok(())
89}
90
91/// Parses a "`data_type`" JSON object into `DataType` (`type_name`, metadata, identifier).
92fn parse_data_type_from_value(value: &serde_json::Value) -> Option<DataType> {
93    let obj = value.get("data_type")?.as_object()?;
94    let type_name = obj.get("type_name")?.as_str()?;
95    let metadata = obj.get("metadata").and_then(|m| {
96        if m.is_null() {
97            None
98        } else {
99            let p: Params = serde_json::from_value(m.clone()).ok()?;
100            if p.is_empty() { None } else { Some(p) }
101        }
102    });
103    let identifier = obj
104        .get("identifier")
105        .and_then(|v| v.as_str())
106        .map(String::from);
107    Some(DataType::new(type_name, metadata, identifier))
108}
109
110/// Parses the canonical `CustomData` JSON envelope `{ type, data_type, payload }` and returns
111/// the payload value to pass to the registered type deserializer. Does not depend on
112/// user payload field names.
113fn parse_envelope_payload(value: &serde_json::Value) -> Result<serde_json::Value, anyhow::Error> {
114    let payload = value
115        .get("payload")
116        .cloned()
117        .ok_or_else(|| anyhow::anyhow!("CustomData JSON missing 'payload' field"))?;
118    Ok(payload)
119}
120
121/// Looks up and runs the JSON deserializer for the given type name.
122/// Returns `None` if the type is not registered.
123///
124/// # Errors
125/// Returns an error if the deserializer fails.
126pub fn deserialize_custom_from_json(
127    type_name: &str,
128    value: &serde_json::Value,
129) -> Result<Option<Data>, anyhow::Error> {
130    let reg = registries();
131    let deserializer_ref = match reg.json.get(type_name) {
132        Some(d) => d,
133        None => return Ok(None),
134    };
135    let data_type = parse_data_type_from_value(value);
136    let payload = parse_envelope_payload(value)?;
137    let arc = deserializer_ref.value()(payload)?;
138    let custom = match data_type {
139        Some(dt) => CustomData::new(arc, dt),
140        None => CustomData::from_arc(arc),
141    };
142    Ok(Some(Data::Custom(custom)))
143}
144
145/// Registers Arrow schema, encoder, and decoder for the given custom data type name.
146///
147/// # Errors
148/// Returns an error if the type is already registered for Arrow.
149pub fn register_arrow(
150    type_name: &str,
151    schema: Arc<Schema>,
152    encoder: ArrowEncoder,
153    decoder: ArrowDecoder,
154) -> Result<(), anyhow::Error> {
155    let reg = registries();
156    match reg.arrow.entry(type_name.to_string()) {
157        Entry::Occupied(_) => {
158            anyhow::bail!("Custom data type \"{type_name}\" is already registered for Arrow");
159        }
160        Entry::Vacant(v) => {
161            v.insert((schema, encoder, decoder));
162            Ok(())
163        }
164    }
165}
166
167/// Registers Arrow schema, encoder, and decoder for the given custom data type name if not already
168/// registered. If the type is already registered, returns `Ok(())` without overwriting (idempotent).
169/// Use this where repeated registration can occur (e.g. module init).
170///
171/// # Errors
172/// Does not return an error (idempotent insert into `DashMap`).
173pub fn ensure_arrow_registered(
174    type_name: &str,
175    schema: Arc<Schema>,
176    encoder: ArrowEncoder,
177    decoder: ArrowDecoder,
178) -> Result<(), anyhow::Error> {
179    let reg = registries();
180    reg.arrow
181        .entry(type_name.to_string())
182        .or_insert_with(|| (schema, encoder, decoder));
183    Ok(())
184}
185
186/// Returns the Arrow schema for the given custom type name, if registered.
187#[must_use]
188pub fn get_arrow_schema(type_name: &str) -> Option<Arc<Schema>> {
189    let reg = registries();
190    reg.arrow
191        .get(type_name)
192        .map(|entry| Arc::clone(&entry.value().0))
193}
194
195/// Encodes a slice of custom data trait objects to a `RecordBatch` using the registered encoder.
196///
197/// # Errors
198/// Returns an error if the type is not registered or encoding fails.
199pub fn encode_custom_to_arrow(
200    type_name: &str,
201    items: &[Arc<dyn CustomDataTrait>],
202) -> Result<Option<RecordBatch>, anyhow::Error> {
203    let reg = registries();
204    let entry = match reg.arrow.get(type_name) {
205        Some(e) => e,
206        None => return Ok(None),
207    };
208    let encoder = &entry.value().1;
209    encoder(items).map(Some)
210}
211
212/// Decodes a `RecordBatch` into `Vec<Data>` using the registered decoder.
213///
214/// # Errors
215/// Returns an error if the type is not registered or decoding fails.
216#[expect(
217    clippy::implicit_hasher,
218    reason = "callers always use the default hasher"
219)]
220pub fn decode_custom_from_arrow(
221    type_name: &str,
222    metadata: &HashMap<String, String>,
223    record_batch: RecordBatch,
224) -> Result<Option<Vec<Data>>, anyhow::Error> {
225    let reg = registries();
226    let entry = match reg.arrow.get(type_name) {
227        Some(e) => e,
228        None => return Ok(None),
229    };
230    let decoder = &entry.value().2;
231    decoder(metadata, record_batch).map(Some)
232}
233
/// Converts a Python object into a custom data value, returning `None` when the object cannot
/// be extracted.
#[cfg(feature = "python")]
pub type PyExtractor = Box<
    dyn for<'a> Fn(&pyo3::Bound<'a, pyo3::PyAny>) -> Option<Arc<dyn CustomDataTrait>>
        + Send
        + Sync,
>;
238
/// Returns the process-wide map from type name to [`PyExtractor`], creating it empty on first
/// access.
#[cfg(feature = "python")]
fn py_extractors() -> &'static DashMap<String, PyExtractor> {
    use std::sync::OnceLock;

    static PY_EXTRACTORS: OnceLock<DashMap<String, PyExtractor>> = OnceLock::new();

    PY_EXTRACTORS.get_or_init(DashMap::new)
}
245
/// Registers `extractor` as the Python-to-Rust converter for the custom data type `type_name`.
///
/// The `CustomData` constructor uses these extractors to turn Python objects into
/// `Arc<dyn CustomDataTrait>`.
///
/// # Errors
/// Returns an error if an extractor is already registered for `type_name`; the existing
/// registration is left in place.
#[cfg(feature = "python")]
pub fn register_py_extractor(type_name: &str, extractor: PyExtractor) -> Result<(), anyhow::Error> {
    let Entry::Vacant(slot) = py_extractors().entry(type_name.to_string()) else {
        anyhow::bail!(
            "Custom data type \"{type_name}\" is already registered for Python extraction"
        );
    };
    slot.insert(extractor);
    Ok(())
}
266
/// Registers `extractor` for `type_name` unless an extractor is already present.
///
/// An existing registration always wins and the new `extractor` is dropped, so this may be
/// called repeatedly (e.g. from module initialization).
///
/// # Errors
/// Never fails; always returns `Ok(())`.
#[cfg(feature = "python")]
pub fn ensure_py_extractor_registered(
    type_name: &str,
    extractor: PyExtractor,
) -> Result<(), anyhow::Error> {
    // `or_insert` only stores the value when the slot is vacant.
    let _ = py_extractors()
        .entry(type_name.to_owned())
        .or_insert(extractor);
    Ok(())
}
283
/// Converts `obj` into `Arc<dyn CustomDataTrait>` with the extractor registered for `type_name`.
///
/// Returns `None` if no extractor is registered for `type_name` or if the extractor itself
/// returns `None`.
#[cfg(feature = "python")]
#[must_use]
pub fn try_extract_from_py(
    type_name: &str,
    obj: &pyo3::Bound<'_, pyo3::PyAny>,
) -> Option<Arc<dyn CustomDataTrait>> {
    py_extractors()
        .get(type_name)
        .and_then(|entry| entry.value()(obj))
}
297
/// Builds a fresh [`PyExtractor`]. It is invoked by [`get_rust_extractor`] each time an
/// extractor is requested for the type name it is registered under.
#[cfg(feature = "python")]
type RustExtractorFactory = Box<dyn Fn() -> PyExtractor + Send + Sync>;
300
/// Returns the process-wide map from type name to [`RustExtractorFactory`], creating it empty
/// on first access.
#[cfg(feature = "python")]
fn rust_extractor_factories() -> &'static DashMap<String, RustExtractorFactory> {
    use std::sync::OnceLock;

    static RUST_EXTRACTOR_FACTORIES: OnceLock<DashMap<String, RustExtractorFactory>> =
        OnceLock::new();

    RUST_EXTRACTOR_FACTORIES.get_or_init(DashMap::new)
}
307
/// Registers `factory` as the producer of [`PyExtractor`]s for the Rust custom data type
/// `type_name`.
///
/// Crates (e.g. persistence) call this at load time for each Rust custom data type. When
/// `register_custom_data_class(cls)` is later called with that type's class, the factory is
/// invoked and its extractor is added to the main `PyExtractor` registry.
///
/// # Errors
/// Returns an error if a factory is already registered for `type_name`; the existing
/// registration is left in place.
#[cfg(feature = "python")]
pub fn register_rust_extractor_factory(
    type_name: &str,
    factory: RustExtractorFactory,
) -> Result<(), anyhow::Error> {
    let Entry::Vacant(slot) = rust_extractor_factories().entry(type_name.to_string()) else {
        anyhow::bail!("Rust extractor factory for \"{type_name}\" is already registered");
    };
    slot.insert(factory);
    Ok(())
}
331
/// Registers `factory` for `type_name` unless a factory is already present.
///
/// An existing registration always wins and the new `factory` is dropped, so this may be
/// called repeatedly (e.g. at module load).
///
/// # Errors
/// Never fails; always returns `Ok(())`.
#[cfg(feature = "python")]
pub fn ensure_rust_extractor_factory_registered(
    type_name: &str,
    factory: RustExtractorFactory,
) -> Result<(), anyhow::Error> {
    // `or_insert` only stores the value when the slot is vacant.
    let _ = rust_extractor_factories()
        .entry(type_name.to_owned())
        .or_insert(factory);
    Ok(())
}
347
/// Builds a [`RustExtractorFactory`] whose extractors convert Python objects into `T` through
/// `T`'s `FromPyObject` implementation.
///
/// Shared by [`register_rust_extractor`] and [`ensure_rust_extractor_registered`] so both
/// register exactly the same conversion.
#[cfg(feature = "python")]
fn rust_extractor_factory<T>() -> RustExtractorFactory
where
    T: CustomDataTrait + for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + Sync + 'static,
{
    Box::new(|| -> PyExtractor {
        Box::new(|obj: &pyo3::Bound<'_, pyo3::PyAny>| {
            // The `PyExtractor` contract reports failure as `None`, so the extraction error
            // itself is discarded here.
            obj.extract::<T>()
                .ok()
                .map(|value| Arc::new(value) as Arc<dyn CustomDataTrait>)
        })
    })
}

/// Registers the Rust custom data type `T` for Python extraction. Call once per type at module
/// load (e.g. in the persistence PyO3 module).
///
/// Registers, under `T::type_name_static()`, a factory (via
/// [`register_rust_extractor_factory`]) whose extractors use `T`'s `FromPyObject`
/// implementation.
///
/// # Errors
/// Returns an error if a factory is already registered for `T::type_name_static()`.
#[cfg(feature = "python")]
pub fn register_rust_extractor<T>() -> Result<(), anyhow::Error>
where
    T: CustomDataTrait + for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + Sync + 'static,
{
    let type_name = T::type_name_static();
    register_rust_extractor_factory(type_name, rust_extractor_factory::<T>())
}

/// Registers the Rust custom data type `T` for Python extraction unless a factory is already
/// registered for `T::type_name_static()`. An existing registration is kept, so this may be
/// called repeatedly (e.g. at module load).
///
/// # Errors
/// Never fails; always returns `Ok(())`.
#[cfg(feature = "python")]
pub fn ensure_rust_extractor_registered<T>() -> Result<(), anyhow::Error>
where
    T: CustomDataTrait + for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + Sync + 'static,
{
    let type_name = T::type_name_static();
    ensure_rust_extractor_factory_registered(type_name, rust_extractor_factory::<T>())
}
391
/// Produces a new [`PyExtractor`] for `type_name` by invoking its registered factory.
///
/// Returns `None` if no factory is registered for `type_name`.
#[cfg(feature = "python")]
#[must_use]
pub fn get_rust_extractor(type_name: &str) -> Option<PyExtractor> {
    rust_extractor_factories()
        .get(type_name)
        .map(|factory| factory.value()())
}
400
#[cfg(test)]
mod tests {
    use nautilus_core::UnixNanos;
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::data::{CustomData, custom::register_custom_data_json};

    /// Minimal custom data type; accepts unknown JSON fields.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestRegCustomData {
        ts_init: UnixNanos,
    }

    /// Like `TestRegCustomData`, but rejects unknown JSON fields.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    #[serde(deny_unknown_fields)]
    struct StrictRegCustomData {
        ts_init: UnixNanos,
    }

    // Implements `HasTsInit` and `CustomDataTrait` for a test struct whose only field is
    // `ts_init`, reporting `$name` as its type name.
    macro_rules! impl_test_custom_data {
        ($ty:ident, $name:literal) => {
            impl crate::data::HasTsInit for $ty {
                fn ts_init(&self) -> UnixNanos {
                    self.ts_init
                }
            }

            impl crate::data::custom::CustomDataTrait for $ty {
                fn type_name(&self) -> &'static str {
                    $name
                }
                fn type_name_static() -> &'static str {
                    $name
                }
                fn as_any(&self) -> &dyn std::any::Any {
                    self
                }
                fn ts_event(&self) -> nautilus_core::UnixNanos {
                    self.ts_init
                }
                fn to_json(&self) -> anyhow::Result<String> {
                    Ok(serde_json::to_string(self)?)
                }
                fn clone_arc(&self) -> Arc<dyn crate::data::CustomDataTrait> {
                    Arc::new(self.clone())
                }
                fn eq_arc(&self, other: &dyn crate::data::CustomDataTrait) -> bool {
                    other.as_any().downcast_ref::<Self>() == Some(self)
                }
                fn from_json(
                    value: serde_json::Value,
                ) -> anyhow::Result<Arc<dyn crate::data::CustomDataTrait>> {
                    let t: Self = serde_json::from_value(value)?;
                    Ok(Arc::new(t))
                }
            }
        };
    }

    impl_test_custom_data!(TestRegCustomData, "TestRegCustomData");
    impl_test_custom_data!(StrictRegCustomData, "StrictRegCustomData");

    /// Serializes `data` to JSON, parses it back, and checks that the custom payload's type
    /// name and `ts_init` survive the round-trip.
    fn assert_json_roundtrip(data: &Data) {
        let json = serde_json::to_string(data).unwrap();
        let back: Data = serde_json::from_str(&json).unwrap();

        let (Data::Custom(original), Data::Custom(restored)) = (data, &back) else {
            panic!("expected Custom variant");
        };
        assert_eq!(original.data.type_name(), restored.data.type_name());
        assert_eq!(original.data.ts_init(), restored.data.ts_init());
    }

    /// Builds a JSON deserializer that decodes its payload as `TestRegCustomData`.
    fn test_reg_deserializer() -> JsonDeserializer {
        Box::new(|value| {
            let t: TestRegCustomData = serde_json::from_value(value)?;
            Ok(Arc::new(t) as Arc<dyn crate::data::CustomDataTrait>)
        })
    }

    #[rstest]
    fn json_registry_roundtrip() {
        let _ = register_custom_data_json::<TestRegCustomData>();

        let data = Data::Custom(CustomData::from_arc(Arc::new(TestRegCustomData {
            ts_init: UnixNanos::from(100),
        })));
        assert_json_roundtrip(&data);
    }

    #[rstest]
    fn json_registry_roundtrip_with_deny_unknown_fields() {
        let _ = register_custom_data_json::<StrictRegCustomData>();

        let data = Data::Custom(CustomData::from_arc(Arc::new(StrictRegCustomData {
            ts_init: UnixNanos::from(200),
        })));
        assert_json_roundtrip(&data);
    }

    #[rstest]
    fn ensure_json_deserializer_registered_is_idempotent() {
        let first =
            ensure_json_deserializer_registered("IdempotentTestJson", test_reg_deserializer());
        assert!(first.is_ok(), "first registration should succeed");

        let second =
            ensure_json_deserializer_registered("IdempotentTestJson", test_reg_deserializer());
        assert!(
            second.is_ok(),
            "second registration with same type_name should succeed (idempotent)"
        );
    }

    #[rstest]
    fn register_json_deserializer_fails_on_duplicate() {
        let first = register_json_deserializer("StrictDuplicateTestJson", test_reg_deserializer());
        assert!(first.is_ok());

        let second =
            register_json_deserializer("StrictDuplicateTestJson", test_reg_deserializer());
        assert!(second.is_err());
        let err_msg = second.unwrap_err().to_string();
        assert!(
            err_msg.contains("already registered"),
            "expected 'already registered' in error, found: {err_msg}"
        );
    }
}