use crate::{
errors::EventErrorKind,
event::{
bases::{EventReturn, PropagateEventResult},
service::Service,
telegram::handler::Handler,
},
filters::Filter,
middlewares::{
inner::{wrap_to_next, Manager as InnerMiddlewareManager},
outer::Manager as OuterMiddlewareManager,
InnerMiddleware, OuterMiddleware,
},
Request,
};
use std::{
convert::Infallible,
fmt::{self, Debug, Formatter},
};
use tracing::{event, instrument, Level};
/// Outcome of [`Observer::trigger`]: the request, as handed back by the filter checks, paired
/// with how propagation through the observer ended.
pub struct Response<Client> {
    /// The request after it went through the observer (filter checks return it by value).
    pub request: Request<Client>,
    /// Propagation outcome, e.g. rejected, handled or unhandled.
    pub propagate_result: PropagateEventResult<Client>,
}

// Written by hand rather than derived: `#[derive(Debug)]` would add a `Client: Debug` bound.
impl<Client> Debug for Response<Client> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            request,
            propagate_result,
        } = self;
        f.debug_struct("Response")
            .field("request", request)
            .field("propagate_result", propagate_result)
            .finish()
    }
}
/// Dispatches requests for a single event type to a list of handlers.
///
/// Observer-level filters (kept in `common`) gate every handler. Handlers are tried in
/// registration order, and the first one whose own filters pass is called, wrapped by the inner
/// middlewares.
pub struct Observer<Client> {
    /// Name of the event this observer handles; shown in the `Debug` output.
    pub(crate) event_name: &'static str,
    /// Registered handlers, tried in order by `trigger`.
    pub(crate) handlers: Vec<Handler<Client>>,
    /// Placeholder handler that only carries the observer-level filters. `trigger` checks its
    /// filters but never calls it. `None` until the first observer-level filter is added.
    pub(crate) common: Option<Handler<Client>>,
    /// Middlewares wrapped around each handler invocation in `trigger`.
    pub inner_middlewares: InnerMiddlewareManager<Client>,
    /// Outer middlewares. `trigger` in this file does not use them; presumably the caller
    /// applies them before the request reaches the observer (TODO confirm).
    pub outer_middlewares: OuterMiddlewareManager<Client>,
}
impl<Client> Observer<Client>
where
    Client: Send + Sync + 'static,
{
    /// Creates an observer for the event called `event_name`.
    ///
    /// The observer starts with no handlers, no observer-level filters and default-constructed
    /// inner/outer middleware managers.
    #[must_use]
    pub fn new(event_name: &'static str) -> Self {
        let inner_middlewares = InnerMiddlewareManager::<Client>::default();
        let outer_middlewares = OuterMiddlewareManager::<Client>::default();
        Self {
            event_name,
            handlers: Vec::new(),
            common: None,
            inner_middlewares,
            outer_middlewares,
        }
    }

    /// Appends `handler` after the already registered handlers.
    #[inline]
    #[must_use]
    pub fn register(self, handler: Handler<Client>) -> Self {
        self.registers(std::iter::once(handler))
    }

    /// Alias for [`Observer::register`].
    #[inline]
    #[must_use]
    pub fn on(self, handler: Handler<Client>) -> Self {
        Self::register(self, handler)
    }

    /// Appends every handler yielded by `handlers`, keeping their iteration order.
    #[must_use]
    pub fn registers(mut self, handlers: impl IntoIterator<Item = Handler<Client>>) -> Self {
        self.handlers.extend(handlers);
        self
    }

    /// Adds an observer-level filter that a request must pass before any handler is tried.
    ///
    /// The filter is attached with `Handler::filter` to the placeholder handler stored in
    /// `common`. The placeholder is created on the first call.
    #[must_use]
    pub fn filter(mut self, val: impl Filter<Client>) -> Self {
        // `trigger` only checks the placeholder's filters; its body is never invoked.
        let holder = self.common.take().unwrap_or_else(|| {
            let never_called = || async move {
                unreachable!("This handler never will be used");
                #[allow(unreachable_code)]
                Ok::<_, Infallible>(())
            };
            Handler::new(never_called)
        });
        self.common = Some(holder.filter(val));
        self
    }

    /// Registers a middleware that wraps every handler call made by this observer.
    #[inline]
    #[must_use]
    pub fn register_inner_middleware(mut self, middleware: impl InnerMiddleware<Client>) -> Self {
        let manager = &mut self.inner_middlewares;
        manager.register(middleware);
        self
    }

    /// Registers an outer middleware with this observer's outer middleware manager.
    #[inline]
    #[must_use]
    pub fn register_outer_middleware(mut self, middleware: impl OuterMiddleware<Client>) -> Self {
        let manager = &mut self.outer_middlewares;
        manager.register(middleware);
        self
    }
}
impl<Client> Observer<Client> {
    /// Propagates `request` through this observer.
    ///
    /// First, the observer-level filters (if any) are checked. If they do not pass, the request
    /// comes back with [`PropagateEventResult::Rejected`]. Otherwise handlers are tried in
    /// registration order. The first handler whose filters pass is called, through the inner
    /// middlewares when any are registered, and its [`EventReturn`] decides what happens:
    ///
    /// * `Skip`: continue with the next handler;
    /// * `Cancel`: stop with [`PropagateEventResult::Rejected`];
    /// * `Finish`, or a handler error: stop with [`PropagateEventResult::Handled`].
    ///
    /// If no handler takes the request, [`PropagateEventResult::Unhandled`] is returned.
    ///
    /// # Errors
    ///
    /// Returns an [`EventErrorKind`] in three cases:
    ///
    /// * a filter check returns an error;
    /// * the first inner middleware returns an error;
    /// * without middlewares, the direct handler call fails (mapped to
    ///   `EventErrorKind::Extraction`).
    #[instrument(skip_all)]
    pub async fn trigger(
        &mut self,
        request: Request<Client>,
    ) -> Result<Response<Client>, EventErrorKind>
    where
        Client: Send + Sync + Clone + 'static,
    {
        let mut request = request;

        // Observer-level filters gate every handler registered here.
        if let Some(common) = self.common.as_mut() {
            let (passed, checked_request) = common.check(request).await?;
            if !passed {
                event!(Level::TRACE, "Request are not pass observer filters");
                return Ok(Response {
                    request: checked_request,
                    propagate_result: PropagateEventResult::Rejected,
                });
            }
            request = checked_request;
        }

        for handler in &mut self.handlers {
            // `check` consumes the request and hands it back together with the verdict.
            let (passed, checked_request) = handler.check(request).await?;
            request = checked_request;
            if !passed {
                continue;
            }
            event!(Level::TRACE, "Request are pass handler filters");

            let response = if let Some((first, rest)) =
                self.inner_middlewares.middlewares.split_first_mut()
            {
                // The first middleware receives a `next` built from the handler's service and
                // the remaining middlewares (see `wrap_to_next`).
                let next = wrap_to_next(handler.service.clone(), rest.to_vec().into_boxed_slice());
                first.call((request.clone(), next)).await?
            } else {
                handler
                    .call(request.clone())
                    .await
                    .map_err(EventErrorKind::Extraction)?
            };

            let propagate_result = match response.result {
                Ok(EventReturn::Skip) => {
                    event!(Level::TRACE, "Handler returns skip");
                    continue;
                }
                Ok(EventReturn::Cancel) => {
                    event!(Level::TRACE, "Handler returns cancel");
                    PropagateEventResult::Rejected
                }
                Ok(EventReturn::Finish) => {
                    event!(Level::TRACE, "Handler returns finish");
                    PropagateEventResult::Handled(response)
                }
                Err(_) => {
                    event!(Level::TRACE, "Handler returns error");
                    PropagateEventResult::Handled(response)
                }
            };
            return Ok(Response {
                request,
                propagate_result,
            });
        }

        event!(Level::TRACE, "Request are not pass handlers filters");
        Ok(Response {
            request,
            propagate_result: PropagateEventResult::Unhandled,
        })
    }
}
// Hand-written so `Client` need not be `Debug`; handlers are summarised by their count.
impl<Client> Debug for Observer<Client> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let handlers_count = self.handlers.len();
        let mut builder = f.debug_struct("Observer");
        builder.field("event_name", &self.event_name);
        builder.field("handlers", &handlers_count);
        builder.finish_non_exhaustive()
    }
}
impl<Client> Default for Observer<Client>
where
Client: Send + Sync + 'static,
{
fn default() -> Self {
Self::new("message")
}
}
// Written by hand rather than derived: `#[derive(Clone)]` would add a `Client: Clone` bound.
impl<Client> Clone for Observer<Client> {
    fn clone(&self) -> Self {
        let Self {
            event_name,
            handlers,
            common,
            inner_middlewares,
            outer_middlewares,
        } = self;
        Self {
            event_name: *event_name,
            handlers: handlers.clone(),
            common: common.clone(),
            inner_middlewares: inner_middlewares.clone(),
            outer_middlewares: outer_middlewares.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::Reqwest,
        errors::HandlerError,
        filters::Command,
        types::{ChatPrivate, MessageText, Update, UpdateMessage},
        Bot, Extensions,
    };
    use anyhow::anyhow;
    use std::sync::Arc;

    /// Builds a request carrying a private-chat text message with the given `text` and default
    /// bot, context and extensions.
    fn create_request(text: &'static str) -> Request<Reqwest> {
        Request {
            update: Arc::new(Update::Message(UpdateMessage::new(
                0,
                MessageText::new(0, 0, ChatPrivate::new(0), text),
            ))),
            bot: Bot::default(),
            context: crate::Context::default(),
            extensions: Extensions::default(),
        }
    }

    // `unreachable_code` is allowed because the second handler panics before its `Ok(..)` tail,
    // which exists only to give the closure a concrete return type.
    #[allow(unreachable_code)]
    #[tokio::test]
    async fn test_observer_trigger() {
        let mut observer = Observer::default()
            .filter(Command::one("start"))
            .register(Handler::new(|| async {
                Ok::<_, Infallible>(EventReturn::Finish)
            }))
            .register(Handler::new(|| async {
                unreachable!("It's shouldn't trigger because the first handler handles the event");
                Ok::<_, Infallible>(EventReturn::Finish)
            }));

        // The observer-level `/start` filter rejects a message without the command.
        let response = observer.trigger(create_request("")).await.unwrap();
        match response.propagate_result {
            PropagateEventResult::Rejected => {}
            other => panic!("Unexpected result: {:?}", other),
        }

        let response = observer.trigger(create_request("/start")).await.unwrap();
        match response.propagate_result {
            PropagateEventResult::Handled(_) => {}
            other => panic!("Unexpected result: {:?}", other),
        }
    }

    // See `test_observer_trigger` for why `unreachable_code` is allowed.
    #[allow(unreachable_code)]
    #[tokio::test]
    async fn test_observer_trigger_error() {
        let mut observer = Observer::<Reqwest>::default()
            .register(Handler::new(|| async {
                Err::<(), _>(HandlerError::new(anyhow!("test")))
            }))
            .register(Handler::new(|| async {
                unreachable!("It's shouldn't trigger because the first handler handles the event");
                Ok::<_, Infallible>(EventReturn::Finish)
            }));

        // A handler error stops propagation and is reported as handled.
        let response = observer.trigger(create_request("")).await.unwrap();
        match response.propagate_result {
            PropagateEventResult::Handled(response) => match response.result {
                Err(_) => {}
                _ => panic!("Unexpected result"),
            },
            other => panic!("Unexpected result: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_observer_event_return() {
        // `Skip` falls through to the next handler, whose `Finish` marks the event handled.
        let mut observer = Observer::default()
            .register(Handler::new(|| async {
                Ok::<_, Infallible>(EventReturn::Skip)
            }))
            .register(Handler::new(|| async {
                Ok::<_, Infallible>(EventReturn::Finish)
            }));

        let response = observer.trigger(create_request("/start")).await.unwrap();
        match response.propagate_result {
            PropagateEventResult::Handled(response) => match response.result {
                Ok(EventReturn::Finish) => {}
                _ => panic!("Unexpected result"),
            },
            other => panic!("Unexpected result: {:?}", other),
        }

        // `Cancel` after a `Skip` rejects the event.
        let mut observer = Observer::default()
            .register(Handler::new(|| async {
                Ok::<_, Infallible>(EventReturn::Skip)
            }))
            .register(Handler::new(|| async {
                Ok::<_, Infallible>(EventReturn::Cancel)
            }));

        let response = observer.trigger(create_request("/start")).await.unwrap();
        match response.propagate_result {
            PropagateEventResult::Rejected => {}
            other => panic!("Unexpected result: {:?}", other),
        }
    }
}