1use std::{collections::HashMap, sync::Arc};
23
24use arrow::{datatypes::Schema, record_batch::RecordBatch};
25use dashmap::{DashMap, mapref::entry::Entry};
26use nautilus_core::Params;
27#[cfg(feature = "python")]
28use pyo3::types::PyAnyMethods;
29
30use crate::data::{CustomData, CustomDataTrait, Data, DataType};
31
32pub type JsonDeserializer =
33 Box<dyn Fn(serde_json::Value) -> Result<Arc<dyn CustomDataTrait>, anyhow::Error> + Send + Sync>;
34pub type ArrowEncoder =
35 Box<dyn Fn(&[Arc<dyn CustomDataTrait>]) -> Result<RecordBatch, anyhow::Error> + Send + Sync>;
36pub type ArrowDecoder = Box<
37 dyn Fn(&HashMap<String, String>, RecordBatch) -> Result<Vec<Data>, anyhow::Error> + Send + Sync,
38>;
39
40struct Registries {
41 json: DashMap<String, JsonDeserializer>,
42 arrow: DashMap<String, (Arc<Schema>, ArrowEncoder, ArrowDecoder)>,
43}
44
45fn registries() -> &'static Registries {
46 static REGISTRIES: std::sync::OnceLock<Registries> = std::sync::OnceLock::new();
47 REGISTRIES.get_or_init(|| Registries {
48 json: DashMap::new(),
49 arrow: DashMap::new(),
50 })
51}
52
53pub fn register_json_deserializer(
59 type_name: &str,
60 deserializer: JsonDeserializer,
61) -> Result<(), anyhow::Error> {
62 let reg = registries();
63 match reg.json.entry(type_name.to_string()) {
64 Entry::Occupied(_) => {
65 anyhow::bail!("Custom data type \"{type_name}\" is already registered for JSON");
66 }
67 Entry::Vacant(v) => {
68 v.insert(deserializer);
69 Ok(())
70 }
71 }
72}
73
74pub fn ensure_json_deserializer_registered(
81 type_name: &str,
82 deserializer: JsonDeserializer,
83) -> Result<(), anyhow::Error> {
84 let reg = registries();
85 reg.json
86 .entry(type_name.to_string())
87 .or_insert_with(|| deserializer);
88 Ok(())
89}
90
91fn parse_data_type_from_value(value: &serde_json::Value) -> Option<DataType> {
93 let obj = value.get("data_type")?.as_object()?;
94 let type_name = obj.get("type_name")?.as_str()?;
95 let metadata = obj.get("metadata").and_then(|m| {
96 if m.is_null() {
97 None
98 } else {
99 let p: Params = serde_json::from_value(m.clone()).ok()?;
100 if p.is_empty() { None } else { Some(p) }
101 }
102 });
103 let identifier = obj
104 .get("identifier")
105 .and_then(|v| v.as_str())
106 .map(String::from);
107 Some(DataType::new(type_name, metadata, identifier))
108}
109
110fn parse_envelope_payload(value: &serde_json::Value) -> Result<serde_json::Value, anyhow::Error> {
114 let payload = value
115 .get("payload")
116 .cloned()
117 .ok_or_else(|| anyhow::anyhow!("CustomData JSON missing 'payload' field"))?;
118 Ok(payload)
119}
120
121pub fn deserialize_custom_from_json(
127 type_name: &str,
128 value: &serde_json::Value,
129) -> Result<Option<Data>, anyhow::Error> {
130 let reg = registries();
131 let deserializer_ref = match reg.json.get(type_name) {
132 Some(d) => d,
133 None => return Ok(None),
134 };
135 let data_type = parse_data_type_from_value(value);
136 let payload = parse_envelope_payload(value)?;
137 let arc = deserializer_ref.value()(payload)?;
138 let custom = match data_type {
139 Some(dt) => CustomData::new(arc, dt),
140 None => CustomData::from_arc(arc),
141 };
142 Ok(Some(Data::Custom(custom)))
143}
144
145pub fn register_arrow(
150 type_name: &str,
151 schema: Arc<Schema>,
152 encoder: ArrowEncoder,
153 decoder: ArrowDecoder,
154) -> Result<(), anyhow::Error> {
155 let reg = registries();
156 match reg.arrow.entry(type_name.to_string()) {
157 Entry::Occupied(_) => {
158 anyhow::bail!("Custom data type \"{type_name}\" is already registered for Arrow");
159 }
160 Entry::Vacant(v) => {
161 v.insert((schema, encoder, decoder));
162 Ok(())
163 }
164 }
165}
166
167pub fn ensure_arrow_registered(
174 type_name: &str,
175 schema: Arc<Schema>,
176 encoder: ArrowEncoder,
177 decoder: ArrowDecoder,
178) -> Result<(), anyhow::Error> {
179 let reg = registries();
180 reg.arrow
181 .entry(type_name.to_string())
182 .or_insert_with(|| (schema, encoder, decoder));
183 Ok(())
184}
185
186#[must_use]
188pub fn get_arrow_schema(type_name: &str) -> Option<Arc<Schema>> {
189 let reg = registries();
190 reg.arrow
191 .get(type_name)
192 .map(|entry| Arc::clone(&entry.value().0))
193}
194
195pub fn encode_custom_to_arrow(
200 type_name: &str,
201 items: &[Arc<dyn CustomDataTrait>],
202) -> Result<Option<RecordBatch>, anyhow::Error> {
203 let reg = registries();
204 let entry = match reg.arrow.get(type_name) {
205 Some(e) => e,
206 None => return Ok(None),
207 };
208 let encoder = &entry.value().1;
209 encoder(items).map(Some)
210}
211
212#[expect(
217 clippy::implicit_hasher,
218 reason = "callers always use the default hasher"
219)]
220pub fn decode_custom_from_arrow(
221 type_name: &str,
222 metadata: &HashMap<String, String>,
223 record_batch: RecordBatch,
224) -> Result<Option<Vec<Data>>, anyhow::Error> {
225 let reg = registries();
226 let entry = match reg.arrow.get(type_name) {
227 Some(e) => e,
228 None => return Ok(None),
229 };
230 let decoder = &entry.value().2;
231 decoder(metadata, record_batch).map(Some)
232}
233
/// Boxed callback that tries to convert a Python object into a custom-data object.
/// `None` means no conversion was produced.
#[cfg(feature = "python")]
pub type PyExtractor =
    Box<dyn Fn(&pyo3::Bound<'_, pyo3::PyAny>) -> Option<Arc<dyn CustomDataTrait>> + Send + Sync>;

/// Returns the process-global table of Python extractors, keyed by type name, creating it
/// empty on first use.
#[cfg(feature = "python")]
fn py_extractors() -> &'static DashMap<String, PyExtractor> {
    static PY_EXTRACTORS: std::sync::OnceLock<DashMap<String, PyExtractor>> =
        std::sync::OnceLock::new();
    PY_EXTRACTORS.get_or_init(DashMap::default)
}
245
/// Registers a Python extractor for the custom data type `type_name`.
///
/// # Errors
///
/// Returns an error if a Python extractor is already registered under `type_name`. The
/// existing registration is left untouched in that case.
#[cfg(feature = "python")]
pub fn register_py_extractor(type_name: &str, extractor: PyExtractor) -> Result<(), anyhow::Error> {
    let Entry::Vacant(slot) = py_extractors().entry(type_name.to_string()) else {
        anyhow::bail!(
            "Custom data type \"{type_name}\" is already registered for Python extraction"
        );
    };
    slot.insert(extractor);
    Ok(())
}
266
/// Registers a Python extractor for `type_name` unless one is already present.
///
/// Unlike [`register_py_extractor`], an existing registration is not an error. It is
/// kept and `extractor` is dropped.
///
/// # Errors
///
/// This function currently never returns an error.
#[cfg(feature = "python")]
pub fn ensure_py_extractor_registered(
    type_name: &str,
    extractor: PyExtractor,
) -> Result<(), anyhow::Error> {
    if let Entry::Vacant(slot) = py_extractors().entry(type_name.to_string()) {
        slot.insert(extractor);
    }
    Ok(())
}
283
/// Tries to convert `obj` with the Python extractor registered for `type_name`.
///
/// Returns `None` when no extractor is registered for `type_name` or when the extractor
/// itself returns `None`.
///
/// Note: the registry entry stays borrowed while the extractor runs. Under the `DashMap`
/// locking rules, an extractor that registers Python extractors itself may deadlock.
#[cfg(feature = "python")]
#[must_use]
pub fn try_extract_from_py(
    type_name: &str,
    obj: &pyo3::Bound<'_, pyo3::PyAny>,
) -> Option<Arc<dyn CustomDataTrait>> {
    py_extractors()
        .get(type_name)
        .and_then(|entry| entry.value()(obj))
}
297
/// Boxed constructor that produces a new [`PyExtractor`] each time it is called.
#[cfg(feature = "python")]
type RustExtractorFactory = Box<dyn Fn() -> PyExtractor + Send + Sync>;

/// Returns the process-global table of Rust extractor factories, keyed by type name,
/// creating it empty on first use.
#[cfg(feature = "python")]
fn rust_extractor_factories() -> &'static DashMap<String, RustExtractorFactory> {
    static RUST_EXTRACTOR_FACTORIES: std::sync::OnceLock<DashMap<String, RustExtractorFactory>> =
        std::sync::OnceLock::new();
    RUST_EXTRACTOR_FACTORIES.get_or_init(DashMap::default)
}
307
/// Registers a Rust extractor factory for the custom data type `type_name`.
///
/// # Errors
///
/// Returns an error if a factory is already registered under `type_name`. The existing
/// registration is left untouched in that case.
#[cfg(feature = "python")]
pub fn register_rust_extractor_factory(
    type_name: &str,
    factory: RustExtractorFactory,
) -> Result<(), anyhow::Error> {
    let Entry::Vacant(slot) = rust_extractor_factories().entry(type_name.to_string()) else {
        anyhow::bail!("Rust extractor factory for \"{type_name}\" is already registered");
    };
    slot.insert(factory);
    Ok(())
}
331
/// Registers a Rust extractor factory for `type_name` unless one is already present.
///
/// Unlike [`register_rust_extractor_factory`], an existing registration is not an error.
/// It is kept and `factory` is dropped.
///
/// # Errors
///
/// This function currently never returns an error.
#[cfg(feature = "python")]
pub fn ensure_rust_extractor_factory_registered(
    type_name: &str,
    factory: RustExtractorFactory,
) -> Result<(), anyhow::Error> {
    if let Entry::Vacant(slot) = rust_extractor_factories().entry(type_name.to_string()) {
        slot.insert(factory);
    }
    Ok(())
}
347
/// Builds a [`RustExtractorFactory`] whose extractors convert a Python object into `T`
/// with pyo3 extraction. Extraction failures become `None`.
///
/// Both the strict and the idempotent registration paths share this helper, so their
/// extractors cannot diverge.
#[cfg(feature = "python")]
fn rust_extractor_factory_for<T>() -> RustExtractorFactory
where
    T: CustomDataTrait + for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + Sync + 'static,
{
    Box::new(|| {
        Box::new(|obj: &pyo3::Bound<'_, pyo3::PyAny>| {
            obj.extract::<T>()
                .ok()
                .map(|x| Arc::new(x) as Arc<dyn CustomDataTrait>)
        })
    })
}

/// Registers a Rust extractor factory for `T` under `T::type_name_static()`.
///
/// # Errors
///
/// Returns an error if a Rust extractor factory is already registered under that name.
#[cfg(feature = "python")]
pub fn register_rust_extractor<T>() -> Result<(), anyhow::Error>
where
    T: CustomDataTrait + for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + Sync + 'static,
{
    register_rust_extractor_factory(T::type_name_static(), rust_extractor_factory_for::<T>())
}

/// Registers a Rust extractor factory for `T` under `T::type_name_static()` unless one is
/// already present. An existing registration is kept.
///
/// # Errors
///
/// This function currently never returns an error.
#[cfg(feature = "python")]
pub fn ensure_rust_extractor_registered<T>() -> Result<(), anyhow::Error>
where
    T: CustomDataTrait + for<'a, 'py> pyo3::FromPyObject<'a, 'py> + Send + Sync + 'static,
{
    ensure_rust_extractor_factory_registered(
        T::type_name_static(),
        rust_extractor_factory_for::<T>(),
    )
}
391
/// Builds a new [`PyExtractor`] from the Rust extractor factory registered for
/// `type_name`. Returns `None` if no factory is registered.
#[cfg(feature = "python")]
#[must_use]
pub fn get_rust_extractor(type_name: &str) -> Option<PyExtractor> {
    rust_extractor_factories()
        .get(type_name)
        .map(|factory| factory.value()())
}
400
#[cfg(test)]
mod tests {
    use nautilus_core::UnixNanos;
    use rstest::rstest;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::data::{CustomData, custom::register_custom_data_json};

    /// Implements `HasTsInit` and `CustomDataTrait` for a test struct whose only field is
    /// `ts_init`, reporting `$name` as both the instance and the static type name.
    macro_rules! impl_test_custom_data {
        ($ty:ty, $name:literal) => {
            impl crate::data::HasTsInit for $ty {
                fn ts_init(&self) -> UnixNanos {
                    self.ts_init
                }
            }

            impl crate::data::custom::CustomDataTrait for $ty {
                fn type_name(&self) -> &'static str {
                    $name
                }

                fn type_name_static() -> &'static str {
                    $name
                }

                fn as_any(&self) -> &dyn std::any::Any {
                    self
                }

                fn ts_event(&self) -> nautilus_core::UnixNanos {
                    self.ts_init
                }

                fn to_json(&self) -> anyhow::Result<String> {
                    Ok(serde_json::to_string(self)?)
                }

                fn clone_arc(&self) -> Arc<dyn crate::data::CustomDataTrait> {
                    Arc::new(self.clone())
                }

                fn eq_arc(&self, other: &dyn crate::data::CustomDataTrait) -> bool {
                    other.as_any().downcast_ref::<Self>() == Some(self)
                }

                fn from_json(
                    value: serde_json::Value,
                ) -> anyhow::Result<Arc<dyn crate::data::CustomDataTrait>> {
                    let parsed: Self = serde_json::from_value(value)?;
                    Ok(Arc::new(parsed))
                }
            }
        };
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestRegCustomData {
        ts_init: UnixNanos,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    #[serde(deny_unknown_fields)]
    struct StrictRegCustomData {
        ts_init: UnixNanos,
    }

    impl_test_custom_data!(TestRegCustomData, "TestRegCustomData");
    impl_test_custom_data!(StrictRegCustomData, "StrictRegCustomData");

    /// Builds a JSON deserializer that decodes payloads as `TestRegCustomData`.
    fn test_reg_deserializer() -> JsonDeserializer {
        Box::new(|value| {
            let parsed: TestRegCustomData = serde_json::from_value(value)?;
            Ok(Arc::new(parsed) as Arc<dyn crate::data::CustomDataTrait>)
        })
    }

    /// Serializes `data` to JSON, parses it back, and checks that the custom payload kept
    /// its type name and `ts_init`.
    fn assert_json_roundtrip(data: &Data) {
        let json = serde_json::to_string(data).unwrap();
        let back: Data = serde_json::from_str(&json).unwrap();

        let (Data::Custom(original), Data::Custom(restored)) = (data, &back) else {
            panic!("expected Custom variant");
        };
        assert_eq!(original.data.type_name(), restored.data.type_name());
        assert_eq!(original.data.ts_init(), restored.data.ts_init());
    }

    #[rstest]
    fn json_registry_roundtrip() {
        let _ = register_custom_data_json::<TestRegCustomData>();

        let data = Data::Custom(CustomData::from_arc(Arc::new(TestRegCustomData {
            ts_init: UnixNanos::from(100),
        })));
        assert_json_roundtrip(&data);
    }

    #[rstest]
    fn json_registry_roundtrip_with_deny_unknown_fields() {
        let _ = register_custom_data_json::<StrictRegCustomData>();

        let data = Data::Custom(CustomData::from_arc(Arc::new(StrictRegCustomData {
            ts_init: UnixNanos::from(200),
        })));
        assert_json_roundtrip(&data);
    }

    #[rstest]
    fn ensure_json_deserializer_registered_is_idempotent() {
        let first =
            ensure_json_deserializer_registered("IdempotentTestJson", test_reg_deserializer());
        assert!(first.is_ok(), "first registration should succeed");

        let second =
            ensure_json_deserializer_registered("IdempotentTestJson", test_reg_deserializer());
        assert!(
            second.is_ok(),
            "second registration with same type_name should succeed (idempotent)"
        );
    }

    #[rstest]
    fn register_json_deserializer_fails_on_duplicate() {
        let first = register_json_deserializer("StrictDuplicateTestJson", test_reg_deserializer());
        assert!(first.is_ok());

        let second = register_json_deserializer("StrictDuplicateTestJson", test_reg_deserializer());
        assert!(second.is_err());

        let err_msg = second.unwrap_err().to_string();
        assert!(
            err_msg.contains("already registered"),
            "expected 'already registered' in error, found: {err_msg}"
        );
    }
}
572}