use indexmap::IndexSet;
use palpo_macros::StringEnum;
use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use super::{
Action, AnyPushRuleRef, ConditionalPushRule, FlattenedJson, InsertPushRuleError, NewPushRule, PatternedPushRule,
PushConditionRoomCtx, RuleKind, RuleNotFoundError, RulesetIter, SimplePushRule, insert_and_move_rule,
};
use crate::push::RemovePushRuleError;
use crate::serde::RawJson;
use crate::{OwnedRoomId, OwnedUserId, PrivOwnedStr};
#[derive(ToSchema, Deserialize, Serialize, Default, Clone, Debug)]
pub struct Ruleset {
#[serde(default, skip_serializing_if = "IndexSet::is_empty")]
#[salvo(schema(value_type = HashSet<PatternedPushRule>))]
pub content: IndexSet<PatternedPushRule>,
#[serde(rename = "override", default, skip_serializing_if = "IndexSet::is_empty")]
#[salvo(schema(value_type = HashSet<ConditionalPushRule>))]
pub override_: IndexSet<ConditionalPushRule>,
#[serde(default, skip_serializing_if = "IndexSet::is_empty")]
#[salvo(schema(value_type = HashSet<SimplePushRule<OwnedRoomId>>))]
pub room: IndexSet<SimplePushRule<OwnedRoomId>>,
#[serde(default, skip_serializing_if = "IndexSet::is_empty")]
#[salvo(schema(value_type = HashSet<SimplePushRule<OwnedUserId>>))]
pub sender: IndexSet<SimplePushRule<OwnedUserId>>,
#[serde(default, skip_serializing_if = "IndexSet::is_empty")]
#[salvo(schema(value_type = HashSet<ConditionalPushRule>))]
pub underride: IndexSet<ConditionalPushRule>,
}
impl Ruleset {
pub fn new() -> Self {
Default::default()
}
pub fn iter(&self) -> RulesetIter<'_> {
self.into_iter()
}
pub fn insert(
&mut self,
rule: NewPushRule,
after: Option<&str>,
before: Option<&str>,
) -> Result<(), InsertPushRuleError> {
let rule_id = rule.rule_id();
if rule_id.starts_with('.') {
return Err(InsertPushRuleError::ServerDefaultRuleId);
}
if rule_id.contains('/') {
return Err(InsertPushRuleError::InvalidRuleId);
}
if rule_id.contains('\\') {
return Err(InsertPushRuleError::InvalidRuleId);
}
if after.is_some_and(|s| s.starts_with('.')) {
return Err(InsertPushRuleError::RelativeToServerDefaultRule);
}
if before.is_some_and(|s| s.starts_with('.')) {
return Err(InsertPushRuleError::RelativeToServerDefaultRule);
}
match rule {
NewPushRule::Override(r) => {
let mut rule = ConditionalPushRule::from(r);
if let Some(prev_rule) = self.override_.get(rule.rule_id.as_str()) {
rule.enabled = prev_rule.enabled;
}
let default_position = 1;
insert_and_move_rule(&mut self.override_, rule, default_position, after, before)
}
NewPushRule::Underride(r) => {
let mut rule = ConditionalPushRule::from(r);
if let Some(prev_rule) = self.underride.get(rule.rule_id.as_str()) {
rule.enabled = prev_rule.enabled;
}
insert_and_move_rule(&mut self.underride, rule, 0, after, before)
}
NewPushRule::Content(r) => {
let mut rule = PatternedPushRule::from(r);
if let Some(prev_rule) = self.content.get(rule.rule_id.as_str()) {
rule.enabled = prev_rule.enabled;
}
insert_and_move_rule(&mut self.content, rule, 0, after, before)
}
NewPushRule::Room(r) => {
let mut rule = SimplePushRule::from(r);
if let Some(prev_rule) = self.room.get(rule.rule_id.as_str()) {
rule.enabled = prev_rule.enabled;
}
insert_and_move_rule(&mut self.room, rule, 0, after, before)
}
NewPushRule::Sender(r) => {
let mut rule = SimplePushRule::from(r);
if let Some(prev_rule) = self.sender.get(rule.rule_id.as_str()) {
rule.enabled = prev_rule.enabled;
}
insert_and_move_rule(&mut self.sender, rule, 0, after, before)
}
}
}
pub fn get(&self, kind: RuleKind, rule_id: impl AsRef<str>) -> Option<AnyPushRuleRef<'_>> {
let rule_id = rule_id.as_ref();
match kind {
RuleKind::Override => self.override_.get(rule_id).map(AnyPushRuleRef::Override),
RuleKind::Underride => self.underride.get(rule_id).map(AnyPushRuleRef::Underride),
RuleKind::Sender => self.sender.get(rule_id).map(AnyPushRuleRef::Sender),
RuleKind::Room => self.room.get(rule_id).map(AnyPushRuleRef::Room),
RuleKind::Content => self.content.get(rule_id).map(AnyPushRuleRef::Content),
RuleKind::_Custom(_) => None,
}
}
pub fn set_enabled(
&mut self,
kind: RuleKind,
rule_id: impl AsRef<str>,
enabled: bool,
) -> Result<(), RuleNotFoundError> {
let rule_id = rule_id.as_ref();
match kind {
RuleKind::Override => {
let mut rule = self.override_.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.enabled = enabled;
self.override_.replace(rule);
}
RuleKind::Underride => {
let mut rule = self.underride.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.enabled = enabled;
self.underride.replace(rule);
}
RuleKind::Sender => {
let mut rule = self.sender.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.enabled = enabled;
self.sender.replace(rule);
}
RuleKind::Room => {
let mut rule = self.room.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.enabled = enabled;
self.room.replace(rule);
}
RuleKind::Content => {
let mut rule = self.content.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.enabled = enabled;
self.content.replace(rule);
}
RuleKind::_Custom(_) => return Err(RuleNotFoundError),
}
Ok(())
}
pub fn set_actions(
&mut self,
kind: RuleKind,
rule_id: impl AsRef<str>,
actions: Vec<Action>,
) -> Result<(), RuleNotFoundError> {
let rule_id = rule_id.as_ref();
match kind {
RuleKind::Override => {
let mut rule = self.override_.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.actions = actions;
self.override_.replace(rule);
}
RuleKind::Underride => {
let mut rule = self.underride.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.actions = actions;
self.underride.replace(rule);
}
RuleKind::Sender => {
let mut rule = self.sender.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.actions = actions;
self.sender.replace(rule);
}
RuleKind::Room => {
let mut rule = self.room.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.actions = actions;
self.room.replace(rule);
}
RuleKind::Content => {
let mut rule = self.content.get(rule_id).ok_or(RuleNotFoundError)?.clone();
rule.actions = actions;
self.content.replace(rule);
}
RuleKind::_Custom(_) => return Err(RuleNotFoundError),
}
Ok(())
}
#[instrument(skip_all, fields(context.room_id = %context.room_id))]
pub fn get_match<T>(&self, event: &RawJson<T>, context: &PushConditionRoomCtx) -> Option<AnyPushRuleRef<'_>> {
let event = FlattenedJson::from_raw(event);
if event.get_str("sender").is_some_and(|sender| sender == context.user_id) {
None
} else {
self.iter().find(|rule| rule.applies(&event, context))
}
}
#[instrument(skip_all, fields(context.room_id = %context.room_id))]
pub fn get_actions<T>(&self, event: &RawJson<T>, context: &PushConditionRoomCtx) -> &[Action] {
self.get_match(event, context).map(|rule| rule.actions()).unwrap_or(&[])
}
pub fn remove(&mut self, kind: RuleKind, rule_id: impl AsRef<str>) -> Result<(), RemovePushRuleError> {
let rule_id = rule_id.as_ref();
if let Some(rule) = self.get(kind.clone(), rule_id) {
if rule.is_server_default() {
return Err(RemovePushRuleError::ServerDefault);
}
} else {
return Err(RemovePushRuleError::NotFound);
}
match kind {
RuleKind::Override => {
self.override_.shift_remove(rule_id);
}
RuleKind::Underride => {
self.underride.shift_remove(rule_id);
}
RuleKind::Sender => {
self.sender.shift_remove(rule_id);
}
RuleKind::Room => {
self.room.shift_remove(rule_id);
}
RuleKind::Content => {
self.content.shift_remove(rule_id);
}
RuleKind::_Custom(_) => unreachable!(),
}
Ok(())
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub enum PredefinedRuleId {
Override(PredefinedOverrideRuleId),
Underride(PredefinedUnderrideRuleId),
Content(PredefinedContentRuleId),
}
impl PredefinedRuleId {
pub fn as_str(&self) -> &str {
match self {
Self::Override(id) => id.as_str(),
Self::Underride(id) => id.as_str(),
Self::Content(id) => id.as_str(),
}
}
pub fn kind(&self) -> RuleKind {
match self {
Self::Override(id) => id.kind(),
Self::Underride(id) => id.kind(),
Self::Content(id) => id.kind(),
}
}
}
impl AsRef<str> for PredefinedRuleId {
fn as_ref(&self) -> &str {
self.as_str()
}
}
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[palpo_enum(rename_all = ".m.rule.snake_case")]
#[non_exhaustive]
pub enum PredefinedOverrideRuleId {
Master,
SuppressNotices,
InviteForMe,
MemberEvent,
IsUserMention,
IsRoomMention,
Tombstone,
Reaction,
#[palpo_enum(rename = ".m.rule.room.server_acl")]
RoomServerAcl,
SuppressEdits,
#[cfg(feature = "unstable-msc3930")]
#[palpo_enum(rename = ".org.matrix.msc3930.rule.poll_response")]
PollResponse,
#[doc(hidden)]
_Custom(PrivOwnedStr),
}
impl PredefinedOverrideRuleId {
pub fn kind(&self) -> RuleKind {
RuleKind::Override
}
}
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[palpo_enum(rename_all = ".m.rule.snake_case")]
#[non_exhaustive]
pub enum PredefinedUnderrideRuleId {
Call,
EncryptedRoomOneToOne,
RoomOneToOne,
Message,
Encrypted,
#[cfg(feature = "unstable-msc3930")]
#[palpo_enum(rename = ".org.matrix.msc3930.rule.poll_start_one_to_one")]
PollStartOneToOne,
#[cfg(feature = "unstable-msc3930")]
#[palpo_enum(rename = ".org.matrix.msc3930.rule.poll_start")]
PollStart,
#[cfg(feature = "unstable-msc3930")]
#[palpo_enum(rename = ".org.matrix.msc3930.rule.poll_end_one_to_one")]
PollEndOneToOne,
#[cfg(feature = "unstable-msc3930")]
#[palpo_enum(rename = ".org.matrix.msc3930.rule.poll_end")]
PollEnd,
#[doc(hidden)]
_Custom(PrivOwnedStr),
}
impl PredefinedUnderrideRuleId {
pub fn kind(&self) -> RuleKind {
RuleKind::Underride
}
}
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, StringEnum)]
#[palpo_enum(rename_all = ".m.rule.snake_case")]
#[non_exhaustive]
pub enum PredefinedContentRuleId {
#[doc(hidden)]
_Custom(PrivOwnedStr),
}
impl PredefinedContentRuleId {
pub fn kind(&self) -> RuleKind {
RuleKind::Content
}
}