use crate::channel::{
Attachment, Channel, IncomingMessage, MAX_ATTACHMENT_BYTES, OutgoingMessage, RoomInfo,
};
use crate::config::DiscordConfig;
use anyhow::{Context, Result};
use async_trait::async_trait;
use serenity::Client;
use serenity::all::{
ChannelId, Context as SerenityCtx, EventHandler, GatewayIntents, Message, Ready,
};
use serenity::http::Http;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::{OnceCell, mpsc};
use tracing::{debug, info, warn};
/// Serenity event handler that filters incoming Discord messages and forwards the
/// accepted ones to the channel consumer.
///
/// Each allow-list is only enforced when it is non-empty; an empty set accepts everything.
struct DiscordHandler {
    /// Destination for messages that pass all filters.
    tx: mpsc::Sender<IncomingMessage>,
    /// Channel IDs the bot listens in; empty means every channel is accepted.
    allowed_channel_ids: HashSet<u64>,
    /// Author IDs the bot responds to; empty means every user is accepted.
    allowed_user_ids: HashSet<u64>,
}
#[serenity::async_trait]
impl EventHandler for DiscordHandler {
    /// Logs the bot's user name once the gateway reports the session as ready.
    async fn ready(&self, _ctx: SerenityCtx, ready: Ready) {
        info!("Discord bot connected as {}", ready.user.name);
    }

    /// Filters one incoming Discord message and forwards it as an [`IncomingMessage`].
    ///
    /// The message is dropped when:
    /// - its author is a bot;
    /// - its channel or author is not allow-listed (an empty list allows everything);
    /// - it has neither non-whitespace text nor a usable image attachment.
    ///
    /// A closed receiver is logged rather than propagated.
    async fn message(&self, _ctx: SerenityCtx, msg: Message) {
        // Skip bot authors (this includes our own replies) to avoid feedback loops.
        if msg.author.bot {
            return;
        }

        let channel_id = msg.channel_id.get();
        let channel_allowed =
            self.allowed_channel_ids.is_empty() || self.allowed_channel_ids.contains(&channel_id);
        if !channel_allowed {
            debug!("Ignoring message from channel {channel_id} (not in allowed list)");
            return;
        }

        let user_id = msg.author.id.get();
        let user_allowed =
            self.allowed_user_ids.is_empty() || self.allowed_user_ids.contains(&user_id);
        if !user_allowed {
            debug!("Ignoring message from user {user_id} (not in allowed list)");
            return;
        }

        let text = msg.content.trim().to_string();
        let images = download_image_attachments(&msg).await;
        if text.is_empty() && images.is_empty() {
            return;
        }

        // `unix_timestamp()` is whole seconds; scale by 1000 (IncomingMessage presumably
        // stores milliseconds — confirm against the trait definition).
        let timestamp_ms = (msg.timestamp.unix_timestamp() as u64) * 1000;

        let incoming = IncomingMessage {
            id: msg.id.to_string(),
            sender: msg.author.id.to_string(),
            content: text,
            room_id: msg.channel_id.to_string(),
            timestamp: timestamp_ms,
            thread_id: None,
            attachments: images,
        };

        if let Err(err) = self.tx.send(incoming).await {
            warn!("Failed to forward Discord message: {err}");
        }
    }
}
/// Discord implementation of [`Channel`], built on serenity.
pub struct DiscordChannel {
    /// Bot token used for the gateway client and for standalone HTTP clients.
    token: String,
    /// Allowed channel IDs parsed from config; handed to the event handler, where an
    /// empty set means every channel is accepted.
    channel_ids: HashSet<u64>,
    /// Allowed author IDs parsed from config; handed to the event handler, where an
    /// empty set means every user is accepted.
    allowed_user_ids: HashSet<u64>,
    /// HTTP handle of the first gateway client started by `listen`, once one exists.
    http: Arc<OnceCell<Arc<Http>>>,
}
impl std::fmt::Debug for DiscordChannel {
    /// Formats the channel for diagnostics.
    ///
    /// Only `channel_ids` is printed. The bot token in particular is never shown, and
    /// `finish_non_exhaustive` renders a trailing `..` to signal omitted fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DiscordChannel");
        builder.field("channel_ids", &self.channel_ids);
        builder.finish_non_exhaustive()
    }
}
impl DiscordChannel {
pub fn new(cfg: &DiscordConfig) -> Result<Self> {
let channel_ids = cfg
.channel_ids
.iter()
.map(|s| s.parse::<u64>().context("Invalid Discord channel_id"))
.collect::<Result<HashSet<_>>>()?;
let allowed_user_ids = cfg
.allowed_users
.iter()
.map(|s| {
s.parse::<u64>()
.context("Invalid Discord user ID in allowed_users")
})
.collect::<Result<HashSet<_>>>()?;
Ok(Self {
token: cfg.bot_token.clone(),
channel_ids,
allowed_user_ids,
http: Arc::new(OnceCell::new()),
})
}
fn get_http_or_new(&self) -> Arc<Http> {
if let Some(http) = self.http.get() {
Arc::clone(http)
} else {
Arc::new(Http::new(&self.token))
}
}
}
#[async_trait]
impl Channel for DiscordChannel {
    /// Stable identifier of this channel implementation.
    fn name(&self) -> &str {
        "discord"
    }

    /// Sends `message.content` to the Discord channel named by `message.room_id`.
    ///
    /// Long content is split with `split_for_discord` and sent as consecutive messages.
    ///
    /// # Errors
    ///
    /// Fails if `room_id` is not a numeric channel ID, or on the first chunk that Discord
    /// rejects. Chunks already sent are not rolled back.
    async fn send(&self, message: &OutgoingMessage) -> Result<()> {
        let http = self.get_http_or_new();
        let channel_id: u64 = message
            .room_id
            .parse()
            .context("Discord room_id is not a valid channel ID")?;
        let channel_id = ChannelId::new(channel_id);
        for chunk in split_for_discord(&message.content) {
            channel_id
                .say(http.as_ref(), chunk)
                .await
                .context("Failed to send Discord message")?;
        }
        Ok(())
    }

    /// Runs the Discord gateway forever, forwarding accepted messages to `tx`.
    ///
    /// Whenever the gateway exits, it is rebuilt after an exponential backoff. The delay
    /// starts at 1 s, doubles after each reconnect, and is capped at 300 s. A session that
    /// stayed up for at least 60 s is considered healthy and resets the delay to 1 s.
    ///
    /// # Errors
    ///
    /// Returns only if building the serenity client fails (e.g. an invalid token).
    async fn listen(&self, tx: mpsc::Sender<IncomingMessage>) -> Result<()> {
        let intents = GatewayIntents::GUILD_MESSAGES
            | GatewayIntents::DIRECT_MESSAGES
            | GatewayIntents::MESSAGE_CONTENT;
        let min_backoff = std::time::Duration::from_secs(1);
        let max_backoff = std::time::Duration::from_secs(300);
        let stable_threshold = std::time::Duration::from_secs(60);
        let mut backoff = min_backoff;
        loop {
            let handler = DiscordHandler {
                tx: tx.clone(),
                allowed_channel_ids: self.channel_ids.clone(),
                allowed_user_ids: self.allowed_user_ids.clone(),
            };
            let mut client = Client::builder(&self.token, intents)
                .event_handler(handler)
                .await
                .context("Failed to build Discord client")?;
            // Publish the gateway client's HTTP handle for send/start_typing/room_info.
            // Only the first client is stored; later `set` calls fail and are ignored.
            let _ = self.http.set(Arc::clone(&client.http));
            info!("Starting Discord gateway...");
            let started = std::time::Instant::now();
            let outcome = client.start().await;
            // Measure uptime *before* sleeping. If the backoff sleep were counted, any
            // delay >= stable_threshold would always look "stable" and reset the backoff,
            // so max_backoff could never be reached.
            if started.elapsed() >= stable_threshold {
                backoff = min_backoff;
            }
            match outcome {
                Ok(()) => {
                    warn!("Discord gateway exited without error; reconnecting in {backoff:?}");
                }
                Err(e) => {
                    warn!("Discord gateway exited with error: {e}; reconnecting in {backoff:?}");
                }
            }
            tokio::time::sleep(backoff).await;
            backoff = (backoff * 2).min(max_backoff);
            info!("Reconnecting Discord gateway...");
        }
    }

    /// Shows the "bot is typing" indicator in the given channel.
    ///
    /// # Errors
    ///
    /// Fails if `room_id` is not numeric or the Discord request fails.
    async fn start_typing(&self, room_id: &str) -> Result<()> {
        let http = self.get_http_or_new();
        let channel_id: u64 = room_id.parse().context("Invalid Discord channel ID")?;
        ChannelId::new(channel_id)
            .broadcast_typing(http.as_ref())
            .await
            .context("Failed to send typing indicator")?;
        Ok(())
    }

    /// Looks up display metadata for a channel.
    ///
    /// Returns `None` if `room_id` is not numeric, the lookup fails, or the channel is
    /// neither a guild channel nor a DM. Voice and stage channels are tagged
    /// `"discord-voice"`; empty topics become `None`.
    async fn room_info(&self, room_id: &str) -> Option<RoomInfo> {
        use serenity::all::{Channel as SerenityChannel, ChannelType};
        let channel_id: u64 = room_id.parse().ok()?;
        let http = self.get_http_or_new();
        let channel = ChannelId::new(channel_id)
            .to_channel(http.as_ref())
            .await
            .ok()?;
        match channel {
            SerenityChannel::Guild(gc) => {
                let kind = match gc.kind {
                    ChannelType::Voice | ChannelType::Stage => "discord-voice",
                    _ => "discord",
                };
                Some(RoomInfo {
                    name: gc.name.clone(),
                    description: gc.topic.clone().filter(|t| !t.is_empty()),
                    kind: kind.to_string(),
                })
            }
            SerenityChannel::Private(pc) => Some(RoomInfo {
                name: format!("DM with {}", pc.recipient.name),
                description: None,
                kind: "discord-dm".to_string(),
            }),
            _ => None,
        }
    }
}
/// Downloads the image attachments of `msg` that can be forwarded as [`Attachment`]s.
///
/// An attachment is kept only if all of these hold:
/// - it declares a content type;
/// - that type is JPEG, PNG, GIF or WebP;
/// - both its advertised size and its downloaded length are within `MAX_ATTACHMENT_BYTES`.
///
/// Attachments without a content type are skipped silently. Every other rejection or
/// download failure is logged, and the remaining attachments are still processed.
async fn download_image_attachments(msg: &Message) -> Vec<Attachment> {
    const SUPPORTED: &[&str] = &["image/jpeg", "image/png", "image/gif", "image/webp"];

    let mut images = Vec::new();
    for attachment in msg.attachments.iter() {
        let media_type = match attachment.content_type.as_deref() {
            Some(ct) => ct,
            None => continue,
        };

        if !SUPPORTED.contains(&media_type) {
            warn!(
                "Discord image '{}' has unsupported MIME type '{}'; skipping",
                attachment.filename, media_type
            );
            continue;
        }

        // Reject on the advertised size first so oversized files are never fetched.
        if (attachment.size as usize) > MAX_ATTACHMENT_BYTES {
            warn!(
                "Discord image '{}' is {} bytes (>5MB); skipping",
                attachment.filename, attachment.size
            );
            continue;
        }

        let bytes = match attachment.download().await {
            Ok(bytes) => bytes,
            Err(e) => {
                warn!(
                    "Failed to download Discord attachment '{}': {e}",
                    attachment.filename
                );
                continue;
            }
        };

        // Re-check the real payload length; the advertised size is only a hint.
        if bytes.len() > MAX_ATTACHMENT_BYTES {
            warn!(
                "Discord image '{}' decoded to {} bytes (>5MB); skipping",
                attachment.filename,
                bytes.len()
            );
            continue;
        }

        images.push(Attachment {
            media_type: media_type.to_string(),
            data: bytes,
        });
    }
    images
}
/// Splits `content` into pieces small enough to send as individual Discord messages.
///
/// Discord limits message content to 2000 characters. The limit here is applied in
/// UTF-8 bytes (never fewer than characters) with a small safety margin.
///
/// How the pieces are formed:
/// - Content that already fits is returned unchanged as a single piece, even if empty.
/// - Otherwise each piece ends just after the last newline that fits, as long as that
///   newline is not at index 0.
/// - Failing that, the cut is made at the largest char boundary within the limit.
/// - Newlines at the start of each following piece are dropped.
fn split_for_discord(content: &str) -> Vec<String> {
    const LIMIT: usize = 1990;
    if content.len() <= LIMIT {
        return vec![content.to_owned()];
    }
    let mut chunks = Vec::new();
    let mut remaining = content;
    while remaining.len() > LIMIT {
        // Back off to a char boundary so slicing cannot panic on multi-byte UTF-8.
        let mut split = LIMIT;
        while !remaining.is_char_boundary(split) {
            split -= 1;
        }
        // Prefer breaking after a newline, but ignore one at index 0. Splitting there
        // would emit a chunk that is just "\n", which carries no text (Discord presumably
        // rejects whitespace-only messages).
        if let Some(nl) = remaining[..split].rfind('\n').filter(|&nl| nl > 0) {
            split = nl + 1;
        }
        chunks.push(remaining[..split].to_owned());
        remaining = remaining[split..].trim_start_matches('\n');
    }
    if !remaining.is_empty() {
        chunks.push(remaining.to_owned());
    }
    chunks
}