use std::sync::Arc;
use ahash::{HashMap, HashSet};
use itertools::Itertools as _;
use nohash_hasher::{IntMap, IntSet};
use re_chunk::{ComponentIdentifier, ComponentType};
use re_sdk_types::ViewClassIdentifier;
use super::view_class_placeholder::ViewClassPlaceholder;
use super::visualizer_entity_subscriber::{VisualizerEntityConfig, VisualizerEntitySubscriber};
use crate::view::view_context_system::ViewContextSystemOncePerFrameResult;
use crate::{
IdentifiedViewSystem, QueryContext, ViewClass, ViewContextCollection, ViewContextSystem,
ViewSystemIdentifier, ViewerContext, VisualizerCollection, VisualizerSystem,
};
use crate::{
component_fallbacks::FallbackProviderRegistry, view::view_context_system::ViewSystemState,
};
/// Errors that can occur while registering, extending or removing view classes and the
/// context systems / visualizers they use.
#[derive(Debug, thiserror::Error)]
pub enum ViewClassRegistryError {
    /// A view class with this identifier is already registered.
    #[error("View with class identifier {0:?} was already registered.")]
    DuplicateClassIdentifier(ViewClassIdentifier),

    /// The identifier is already taken by a context system: either the same context system was
    /// registered twice for one view class, or a visualizer tried to use an identifier that a
    /// context system already uses.
    #[error("A context system with identifier {0:?} was already registered.")]
    IdentifierAlreadyInUseForContextSystem(&'static str),

    /// The identifier is already taken by a visualizer: either the same visualizer was
    /// registered twice for one view class, or a context system tried to use an identifier that
    /// a visualizer already uses.
    #[error("A visualizer system with identifier {0:?} was already registered.")]
    IdentifierAlreadyInUseForVisualizer(&'static str),

    /// The referenced view class is not registered.
    #[error("View with class identifier {0:?} was not registered.")]
    UnknownClassIdentifier(ViewClassIdentifier),
}
/// Handed to a view class' registration callback (see `ViewClassRegistry::extend_class`) so it
/// can register the context systems, visualizers and fallback providers the class uses.
pub struct ViewSystemRegistrator<'a> {
    /// Registry holding the per-type context system / visualizer entries shared across classes.
    registry: &'a mut ViewClassRegistry,

    /// Registry that view-specific fallback providers are added to.
    fallback_registry: &'a mut FallbackProviderRegistry,

    /// The view class that systems are being registered for.
    identifier: ViewClassIdentifier,

    /// Context systems registered for this class; starts out as the class entry's
    /// `context_system_ids` and is used to reject duplicate registrations.
    context_systems: HashSet<ViewSystemIdentifier>,

    /// Visualizers registered for this class; starts out as the class entry's
    /// `visualizer_system_ids` and is used to reject duplicate registrations.
    visualizers: HashSet<ViewSystemIdentifier>,

    /// Passed to `visualizer_query_info` when a visualizer type is registered for the first time.
    app_options: &'a crate::AppOptions,

    /// Component types that the reflection data marks as enums; shared (via `Arc`) with the
    /// entity config of every visualizer registered through this registrator.
    known_builtin_enum_components: Arc<IntSet<ComponentType>>,
}
impl ViewSystemRegistrator<'_> {
pub fn register_context_system<
T: ViewContextSystem + IdentifiedViewSystem + Default + 'static,
>(
&mut self,
) -> Result<(), ViewClassRegistryError> {
if self.registry.visualizers.contains_key(&T::identifier()) {
return Err(ViewClassRegistryError::IdentifierAlreadyInUseForVisualizer(
T::identifier().as_str(),
));
}
if self.context_systems.insert(T::identifier()) {
self.registry
.context_systems
.entry(T::identifier())
.or_insert_with(|| ContextSystemTypeRegistryEntry {
factory_method: Box::new(|| Box::<T>::default()),
once_per_frame_execution_method: T::execute_once_per_frame,
used_by: Default::default(),
})
.used_by
.insert(self.identifier);
Ok(())
} else {
Err(
ViewClassRegistryError::IdentifierAlreadyInUseForContextSystem(
T::identifier().as_str(),
),
)
}
}
pub fn register_visualizer<T: VisualizerSystem + IdentifiedViewSystem + Default + 'static>(
&mut self,
) -> Result<(), ViewClassRegistryError> {
if self.registry.context_systems.contains_key(&T::identifier()) {
return Err(
ViewClassRegistryError::IdentifierAlreadyInUseForContextSystem(
T::identifier().as_str(),
),
);
}
if self.visualizers.insert(T::identifier()) {
let app_options = self.app_options;
let known_builtin_enum_components = Arc::clone(&self.known_builtin_enum_components);
self.registry
.visualizers
.entry(T::identifier())
.or_insert_with(move || {
let visualizer = T::default();
let visualizer_query_info = visualizer.visualizer_query_info(app_options);
let entity_config = VisualizerEntityConfig {
visualizer: T::identifier(),
relevant_archetype: visualizer_query_info.relevant_archetype,
constraints: Arc::new(visualizer_query_info.constraints),
known_builtin_enum_components,
};
VisualizerTypeRegistryEntry {
factory_method: Box::new(|| Box::<T>::default()),
used_by: Default::default(),
entity_config,
}
})
.used_by
.insert(self.identifier);
Ok(())
} else {
Err(ViewClassRegistryError::IdentifierAlreadyInUseForVisualizer(
T::identifier().as_str(),
))
}
}
pub fn register_fallback_provider<C: re_sdk_types::Component>(
&mut self,
component: ComponentIdentifier,
provider: impl Fn(&QueryContext<'_>) -> C + Send + Sync + 'static,
) {
self.fallback_registry.register_view_fallback_provider(
self.identifier,
component,
provider,
);
}
pub fn register_array_fallback_provider<
C: re_sdk_types::Component,
I: IntoIterator<Item = C>,
>(
&mut self,
component: ComponentIdentifier,
provider: impl Fn(&QueryContext<'_>) -> I + Send + Sync + 'static,
) {
self.fallback_registry
.register_view_array_fallback_provider(self.identifier, component, provider);
}
}
/// A registered view class together with the identifiers of the systems it uses.
pub struct ViewClassRegistryEntry {
    /// The view class implementation.
    pub class: Box<dyn ViewClass>,

    /// Identifier the class is registered under.
    pub identifier: ViewClassIdentifier,

    /// Context systems registered for this class; `ViewClassRegistry::new_context_collection`
    /// instantiates these.
    pub context_system_ids: HashSet<ViewSystemIdentifier>,

    /// Visualizers registered for this class; `ViewClassRegistry::new_visualizer_collection`
    /// instantiates these.
    pub visualizer_system_ids: HashSet<ViewSystemIdentifier>,
}
impl Default for ViewClassRegistryEntry {
fn default() -> Self {
Self {
class: Box::<ViewClassPlaceholder>::default(),
identifier: ViewClassPlaceholder::identifier(),
context_system_ids: Default::default(),
visualizer_system_ids: Default::default(),
}
}
}
/// Registry entry for one context system type, shared by every view class that registers it.
struct ContextSystemTypeRegistryEntry {
    /// Creates a fresh, default-constructed instance of the context system.
    factory_method: Box<dyn Fn() -> Box<dyn ViewContextSystem> + Send + Sync>,

    /// The registered type's `execute_once_per_frame`; invoked by
    /// `ViewClassRegistry::run_once_per_frame_context_systems`.
    once_per_frame_execution_method: fn(&ViewerContext<'_>) -> ViewContextSystemOncePerFrameResult,

    /// View classes using this context system. When a class is removed and this set becomes
    /// empty, the entry is dropped from the registry.
    used_by: HashSet<ViewClassIdentifier>,
}
/// Registry entry for one visualizer type, shared by every view class that registers it.
struct VisualizerTypeRegistryEntry {
    /// Creates a fresh, default-constructed instance of the visualizer.
    factory_method: Box<dyn Fn() -> Box<dyn VisualizerSystem> + Send + Sync>,

    /// View classes using this visualizer. When a class is removed and this set becomes empty,
    /// the entry is dropped from the registry.
    used_by: HashSet<ViewClassIdentifier>,

    /// Built from the visualizer's query info on first registration; used by
    /// `ViewClassRegistry::create_entity_subscribers` to create this visualizer's subscriber.
    entity_config: VisualizerEntityConfig,
}
/// Registry of all known view classes and of the context systems and visualizers they use.
///
/// Each system type is stored once and shared between classes; its entry records which classes
/// use it.
#[derive(Default)]
pub struct ViewClassRegistry {
    /// Registered view classes, keyed by class identifier.
    view_classes: HashMap<ViewClassIdentifier, ViewClassRegistryEntry>,

    /// Context system types, keyed by system identifier.
    context_systems: HashMap<ViewSystemIdentifier, ContextSystemTypeRegistryEntry>,

    /// Visualizer types, keyed by system identifier.
    visualizers: HashMap<ViewSystemIdentifier, VisualizerTypeRegistryEntry>,

    /// Returned by `get_class_entry_or_log_error` when a class is unknown.
    placeholder: ViewClassRegistryEntry,
}
impl ViewClassRegistry {
pub fn add_class<T: ViewClass + Default + 'static>(
&mut self,
reflection: &re_types_core::reflection::Reflection,
app_options: &crate::AppOptions,
fallback_registry: &mut FallbackProviderRegistry,
) -> Result<(), ViewClassRegistryError> {
let identifier = T::identifier();
if self.view_classes.contains_key(&identifier) {
return Err(ViewClassRegistryError::DuplicateClassIdentifier(identifier));
}
self.view_classes.insert(
identifier,
ViewClassRegistryEntry {
class: Box::<T>::default(),
identifier,
context_system_ids: Default::default(),
visualizer_system_ids: Default::default(),
},
);
self.extend_class(
identifier,
reflection,
app_options,
fallback_registry,
|reg| T::default().on_register(reg),
)?;
Ok(())
}
pub fn extend_class(
&mut self,
view_class: ViewClassIdentifier,
reflection: &re_sdk_types::reflection::Reflection,
app_options: &crate::AppOptions,
fallback_registry: &mut FallbackProviderRegistry,
register_fn: impl FnOnce(&mut ViewSystemRegistrator<'_>) -> Result<(), ViewClassRegistryError>,
) -> Result<(), ViewClassRegistryError> {
let Some(mut class_entry) = self.view_classes.remove(&view_class) else {
return Err(ViewClassRegistryError::UnknownClassIdentifier(view_class));
};
let known_builtin_enum_components: Arc<IntSet<ComponentType>> = Arc::new(
reflection
.components
.iter()
.filter(|(_, r)| r.is_enum)
.map(|(ct, _)| *ct)
.collect(),
);
let mut registrator = ViewSystemRegistrator {
registry: self,
identifier: view_class,
context_systems: class_entry.context_system_ids,
visualizers: class_entry.visualizer_system_ids,
fallback_registry,
app_options,
known_builtin_enum_components,
};
register_fn(&mut registrator)?;
let ViewSystemRegistrator {
registry: _,
identifier: _,
context_systems,
visualizers,
fallback_registry: _,
app_options: _,
known_builtin_enum_components: _,
} = registrator;
class_entry.context_system_ids = context_systems;
class_entry.visualizer_system_ids = visualizers;
self.view_classes
.insert(class_entry.identifier, class_entry);
Ok(())
}
pub fn remove_class<T: ViewClass + Sized>(&mut self) -> Result<(), ViewClassRegistryError> {
let identifier = T::identifier();
if self.view_classes.remove(&identifier).is_none() {
return Err(ViewClassRegistryError::UnknownClassIdentifier(identifier));
}
self.context_systems.retain(|_, context_system_entry| {
context_system_entry.used_by.remove(&identifier);
!context_system_entry.used_by.is_empty()
});
self.visualizers.retain(|_, visualizer_entry| {
visualizer_entry.used_by.remove(&identifier);
!visualizer_entry.used_by.is_empty()
});
Ok(())
}
pub fn register_context_system<
T: ViewContextSystem + IdentifiedViewSystem + Default + 'static,
>(
&mut self,
view_class: ViewClassIdentifier,
) -> Result<(), ViewClassRegistryError> {
if self.visualizers.contains_key(&T::identifier()) {
return Err(ViewClassRegistryError::IdentifierAlreadyInUseForVisualizer(
T::identifier().as_str(),
));
}
let class_entry = self
.view_classes
.get_mut(&view_class)
.ok_or(ViewClassRegistryError::UnknownClassIdentifier(view_class))?;
if class_entry.context_system_ids.insert(T::identifier()) {
self.context_systems
.entry(T::identifier())
.or_insert_with(|| ContextSystemTypeRegistryEntry {
factory_method: Box::new(|| Box::<T>::default()),
once_per_frame_execution_method: T::execute_once_per_frame,
used_by: Default::default(),
})
.used_by
.insert(view_class);
Ok(())
} else {
Err(
ViewClassRegistryError::IdentifierAlreadyInUseForContextSystem(
T::identifier().as_str(),
),
)
}
}
pub fn class_entry(&self, name: ViewClassIdentifier) -> Option<&ViewClassRegistryEntry> {
self.view_classes.get(&name)
}
pub fn get_class_entry_or_log_error(
&self,
name: ViewClassIdentifier,
) -> &ViewClassRegistryEntry {
if let Some(result) = self.class_entry(name) {
result
} else {
re_log::error_once!("Unknown view class {:?}", name);
&self.placeholder
}
}
pub fn class(&self, name: ViewClassIdentifier) -> Option<&dyn ViewClass> {
self.class_entry(name).map(|e| e.class.as_ref())
}
pub fn get_class_or_log_error(&self, name: ViewClassIdentifier) -> &dyn ViewClass {
self.get_class_entry_or_log_error(name).class.as_ref()
}
pub fn display_name(&self, name: ViewClassIdentifier) -> &'static str {
self.view_classes
.get(&name)
.map_or("<unknown view class>", |boxed| boxed.class.display_name())
}
pub fn iter_registry(&self) -> impl Iterator<Item = &ViewClassRegistryEntry> {
self.view_classes
.values()
.sorted_by_key(|entry| entry.class.display_name())
}
pub fn create_entity_subscribers(
&self,
) -> IntMap<ViewSystemIdentifier, VisualizerEntitySubscriber> {
self.visualizers
.iter()
.map(|(id, entry)| (*id, entry.entity_config.create_subscriber()))
.collect()
}
pub fn run_once_per_frame_context_systems(
&self,
viewer_ctx: &ViewerContext<'_>,
view_classes: impl Iterator<Item = ViewClassIdentifier>,
) -> IntMap<ViewSystemIdentifier, ViewContextSystemOncePerFrameResult> {
re_tracing::profile_function!();
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
let context_system_ids = view_classes
.filter_map(|view_class_identifier| self.view_classes.get(&view_class_identifier))
.flat_map(|view_class| view_class.context_system_ids.iter().copied())
.unique()
.collect_vec();
context_system_ids
.into_par_iter()
.filter_map(|context_system_id| {
self.context_systems.get(&context_system_id).map(|entry| {
(
context_system_id,
(entry.once_per_frame_execution_method)(viewer_ctx),
)
})
})
.collect()
}
pub fn new_context_collection(
&self,
view_class_identifier: ViewClassIdentifier,
) -> ViewContextCollection {
re_tracing::profile_function!();
let Some(class) = self.view_classes.get(&view_class_identifier) else {
return ViewContextCollection {
systems: Default::default(),
view_class_identifier,
};
};
ViewContextCollection {
systems: class
.context_system_ids
.iter()
.filter_map(|name| {
self.context_systems.get(name).map(|entry| {
let system = (entry.factory_method)();
(*name, (system, ViewSystemState::default()))
})
})
.collect(),
view_class_identifier,
}
}
pub fn new_visualizer_collection(
&self,
view_class_identifier: ViewClassIdentifier,
) -> VisualizerCollection {
re_tracing::profile_function!();
let Some(class) = self.view_classes.get(&view_class_identifier) else {
return VisualizerCollection {
systems: Default::default(),
};
};
VisualizerCollection {
systems: class
.visualizer_system_ids
.iter()
.filter_map(|name| {
self.visualizers.get(name).map(|entry| {
let system = (entry.factory_method)();
(*name, system)
})
})
.collect(),
}
}
}