use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::fsm::{FsmState, StateContext, StateKey, StateKeyStrategy, StateStorage};
use crate::middleware::{BoxFuture, DispatchResult, Middleware, Next, PanicRecoveryMiddleware};
use crate::update::{IncomingMessage, Update};
/// A predicate that decides whether a handler should receive an [`IncomingMessage`].
///
/// Implementors must be `Send + Sync + 'static` so they can be stored behind
/// an `Arc<dyn Filter>` (see [`BoxFilter`]).
pub trait Filter: Send + Sync + 'static {
    /// Returns `true` when `message` should be routed to the guarded handler.
    fn check(&self, message: &IncomingMessage) -> bool;
}
/// Lets a shared trait object be used anywhere a [`Filter`] is expected.
impl Filter for Arc<dyn Filter> {
    fn check(&self, msg: &IncomingMessage) -> bool {
        // Borrow the pointee explicitly so the call goes through the inner
        // trait object instead of recursing into this impl.
        let inner: &dyn Filter = Arc::as_ref(self);
        inner.check(msg)
    }
}
/// A cheaply clonable, type-erased [`Filter`] handle.
///
/// Cloning only bumps the reference count of the shared `Arc`. Handles
/// compose with the `&`, `|` and `!` operators.
#[derive(Clone)]
pub struct BoxFilter(Arc<dyn Filter>);

impl BoxFilter {
    /// Erases a concrete filter behind a shared pointer.
    fn new<F: Filter>(f: F) -> Self {
        let shared: Arc<dyn Filter> = Arc::new(f);
        Self(shared)
    }
}

impl Filter for BoxFilter {
    fn check(&self, msg: &IncomingMessage) -> bool {
        let BoxFilter(inner) = self;
        inner.check(msg)
    }
}
/// `a & b` accepts a message only if both `a` and `b` accept it; `b` is not
/// evaluated when `a` rejects.
impl std::ops::BitAnd for BoxFilter {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self::Output {
        Self::new(AndFilter(self, rhs))
    }
}

/// `a | b` accepts a message if either filter accepts it; `b` is not
/// evaluated when `a` accepts.
impl std::ops::BitOr for BoxFilter {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self::Output {
        Self::new(OrFilter(self, rhs))
    }
}

/// `!a` accepts exactly the messages that `a` rejects.
impl std::ops::Not for BoxFilter {
    type Output = Self;

    fn not(self) -> Self::Output {
        Self::new(NotFilter(self))
    }
}
/// Conjunction of two filters; the right operand is skipped once the left fails.
struct AndFilter(BoxFilter, BoxFilter);

impl Filter for AndFilter {
    fn check(&self, msg: &IncomingMessage) -> bool {
        let AndFilter(lhs, rhs) = self;
        lhs.check(msg) && rhs.check(msg)
    }
}

/// Disjunction of two filters; the right operand is skipped once the left passes.
struct OrFilter(BoxFilter, BoxFilter);

impl Filter for OrFilter {
    fn check(&self, msg: &IncomingMessage) -> bool {
        let OrFilter(lhs, rhs) = self;
        lhs.check(msg) || rhs.check(msg)
    }
}

/// Negation of a filter.
struct NotFilter(BoxFilter);

impl Filter for NotFilter {
    fn check(&self, msg: &IncomingMessage) -> bool {
        let NotFilter(inner) = self;
        !inner.check(msg)
    }
}
/// Adapts a plain predicate closure into a [`Filter`].
///
/// Generic over the closure type so the closure is stored inline. The only
/// heap allocation and dynamic dispatch is the `Arc<dyn Filter>` created by
/// `BoxFilter::new`; previously the closure was boxed a second time.
struct FnFilter<F>(F);

impl<F> Filter for FnFilter<F>
where
    F: Fn(&IncomingMessage) -> bool + Send + Sync + 'static,
{
    fn check(&self, m: &IncomingMessage) -> bool {
        (self.0)(m)
    }
}

/// Wraps a predicate closure in a [`BoxFilter`].
fn make<F>(f: F) -> BoxFilter
where
    F: Fn(&IncomingMessage) -> bool + Send + Sync + 'static,
{
    BoxFilter::new(FnFilter(f))
}
/// Matches every message.
pub fn all() -> BoxFilter {
    make(|_msg| true)
}

/// Matches no message.
pub fn none() -> BoxFilter {
    make(|_msg| false)
}

/// Matches messages for which `IncomingMessage::is_private` returns `true`.
pub fn private() -> BoxFilter {
    make(|msg| msg.is_private())
}

/// Matches messages for which `IncomingMessage::is_group` returns `true`.
pub fn group() -> BoxFilter {
    make(|msg| msg.is_group())
}

/// Matches messages for which `IncomingMessage::is_channel` returns `true`.
pub fn channel() -> BoxFilter {
    make(|msg| msg.is_channel())
}

/// Matches messages whose `text()` is present.
pub fn text() -> BoxFilter {
    make(|msg| matches!(msg.text(), Some(_)))
}

/// Matches messages for which `IncomingMessage::has_media` returns `true`.
pub fn media() -> BoxFilter {
    make(|msg| msg.has_media())
}

/// Matches messages for which `IncomingMessage::has_photo` returns `true`.
pub fn photo() -> BoxFilter {
    make(|msg| msg.has_photo())
}

/// Matches messages for which `IncomingMessage::has_document` returns `true`.
pub fn document() -> BoxFilter {
    make(|msg| msg.has_document())
}

/// Matches messages for which `IncomingMessage::is_forwarded` returns `true`.
pub fn forwarded() -> BoxFilter {
    make(|msg| msg.is_forwarded())
}

/// Matches messages for which `IncomingMessage::is_reply` returns `true`.
pub fn reply() -> BoxFilter {
    make(|msg| msg.is_reply())
}

/// Matches messages that carry an `album_id`.
pub fn album() -> BoxFilter {
    make(|msg| matches!(msg.album_id(), Some(_)))
}

/// Matches messages for which `IncomingMessage::is_bot_command` returns `true`.
pub fn any_command() -> BoxFilter {
    make(|msg| msg.is_bot_command())
}
/// Matches messages for which `IncomingMessage::is_command_named(name)` is `true`.
pub fn command(name: impl Into<String>) -> BoxFilter {
    let wanted: String = name.into();
    make(move |msg| msg.is_command_named(&wanted))
}

/// Matches messages whose text contains `needle`; messages without text never match.
pub fn text_contains(needle: impl Into<String>) -> BoxFilter {
    let pattern: String = needle.into();
    make(move |msg| match msg.text() {
        Some(body) => body.contains(pattern.as_str()),
        None => false,
    })
}

/// Matches messages whose text begins with `prefix`; messages without text never match.
pub fn text_starts_with(prefix: impl Into<String>) -> BoxFilter {
    let head: String = prefix.into();
    make(move |msg| match msg.text() {
        Some(body) => body.starts_with(head.as_str()),
        None => false,
    })
}

/// Matches messages whose `sender_user_id()` equals `id`.
pub fn from_user(id: i64) -> BoxFilter {
    let wanted = Some(id);
    make(move |msg| msg.sender_user_id() == wanted)
}

/// Matches messages whose `chat_id()` equals `id`.
pub fn in_chat(id: i64) -> BoxFilter {
    make(move |msg| msg.chat_id() == id)
}

/// Builds a filter from an arbitrary predicate closure.
pub fn custom<F>(f: F) -> BoxFilter
where
    F: Fn(&IncomingMessage) -> bool + Send + Sync + 'static,
{
    make(f)
}
/// Type-erased future returned by every message handler.
type MsgFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;

/// Type-erased handler for stateless message routes.
type HandlerFn = Arc<dyn Fn(IncomingMessage) -> MsgFuture + Send + Sync + 'static>;

/// Type-erased handler for state-bound routes; also receives a [`StateContext`].
type FsmHandlerFn =
    Arc<dyn Fn(IncomingMessage, StateContext) -> MsgFuture + Send + Sync + 'static>;

/// A stateless route: `handler` runs for messages accepted by `filter`.
#[derive(Clone)]
pub(crate) struct MessageHandler {
    filter: BoxFilter,
    handler: HandlerFn,
}

/// A state-bound route: `handler` runs when the stored state equals
/// `expected_state` and `filter` accepts the message.
#[derive(Clone)]
pub(crate) struct FsmMessageHandler {
    filter: BoxFilter,
    // Key produced by `FsmState::as_key` when the handler was registered.
    expected_state: String,
    handler: FsmHandlerFn,
}
/// A composable group of message handlers with an optional scope filter.
///
/// Routers can be nested; when flattened, every handler's filter is prefixed
/// with the scopes of all routers enclosing it.
pub struct Router {
    // Extra filter in front of every handler here and in `children`.
    scope: Option<BoxFilter>,
    // Stateless handlers for new and edited messages.
    new_msg: Vec<MessageHandler>,
    edited_msg: Vec<MessageHandler>,
    // State-bound handlers for new and edited messages.
    fsm_new_msg: Vec<FsmMessageHandler>,
    fsm_edited_msg: Vec<FsmMessageHandler>,
    // Nested routers, flattened after this router's own handlers.
    children: Vec<Router>,
}
impl Router {
    /// Creates an empty router: no scope, no handlers, no child routers.
    #[must_use]
    pub fn new() -> Self {
        Self {
            scope: None,
            new_msg: Vec::new(),
            edited_msg: Vec::new(),
            fsm_new_msg: Vec::new(),
            fsm_edited_msg: Vec::new(),
            children: Vec::new(),
        }
    }

    /// Sets a filter that every handler of this router, and of every router
    /// nested in it, must pass in addition to its own filter.
    ///
    /// A later call replaces an earlier scope. The router is consumed and
    /// returned, so ignoring the result drops the router.
    #[must_use = "`scope` returns the scoped router; ignoring it drops the router"]
    pub fn scope(mut self, filter: BoxFilter) -> Self {
        self.scope = Some(filter);
        self
    }

    /// Registers `handler` for new messages accepted by `filter`.
    pub fn on_message<H, Fut>(&mut self, filter: BoxFilter, handler: H)
    where
        H: Fn(IncomingMessage) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let hfn: HandlerFn = Arc::new(move |msg| Box::pin(handler(msg)) as MsgFuture);
        self.new_msg.push(MessageHandler { filter, handler: hfn });
    }

    /// Registers `handler` for edited messages accepted by `filter`.
    pub fn on_edit<H, Fut>(&mut self, filter: BoxFilter, handler: H)
    where
        H: Fn(IncomingMessage) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let hfn: HandlerFn = Arc::new(move |msg| Box::pin(handler(msg)) as MsgFuture);
        self.edited_msg.push(MessageHandler { filter, handler: hfn });
    }

    /// Registers `handler` for new messages accepted by `filter` while the
    /// stored FSM state equals `state.as_key()`.
    pub fn on_message_fsm<S, H, Fut>(&mut self, filter: BoxFilter, state: S, handler: H)
    where
        S: FsmState,
        H: Fn(IncomingMessage, StateContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let expected_state = state.as_key();
        let hfn: FsmHandlerFn = Arc::new(move |msg, ctx| Box::pin(handler(msg, ctx)) as MsgFuture);
        self.fsm_new_msg.push(FsmMessageHandler {
            filter,
            expected_state,
            handler: hfn,
        });
    }

    /// Registers `handler` for edited messages accepted by `filter` while the
    /// stored FSM state equals `state.as_key()`.
    pub fn on_edit_fsm<S, H, Fut>(&mut self, filter: BoxFilter, state: S, handler: H)
    where
        S: FsmState,
        H: Fn(IncomingMessage, StateContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let expected_state = state.as_key();
        let hfn: FsmHandlerFn = Arc::new(move |msg, ctx| Box::pin(handler(msg, ctx)) as MsgFuture);
        self.fsm_edited_msg.push(FsmMessageHandler {
            filter,
            expected_state,
            handler: hfn,
        });
    }

    /// Nests `router` inside this one. Its handlers are flattened after this
    /// router's own handlers and also inherit this router's scope.
    pub fn include(&mut self, router: Router) {
        self.children.push(router);
    }

    /// Flattens this router tree into plain handler lists.
    ///
    /// In each list, this router's handlers come first in registration order.
    /// Each child's handlers follow, in inclusion order, depth first. Every
    /// filter is prefixed with `parent_scope` and the scopes of all routers on
    /// the path to the handler, outermost first.
    pub(crate) fn flatten(self, parent_scope: Option<BoxFilter>) -> FlatHandlers {
        let mut flat = FlatHandlers::default();
        self.flatten_into(parent_scope, &mut flat);
        flat
    }

    /// Appends this subtree's scoped handlers to `out`. A single accumulator
    /// avoids an intermediate `FlatHandlers` per nesting level.
    fn flatten_into(self, parent_scope: Option<BoxFilter>, out: &mut FlatHandlers) {
        let scope = combine_scopes(parent_scope, self.scope);
        let scope_ref = scope.as_ref();

        out.new_msg
            .extend(self.new_msg.into_iter().map(|h| scoped(h, scope_ref)));
        out.edited_msg
            .extend(self.edited_msg.into_iter().map(|h| scoped(h, scope_ref)));
        out.fsm_new_msg
            .extend(self.fsm_new_msg.into_iter().map(|h| scoped_fsm(h, scope_ref)));
        out.fsm_edited_msg
            .extend(self.fsm_edited_msg.into_iter().map(|h| scoped_fsm(h, scope_ref)));

        for child in self.children {
            child.flatten_into(scope.clone(), out);
        }
    }
}
/// Merges an inherited scope with a router's own scope.
///
/// When both exist the result is `parent & own`, so the parent is checked
/// first. Otherwise whichever one is present, or `None`.
fn combine_scopes(parent: Option<BoxFilter>, own: Option<BoxFilter>) -> Option<BoxFilter> {
    match (parent, own) {
        (Some(outer), Some(inner)) => Some(outer & inner),
        (outer, inner) => outer.or(inner),
    }
}
/// Prefixes a stateless handler's filter with `scope`, if any.
fn scoped(h: MessageHandler, scope: Option<&BoxFilter>) -> MessageHandler {
    let Some(outer) = scope else {
        return h;
    };
    let MessageHandler { filter, handler } = h;
    MessageHandler {
        filter: outer.clone() & filter,
        handler,
    }
}

/// Prefixes a state-bound handler's filter with `scope`, if any; the
/// expected state and handler are kept unchanged.
fn scoped_fsm(h: FsmMessageHandler, scope: Option<&BoxFilter>) -> FsmMessageHandler {
    let Some(outer) = scope else {
        return h;
    };
    let FsmMessageHandler {
        filter,
        expected_state,
        handler,
    } = h;
    FsmMessageHandler {
        filter: outer.clone() & filter,
        expected_state,
        handler,
    }
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
/// Handler lists produced by flattening a [`Router`] tree, with every
/// filter already prefixed by its enclosing scopes.
#[derive(Default)]
pub(crate) struct FlatHandlers {
    /// Stateless handlers for new messages.
    pub new_msg: Vec<MessageHandler>,
    /// Stateless handlers for edited messages.
    pub edited_msg: Vec<MessageHandler>,
    /// State-bound handlers for new messages.
    pub fsm_new_msg: Vec<FsmMessageHandler>,
    /// State-bound handlers for edited messages.
    pub fsm_edited_msg: Vec<FsmMessageHandler>,
}
/// Handler lists owned by a [`Dispatcher`].
///
/// Kept behind one `Arc` so each dispatched update shares the table by
/// bumping a single reference count instead of cloning every handler.
#[derive(Clone, Default)]
struct DispatchTable {
    new_msg: Vec<MessageHandler>,
    edited_msg: Vec<MessageHandler>,
    fsm_new_msg: Vec<FsmMessageHandler>,
    fsm_edited_msg: Vec<FsmMessageHandler>,
}

/// Routes incoming [`Update`]s through the middleware chain to at most one
/// matching handler.
pub struct Dispatcher {
    // Copy-on-write: registration mutates via `Arc::make_mut`, and dispatch
    // only clones the `Arc`.
    handlers: Arc<DispatchTable>,
    // Stored in the shape `Next::new` consumes, so dispatch does not rebuild it.
    middlewares: Arc<[Arc<dyn Middleware>]>,
    state_storage: Option<Arc<dyn StateStorage>>,
    key_strategy: StateKeyStrategy,
}

impl Dispatcher {
    /// Creates a dispatcher with no handlers, no state storage, the default
    /// key strategy, and a `PanicRecoveryMiddleware` as its only middleware.
    pub fn new() -> Self {
        let default_chain: Vec<Arc<dyn Middleware>> =
            vec![Arc::new(PanicRecoveryMiddleware::new())];
        Self {
            handlers: Arc::default(),
            middlewares: default_chain.into(),
            state_storage: None,
            key_strategy: StateKeyStrategy::default(),
        }
    }

    /// Mutable access to the handler table. It is cloned only if another
    /// `Arc` still refers to the current snapshot.
    fn table_mut(&mut self) -> &mut DispatchTable {
        Arc::make_mut(&mut self.handlers)
    }

    /// Appends `mw` to the end of the middleware chain.
    pub fn middleware(&mut self, mw: impl Middleware) {
        let mut chain = self.middlewares.to_vec();
        chain.push(Arc::new(mw));
        self.middlewares = chain.into();
    }

    /// Sets the storage used to look up FSM state; without it FSM handlers never run.
    pub fn with_state_storage(&mut self, storage: Arc<dyn StateStorage>) {
        self.state_storage = Some(storage);
    }

    /// Sets how a message is mapped to a [`StateKey`].
    pub fn with_key_strategy(&mut self, strategy: StateKeyStrategy) {
        self.key_strategy = strategy;
    }

    /// Registers `handler` for new messages accepted by `filter`.
    pub fn on_message<H, Fut>(&mut self, filter: BoxFilter, handler: H)
    where
        H: Fn(IncomingMessage) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let hfn: HandlerFn = Arc::new(move |msg| Box::pin(handler(msg)) as MsgFuture);
        self.table_mut().new_msg.push(MessageHandler { filter, handler: hfn });
    }

    /// Registers `handler` for edited messages accepted by `filter`.
    pub fn on_edit<H, Fut>(&mut self, filter: BoxFilter, handler: H)
    where
        H: Fn(IncomingMessage) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let hfn: HandlerFn = Arc::new(move |msg| Box::pin(handler(msg)) as MsgFuture);
        self.table_mut().edited_msg.push(MessageHandler { filter, handler: hfn });
    }

    /// Registers `handler` for new messages accepted by `filter` while the
    /// stored state equals `state.as_key()`. Logs a warning if no state
    /// storage is configured yet.
    pub fn on_message_fsm<S, H, Fut>(&mut self, filter: BoxFilter, state: S, handler: H)
    where
        S: FsmState,
        H: Fn(IncomingMessage, StateContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        if self.state_storage.is_none() {
            tracing::warn!(
                "on_message_fsm registered without a StateStorage - \
                 this handler will never fire. Call dp.with_state_storage(storage) first."
            );
        }
        let expected_state = state.as_key();
        let hfn: FsmHandlerFn = Arc::new(move |msg, ctx| Box::pin(handler(msg, ctx)) as MsgFuture);
        self.table_mut().fsm_new_msg.push(FsmMessageHandler {
            filter,
            expected_state,
            handler: hfn,
        });
    }

    /// Registers `handler` for edited messages accepted by `filter` while the
    /// stored state equals `state.as_key()`. Logs a warning if no state
    /// storage is configured yet.
    pub fn on_edit_fsm<S, H, Fut>(&mut self, filter: BoxFilter, state: S, handler: H)
    where
        S: FsmState,
        H: Fn(IncomingMessage, StateContext) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        if self.state_storage.is_none() {
            tracing::warn!(
                "on_edit_fsm registered without a StateStorage - \
                 this handler will never fire. Call dp.with_state_storage(storage) first."
            );
        }
        let expected_state = state.as_key();
        let hfn: FsmHandlerFn = Arc::new(move |msg, ctx| Box::pin(handler(msg, ctx)) as MsgFuture);
        self.table_mut().fsm_edited_msg.push(FsmMessageHandler {
            filter,
            expected_state,
            handler: hfn,
        });
    }

    /// Flattens `router` and appends its handlers after those already registered.
    ///
    /// Like `on_*_fsm`, this warns when the router contains FSM handlers but
    /// no state storage is configured yet.
    pub fn include(&mut self, router: Router) {
        let flat = router.flatten(None);
        let has_fsm = !flat.fsm_new_msg.is_empty() || !flat.fsm_edited_msg.is_empty();
        if has_fsm && self.state_storage.is_none() {
            tracing::warn!(
                "included router has FSM handlers but no StateStorage - \
                 they will never fire. Call dp.with_state_storage(storage) first."
            );
        }
        let table = self.table_mut();
        table.new_msg.extend(flat.new_msg);
        table.edited_msg.extend(flat.edited_msg);
        table.fsm_new_msg.extend(flat.fsm_new_msg);
        table.fsm_edited_msg.extend(flat.fsm_edited_msg);
    }

    /// Runs `update` through the middleware chain and then to the handlers.
    ///
    /// Errors returned by the chain are logged, not propagated. Per update,
    /// this only clones `Arc`s; no handler list is copied.
    pub async fn dispatch(&self, update: Update) {
        let handlers = Arc::clone(&self.handlers);
        let storage = self.state_storage.clone();
        let strategy = self.key_strategy;
        let endpoint: Arc<dyn Fn(Update) -> BoxFuture + Send + Sync> =
            Arc::new(move |upd: Update| {
                let handlers = Arc::clone(&handlers);
                let storage = storage.clone();
                Box::pin(async move {
                    dispatch_to_handlers(
                        upd,
                        &handlers.new_msg,
                        &handlers.edited_msg,
                        &handlers.fsm_new_msg,
                        &handlers.fsm_edited_msg,
                        storage,
                        strategy,
                    )
                    .await;
                    Ok(()) as DispatchResult
                })
            });
        if self.middlewares.is_empty() {
            if let Err(e) = (endpoint)(update).await {
                tracing::error!(error = %e, "dispatch error");
            }
            return;
        }
        let next = Next::new(Arc::clone(&self.middlewares), endpoint);
        if let Err(e) = next.run(update).await {
            tracing::error!(error = %e, "dispatch error");
        }
    }
}
impl Default for Dispatcher {
fn default() -> Self {
Self::new()
}
}
/// Chooses the handler set matching the update kind and runs it.
///
/// New messages use `new_msg`/`fsm_new`, edits use `edited_msg`/`fsm_edited`.
/// Every other update kind is ignored.
async fn dispatch_to_handlers(
    update: Update,
    new_msg: &[MessageHandler],
    edited_msg: &[MessageHandler],
    fsm_new: &[FsmMessageHandler],
    fsm_edited: &[FsmMessageHandler],
    storage: Option<Arc<dyn StateStorage>>,
    strategy: StateKeyStrategy,
) {
    let (msg, regular, fsm) = match update {
        Update::NewMessage(msg) => (msg, new_msg, fsm_new),
        Update::MessageEdited(msg) => (msg, edited_msg, fsm_edited),
        // No handlers are registered for other update kinds.
        _ => return,
    };
    run_message(msg, regular, fsm, storage, strategy).await;
}
/// Delivers `msg` to at most one handler.
///
/// If `storage` is set and `fsm` is non-empty, the state stored under the
/// message's [`StateKey`] is read first. The first FSM handler whose
/// `expected_state` equals it and whose filter accepts the message runs.
/// Otherwise the first regular handler whose filter accepts the message runs.
/// This covers a missing state, a failed read, and no FSM match. If no
/// handler matches, nothing happens.
async fn run_message(
    msg: IncomingMessage,
    regular: &[MessageHandler],
    fsm: &[FsmMessageHandler],
    storage: Option<Arc<dyn StateStorage>>,
    strategy: StateKeyStrategy,
) {
    if let Some(store) = &storage
        && !fsm.is_empty()
    {
        let key = StateKey::from_message(&msg, strategy);
        // A read failure is logged and treated as "no state", so regular
        // handlers still get a chance.
        let state = store.get_state(key.clone()).await.unwrap_or_else(|e| {
            tracing::error!(error = %e, "FSM: failed to read state");
            None
        });
        if let Some(current) = &state
            && let Some(hit) = fsm
                .iter()
                .find(|h| h.expected_state == *current && h.filter.check(&msg))
        {
            let ctx = StateContext::new(Arc::clone(store), key, current.clone());
            (hit.handler)(msg, ctx).await;
            return;
        }
    }
    if let Some(hit) = regular.iter().find(|h| h.filter.check(&msg)) {
        (hit.handler)(msg).await;
    }
}