teamtalk 6.0.0

TeamTalk SDK for Rust
Documentation
use super::router::HandlerResult;
use crate::events::Result;
use crate::types::{UserId, UserRights};
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// A hook that wraps message/command handling.
///
/// Both methods have no-op defaults, so an implementor only overrides what it
/// needs:
/// - `before` defaults to `Ok(HandlerResult::Continue)`;
/// - `after` defaults to `Ok(())`.
///
/// NOTE(review): presumably the router skips the handler (and later
/// middleware) when `before` returns `HandlerResult::Stop` — confirm in
/// `super::router`.
pub trait Middleware {
    /// Runs ahead of the handler; the returned verdict gates further processing.
    fn before(&mut self, _context: &mut super::Context<'_>) -> Result<HandlerResult> {
        let verdict = HandlerResult::Continue;
        Ok(verdict)
    }

    /// Runs after the handler. The default does nothing.
    fn after(&mut self, _context: &mut super::Context<'_>) -> Result<()> {
        Ok(())
    }
}

/// Boxed closure type backing the `before` step of [`FnMiddleware`].
///
/// The `for<..>` binder is what the elided Fn-sugar form expands to: the hook
/// must accept a context borrowed for any lifetime.
type BeforeHook =
    dyn for<'r, 'ctx> FnMut(&'r mut super::Context<'ctx>) -> Result<HandlerResult> + Send;
/// Boxed closure type backing the optional `after` step of [`FnMiddleware`].
type AfterHook = dyn for<'r, 'ctx> FnMut(&'r mut super::Context<'ctx>) -> Result<()> + Send;

/// A [`Middleware`] assembled from closures.
///
/// The `before` closure is mandatory. The `after` closure is optional; when it
/// is absent, `after` simply returns `Ok(())`. Closures must be
/// `Send + 'static`.
pub struct FnMiddleware {
    before: Box<BeforeHook>,
    after: Option<Box<AfterHook>>,
}

impl FnMiddleware {
    /// Builds a middleware that only has a `before` step.
    pub fn new<F>(before: F) -> Self
    where
        F: FnMut(&mut super::Context<'_>) -> Result<HandlerResult> + Send + 'static,
    {
        let before: Box<BeforeHook> = Box::new(before);
        Self {
            before,
            after: None,
        }
    }

    /// Builds a middleware with both a `before` and an `after` step.
    pub fn with_after<F, A>(before: F, after: A) -> Self
    where
        F: FnMut(&mut super::Context<'_>) -> Result<HandlerResult> + Send + 'static,
        A: FnMut(&mut super::Context<'_>) -> Result<()> + Send + 'static,
    {
        let mut middleware = Self::new(before);
        let after: Box<AfterHook> = Box::new(after);
        middleware.after = Some(after);
        middleware
    }
}

impl Middleware for FnMiddleware {
    /// Delegates straight to the stored `before` closure.
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        let hook = &mut self.before;
        hook(ctx)
    }

    /// Delegates to the stored `after` closure, or succeeds trivially if none was given.
    fn after(&mut self, ctx: &mut super::Context<'_>) -> Result<()> {
        match self.after.as_mut() {
            Some(hook) => hook(ctx),
            None => Ok(()),
        }
    }
}

/// Lets a message through only when the context carries a command
/// (`ctx.command` is `Some`); anything else is stopped.
pub struct CommandOnly;

impl Middleware for CommandOnly {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        if ctx.command.is_none() {
            return Ok(HandlerResult::Stop);
        }
        Ok(HandlerResult::Continue)
    }
}

/// Continues only when `ctx.channel_id()` is `None`.
///
/// NOTE(review): presumably "no channel" means a direct/private message —
/// confirm against `Context::channel_id`.
pub struct RequirePrivateMessage;

impl Middleware for RequirePrivateMessage {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        let verdict = match ctx.channel_id() {
            None => HandlerResult::Continue,
            Some(_) => HandlerResult::Stop,
        };
        Ok(verdict)
    }
}

/// Continues only when `ctx.channel_id()` is `Some`, i.e. the message is
/// associated with a channel; otherwise stops.
pub struct RequireChannelMessage;

impl Middleware for RequireChannelMessage {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        if ctx.channel_id().is_none() {
            return Ok(HandlerResult::Stop);
        }
        Ok(HandlerResult::Continue)
    }
}

/// Continues only when `ctx.is_command` reports a match for the configured
/// command name; otherwise stops.
pub struct RequireCommand {
    command: String,
}

impl RequireCommand {
    /// Creates a gate for the given command name.
    pub fn new(command: impl Into<String>) -> Self {
        let command: String = command.into();
        Self { command }
    }
}

impl Middleware for RequireCommand {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        if !ctx.is_command(&self.command) {
            return Ok(HandlerResult::Stop);
        }
        Ok(HandlerResult::Continue)
    }
}

/// Continues only when the context carries a command whose `prefix` equals
/// the configured character. Messages without a command are stopped.
pub struct RequireCommandPrefix {
    prefix: char,
}

impl RequireCommandPrefix {
    /// Creates a gate for commands introduced by `prefix`.
    pub fn new(prefix: char) -> Self {
        Self { prefix }
    }
}

impl Middleware for RequireCommandPrefix {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        let verdict = match ctx.command.as_ref() {
            Some(cmd) if cmd.prefix == self.prefix => HandlerResult::Continue,
            _ => HandlerResult::Stop,
        };
        Ok(verdict)
    }
}

/// Continues only when the sender (`ctx.sender_id()`) is one of the allowed
/// user ids; otherwise stops. Lookup is a linear scan of the list.
pub struct RequireUserIds {
    allowed: Vec<UserId>,
}

impl RequireUserIds {
    /// Creates an allow-list gate from anything convertible into `Vec<UserId>`.
    pub fn new(allowed: impl Into<Vec<UserId>>) -> Self {
        let allowed: Vec<UserId> = allowed.into();
        Self { allowed }
    }
}

impl Middleware for RequireUserIds {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        let sender = ctx.sender_id();
        let permitted = self.allowed.iter().any(|id| *id == sender);
        if permitted {
            Ok(HandlerResult::Continue)
        } else {
            Ok(HandlerResult::Stop)
        }
    }
}

/// Continues only when the sender's raw `uUserType` value is in the allowed list.
///
/// The sender is looked up through the client backend's `get_user`. If that
/// lookup fails, the message is stopped.
pub struct RequireUserType {
    allowed: Vec<u32>,
}

impl RequireUserType {
    /// Creates a gate from anything convertible into `Vec<u32>` of user-type values.
    pub fn new(allowed: impl Into<Vec<u32>>) -> Self {
        let allowed: Vec<u32> = allowed.into();
        Self { allowed }
    }
}

impl Middleware for RequireUserType {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        // SAFETY: `mem::zeroed` is sound only if an all-zero bit pattern is a
        // valid `teamtalk_sys::User`. NOTE(review): assumes `User` is a plain
        // C (bindgen) struct of integers/arrays/raw pointers — confirm. The
        // value serves as an out-parameter, presumably filled by `get_user`.
        let mut user = unsafe { std::mem::zeroed::<teamtalk_sys::User>() };
        let found = ctx
            .client
            .backend()
            .get_user(ctx.client.ptr.0, ctx.sender_id().0, &mut user);
        // The lookup must succeed before its output is trusted; `&&` short-circuits.
        if found && self.allowed.contains(&user.uUserType) {
            Ok(HandlerResult::Continue)
        } else {
            Ok(HandlerResult::Stop)
        }
    }
}

/// Continues only when `ctx.client.my_user_rights().has_any(rights)` is true.
///
/// Note: this inspects the rights of the connected client itself
/// (`my_user_rights`), not those of the message sender.
pub struct RequireClientRightsAny {
    rights: UserRights,
}

impl RequireClientRightsAny {
    /// Creates a gate that requires any of `rights`.
    pub fn new(rights: UserRights) -> Self {
        Self { rights }
    }
}

impl Middleware for RequireClientRightsAny {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        let granted = ctx.client.my_user_rights();
        if !granted.has_any(self.rights) {
            return Ok(HandlerResult::Stop);
        }
        Ok(HandlerResult::Continue)
    }
}

/// Continues only when `ctx.client.my_user_rights().has_all(rights)` is true.
///
/// Note: this inspects the rights of the connected client itself
/// (`my_user_rights`), not those of the message sender.
pub struct RequireClientRightsAll {
    rights: UserRights,
}

impl RequireClientRightsAll {
    /// Creates a gate that requires all of `rights`.
    pub fn new(rights: UserRights) -> Self {
        Self { rights }
    }
}

impl Middleware for RequireClientRightsAll {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        let granted = ctx.client.my_user_rights();
        let verdict = match granted.has_all(self.rights) {
            true => HandlerResult::Continue,
            false => HandlerResult::Stop,
        };
        Ok(verdict)
    }
}

/// Stops a message when another message from the same source was *accepted*
/// less than `period` ago; otherwise accepts it and records the time.
///
/// The source key is whatever `ctx.message.source()` returns (an `i32`).
/// NOTE(review): presumably a user or channel id — confirm against the
/// message type.
///
/// Stale entries (older than `period`) are swept at most once per `period`.
/// This keeps memory proportional to recently active sources rather than to
/// every source ever seen.
pub struct RateLimitBySource {
    /// Minimum spacing between two accepted messages from one source.
    period: Duration,
    /// Timestamp of the last accepted message for each source.
    seen: HashMap<i32, Instant>,
    /// When `seen` was last swept for expired entries.
    last_sweep: Instant,
}

impl RateLimitBySource {
    /// Creates a limiter allowing one message per source per `period`.
    ///
    /// `period` is clamped to a minimum of 50 ms.
    pub fn new(period: Duration) -> Self {
        Self {
            period: period.max(Duration::from_millis(50)),
            seen: HashMap::new(),
            last_sweep: Instant::now(),
        }
    }

    /// Drops sources whose last accepted message is at least `period` old.
    ///
    /// Such entries can no longer cause a `Stop`, so removing them never
    /// changes a decision. Without this sweep, `seen` grows with every
    /// distinct source for the lifetime of the bot. The sweep runs at most
    /// once per `period`, so its O(n) cost is amortised.
    fn sweep_expired(&mut self, now: Instant) {
        if now.duration_since(self.last_sweep) < self.period {
            return;
        }
        let period = self.period;
        self.seen
            .retain(|_, last| now.duration_since(*last) < period);
        self.last_sweep = now;
    }
}

impl Middleware for RateLimitBySource {
    fn before(&mut self, ctx: &mut super::Context<'_>) -> Result<HandlerResult> {
        let source = ctx.message.source();
        let now = Instant::now();
        self.sweep_expired(now);
        if let Some(last) = self.seen.get(&source)
            && now.duration_since(*last) < self.period
        {
            // Rejected messages do not refresh the timestamp: the window is
            // measured from the last *accepted* message.
            return Ok(HandlerResult::Stop);
        }
        self.seen.insert(source, now);
        Ok(HandlerResult::Continue)
    }
}