use std::sync::Arc;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
use serenity::all::{
ButtonStyle, Client, Command, CreateActionRow, CreateAttachment, CreateButton, CreateCommand,
CreateInteractionResponse, CreateMessage, EditMessage, GatewayIntents, Interaction, Message,
Ready,
};
use tokio::sync::mpsc;
use super::adapter::{ChannelId as ColletChannelId, IncomingCommand, PlatformAdapter};
use super::commands::parse_remote_command;
/// Discord implementation of `PlatformAdapter`, built on serenity.
pub struct DiscordAdapter {
    /// Shared REST client created from the bot token; used for every outbound API call.
    http: Arc<serenity::http::Http>,
    /// Bot token, retained so `run` can build the gateway client with it.
    token: String,
}
impl DiscordAdapter {
    /// Creates an adapter whose REST client authenticates with `token`.
    ///
    /// The same token is kept for opening the gateway connection in `run`.
    pub fn new(token: String) -> Self {
        let http = Arc::new(serenity::http::Http::new(&token));
        Self { token, http }
    }

    /// Converts a platform-neutral channel identifier into a Discord `ChannelId`.
    ///
    /// # Errors
    /// Returns an error if `channel.channel` is not a non-zero unsigned 64-bit integer.
    fn parse_channel(channel: &ColletChannelId) -> Result<serenity::all::ChannelId> {
        // Parse as NonZeroU64 rather than u64: `ChannelId::new` panics on 0, so a "0"
        // arriving from outside must surface as an error instead of crashing the task.
        let id: std::num::NonZeroU64 = channel
            .channel
            .parse()
            .context("invalid Discord channel ID")?;
        Ok(serenity::all::ChannelId::new(id.get()))
    }
}
#[async_trait]
impl PlatformAdapter for DiscordAdapter {
fn platform_name(&self) -> &str {
"discord"
}
fn max_message_length(&self) -> usize {
2000
}
async fn register_commands(&self, commands: &[(&str, &str)]) -> Result<()> {
if commands.is_empty() {
return Ok(());
}
if self.http.application_id().is_none() {
let info = self.http.get_current_application_info().await?;
self.http.set_application_id(info.id);
}
let cmds: Vec<CreateCommand> = commands
.iter()
.map(|(name, desc)| CreateCommand::new(*name).description(*desc))
.collect();
Command::set_global_commands(&self.http, cmds).await?;
tracing::info!("[discord] registered {} skill command(s)", commands.len());
Ok(())
}
async fn send_typing(&self, channel: &ColletChannelId) -> Result<()> {
let ch = Self::parse_channel(channel)?;
ch.broadcast_typing(&self.http).await?;
Ok(())
}
async fn send_message(&self, channel: &ColletChannelId, text: &str) -> Result<()> {
let ch = Self::parse_channel(channel)?;
ch.say(&self.http, text).await?;
Ok(())
}
async fn send_long_message(
&self,
channel: &ColletChannelId,
text: &str,
filename: Option<&str>,
) -> Result<()> {
let ch = Self::parse_channel(channel)?;
let name = filename.unwrap_or("output.txt");
let attachment = CreateAttachment::bytes(text.as_bytes(), name);
let msg = CreateMessage::new().add_file(attachment);
ch.send_message(&self.http, msg).await?;
Ok(())
}
async fn send_buttons(
&self,
channel: &ColletChannelId,
text: &str,
buttons: &[(String, String)],
) -> Result<()> {
let ch = Self::parse_channel(channel)?;
let btns: Vec<CreateButton> = buttons
.iter()
.map(|(id, label)| {
CreateButton::new(id)
.label(label)
.style(ButtonStyle::Primary)
})
.collect();
let row = CreateActionRow::Buttons(btns);
let msg = CreateMessage::new().content(text).components(vec![row]);
ch.send_message(&self.http, msg).await?;
Ok(())
}
async fn edit_message(
&self,
channel: &ColletChannelId,
message_id: &str,
new_text: &str,
) -> Result<bool> {
let ch = Self::parse_channel(channel)?;
let msg_id = serenity::all::MessageId::new(
message_id
.parse::<u64>()
.context("invalid Discord message ID")?,
);
ch.edit_message(&self.http, msg_id, EditMessage::new().content(new_text))
.await?;
Ok(true)
}
async fn run(&self, command_tx: mpsc::UnboundedSender<IncomingCommand>) -> Result<()> {
let handler = DiscordHandler { command_tx };
let intents = GatewayIntents::GUILD_MESSAGES
| GatewayIntents::DIRECT_MESSAGES
| GatewayIntents::MESSAGE_CONTENT;
let mut client = Client::builder(&self.token, intents)
.event_handler(handler)
.await
.context("failed to build Discord client")?;
client.start().await.context("Discord client error")?;
Ok(())
}
}
struct DiscordHandler {
command_tx: mpsc::UnboundedSender<IncomingCommand>,
}
#[async_trait]
impl serenity::prelude::EventHandler for DiscordHandler {
async fn message(&self, _ctx: serenity::prelude::Context, msg: Message) {
if msg.author.bot {
return;
}
let channel = ColletChannelId {
platform: "discord".into(),
channel: msg.channel_id.to_string(),
thread: None,
};
let command = parse_remote_command(&msg.content);
let incoming = IncomingCommand {
channel,
user_id: msg.author.id.to_string(),
command,
};
if let Err(e) = self.command_tx.send(incoming) {
tracing::error!("[discord] failed to forward command: {e}");
}
}
async fn ready(&self, _ctx: serenity::prelude::Context, ready: Ready) {
tracing::info!("[discord] connected as {}", ready.user.name);
}
async fn interaction_create(&self, ctx: serenity::prelude::Context, interaction: Interaction) {
match interaction {
Interaction::Component(component) => {
let custom_id = &component.data.custom_id;
let command = parse_remote_command(custom_id);
let channel = ColletChannelId {
platform: "discord".into(),
channel: component.channel_id.to_string(),
thread: None,
};
let incoming = IncomingCommand {
channel,
user_id: component.user.id.to_string(),
command,
};
if let Err(e) = self.command_tx.send(incoming) {
tracing::error!("[discord] failed to forward button interaction: {e}");
}
if let Err(e) = component
.create_response(&ctx.http, CreateInteractionResponse::Acknowledge)
.await
{
tracing::error!("[discord] failed to acknowledge interaction: {e}");
}
}
Interaction::Command(cmd) => {
let command_name = format!("/{}", cmd.data.name);
let command = parse_remote_command(&command_name);
let channel = ColletChannelId {
platform: "discord".into(),
channel: cmd.channel_id.to_string(),
thread: None,
};
let incoming = IncomingCommand {
channel,
user_id: cmd.user.id.to_string(),
command,
};
if let Err(e) = self.command_tx.send(incoming) {
tracing::error!("[discord] failed to forward slash command: {e}");
}
if let Err(e) = cmd
.create_response(&ctx.http, CreateInteractionResponse::Acknowledge)
.await
{
tracing::error!("[discord] failed to acknowledge slash command: {e}");
}
}
_ => {}
}
}
}