use crate::{
bot::Router,
ctx::Ctx,
error::{Error, Result},
keyboard::{ButtonKind, Embed as FkEmbed, Reply},
platform::PlatformKind,
};
use async_trait::async_trait;
use serenity::{
all::{
ButtonStyle, ChannelId, Command, CommandInteraction, CommandOptionType,
ComponentInteraction, CreateActionRow, CreateButton, CreateCommand, CreateCommandOption,
CreateEmbed, CreateEmbedFooter, CreateInteractionResponse,
CreateInteractionResponseFollowup, CreateInteractionResponseMessage, CreateMessage,
EditInteractionResponse, GatewayIntents, Interaction, Message, Ready,
},
client::{Context as SerenityContext, EventHandler},
Client,
};
use std::sync::Arc;
pub async fn run(
token: String,
router: Arc<Router>,
commands: Vec<(String, Option<String>)>,
) -> Result<()> {
tracing::info!("starting discord adapter");
let intents = GatewayIntents::GUILD_MESSAGES
| GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::MESSAGE_CONTENT;
let mut client = Client::builder(&token, intents)
.event_handler(Handler {
router,
commands: Arc::new(commands),
})
.await
.map_err(|e| Error::platform("discord", e))?;
client
.start()
.await
.map_err(|e| Error::platform("discord", e))?;
Ok(())
}
/// serenity event handler that forwards Discord gateway events to the shared [`Router`].
struct Handler {
    /// `(name, description)` pairs offered as global slash commands when the bot is ready.
    commands: Arc<Vec<(String, Option<String>)>>,
    /// Dispatch target for incoming messages, slash commands and component clicks.
    router: Arc<Router>,
}
#[async_trait]
impl EventHandler for Handler {
async fn ready(&self, ctx: SerenityContext, ready: Ready) {
tracing::info!(bot = %ready.user.name, "discord adapter ready");
let mut batch: Vec<CreateCommand> = Vec::new();
for (name, desc) in self.commands.iter() {
let trimmed = name.trim_start_matches('/').to_ascii_lowercase();
if trimmed.is_empty() || !is_valid_slash_name(&trimmed) {
continue;
}
let description = desc
.clone()
.filter(|s| !s.is_empty())
.unwrap_or_else(|| trimmed.clone());
let description = truncate_chars(&description, 100);
let opt = CreateCommandOption::new(
CommandOptionType::String,
"args",
"arguments passed to the command (optional)",
)
.required(false);
batch.push(
CreateCommand::new(trimmed)
.description(description)
.add_option(opt),
);
}
match Command::set_global_commands(&ctx.http, batch).await {
Ok(cmds) => {
tracing::info!(
count = cmds.len(),
"discord: registered global slash commands"
);
}
Err(e) => {
tracing::warn!(error = %e, "discord: could not register slash commands");
}
}
}
async fn message(&self, ctx: SerenityContext, msg: Message) {
if msg.author.bot {
return;
}
let router = Arc::clone(&self.router);
let channel_id = msg.channel_id;
let http = ctx.http.clone();
let is_dm = detect_dm(&ctx, channel_id).await;
let reply_fn: crate::ctx::ReplyFn = Box::new(move |reply: Reply| {
let http = http.clone();
Box::pin(async move {
let mut msg = CreateMessage::new();
if !reply.get_text().is_empty() {
msg = msg.content(reply.get_text());
}
if let Some(em) = reply.get_embed() {
msg = msg.add_embed(to_discord_embed(em));
}
if let Some(kb) = reply.get_keyboard() {
msg = msg.components(build_rows(kb));
}
channel_id
.send_message(&http, msg)
.await
.map_err(|e| Error::platform("discord", e))?;
Ok(())
})
});
let fouko_ctx = Ctx::new_full(
PlatformKind::Discord,
channel_id.to_string(),
msg.author.id.to_string(),
msg.content.clone(),
reply_fn,
Some(is_dm),
None,
);
if let Err(e) = router.dispatch(fouko_ctx).await {
tracing::warn!(error = %e, "discord handler error");
}
}
async fn interaction_create(&self, ctx: SerenityContext, interaction: Interaction) {
match interaction {
Interaction::Component(component) => {
handle_component(&ctx, &self.router, component).await;
}
Interaction::Command(command) => {
handle_command(&ctx, &self.router, command).await;
}
_ => {}
}
}
}
/// Handles a slash-command invocation by dispatching it through the router.
///
/// The interaction is deferred first. The command is then rebuilt as text: `/name`
/// followed by every string option value, separated by spaces. The first reply fills in
/// the deferred response, and every later reply is sent as a follow-up message.
async fn handle_command(ctx: &SerenityContext, router: &Router, command: CommandInteraction) {
    // Discord only waits a few seconds for the initial interaction response. Defer before
    // any other network round-trip, because the DM lookup below may hit the REST API.
    let defer = CreateInteractionResponse::Defer(CreateInteractionResponseMessage::new());
    if let Err(e) = command.create_response(&ctx.http, defer).await {
        tracing::debug!(error = %e, "discord slash-cmd defer failed");
    }
    let channel_id = command.channel_id;
    let user_id = command.user.id.to_string();
    let is_dm = detect_dm(ctx, channel_id).await;
    let mut text = format!("/{}", command.data.name);
    for opt in &command.data.options {
        if let Some(v) = opt.value.as_str() {
            text.push(' ');
            text.push_str(v);
        }
    }
    let http = ctx.http.clone();
    // `true` while the deferred response has not been filled in yet.
    let first_call = Arc::new(std::sync::atomic::AtomicBool::new(true));
    // `command` is no longer needed here, so it moves into the closure without an extra clone.
    let reply_fn: crate::ctx::ReplyFn = Box::new(move |reply: Reply| {
        let http = http.clone();
        let command = command.clone();
        let first = Arc::clone(&first_call);
        Box::pin(async move {
            // `swap` hands `true` back to exactly one caller, so only one reply edits the
            // deferred response even if replies race.
            let is_first = first.swap(false, std::sync::atomic::Ordering::SeqCst);
            if is_first {
                let mut edit = EditInteractionResponse::new();
                if !reply.get_text().is_empty() {
                    edit = edit.content(reply.get_text());
                }
                if let Some(em) = reply.get_embed() {
                    edit = edit.embed(to_discord_embed(em));
                }
                if let Some(kb) = reply.get_keyboard() {
                    edit = edit.components(build_rows(kb));
                }
                command
                    .edit_response(&http, edit)
                    .await
                    .map_err(|e| Error::platform("discord", e))?;
            } else {
                let mut follow = CreateInteractionResponseFollowup::new();
                if !reply.get_text().is_empty() {
                    follow = follow.content(reply.get_text());
                }
                if let Some(em) = reply.get_embed() {
                    follow = follow.add_embed(to_discord_embed(em));
                }
                if let Some(kb) = reply.get_keyboard() {
                    follow = follow.components(build_rows(kb));
                }
                command
                    .create_followup(&http, follow)
                    .await
                    .map_err(|e| Error::platform("discord", e))?;
            }
            Ok(())
        })
    });
    let fouko_ctx = Ctx::new_full(
        PlatformKind::Discord,
        channel_id.to_string(),
        user_id,
        text,
        reply_fn,
        Some(is_dm),
        None,
    );
    if let Err(e) = router.dispatch(fouko_ctx).await {
        tracing::warn!(error = %e, "discord slash-cmd handler error");
    }
}
/// Handles a button (component) click by dispatching its `custom_id` through the router.
///
/// The interaction is acknowledged first. The router context uses the `custom_id` both as
/// its text and as its callback data. Replies are sent as follow-up messages. Edits
/// rewrite the clicked message in place and replace its content, embeds and components
/// with those of the reply.
async fn handle_component(ctx: &SerenityContext, router: &Router, component: ComponentInteraction) {
    let http = ctx.http.clone();
    // Acknowledge before any other network round-trip. Discord only waits a few seconds
    // for the initial response, and the DM lookup below may hit the REST API.
    let defer = CreateInteractionResponse::Acknowledge;
    if let Err(e) = component.create_response(&http, defer).await {
        tracing::debug!(error = %e, "discord component defer failed");
    }
    let channel_id = component.channel_id;
    let data = component.data.custom_id.clone();
    let user_id = component.user.id.to_string();
    let is_dm = detect_dm(ctx, channel_id).await;

    // Edits only need the clicked message, so clone that instead of the whole interaction.
    let clicked_message = component.message.clone();
    let http_for_edit = http.clone();
    let edit_fn: crate::ctx::EditFn = Arc::new(move |reply: Reply| {
        let http = http_for_edit.clone();
        let mut msg = clicked_message.clone();
        Box::pin(async move {
            let mut edit = serenity::all::EditMessage::new();
            // Content, embeds and components are always set, so stale parts are cleared.
            edit = edit.content(reply.get_text().to_string());
            let mut embeds = Vec::new();
            if let Some(em) = reply.get_embed() {
                embeds.push(to_discord_embed(em));
            }
            edit = edit.embeds(embeds);
            let components = match reply.get_keyboard() {
                Some(kb) => build_rows(kb),
                None => Vec::new(),
            };
            edit = edit.components(components);
            msg.edit(&http, edit)
                .await
                .map_err(|e| Error::platform("discord", e))?;
            Ok(())
        })
    });

    // `component` and `http` are not needed after this point, so they move into the closure.
    let reply_fn: crate::ctx::ReplyFn = Box::new(move |reply: Reply| {
        let http = http.clone();
        let comp = component.clone();
        Box::pin(async move {
            let mut follow = CreateInteractionResponseFollowup::new();
            if !reply.get_text().is_empty() {
                follow = follow.content(reply.get_text());
            }
            if let Some(em) = reply.get_embed() {
                follow = follow.add_embed(to_discord_embed(em));
            }
            if let Some(kb) = reply.get_keyboard() {
                follow = follow.components(build_rows(kb));
            }
            comp.create_followup(&http, follow)
                .await
                .map_err(|e| Error::platform("discord", e))?;
            Ok(())
        })
    });

    let fouko_ctx = Ctx::new_with_edit(
        PlatformKind::Discord,
        channel_id.to_string(),
        user_id,
        data.clone(),
        reply_fn,
        Some(is_dm),
        Some(data),
        Some(edit_fn),
    );
    if let Err(e) = router.dispatch(fouko_ctx).await {
        tracing::warn!(error = %e, "discord interaction handler error");
    }
}
fn build_rows(kb: &crate::keyboard::Keyboard) -> Vec<CreateActionRow> {
kb.rows()
.iter()
.map(|row| {
let buttons: Vec<CreateButton> = row
.iter()
.map(|b| match &b.kind {
ButtonKind::Callback(id) => CreateButton::new(id.clone())
.label(b.label())
.style(ButtonStyle::Primary),
ButtonKind::Url(u) => CreateButton::new_link(u.clone()).label(b.label()),
})
.collect();
CreateActionRow::Buttons(buttons)
})
.collect()
}
/// Returns `true` when `s` is acceptable as a slash-command name for this adapter.
///
/// An accepted name is at most 32 bytes long and starts with a lowercase ASCII letter.
/// Every other character is a lowercase ASCII letter, an ASCII digit, `_` or `-`. The
/// empty string is rejected.
fn is_valid_slash_name(s: &str) -> bool {
    let leads_with_letter = matches!(s.chars().next(), Some('a'..='z'));
    let allowed_chars_only = s
        .chars()
        .all(|c| matches!(c, 'a'..='z' | '0'..='9' | '_' | '-'));
    leads_with_letter && allowed_chars_only && s.len() <= 32
}
/// Returns the first `max_chars` Unicode scalar values of `s` as a new `String`.
///
/// Counting `char`s rather than bytes means a multi-byte character is never split. If `s`
/// is already short enough, it is returned unchanged as an owned copy.
fn truncate_chars(s: &str, max_chars: usize) -> String {
    s.chars().take(max_chars).collect()
}
async fn detect_dm(ctx: &SerenityContext, channel_id: ChannelId) -> bool {
match channel_id.to_channel(&ctx.http).await {
Ok(ch) => ch.private().is_some(),
Err(_) => false,
}
}
/// Translates a platform-neutral [`FkEmbed`] into a serenity [`CreateEmbed`] builder.
///
/// Only the parts that are present on `src` are copied: title, URL, description, colour,
/// fields (in their original order, with their inline flags), image, thumbnail, and footer
/// text.
fn to_discord_embed(src: &FkEmbed) -> CreateEmbed {
    let embed = CreateEmbed::new();
    let embed = if let Some(title) = src.get_title() {
        embed.title(title)
    } else {
        embed
    };
    let embed = if let Some(link) = src.get_url() {
        embed.url(link)
    } else {
        embed
    };
    let embed = if let Some(body) = src.get_description() {
        embed.description(body)
    } else {
        embed
    };
    let embed = if let Some(colour) = src.get_color() {
        embed.colour(colour)
    } else {
        embed
    };
    let embed = src
        .get_fields()
        .into_iter()
        .fold(embed, |acc, field| {
            acc.field(field.name(), field.value(), field.is_inline())
        });
    let embed = if let Some(image_url) = src.get_image() {
        embed.image(image_url)
    } else {
        embed
    };
    let embed = if let Some(thumb_url) = src.get_thumbnail() {
        embed.thumbnail(thumb_url)
    } else {
        embed
    };
    if let Some(footer_text) = src.get_footer() {
        embed.footer(CreateEmbedFooter::new(footer_text))
    } else {
        embed
    }
}