use super::{Middleware, MiddlewareResponse};
use crate::{errors::EventErrorKind, event::EventReturn, Request};
use tracing::instrument;
/// Middleware that copies identifiers from the incoming update into the request context.
///
/// Its `Middleware` implementation in this module inserts `event_user`, `event_chat`,
/// `event_message_thread_id` and `event_business_connection_id` when the update carries
/// the corresponding value. In this module's tests it is registered as an outer
/// middleware on the `update` observer.
#[derive(Clone, Debug, Default)]
pub struct UserContext;

impl UserContext {
    /// Creates the middleware.
    ///
    /// Equivalent to [`UserContext::default`] or the unit value `UserContext`; provided as a
    /// `const fn` so it can also be used in constant contexts.
    #[must_use]
    #[inline]
    pub const fn new() -> Self {
        Self
    }
}
/// Enriches the request context with identifiers taken from the incoming update.
///
/// Each entry is inserted only when the update exposes the value:
///
/// | context key                    | source                                          |
/// |--------------------------------|-------------------------------------------------|
/// | `event_user`                   | `update.from()` (cloned)                        |
/// | `event_chat`                   | `update.chat()` (cloned)                        |
/// | `event_message_thread_id`      | `update.message_thread_id()`                    |
/// | `event_business_connection_id` | `update.business_connection_id()` (owned copy)  |
///
/// Absent values are skipped. The middleware always returns `Ok` with the (possibly
/// enriched) request and a default `EventReturn`.
impl<Client> Middleware<Client> for UserContext
where
    Client: Send + Sync + 'static,
{
    #[instrument(skip_all)]
    async fn call(
        &mut self,
        mut request: Request<Client>,
    ) -> Result<MiddlewareResponse<Client>, EventErrorKind> {
        // Borrow the two fields separately: the update is only read, while the
        // context is filled in from it.
        let update = &request.update;
        let context = &mut request.context;

        if let Some(user) = update.from() {
            context.insert("event_user", user.clone());
        }
        if let Some(chat) = update.chat() {
            context.insert("event_chat", chat.clone());
        }
        if let Some(thread_id) = update.message_thread_id() {
            context.insert("event_message_thread_id", thread_id);
        }
        if let Some(connection_id) = update.business_connection_id() {
            context.insert("event_business_connection_id", connection_id.to_owned());
        }

        Ok((request, EventReturn::default()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::Reqwest,
        context::Context,
        enums::UpdateType,
        event::telegram::Handler,
        router::{PropagateEvent as _, Router},
        types::{Chat, ChatPrivate, Message, MessageText, Update, UpdateMessage, User},
        Bot, Extensions,
    };
    use std::{convert::Infallible, sync::Arc};

    /// Registers `UserContext` as an outer `update` middleware, then checks that a message
    /// handler can read `event_user`, `event_chat` and `event_message_thread_id` from the
    /// context when the message carries a sender and a thread id.
    #[tokio::test]
    async fn test_user_context() {
        let router = Router::new("main")
            .on_update(|observer| observer.register_outer_middleware(UserContext))
            .on_message(|observer| {
                observer.register(Handler::new(|context: Context| async move {
                    context.get::<User>("event_user").unwrap();
                    context.get::<Chat>("event_chat").unwrap();
                    context.get::<i64>("event_message_thread_id").unwrap();
                    Ok::<_, Infallible>(EventReturn::default())
                }))
            });
        let mut router_configured = router.configure_default();
        let request = Request::<Reqwest> {
            update: Arc::new(Update::Message(UpdateMessage::new(
                0,
                Message::Text(
                    MessageText::new(0, 0, ChatPrivate::new(0), "")
                        .from(User::new(0, true, ""))
                        .message_thread_id(0),
                ),
            ))),
            bot: Bot::default(),
            context: Context::default(),
            extensions: Extensions::default(),
        };
        router_configured
            .propagate_event(UpdateType::Message, request)
            .await
            .unwrap();
    }

    /// Propagates a message built without `.from(...)` or `.message_thread_id(...)`.
    /// The context therefore lacks those entries, and the handler's `unwrap` on
    /// `event_user` is expected to panic.
    #[tokio::test]
    #[should_panic]
    async fn test_user_context_panic() {
        let router = Router::new("main")
            .on_message(|observer| {
                observer.register(Handler::new(|context: Context| async move {
                    context.get::<User>("event_user").unwrap();
                    context.get::<Chat>("event_chat").unwrap();
                    context.get::<i64>("event_message_thread_id").unwrap();
                    Ok::<_, Infallible>(EventReturn::default())
                }))
            })
            .on_update(|observer| observer.register_outer_middleware(UserContext));
        let mut router_configured = router.configure_default();
        let request = Request::<Reqwest> {
            update: Arc::new(Update::Message(UpdateMessage::new(
                0,
                Message::Text(MessageText::new(0, 0, ChatPrivate::new(0), "")),
            ))),
            bot: Bot::default(),
            // Use the imported `Context`, matching the first test.
            context: Context::default(),
            extensions: Extensions::default(),
        };
        router_configured
            .propagate_event(UpdateType::Message, request)
            .await
            .unwrap();
    }
}