use std::{collections::HashMap, sync::Mutex};
use nautilus_core::{MUTEX_POISONED, python::to_pynotimplemented_err};
use pyo3::prelude::*;
use crate::factories::{ClientConfig, DataClientFactory};
/// Function that converts a Python factory object into a boxed Rust [`DataClientFactory`].
pub type FactoryExtractor =
    fn(py: Python<'_>, factory: Py<PyAny>) -> PyResult<Box<dyn DataClientFactory>>;

/// Function that converts a Python config object into a boxed Rust [`ClientConfig`].
pub type ConfigExtractor =
    fn(py: Python<'_>, config: Py<PyAny>) -> PyResult<Box<dyn ClientConfig>>;
/// Registry of functions that convert Python factory and config objects into their Rust
/// counterparts, looked up by a string key.
///
/// Each lookup table sits behind its own [`Mutex`], so extractors can be registered and
/// looked up through a shared `&self` reference.
#[derive(Debug)]
pub struct FactoryRegistry {
    /// Factory extractors, keyed by the string returned from the Python factory's `name()`.
    factory_extractors: Mutex<HashMap<String, FactoryExtractor>>,
    /// Config extractors, keyed by the Python config object's class `__name__`.
    config_extractors: Mutex<HashMap<String, ConfigExtractor>>,
}
impl FactoryRegistry {
    /// Creates an empty registry with no factory or config extractors registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            factory_extractors: Mutex::new(HashMap::new()),
            config_extractors: Mutex::new(HashMap::new()),
        }
    }

    /// Registers `extractor` for Python factories whose `name()` returns `name`.
    ///
    /// # Errors
    ///
    /// Returns an error if an extractor is already registered under `name`. The existing
    /// registration is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn register_factory_extractor(
        &self,
        name: String,
        extractor: FactoryExtractor,
    ) -> anyhow::Result<()> {
        let mut extractors = self.factory_extractors.lock().expect(MUTEX_POISONED);
        if extractors.contains_key(&name) {
            anyhow::bail!("Factory extractor '{name}' is already registered");
        }
        extractors.insert(name, extractor);
        Ok(())
    }

    /// Registers `extractor` for Python config objects whose class `__name__` is `type_name`.
    ///
    /// # Errors
    ///
    /// Returns an error if an extractor is already registered under `type_name`. The existing
    /// registration is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn register_config_extractor(
        &self,
        type_name: String,
        extractor: ConfigExtractor,
    ) -> anyhow::Result<()> {
        let mut extractors = self.config_extractors.lock().expect(MUTEX_POISONED);
        if extractors.contains_key(&type_name) {
            anyhow::bail!("Config extractor '{type_name}' is already registered");
        }
        extractors.insert(type_name, extractor);
        Ok(())
    }

    /// Converts a Python factory object into a Rust [`DataClientFactory`].
    ///
    /// The extractor is chosen by the string returned from calling `factory.name()`.
    ///
    /// # Errors
    ///
    /// - Propagates any Python error raised while calling `factory.name()` or converting its
    ///   result to a `String`.
    /// - Returns a `NotImplementedError` if no extractor is registered for that name.
    /// - Propagates any error returned by the selected extractor.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn extract_factory(
        &self,
        py: Python<'_>,
        factory: Py<PyAny>,
    ) -> PyResult<Box<dyn DataClientFactory>> {
        let factory_name = factory
            .getattr(py, "name")?
            .call0(py)?
            .extract::<String>(py)?;

        // Copy the `fn` pointer out so the guard drops at the end of this statement, before
        // the extractor runs. The extractor executes Python code, which may release the GIL
        // or re-enter this registry. Holding the non-reentrant mutex across that call risks
        // a deadlock.
        let extractor = self
            .factory_extractors
            .lock()
            .expect(MUTEX_POISONED)
            .get(&factory_name)
            .copied();

        let Some(extractor) = extractor else {
            return Err(to_pynotimplemented_err(format!(
                "No factory extractor registered for '{factory_name}'"
            )));
        };

        extractor(py, factory)
    }

    /// Converts a Python config object into a Rust [`ClientConfig`].
    ///
    /// The extractor is chosen by `config.__class__.__name__`.
    // NOTE(review): only the unqualified class name is used as the key, so same-named classes
    // from different modules share one extractor slot. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// - Propagates any Python error raised while reading the class name.
    /// - Returns a `NotImplementedError` if no extractor is registered for that class name.
    /// - Propagates any error returned by the selected extractor.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn extract_config(
        &self,
        py: Python<'_>,
        config: Py<PyAny>,
    ) -> PyResult<Box<dyn ClientConfig>> {
        let config_type_name = config
            .getattr(py, "__class__")?
            .getattr(py, "__name__")?
            .extract::<String>(py)?;

        // Release the lock before running Python code (see `extract_factory`).
        let extractor = self
            .config_extractors
            .lock()
            .expect(MUTEX_POISONED)
            .get(&config_type_name)
            .copied();

        let Some(extractor) = extractor else {
            return Err(to_pynotimplemented_err(format!(
                "No config extractor registered for '{config_type_name}'"
            )));
        };

        extractor(py, config)
    }
}
impl Default for FactoryRegistry {
fn default() -> Self {
Self::new()
}
}
/// Process-wide extractor registry, created empty on first access.
static GLOBAL_PYO3_REGISTRY: std::sync::LazyLock<FactoryRegistry> =
    std::sync::LazyLock::new(FactoryRegistry::new);

/// Returns the process-wide [`FactoryRegistry`].
///
/// The first call initializes the registry (empty). Every call returns the same instance.
#[must_use]
pub fn get_global_pyo3_registry() -> &'static FactoryRegistry {
    std::sync::LazyLock::force(&GLOBAL_PYO3_REGISTRY)
}