use rustc_hash::FxHashMap;
use std::sync::Arc;
use crate::{
node::NodePartial,
reducers::{AddErrors, AddMessages, MapMerge, Reducer, ReducerError},
state::VersionedState,
types::ChannelType,
};
use tracing::instrument;
/// Registry mapping each [`ChannelType`] to the ordered list of [`Reducer`]s
/// that fold a [`NodePartial`]'s payload for that channel into a
/// [`VersionedState`].
///
/// Cloning the registry clones the map and its `Vec`s, but the reducers
/// themselves are shared through `Arc` rather than deep-copied.
#[derive(Clone)]
pub struct ReducerRegistry {
    // Per-channel reducers; within one channel they are stored (and later
    // applied) in the order they were registered.
    reducer_map: FxHashMap<ChannelType, Vec<Arc<dyn Reducer>>>,
}
fn channel_has_data(channel: &ChannelType, partial: &NodePartial) -> bool {
match channel {
ChannelType::Message => partial.messages.as_ref().is_some_and(|v| !v.is_empty()),
ChannelType::Extra => partial.extra.as_ref().is_some_and(|m| !m.is_empty()),
ChannelType::Error => partial.errors.as_ref().is_some_and(|v| !v.is_empty()),
}
}
impl Default for ReducerRegistry {
fn default() -> Self {
Self::new()
.with_reducer(ChannelType::Message, Arc::new(AddMessages))
.with_reducer(ChannelType::Extra, Arc::new(MapMerge))
.with_reducer(ChannelType::Error, Arc::new(AddErrors))
}
}
impl ReducerRegistry {
    /// Creates an empty registry with no reducers registered on any channel.
    ///
    /// This differs from `ReducerRegistry::default()`, which pre-registers the
    /// stock reducers for the `Message`, `Extra` and `Error` channels.
    #[must_use]
    pub fn new() -> Self {
        Self {
            reducer_map: FxHashMap::default(),
        }
    }

    /// Appends `reducer` to the reducers for `channel`.
    ///
    /// Reducers on the same channel are applied in registration order by
    /// [`try_update`](Self::try_update). Returns `&mut Self` so calls can be
    /// chained.
    pub fn register(&mut self, channel: ChannelType, reducer: Arc<dyn Reducer>) -> &mut Self {
        self.reducer_map.entry(channel).or_default().push(reducer);
        self
    }

    /// Builder-style counterpart of [`register`](Self::register) that consumes
    /// the registry and returns it with `reducer` added to `channel`.
    ///
    /// Marked `#[must_use]` because discarding the returned value silently
    /// drops the registration.
    #[must_use]
    pub fn with_reducer(mut self, channel: ChannelType, reducer: Arc<dyn Reducer>) -> Self {
        self.register(channel, reducer);
        self
    }

    /// Returns a sorted textual description of every registered reducer.
    ///
    /// Each entry has the form `"{channel}:[{i}:{label},...]"`, where `i` is
    /// the reducer's position within that channel's list and `label` is its
    /// `definition_label()`. Entries are sorted, so the result does not depend
    /// on hash-map iteration order.
    #[must_use]
    pub fn definition_signature(&self) -> Vec<String> {
        let mut entries: Vec<String> = self
            .reducer_map
            .iter()
            .map(|(channel, reducers)| {
                let labels = reducers
                    .iter()
                    .enumerate()
                    .map(|(i, r)| format!("{i}:{}", r.definition_label()))
                    .collect::<Vec<_>>()
                    .join(",");
                format!("{channel}:[{labels}]")
            })
            .collect();
        entries.sort();
        entries
    }

    /// Applies every reducer registered for `channel_type` to `state`, in
    /// registration order, using the payload in `to_update`.
    ///
    /// If `to_update` has no non-empty payload for `channel_type`, this is a
    /// no-op that returns `Ok(())` without consulting the registry.
    ///
    /// # Errors
    ///
    /// Returns [`ReducerError::UnknownChannel`] when `to_update` carries data
    /// for `channel_type` but no reducer is registered for that channel.
    #[instrument(skip(self, state, to_update), err)]
    pub fn try_update(
        &self,
        channel_type: ChannelType,
        state: &mut VersionedState,
        to_update: &NodePartial,
    ) -> Result<(), ReducerError> {
        if !channel_has_data(&channel_type, to_update) {
            return Ok(());
        }
        match self.reducer_map.get(&channel_type) {
            Some(reducers) => {
                for reducer in reducers {
                    reducer.apply(state, to_update);
                }
                Ok(())
            }
            None => Err(ReducerError::UnknownChannel(channel_type)),
        }
    }

    /// Runs [`try_update`](Self::try_update) for every channel that has at
    /// least one registered reducer.
    ///
    /// Channels are visited in the underlying hash map's iteration order, not
    /// in any documented order. Only registered channels are visited, so any
    /// payload in `merged_updates` for an unregistered channel is skipped
    /// rather than reported.
    /// NOTE(review): confirm that silently skipping unregistered channels is
    /// intended.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `try_update`.
    #[instrument(skip(self, state, merged_updates), err)]
    pub fn apply_all(
        &self,
        state: &mut VersionedState,
        merged_updates: &NodePartial,
    ) -> Result<(), ReducerError> {
        for channel in self.reducer_map.keys() {
            self.try_update(channel.clone(), state, merged_updates)?;
        }
        Ok(())
    }
}