tgreddit 0.1.3

Get the top posts of your favorite subreddits to Telegram
use crate::*;
use anyhow::Result;
use lazy_static::lazy_static;
use regex::Regex;
use std::sync::Arc;
use teloxide::{
    dispatching::DefaultKey,
    utils::command::{BotCommands, ParseError},
};

// Commands understood by the bot.
//
// The `BotCommands` derive builds the command parser and the help text from
// the `#[command(...)]` attributes below. With `rename_rule = "lowercase"` the
// commands are presumably spelled as the lowercased variant names (e.g.
// `/listsubs`) — confirm against the teloxide documentation.
//
// NOTE: plain `//` comments are used here on purpose instead of `///` doc
// comments so that the input seen by the derive macro stays unchanged.
#[derive(BotCommands, Clone)]
#[command(
    rename_rule = "lowercase",
    description = "These commands are supported:"
)]
pub enum Command {
    // Replies with the generated command descriptions.
    #[command(description = "display this text")]
    Help,
    // Subscribes the chat to a subreddit. The argument string is parsed by
    // `parse_subscribe_message` (subreddit plus optional limit/time/filter).
    #[command(
        description = "subscribe to subreddit's top posts",
        parse_with = parse_subscribe_message
    )]
    Sub(SubscriptionArgs),
    // Removes a subscription; the payload is the raw text after the command.
    #[command(description = "unsubscribe from subreddit's top posts")]
    Unsub(String),
    // Lists the subscriptions of the current chat.
    #[command(description = "list subreddit subscriptions")]
    ListSubs,
    // Fetches top posts once without subscribing; shares the argument parser
    // with `Sub`.
    #[command(description = "get top posts", parse_with = parse_subscribe_message)]
    Get(SubscriptionArgs),
}

/// Bundles the Telegram client with the dispatcher that routes incoming
/// updates to the command handler.
pub struct MyBot {
    /// Update dispatcher built in [`MyBot::new`] and driven by [`MyBot::spawn`].
    pub dispatcher: Dispatcher<Arc<Bot>, anyhow::Error, DefaultKey>,
    /// Shared handle to the Telegram bot client.
    pub tg: Arc<Bot>,
}

impl MyBot {
    /// Creates the Telegram client, registers the command list with Telegram
    /// and builds the update dispatcher.
    ///
    /// Only messages whose sender id is listed in
    /// `config.authorized_user_ids` are parsed as [`Command`]s and forwarded
    /// to [`handle_command`]. Updates that no branch accepts are passed to the
    /// default handler, which logs them with `warn!`.
    ///
    /// # Errors
    ///
    /// Fails if registering the bot commands with Telegram fails.
    pub async fn new(config: Arc<config::Config>) -> Result<Self> {
        let tg = Arc::new(Bot::new(config.telegram_bot_token.expose_secret()));
        tg.set_my_commands(Command::bot_commands()).await?;

        // A message without a sender is treated as unauthorized.
        let is_authorized = |msg: Message, config: Arc<config::Config>| {
            msg.from()
                .map_or(false, |user| config.authorized_user_ids.contains(&user.id.0))
        };

        let command_branch = dptree::filter(is_authorized)
            .filter_command::<Command>()
            .endpoint(handle_command);
        let handler = Update::filter_message().branch(command_branch);

        let dispatcher = Dispatcher::builder(Arc::clone(&tg), handler)
            .dependencies(dptree::deps![Arc::clone(&config)])
            .default_handler(|upd| async move {
                warn!("unhandled update: {:?}", upd);
            })
            .error_handler(LoggingErrorHandler::with_custom_text(
                "an error has occurred in the dispatcher",
            ))
            .build();

        Ok(MyBot { dispatcher, tg })
    }

    /// Consumes the bot and runs its dispatcher on a new Tokio task.
    ///
    /// Returns the task's join handle together with a shutdown token that was
    /// obtained from the dispatcher before it was moved into the task.
    pub fn spawn(
        mut self,
    ) -> (
        tokio::task::JoinHandle<()>,
        teloxide::dispatching::ShutdownToken,
    ) {
        // The token must be taken first: `self` is moved into the task below.
        let shutdown_token = self.dispatcher.shutdown_token();
        let join_handle = tokio::spawn(async move { self.dispatcher.dispatch().await });
        (join_handle, shutdown_token)
    }
}

pub async fn handle_command(
    message: Message,
    tg: Arc<Bot>,
    command: Command,
    config: Arc<config::Config>,
) -> Result<()> {
    async fn handle(
        message: &Message,
        tg: &Bot,
        command: Command,
        config: Arc<config::Config>,
    ) -> Result<()> {
        match command {
            Command::Help => {
                tg.send_message(message.chat.id, Command::descriptions().to_string())
                    .await?;
            }
            Command::Sub(mut args) => {
                let db = db::Database::open(&config)?;
                let chat_id = message.chat.id.0;
                let subreddit_about = reddit::get_subreddit_about(&args.subreddit).await;
                match subreddit_about {
                    Ok(data) => {
                        args.subreddit = data.display_name;
                        db.subscribe(chat_id, &args)?;
                        info!("subscribed in chat id {chat_id} with {args:#?};");
                        tg.send_message(
                            ChatId(chat_id),
                            format!("Subscribed to r/{}", args.subreddit),
                        )
                        .await?;
                    }
                    Err(reddit::SubredditAboutError::NoSuchSubreddit) => {
                        tg.send_message(ChatId(chat_id), "No such subreddit")
                            .await?;
                    }
                    Err(err) => {
                        Err(err)?;
                    }
                }
            }
            Command::Unsub(subreddit) => {
                let db = db::Database::open(&config)?;
                let chat_id = message.chat.id.0;
                let subreddit = subreddit.replace("r/", "");
                let reply = match db.unsubscribe(chat_id, &subreddit) {
                    Ok(sub) => format!("Unsubscribed from r/{sub}"),
                    Err(_) => format!("Error: Not subscribed to r/{subreddit}"),
                };
                tg.send_message(ChatId(chat_id), reply).await?;
            }
            Command::ListSubs => {
                let db = db::Database::open(&config)?;
                let subs = db.get_subscriptions_for_chat(message.chat.id.0)?;
                let reply = messages::format_subscription_list(&subs);
                tg.send_message(message.chat.id, reply).await?;
            }
            Command::Get(args) => {
                let subreddit = &args.subreddit;
                let limit = args
                    .limit
                    .or(config.default_limit)
                    .unwrap_or(config::DEFAULT_LIMIT);
                let time = args
                    .time
                    .or(config.default_time)
                    .unwrap_or(config::DEFAULT_TIME_PERIOD);
                let filter = args.filter.or(config.default_filter);
                let chat_id = message.chat.id.0;

                let posts = reddit::get_subreddit_top_posts(subreddit, limit, &time)
                    .await
                    .context("failed to get posts")?
                    .into_iter()
                    .filter(|p| {
                        if filter.is_some() {
                            filter.as_ref() == Some(&p.post_type)
                        } else {
                            true
                        }
                    })
                    .collect::<Vec<_>>();

                debug!("got {} post(s) for subreddit /r/{}", posts.len(), subreddit);

                if !posts.is_empty() {
                    for post in posts {
                        if let Err(e) = handle_new_post(&config, tg, chat_id, &post).await {
                            error!("failed to handle new post: {e}");
                        }
                    }
                } else {
                    tg.send_message(message.chat.id, "No posts found").await?;
                }
            }
        };

        Ok(())
    }

    if let Err(err) = handle(&message, &tg, command, config).await {
        error!("failed to handle message: {}", err);
        tg.send_message(message.chat.id, "Something went wrong")
            .await?;
    }

    Ok(())
}

/// Parses the argument string of the `Sub` and `Get` commands into
/// [`SubscriptionArgs`].
///
/// Expected shape: `<subreddit> [limit=<n>] [time=<period>] [filter=<type>]`.
/// The subreddit is the first whitespace-delimited token and may be written
/// as `name`, `r/name` or `/r/name`. The options are searched for anywhere in
/// the remainder of the input, so their order does not matter.
///
/// The result is a one-element tuple because that is the return shape used
/// with `parse_with` on single-field command variants in this file.
///
/// # Errors
///
/// * [`ParseError::Custom`] when no subreddit name is given.
/// * [`ParseError::IncorrectFormat`] when a `limit`, `time` or `filter` value
///   is present but cannot be parsed.
fn parse_subscribe_message(input: String) -> Result<(SubscriptionArgs,), ParseError> {
    lazy_static! {
        static ref SUBREDDIT_RE: Regex = Regex::new(r"^[^\s]+").unwrap();
        static ref LIMIT_RE: Regex = Regex::new(r"\blimit=(\d+)\b").unwrap();
        static ref TIME_RE: Regex = Regex::new(r"\btime=(\w+)\b").unwrap();
        static ref FILTER_RE: Regex = Regex::new(r"\bfilter=(\w+)\b").unwrap();
    }

    // The subreddit pattern is anchored at the start, so surrounding
    // whitespace has to go first or a leading space would hide the name.
    let input = input.trim();

    let subreddit_match = SUBREDDIT_RE
        .find(input)
        .ok_or_else(|| ParseError::Custom("No subreddit given".into()))?;
    let token = subreddit_match.as_str();
    // Remove only a leading `/r/` or `r/`; the rest of the token is the name.
    let subreddit = token
        .strip_prefix("/r/")
        .or_else(|| token.strip_prefix("r/"))
        .unwrap_or(token);
    if subreddit.is_empty() {
        // The token was just the prefix, e.g. a bare `r/`.
        return Err(ParseError::Custom("No subreddit given".into()));
    }
    let rest = &input[subreddit_match.end()..];

    // Each option follows the same pattern: absent -> `None`, present and
    // valid -> `Some(value)`, present but invalid -> parse error. `limit`
    // used to be dropped silently when its number did not fit into `u32`.
    let limit: Option<u32> = LIMIT_RE
        .captures(rest)
        .and_then(|caps| caps.get(1))
        .map(|m| m.as_str().parse::<u32>())
        .transpose()
        .map_err(|e| ParseError::IncorrectFormat(e.into()))?;

    let time = TIME_RE
        .captures(rest)
        .and_then(|caps| caps.get(1))
        .map(|m| m.as_str().parse::<TopPostsTimePeriod>())
        .transpose()
        .map_err(|e| ParseError::IncorrectFormat(e.into()))?;

    let filter = FILTER_RE
        .captures(rest)
        .and_then(|caps| caps.get(1))
        .map(|m| m.as_str().parse::<PostType>())
        .transpose()
        .map_err(|e| ParseError::IncorrectFormat(e.into()))?;

    let args = SubscriptionArgs {
        subreddit: subreddit.to_string(),
        limit,
        time,
        filter,
    };

    Ok((args,))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` and unwraps the one-element tuple; panics on a parse
    /// error so a failing parse fails the test.
    fn parse(input: &str) -> SubscriptionArgs {
        let (args,) = parse_subscribe_message(input.to_string()).unwrap();
        args
    }

    /// Expected arguments when only a subreddit name was supplied.
    fn only_subreddit(name: &str) -> SubscriptionArgs {
        SubscriptionArgs {
            subreddit: name.to_string(),
            limit: None,
            time: None,
            filter: None,
        }
    }

    #[test]
    fn test_parse_subscribe_message_only_subreddit() {
        assert_eq!(
            parse("AnimalsBeingJerks"),
            only_subreddit("AnimalsBeingJerks"),
        );
    }

    #[test]
    fn test_parse_subscribe_message_strips_prefix() {
        // Both the `r/` and the `/r/` spelling must yield the bare name.
        for input in ["r/AnimalsBeingJerks", "/r/AnimalsBeingJerks"] {
            assert_eq!(parse(input), only_subreddit("AnimalsBeingJerks"));
        }
    }

    #[test]
    fn test_parse_subscribe_message() {
        let expected = SubscriptionArgs {
            limit: Some(5),
            time: Some(TopPostsTimePeriod::Week),
            filter: Some(PostType::Video),
            ..only_subreddit("AnimalsBeingJerks")
        };
        assert_eq!(
            parse("AnimalsBeingJerks limit=5 time=week filter=video"),
            expected,
        );
    }
}