use std::{
borrow::Borrow,
cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet},
hash::Hash,
sync::OnceLock,
};
use ruma_common::{
EventId, MilliSecondsSinceUnixEpoch, OwnedUserId,
room_version_rules::{AuthorizationRules, StateResolutionV2Rules},
};
use ruma_events::{
StateEventType, TimelineEventType,
room::{member::MembershipState, power_levels::UserPowerLevel},
};
use tracing::{debug, info, instrument, trace, warn};
#[cfg(test)]
mod tests;
use crate::{
Error, Event, Result, auth_types_for_event, check_state_dependent_auth_rules,
events::{
RoomCreateEvent, RoomMemberEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField,
power_levels::RoomPowerLevelsEventOptionExt,
},
utils::{RoomIdExt, event_id_map::EventIdMap, event_id_set::EventIdSet},
};
/// A map whose keys are `(StateEventType, String)` pairs and whose values are of type `T`.
///
/// In this file the `String` half of the key is filled with an event's state key (see
/// `EventTypeExt::with_state_key`).
pub type StateMap<T> = HashMap<(StateEventType, String), T>;
/// Resolves the given state maps into a single state map.
///
/// The steps, in the order the body performs them:
///
/// 1. Split `state_maps` into the unconflicted state map and the conflicted state set. If
///    nothing is conflicted, the unconflicted state map is returned right away.
/// 2. Build the full conflicted set from the auth difference of `auth_chains`, the conflicted
///    events and, when `state_res_rules.consider_conflicted_state_subgraph` is set, the
///    conflicted state subgraph. IDs that `fetch_event` cannot find are dropped.
/// 3. Sort the power events of that set with `sort_power_events` and run
///    `iterative_auth_checks` over them.
/// 4. Sort the remaining events with `mainline_sort`, using the `RoomPowerLevels` entry with an
///    empty state key from the partially resolved state, and run `iterative_auth_checks` again.
/// 5. Extend the result with the unconflicted state map, so unconflicted entries replace any
///    entry with the same key.
///
/// # Arguments
///
/// * `auth_rules` - Authorization rules forwarded to the sorting and auth check helpers.
/// * `state_res_rules` - Flags selecting the variations of the algorithm described above.
/// * `state_maps` - The state maps to resolve.
/// * `auth_chains` - Sets of event IDs handed to `auth_difference`; only IDs missing from at
///   least one of the sets end up in the full conflicted set through this path.
/// * `fetch_event` - Returns the event with the given ID, or `None` if it is unavailable.
/// * `fetch_conflicted_state_subgraph` - Called with the conflicted state set; only invoked when
///   `state_res_rules.consider_conflicted_state_subgraph` is set.
///
/// # Errors
///
/// Returns `Error::FetchConflictedStateSubgraphFailed` if `fetch_conflicted_state_subgraph`
/// returns `None`, and propagates any error returned by `sort_power_events`,
/// `iterative_auth_checks` or `mainline_sort`.
#[instrument(skip_all)]
pub fn resolve<'a, E, MapsIter>(
auth_rules: &AuthorizationRules,
state_res_rules: &StateResolutionV2Rules,
state_maps: impl IntoIterator<IntoIter = MapsIter>,
auth_chains: Vec<EventIdSet<E::Id>>,
fetch_event: impl Fn(&EventId) -> Option<E>,
fetch_conflicted_state_subgraph: impl Fn(&StateMap<Vec<E::Id>>) -> Option<EventIdSet<E::Id>>,
) -> Result<StateMap<E::Id>>
where
E: Event + Clone,
E::Id: 'a,
MapsIter: Iterator<Item = &'a StateMap<E::Id>> + Clone,
{
info!("state resolution starting");
// Separate the entries every state map agrees on from the ones that differ.
let (unconflicted_state_map, conflicted_state_set) =
split_conflicted_state_set(state_maps.into_iter());
info!(count = unconflicted_state_map.len(), "unconflicted events");
trace!(map = ?unconflicted_state_map, "unconflicted events");
// Nothing to resolve: the unconflicted state map is already the answer.
if conflicted_state_set.is_empty() {
info!("no conflicted state found");
return Ok(unconflicted_state_map);
}
info!(count = conflicted_state_set.len(), "conflicted events");
trace!(map = ?conflicted_state_set, "conflicted events");
// The conflicted state subgraph is only requested when the rules ask for it; otherwise an
// empty set is chained into the full conflicted set below.
let conflicted_state_subgraph = if state_res_rules.consider_conflicted_state_subgraph {
let conflicted_state_subgraph = fetch_conflicted_state_subgraph(&conflicted_state_set)
.ok_or(Error::FetchConflictedStateSubgraphFailed)?;
info!(count = conflicted_state_subgraph.len(), "events in conflicted state subgraph");
trace!(set = ?conflicted_state_subgraph, "conflicted state subgraph");
conflicted_state_subgraph
} else {
EventIdSet::new()
};
// Full conflicted set: auth difference + conflicted events + conflicted state subgraph,
// restricted to the events that can actually be fetched.
let full_conflicted_set: EventIdSet<_> = auth_difference(auth_chains)
.chain(conflicted_state_set.into_values().flatten())
.chain(conflicted_state_subgraph)
.filter(|id| fetch_event(id.borrow()).is_some())
.collect();
info!(count = full_conflicted_set.len(), "full conflicted set");
trace!(set = ?full_conflicted_set, "full conflicted set");
// Power events are resolved first.
let conflicted_power_events = full_conflicted_set
.iter()
.filter(|&id| is_power_event_id(id.borrow(), &fetch_event))
.cloned()
.collect::<Vec<_>>();
let sorted_power_events =
sort_power_events(conflicted_power_events, &full_conflicted_set, auth_rules, &fetch_event)?;
debug!(count = sorted_power_events.len(), "power events");
trace!(list = ?sorted_power_events, "sorted power events");
// Depending on the rules, the first round of auth checks starts from an empty state map or
// from a copy of the unconflicted state map.
let initial_state_map = if state_res_rules.begin_iterative_auth_checks_with_empty_state_map {
HashMap::new()
} else {
unconflicted_state_map.clone()
};
let partially_resolved_state =
iterative_auth_checks(auth_rules, &sorted_power_events, initial_state_map, &fetch_event)?;
debug!(count = partially_resolved_state.len(), "resolved power events");
trace!(map = ?partially_resolved_state, "resolved power events");
// Everything in the full conflicted set that was not part of the sorted power events still
// needs to be resolved.
let sorted_power_events_set = sorted_power_events.into_iter().collect::<EventIdSet<_>>();
let remaining_events = full_conflicted_set
.iter()
.filter(|&id| !sorted_power_events_set.contains(id.borrow()))
.cloned()
.collect::<Vec<_>>();
debug!(count = remaining_events.len(), "events left to resolve");
trace!(list = ?remaining_events, "events left to resolve");
// The power levels event of the partially resolved state is the start of the mainline.
let power_event = partially_resolved_state.get(&(StateEventType::RoomPowerLevels, "".into()));
debug!(event_id = ?power_event, "power event");
let sorted_remaining_events =
mainline_sort(&remaining_events, power_event.cloned(), &fetch_event)?;
trace!(list = ?sorted_remaining_events, "events left, sorted");
let mut resolved_state = iterative_auth_checks(
auth_rules,
&sorted_remaining_events,
partially_resolved_state,
&fetch_event,
)?;
// Unconflicted entries replace any resolved entry that has the same key.
resolved_state.extend(unconflicted_state_map);
info!("state resolution finished");
Ok(resolved_state)
}
/// Splits the given state maps into the unconflicted state map and the conflicted state set.
///
/// For every `(event type, state key)` key, the number of state maps that contain each value is
/// counted. A value found in *all* state maps goes into the unconflicted state map. Any other
/// value (it differs between maps, or the key is missing from at least one map) is appended to
/// the list stored under that key in the conflicted state set.
///
/// The order of the values inside each conflicted list is unspecified, because it follows the
/// iteration order of a `HashMap`.
///
/// # Arguments
///
/// * `state_maps` - The state maps to compare.
///
/// # Returns
///
/// A `(unconflicted state map, conflicted state set)` tuple. Both are empty when `state_maps`
/// yields nothing.
fn split_conflicted_state_set<'a, Id>(
    state_maps: impl Iterator<Item = &'a StateMap<Id>>,
) -> (StateMap<Id>, StateMap<Vec<Id>>)
where
    Id: Clone + Eq + Hash + 'a,
{
    let mut state_set_count = 0_usize;
    // key -> value -> number of state maps containing that `(key, value)` pair.
    let mut occurrences = HashMap::<_, HashMap<_, usize>>::new();

    // Count the state maps while iterating over their entries, so the input is walked only once.
    let state_maps = state_maps.inspect(|_| state_set_count += 1);
    for (key, id) in state_maps.flatten() {
        *occurrences.entry(key).or_default().entry(id).or_default() += 1;
    }

    let mut unconflicted_state_map = StateMap::new();
    let mut conflicted_state_set: StateMap<Vec<Id>> = StateMap::new();

    for (key, id_counts) in occurrences {
        for (id, occurrence_count) in id_counts {
            if occurrence_count == state_set_count {
                unconflicted_state_map.insert(key.clone(), id.clone());
            } else {
                // `or_default()` only allocates a list the first time a key is seen and the ID
                // is cloned exactly once.
                conflicted_state_set.entry(key.clone()).or_default().push(id.clone());
            }
        }
    }

    (unconflicted_state_map, conflicted_state_set)
}
/// Computes the auth difference of the given auth chains.
///
/// Returns an iterator over the IDs that are present in at least one of the given sets but not
/// in all of them. The iteration order is the one of the underlying `EventIdMap`.
///
/// # Arguments
///
/// * `auth_chains` - The sets of event IDs to compare.
fn auth_difference<Id>(auth_chains: Vec<EventIdSet<Id>>) -> impl Iterator<Item = Id>
where
    Id: Eq + Hash + Borrow<EventId>,
{
    let chain_count = auth_chains.len();

    // How many of the chains each ID was seen in.
    let mut occurrences: EventIdMap<Id, usize> = EventIdMap::new();
    for chain in auth_chains {
        for id in chain {
            *occurrences.entry(id).or_default() += 1;
        }
    }

    // An ID seen in fewer chains than there are chains is missing from at least one of them.
    occurrences.into_iter().filter(move |(_, seen)| *seen < chain_count).map(|(id, _)| id)
}
#[instrument(skip_all)]
fn sort_power_events<E: Event>(
conflicted_power_events: Vec<E::Id>,
full_conflicted_set: &EventIdSet<E::Id>,
rules: &AuthorizationRules,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<E::Id>> {
debug!("reverse topological sort of power events");
let mut graph = EventIdMap::new();
for event_id in conflicted_power_events {
add_event_and_auth_chain_to_graph(&mut graph, event_id, full_conflicted_set, &fetch_event);
}
let mut event_to_power_level = EventIdMap::new();
let creators_lock = OnceLock::new();
for event_id in graph.keys() {
let sender_power_level =
power_level_for_sender(event_id.borrow(), rules, &creators_lock, &fetch_event)
.map_err(Error::AuthEvent)?;
debug!(
event_id = event_id.borrow().as_str(),
power_level = ?sender_power_level,
"found the power level of an event's sender",
);
event_to_power_level.insert(event_id.clone(), sender_power_level);
}
reverse_topological_power_sort(&graph, |event_id| {
let event = fetch_event(event_id).ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
let power_level = *event_to_power_level
.get(event_id)
.ok_or_else(|| Error::NotFound(event_id.to_owned()))?;
Ok((power_level, event.origin_server_ts()))
})
}
/// Sorts the events of the given graph so that an event is only emitted once all the events it
/// has outgoing edges to have been emitted.
///
/// `graph` maps an event ID to the set of event IDs it has outgoing edges to. Events without
/// outgoing edges are emitted first. Whenever several events are ready at the same time, the
/// next one is chosen with the following tie-breakers, in order:
///
/// 1. the greater power level,
/// 2. the earlier `origin_server_ts`,
/// 3. the smaller event ID.
///
/// An event whose outgoing edges are never all emitted (for example because one of them is not
/// a key of `graph`) does not appear in the output.
///
/// # Arguments
///
/// * `graph` - The events to sort, with their outgoing edges.
/// * `event_details_fn` - Returns the power level and the `origin_server_ts` used as
///   tie-breakers for the event with the given ID.
///
/// # Errors
///
/// Propagates any error returned by `event_details_fn`.
///
/// # Panics
///
/// Panics if the internal map of outgoing edges is missing an event that still has incoming
/// edges registered; see the `expect` calls in the body.
#[instrument(skip_all)]
pub fn reverse_topological_power_sort<Id, F>(
graph: &EventIdMap<Id, EventIdSet<Id>>,
event_details_fn: F,
) -> Result<Vec<Id>>
where
F: Fn(&EventId) -> Result<(UserPowerLevel, MilliSecondsSinceUnixEpoch)>,
Id: Clone + Eq + Ord + Hash + Borrow<EventId>,
{
// Sort key of an event that is ready to be emitted.
#[derive(PartialEq, Eq)]
struct TieBreaker<Id> {
power_level: UserPowerLevel,
origin_server_ts: MilliSecondsSinceUnixEpoch,
event_id: Id,
}
impl<Id> Ord for TieBreaker<Id>
where
Id: Ord,
{
fn cmp(&self, other: &Self) -> Ordering {
// The power levels are compared in reverse (`other` first), so a greater power level sorts
// as "less"; the timestamp and the event ID are compared in their natural order.
other
.power_level
.cmp(&self.power_level)
.then(self.origin_server_ts.cmp(&other.origin_server_ts))
.then(self.event_id.cmp(&other.event_id))
}
}
impl<Id> PartialOrd for TieBreaker<Id>
where
Id: Ord,
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
// Remaining outgoing edges of the events that are not ready yet.
let mut outgoing_edges_map: EventIdMap<_, EventIdSet<_>> = EventIdMap::new();
// For each event, the events that have an outgoing edge to it.
let mut incoming_edges_map: EventIdMap<_, EventIdSet<_>> = EventIdMap::new();
// `BinaryHeap` is a max-heap; wrapping the items in `Reverse` makes `pop()` return the smallest
// `TieBreaker`.
let mut heap = BinaryHeap::new();
for (event_id, outgoing_edges) in graph {
if outgoing_edges.is_empty() {
// No outgoing edges: the event is ready right away.
let (power_level, origin_server_ts) = event_details_fn(event_id.borrow())?;
heap.push(Reverse(TieBreaker {
power_level,
origin_server_ts,
event_id: event_id.clone(),
}));
} else {
for auth_event_id in outgoing_edges {
incoming_edges_map
.entry(auth_event_id.borrow())
.or_default()
.insert(event_id.borrow());
}
outgoing_edges_map
.insert(event_id.clone(), outgoing_edges.iter().map(Borrow::borrow).collect());
}
}
let mut sorted = vec![];
while let Some(Reverse(TieBreaker { event_id, .. })) = heap.pop() {
// Emitting `event_id` removes one outgoing edge from every event that points to it.
for &parent_id in incoming_edges_map.get(event_id.borrow()).into_iter().flatten() {
let parent_has_zero_outdegrees = {
let outgoing_edges = outgoing_edges_map.get_mut(parent_id).expect(
"outgoing edges map should have a key for all event IDs with outgoing edges",
);
outgoing_edges.remove(event_id.borrow());
outgoing_edges.is_empty()
};
if parent_has_zero_outdegrees {
// The parent just became ready: move its owned ID out of the map and into the heap.
let (power_level, origin_server_ts) = event_details_fn(parent_id)?;
let (parent_id, _) = outgoing_edges_map
.remove_entry(parent_id)
.expect("outgoing edges map should have a key for all event IDs");
heap.push(Reverse(TieBreaker {
power_level,
origin_server_ts,
event_id: parent_id,
}));
}
}
sorted.push(event_id);
}
Ok(sorted)
}
/// Finds the power level of the sender of the event with the given ID.
///
/// The power levels event is searched among the auth events of the event. The room creators are
/// taken from `creators_lock` when it is already initialized; otherwise they are computed from
/// the room create event and stored in `creators_lock` for the following calls.
///
/// When the event or the creators are unavailable, the `UsersDefault` value of the power levels
/// event (or its default) is returned instead of the power level of the sender.
///
/// # Arguments
///
/// * `event_id` - The ID of the event whose sender is looked up.
/// * `rules` - The authorization rules to apply.
/// * `creators_lock` - Cache of the room creators, shared between calls.
/// * `fetch_event` - Returns the event with the given ID, or `None` if it is unavailable.
///
/// # Errors
///
/// Returns the error message produced by `RoomCreateEvent::creators`, `user_power_level` or
/// `get_as_int_or_default` when one of them fails.
fn power_level_for_sender<E: Event>(
event_id: &EventId,
rules: &AuthorizationRules,
creators_lock: &OnceLock<HashSet<OwnedUserId>>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> std::result::Result<UserPowerLevel, String> {
let event = fetch_event(event_id);
let mut room_create_event = None;
let mut room_power_levels_event = None;
// When the rules say that the room ID holds the ID of the create event, fetch the create event
// through the room ID, but only if the creators are not cached yet.
if let Some(event) = &event
&& rules.room_create_event_id_as_room_id
&& creators_lock.get().is_none()
{
room_create_event = event
.room_id()
.and_then(|room_id| room_id.room_create_event_id().ok())
.and_then(|room_create_event_id| fetch_event(&room_create_event_id));
}
// Look through the auth events for the power levels event and, when it is still needed, for
// the create event. Auth events that cannot be fetched are skipped.
for auth_event_id in event.as_ref().map(|pdu| pdu.auth_events()).into_iter().flatten() {
if let Some(auth_event) = fetch_event(auth_event_id.borrow()) {
if is_type_and_key(&auth_event, &TimelineEventType::RoomPowerLevels, "") {
room_power_levels_event = Some(RoomPowerLevelsEvent::new(auth_event));
} else if !rules.room_create_event_id_as_room_id
&& creators_lock.get().is_none()
&& is_type_and_key(&auth_event, &TimelineEventType::RoomCreate, "")
{
room_create_event = Some(auth_event);
}
// Stop early once the power levels event was found and the create event is either not
// searched for in the auth events, not needed anymore, or found too.
if room_power_levels_event.is_some()
&& (rules.room_create_event_id_as_room_id
|| creators_lock.get().is_some()
|| room_create_event.is_some())
{
break;
}
}
}
// Prefer the cached creators; otherwise compute them from the create event and cache them.
let creators = if let Some(creators) = creators_lock.get() {
Some(creators)
} else if let Some(room_create_event) = room_create_event {
let room_create_event = RoomCreateEvent::new(room_create_event);
let creators = room_create_event.creators(rules)?;
Some(creators_lock.get_or_init(|| creators))
} else {
None
};
if let Some((event, creators)) = event.zip(creators) {
room_power_levels_event.user_power_level(event.sender(), creators, rules)
} else {
// Without the event or the creators, fall back to the `UsersDefault` field.
room_power_levels_event
.get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, rules)
.map(Into::into)
}
}
/// Applies the given events, in order, on top of the given state map, keeping only the events
/// that pass `check_state_dependent_auth_rules`.
///
/// For each event, the auth events used for the check are collected from:
///
/// 1. the event's own auth events that can be fetched and are not rejected,
/// 2. the room create event found through the room ID, when
///    `rules.room_create_event_id_as_room_id` is set and the event is not the create event,
/// 3. the current `state`, for every key returned by `auth_types_for_event`; these entries
///    replace the ones collected before for the same key.
///
/// An event that passes the check is inserted into `state` under its `(event type, state key)`
/// key. An event that fails the check, or for which `auth_types_for_event` fails, is skipped
/// with a warning.
///
/// # Arguments
///
/// * `rules` - The authorization rules to apply.
/// * `events` - The IDs of the events to check, in the order they must be applied.
/// * `state` - The state map to start from.
/// * `fetch_event` - Returns the event with the given ID, or `None` if it is unavailable.
///
/// # Errors
///
/// Returns `Error::NotFound` if one of `events` cannot be fetched, and
/// `Error::MissingStateKey` if one of `events` or one of its fetched, non-rejected auth events
/// has no state key.
fn iterative_auth_checks<E: Event + Clone>(
rules: &AuthorizationRules,
events: &[E::Id],
mut state: StateMap<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<StateMap<E::Id>> {
debug!("starting iterative auth checks");
trace!(list = ?events, "events to check");
for event_id in events {
let event = fetch_event(event_id.borrow())
.ok_or_else(|| Error::NotFound(event_id.borrow().to_owned()))?;
let state_key = event.state_key().ok_or(Error::MissingStateKey)?;
// Start with the auth events listed by the event itself.
let mut auth_events = StateMap::new();
for auth_event_id in event.auth_events() {
if let Some(auth_event) = fetch_event(auth_event_id.borrow()) {
if !auth_event.rejected() {
auth_events.insert(
auth_event
.event_type()
.with_state_key(auth_event.state_key().ok_or(Error::MissingStateKey)?),
auth_event,
);
}
} else {
warn!(event_id = %auth_event_id.borrow(), "missing auth event");
}
}
// With these rules the create event is looked up through the room ID and added explicitly.
if rules.room_create_event_id_as_room_id
&& *event.event_type() != TimelineEventType::RoomCreate
{
if let Some(room_create_event) = event
.room_id()
.and_then(|room_id| room_id.room_create_event_id().ok())
.and_then(|room_create_event_id| fetch_event(&room_create_event_id))
{
auth_events.insert((StateEventType::RoomCreate, String::new()), room_create_event);
} else {
warn!("missing m.room.create event");
}
}
let auth_types = match auth_types_for_event(
event.event_type(),
event.sender(),
Some(state_key),
event.content(),
rules,
) {
Ok(auth_types) => auth_types,
Err(error) => {
warn!("failed to get list of required auth events for malformed event: {error}");
continue;
}
};
// Entries of the current state replace the auth events collected above for the same key.
for key in auth_types {
if let Some(auth_event_id) = state.get(&key) {
if let Some(auth_event) = fetch_event(auth_event_id.borrow()) {
if !auth_event.rejected() {
auth_events.insert(key.to_owned(), auth_event);
}
} else {
warn!(event_id = %auth_event_id.borrow(), "missing auth event");
}
}
}
match check_state_dependent_auth_rules(rules, &event, |ty, key| {
auth_events.get(&ty.with_state_key(key))
}) {
Ok(()) => {
// The event passed the check: it becomes part of the state.
state.insert(event.event_type().with_state_key(state_key), event_id.clone());
}
Err(error) => {
// A failed check only skips the event, it does not abort the resolution.
warn!(event_id = ?event.event_id(), "event failed the authentication check: {error}");
}
}
}
Ok(state)
}
fn mainline_sort<E: Event>(
events: &[E::Id],
mut power_level: Option<E::Id>,
fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<Vec<E::Id>> {
debug!("mainline sort of events");
if events.is_empty() {
return Ok(vec![]);
}
let mut mainline = vec![];
while let Some(power_level_event_id) = power_level {
mainline.push(power_level_event_id.clone());
let power_level_event = fetch_event(power_level_event_id.borrow())
.ok_or_else(|| Error::NotFound(power_level_event_id.borrow().to_owned()))?;
power_level = None;
for auth_event_id in power_level_event.auth_events() {
let auth_event = fetch_event(auth_event_id.borrow())
.ok_or_else(|| Error::NotFound(power_level_event_id.borrow().to_owned()))?;
if is_type_and_key(&auth_event, &TimelineEventType::RoomPowerLevels, "") {
power_level = Some(auth_event_id.to_owned());
break;
}
}
}
let mainline_map = mainline
.iter()
.rev()
.enumerate()
.map(|(idx, event_id)| ((*event_id).clone(), idx))
.collect::<EventIdMap<_, _>>();
let mut order_map = HashMap::new();
for event_id in events.iter() {
if let Some(event) = fetch_event(event_id.borrow())
&& let Ok(position) = mainline_position(event, &mainline_map, &fetch_event)
{
order_map.insert(
event_id,
(
position,
fetch_event(event_id.borrow()).map(|event| event.origin_server_ts()),
event_id,
),
);
}
}
let mut sorted_event_ids = order_map.keys().map(|&k| k.clone()).collect::<Vec<_>>();
sorted_event_ids.sort_by_key(|event_id| order_map.get(event_id).unwrap());
Ok(sorted_event_ids)
}
/// Finds the position, relative to the mainline, of the given event.
///
/// Starting at `event`, the first auth event that is a power levels event with an empty state
/// key is followed repeatedly until an event that is part of `mainline_map` is reached; the
/// position stored for that event is returned. If the walk ends without reaching the mainline,
/// `0` is returned.
///
/// # Arguments
///
/// * `event` - The event to find the position of.
/// * `mainline_map` - The position of each event of the mainline.
/// * `fetch_event` - Returns the event with the given ID, or `None` if it is unavailable.
///
/// # Errors
///
/// Returns `Error::NotFound` if an auth event inspected during the walk cannot be fetched.
fn mainline_position<E: Event>(
    event: E,
    mainline_map: &EventIdMap<E::Id, usize>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) -> Result<usize> {
    let mut cursor = event;

    loop {
        let event_id = cursor.event_id();
        debug!(event_id = event_id.borrow().as_str(), "mainline");

        if let Some(&position) = mainline_map.get(event_id.borrow()) {
            return Ok(position);
        }

        // Look for the power levels event among the auth events of the current event.
        let mut next_power_levels = None;
        for auth_event_id in cursor.auth_events() {
            let auth_event = fetch_event(auth_event_id.borrow())
                .ok_or_else(|| Error::NotFound(auth_event_id.borrow().to_owned()))?;
            if is_type_and_key(&auth_event, &TimelineEventType::RoomPowerLevels, "") {
                next_power_levels = Some(auth_event);
                break;
            }
        }

        match next_power_levels {
            Some(power_levels_event) => cursor = power_levels_event,
            // The walk ended without reaching the mainline.
            None => return Ok(0),
        }
    }
}
/// Adds the given event and the part of its auth chain that lies inside the full conflicted set
/// to the graph.
///
/// Every visited event gets an entry in `graph`. For each of its auth events that is in
/// `full_conflicted_set`, an edge from the event to the auth event is added, and the auth event
/// is visited too if it has no entry in `graph` yet. Events that cannot be fetched keep an
/// entry without edges.
///
/// # Arguments
///
/// * `graph` - The graph to extend.
/// * `event_id` - The ID of the event to start from.
/// * `full_conflicted_set` - The full conflicted set; auth events outside of it are ignored.
/// * `fetch_event` - Returns the event with the given ID, or `None` if it is unavailable.
fn add_event_and_auth_chain_to_graph<E: Event>(
    graph: &mut EventIdMap<E::Id, EventIdSet<E::Id>>,
    event_id: E::Id,
    full_conflicted_set: &EventIdSet<E::Id>,
    fetch_event: impl Fn(&EventId) -> Option<E>,
) {
    // Events that still have to be visited.
    let mut pending = vec![event_id];

    while let Some(current_id) = pending.pop() {
        graph.entry(current_id.clone()).or_default();

        let Some(current_event) = fetch_event(current_id.borrow()) else {
            continue;
        };

        for auth_event_id in current_event.auth_events() {
            if !full_conflicted_set.contains(auth_event_id.borrow()) {
                continue;
            }

            if !graph.contains_event_id(auth_event_id.borrow()) {
                pending.push(auth_event_id.to_owned());
            }

            // The entry was created at the top of this iteration.
            graph.get_mut(current_id.borrow()).unwrap().insert(auth_event_id.to_owned());
        }
    }
}
/// Whether the event with the given ID is a power event.
///
/// Returns `false` if `fetch` cannot find the event; otherwise returns the result of
/// `is_power_event` for the fetched event.
fn is_power_event_id<E: Event>(event_id: &EventId, fetch: impl Fn(&EventId) -> Option<E>) -> bool {
    fetch(event_id).as_ref().is_some_and(|event| is_power_event(event))
}
/// Whether the given event has the given event type and the given state key.
///
/// An event without a state key never matches.
fn is_type_and_key(event: impl Event, event_type: &TimelineEventType, state_key: &str) -> bool {
    if event.event_type() != event_type {
        return false;
    }

    event.state_key().is_some_and(|key| key == state_key)
}
/// Whether the given event is a power event.
///
/// An event is a power event if it is:
///
/// * a power levels, join rules or create event with an empty state key, or
/// * a member event whose membership is `Leave` or `Ban` and whose sender differs from its
///   state key.
///
/// A member event whose membership cannot be read is not a power event.
fn is_power_event(event: impl Event) -> bool {
    match event.event_type() {
        TimelineEventType::RoomPowerLevels
        | TimelineEventType::RoomJoinRules
        | TimelineEventType::RoomCreate => event.state_key() == Some(""),
        TimelineEventType::RoomMember => {
            let member_event = RoomMemberEvent::new(event);
            let is_leave_or_ban = member_event.membership().is_ok_and(|membership| {
                matches!(membership, MembershipState::Leave | MembershipState::Ban)
            });

            // The sender is only compared with the state key for leave and ban memberships.
            is_leave_or_ban && Some(member_event.sender().as_str()) != member_event.state_key()
        }
        _ => false,
    }
}
/// Helper trait to build the key of a `StateMap` from an event type and a state key.
pub(crate) trait EventTypeExt {
    /// Pairs this event type, converted to a `StateEventType`, with the given state key.
    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
}

impl EventTypeExt for StateEventType {
    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
        let state_key: String = state_key.into();
        (self, state_key)
    }
}

impl EventTypeExt for TimelineEventType {
    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
        // The conversion goes through the string representation of the event type.
        let event_type: StateEventType = self.to_string().into();
        (event_type, state_key.into())
    }
}

/// References to event types can be used too; the event type is cloned for the conversion.
impl<T> EventTypeExt for &T
where
    T: EventTypeExt + Clone,
{
    fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
        T::with_state_key(self.clone(), state_key)
    }
}