use super::{ApiClient, BotConfig, BotResult, Database, super::super::util::Record};
use doc_for::{doc, doc_impl};
use frankenstein::{
AsyncTelegramApi, Error, ParseMode,
client_reqwest::Bot,
input_file::FileUpload,
methods::{SendMessageParams, SendStickerParams, SetMyCommandsParams},
stickers::StickerType,
types::{BotCommand, ChatType, LinkPreviewOptions, Message, ReplyParameters, User},
};
use log::{error, info};
use semantic_search::Embedding;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Canned cat noises sent by `answer_fallback` when a private-chat message is not a
/// command this bot understands.
///
/// The entry is chosen from the message id, so the same message always gets the same
/// noise while consecutive messages usually differ.
const FALLBACK_MESSAGES: [&str; 5] = [
    "😹 Maow?",
    "😼 Meowww :3",
    "🙀 Nyaaa!",
    "😿 Mew...",
    "😾 Prrrrr...!",
];
// NOTE: the `///` doc comments on this enum and its variants are RUNTIME TEXT, not just
// rustdoc. `#[doc_impl]` exposes them through `doc!(Command, …)`, which builds the
// `/help` reply and the `setMyCommands` descriptions (Telegram requires those to be
// non-empty). The `/help` reply is sent with HTML parse mode, so keep `<`, `>` and `&`
// out of them. `strip = 1` presumably drops the single space after `///`.
/// 🐾 Meow! Here are the commands I understand:
#[derive(Clone, Debug)]
#[doc_impl(strip = 1)]
pub enum Command {
    /// Show this help message
    Help,
    /// Search for memes matching a query, e.g. /search sleepy cat
    Search(String),
    /// Learn how to search for memes inline in any chat
    Inline,
    /// Send a sticker by its file id
    Sticker(String),
    /// Reply to a sticker to add it with a description (owner only)
    Add(String),
}
impl Command {
    /// Builds the `/help` text.
    ///
    /// The text is the `Command` doc header, then one `/name - description` line per
    /// command, then `config.postscript` (trimmed) on its own line if it is non-empty.
    fn description(config: &BotConfig) -> String {
        let content = format!(
            "{}\n/help - {}\n/search - {}\n/inline - {}\n/sticker - {}\n/add - {}",
            doc!(Command),
            doc!(Command, Help),
            doc!(Command, Search),
            doc!(Command, Inline),
            doc!(Command, Sticker),
            doc!(Command, Add),
        );
        let postscript = config.postscript.trim();
        if postscript.is_empty() {
            content
        } else {
            format!("{content}\n{postscript}")
        }
    }

    /// Parses message text into a [`Command`].
    ///
    /// Accepts `/name` or `/name@username`, optionally followed by an argument. The
    /// argument is separated from the command by any whitespace, including a newline.
    /// Command names are matched case-insensitively. A `@mention` must name this bot;
    /// Telegram usernames are case-insensitive, so the comparison is too.
    ///
    /// Returns `None` in three cases: the text is not a slash command, it is addressed
    /// to another bot, or it names an unknown command.
    fn parse(text: &str, username: &str) -> Option<Self> {
        let text = text.trim();
        // Split on any whitespace so "/add\nfunny cat" behaves like "/add funny cat".
        let (head, arg) = text.split_once(char::is_whitespace).unwrap_or((text, ""));
        // Collapse extra separators ("/search   cat") so they don't leak into the argument.
        let arg = arg.trim_start();
        let head = head.strip_prefix('/')?;
        let (name, mention) = head.split_once('@').unwrap_or((head, ""));
        if !mention.is_empty() && !mention.eq_ignore_ascii_case(username) {
            return None;
        }
        match name.to_lowercase().as_str() {
            "help" => Some(Self::Help),
            "search" => Some(Self::Search(arg.to_string())),
            "inline" => Some(Self::Inline),
            "sticker" => Some(Self::Sticker(arg.to_string())),
            "add" => Some(Self::Add(arg.to_string())),
            _ => None,
        }
    }
}
pub async fn set_commands(bot: &Bot) -> BotResult<()> {
let commands = [
("/help", doc!(Command, Help)),
("/search", doc!(Command, Search)),
("/inline", doc!(Command, Inline)),
("/sticker", doc!(Command, Sticker)),
("/add", doc!(Command, Sticker)),
];
let commands: Vec<_> = commands
.into_iter()
.map(|(command, description)| (command.to_string(), description.to_string()))
.map(|(command, description)| BotCommand {
command,
description,
})
.collect();
let set_params = SetMyCommandsParams::builder().commands(commands).build();
bot.set_my_commands(&set_params).await?;
Ok(())
}
pub async fn message_handler(
bot: &Bot,
me: &User,
msg: Message,
db: Arc<Mutex<Database>>,
api: &ApiClient,
config: &BotConfig,
) -> BotResult<()> {
let Some(username) = &me.username else {
log::error!("Bot username not found.");
return Ok(());
};
let Some(text) = &msg.text else {
if let Some(sticker) = &msg.sticker
&& matches!(sticker.sticker_type, StickerType::Regular)
{
let id = &sticker.file_id;
return reply(bot, &msg, format!("Sticker file_id: <code>{id}</code>")).await;
} else {
return answer_fallback(bot, &msg).await;
};
};
let Some(cmd) = Command::parse(text, username) else {
return answer_fallback(bot, &msg).await;
};
info!("Received valid command: `{text}`, parsed as: {cmd:?}");
match answer_command(bot, &msg, cmd, db, api, config).await {
Ok(_) => Ok(()),
Err(e) => {
error!("Failed to answer the command: {e}");
Err(e)
}
}
}
async fn answer_command(
bot: &Bot,
msg: &Message,
cmd: Command,
db: Arc<Mutex<Database>>,
api: &ApiClient,
config: &BotConfig,
) -> BotResult<()> {
let result = match cmd {
Command::Help => {
Ok(Command::description(config))
}
Command::Search(query) => {
answer_search(api, &query, db, config).await
}
Command::Inline => {
Ok("🐾 Just mention me in any chat, followed by your query, and I'll pounce into action to fetch the purr-fect meme for you! 😼✨".to_string())
}
Command::Sticker(file_id) => {
if file_id.is_empty() {
Ok("🐾 Paws and reflect! Please provide a sticker file id... 😾".to_string())
} else {
let sticker = FileUpload::String(file_id);
let send_params = SendStickerParams::builder()
.chat_id(msg.chat.id)
.sticker(sticker)
.build();
if let Err(e) = bot.send_sticker(&send_params).await {
if let Error::Api(e) = e {
if e.description.starts_with("Bad Request: wrong remote file identifier specified") {
Err("🐾 Paws and reflect! Please provide a valid sticker file id... 😾".to_string())
} else {
Err(format!("Failed to send the sticker: Api Error {}", e.description))
}
} else {
Err(format!("Failed to send the sticker: {e}"))
}
} else {
Ok("🐾 Sticker sent! Hope it made your whiskers twitch! 😼".to_string())
}
}
}
Command::Add(description) => {
if let Some(user) = &msg.from {
if user.id != config.owner {
Err("😾 Only my owner can use this command.".to_string())
} else if let Some(reply) = &msg.reply_to_message && let Some(sticker) = &reply.sticker {
insert_sticker(db, api, sticker.file_id.clone(), description).await
} else {
Err("🐾 Paws and reflect! Please reply to a sticker. 😾".to_string())
}
} else {
Err("😾 Who're you?".to_string())
}
}
};
let reply_msg = match result {
Ok(reply) => reply,
Err(error) => {
format!("😿 Oops! Something went wrong...\n{error}")
}
};
reply(bot, msg, reply_msg).await
}
/// Handles `/search`: embeds `query` and formats the closest database records as an
/// HTML reply body, one line per hit (similarity %, path, and a `/sticker` shortcut).
///
/// Returns `Ok` with a user-facing message. This also covers an empty query and no
/// results. Returns `Err` with a plain-text reason when embedding or the database
/// search fails.
async fn answer_search(
    api: &ApiClient,
    query: &str,
    db: Arc<Mutex<Database>>,
    config: &BotConfig,
) -> Result<String, String> {
    /// Escapes the characters Telegram's HTML parse mode treats as markup.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                _ => escaped.push(ch),
            }
        }
        escaped
    }

    if query.is_empty() {
        return Ok("😾 Please prrr-ovide a query...".to_string());
    }
    let Ok(raw_embedding) = api.embed(query).await else {
        return Err("Failed to embed the query".to_string());
    };
    let embedding: Embedding = raw_embedding.into();
    // Hold the database lock only for the search itself.
    let results = {
        let mut db = db.lock().await;
        db.search_with_id(config.num_results, &embedding).await
    };
    let Ok(results) = results else {
        return Err("Failed to search the database".to_string());
    };
    if results.is_empty() {
        return Ok("😿 No results found...".to_string());
    }
    let lines: Vec<String> = results
        .iter()
        .map(|(path, similarity, file_id)| {
            let percent = similarity * 100.0;
            // Stored paths are arbitrary text; an unescaped `&`, `<` or `>` would make
            // Telegram reject the whole HTML message.
            let path = escape_html(path);
            format!("🐾 {percent:.2}%: {path} | <code>/sticker {file_id}</code>")
        })
        .collect();
    Ok(lines.join("\n"))
}
/// Replies with a canned cat noise to a message the bot did not understand.
///
/// Only private chats get a reply; in every other chat type this does nothing. The
/// message id selects the entry in `FALLBACK_MESSAGES`, so the choice is deterministic.
///
/// # Errors
///
/// Propagates any error from sending the reply.
async fn answer_fallback(bot: &Bot, msg: &Message) -> BotResult<()> {
    if matches!(msg.chat.type_field, ChatType::Private) {
        let slot = msg.message_id.unsigned_abs() as usize % FALLBACK_MESSAGES.len();
        reply(bot, msg, FALLBACK_MESSAGES[slot].to_owned()).await
    } else {
        Ok(())
    }
}
/// Sends `text` to `msg`'s chat as a reply to `msg`.
///
/// The text is sent with HTML parse mode, so literal `<`, `>` and `&` must already be
/// escaped by the caller. Link previews are disabled.
///
/// # Errors
///
/// Propagates any error from the `sendMessage` API call.
async fn reply(bot: &Bot, msg: &Message, text: String) -> BotResult<()> {
    let params = SendMessageParams::builder()
        .chat_id(msg.chat.id)
        .text(text)
        .reply_parameters(ReplyParameters::builder().message_id(msg.message_id).build())
        .parse_mode(ParseMode::Html)
        .link_preview_options(LinkPreviewOptions::DISABLED)
        .build();
    bot.send_message(&params).await?;
    Ok(())
}
/// Embeds `description` and stores the sticker `file_id` in the database under it.
///
/// The record's path is `tg-sticker://{file_id}`, its hash is the placeholder
/// `"Unknown"`, and its label is `description`.
///
/// Returns `Ok` in two cases: a confirmation after inserting, or a usage hint when
/// `description` is blank. Returns `Err` with a plain-text reason when embedding or the
/// insert fails.
async fn insert_sticker(
    db: Arc<Mutex<Database>>,
    api: &ApiClient,
    file_id: String,
    description: String,
) -> Result<String, String> {
    // A blank label gives the sticker nothing meaningful to be found by; ask for one
    // before spending an embedding call.
    if description.trim().is_empty() {
        return Ok("🐾 Paws and reflect! Please provide a description for the sticker... 😾".to_string());
    }
    let Ok(raw_embedding) = api.embed(&description).await else {
        return Err("Failed to embed the description".to_string());
    };
    let embedding: Embedding = raw_embedding.into();
    let file_path = format!("tg-sticker://{file_id}");
    let record = Record {
        embedding,
        file_hash: "Unknown".to_string(),
        file_path,
        file_id: Some(file_id),
        label: description,
    };
    // The guard is a temporary, so the lock is released as soon as the insert finishes.
    let outcome = db.lock().await.insert(record).await;
    match outcome {
        Ok(_) => Ok("Successfully inserted sticker.".to_string()),
        Err(e) => Err(format!("Failed to insert record: {e}")),
    }
}