Skip to main content

rig_core/integrations/
discord_bot.rs

//! Integration for deploying Rig agents (and other types implementing [`DiscordExt`]) as Discord bots.
//! This feature is not WASM-compatible and is therefore incompatible with the `worker` feature.
3use crate::agent::Agent;
4use crate::completion::{CompletionModel, request::Chat};
5use crate::message::Message as RigMessage;
6use serenity::all::{
7    Command, CommandInteraction, Context, CreateCommand, CreateThread, EventHandler,
8    GatewayIntents, Interaction, Message, Ready, async_trait,
9};
10use std::collections::HashMap;
11use std::env;
12use std::sync::Arc;
13use thiserror::Error;
14use tokio::sync::RwLock;
15
16#[derive(Debug, Error)]
17pub enum DiscordBotError {
18    #[error("Discord bot token missing from environment: {0}")]
19    MissingToken(#[from] env::VarError),
20    #[error("Failed to build Discord client: {0}")]
21    ClientBuild(#[from] serenity::Error),
22}
23
24// Bot state containing the agent and conversation histories
25struct BotState<M: CompletionModel> {
26    agent: Agent<M>,
27    conversations: Arc<RwLock<HashMap<u64, Vec<RigMessage>>>>,
28}
29
30impl<M: CompletionModel> BotState<M> {
31    fn new(agent: Agent<M>) -> Self {
32        Self {
33            agent,
34            conversations: Arc::new(RwLock::new(HashMap::new())),
35        }
36    }
37}
38
39// Event handler for the Discord bot
40struct Handler<M: CompletionModel> {
41    state: Arc<BotState<M>>,
42}
43
44#[async_trait]
45impl<M> EventHandler for Handler<M>
46where
47    M: CompletionModel + Send + Sync + 'static,
48{
49    async fn ready(&self, ctx: Context, ready: Ready) {
50        println!("{} is connected!", ready.user.name);
51
52        let register_cmd =
53            CreateCommand::new("new").description("Start a new chat session with the bot");
54
55        // Register slash command globally
56        let command = Command::create_global_command(&ctx.http, register_cmd).await;
57
58        match command {
59            Ok(cmd) => println!("Registered global command: {}", cmd.name),
60            Err(e) => eprintln!("Failed to register command: {}", e),
61        }
62    }
63
64    async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
65        if let Interaction::Command(command) = interaction {
66            self.handle_command(&ctx, &command).await;
67        }
68    }
69
70    async fn message(&self, ctx: Context, msg: Message) {
71        // Ignore bot's own messages
72        if msg.author.bot {
73            return;
74        }
75
76        // Only respond to messages in threads created by the bot
77        let conversations = self.state.conversations.read().await;
78        if conversations.contains_key(&msg.channel_id.get()) {
79            drop(conversations);
80            self.handle_thread_message(&ctx, &msg).await;
81        }
82    }
83}
84
85impl<M> Handler<M>
86where
87    M: CompletionModel + Send + Sync + 'static,
88{
89    async fn handle_command(&self, ctx: &Context, command: &CommandInteraction) {
90        if command.data.name.as_str() == "new" {
91            // Defer the response to prevent timeout
92            if let Err(e) = command.defer(&ctx.http).await {
93                eprintln!("Failed to defer command: {}", e);
94                return;
95            }
96
97            // Create a new thread
98            let thread_name = format!("AI Conversation - {}", command.user.name);
99
100            let thread = match command
101                .channel_id
102                .create_thread(
103                    &ctx.http,
104                    CreateThread::new(thread_name)
105                        .kind(serenity::all::ChannelType::PublicThread)
106                        .auto_archive_duration(serenity::all::AutoArchiveDuration::OneDay),
107                )
108                .await
109            {
110                Ok(t) => t,
111                Err(e) => {
112                    eprintln!("Failed to create thread: {}", e);
113                    let _ = command
114                        .edit_response(
115                            &ctx.http,
116                            serenity::all::EditInteractionResponse::new()
117                                .content("Failed to create thread. Please try again."),
118                        )
119                        .await;
120                    return;
121                }
122            };
123
124            // Initialize conversation history for this thread
125            let mut conversations = self.state.conversations.write().await;
126            conversations.insert(thread.id.get(), Vec::new());
127            drop(conversations);
128
129            // Edit the deferred response
130            if let Err(e) = command
131                .edit_response(
132                    &ctx.http,
133                    serenity::all::EditInteractionResponse::new()
134                        .content(format!(
135                            "Started a new conversation in <#{}>! Send messages there to chat with the AI.",
136                            thread.id
137                        ))
138                )
139                .await
140            {
141                eprintln!("Failed to edit response: {}", e);
142            }
143
144            // Send welcome message to the thread
145            if let Err(e) = thread
146                .send_message(
147                    &ctx.http,
148                    serenity::all::CreateMessage::new()
149                        .content("Hello! I'm ready to help. What would you like to talk about?"),
150                )
151                .await
152            {
153                eprintln!("Failed to send welcome message: {}", e);
154            }
155        }
156    }
157
158    async fn handle_thread_message(&self, ctx: &Context, msg: &Message) {
159        let thread_id = msg.channel_id.get();
160
161        // Show typing indicator
162        let _ = msg.channel_id.broadcast_typing(&ctx.http).await;
163
164        // Get conversation history snapshot.
165        let conversations = self.state.conversations.read().await;
166        let mut history = if let Some(history) = conversations.get(&thread_id) {
167            history.clone()
168        } else {
169            vec![]
170        };
171        drop(conversations);
172
173        // Generate response. `chat` appends the user prompt and generated
174        // assistant/tool messages onto the history snapshot.
175        let response = match self.state.agent.chat(&msg.content, &mut history).await {
176            Ok(resp) => resp,
177            Err(e) => {
178                eprintln!("Agent error: {}", e);
179                let _ = msg
180                    .channel_id
181                    .say(
182                        &ctx.http,
183                        "Sorry, I encountered an error processing your message.",
184                    )
185                    .await;
186                return;
187            }
188        };
189
190        // Persist the round-tripped history back into the conversations map.
191        {
192            let mut conversations = self.state.conversations.write().await;
193            conversations.insert(thread_id, history);
194        }
195
196        // Send response (split if too long for Discord's 2000 char limit)
197        let chunks: Vec<String> = response
198            .chars()
199            .collect::<Vec<_>>()
200            .chunks(1900)
201            .map(|c| c.iter().collect())
202            .collect();
203
204        for chunk in chunks {
205            if let Err(e) = msg.channel_id.say(&ctx.http, &chunk).await {
206                eprintln!("Failed to send message: {}", e);
207            }
208        }
209    }
210}
211
212/// A trait for turning a type into a `serenity` client.
213///
214pub trait DiscordExt: Sized + Send + Sync
215where
216    Self: 'static,
217{
218    fn into_discord_bot(
219        self,
220        token: &str,
221    ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send;
222
223    fn into_discord_bot_from_env(
224        self,
225    ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send {
226        async move {
227            let token = std::env::var("DISCORD_BOT_TOKEN")?;
228            DiscordExt::into_discord_bot(self, &token).await
229        }
230    }
231}
232
233impl<M> DiscordExt for Agent<M>
234where
235    M: CompletionModel + Send + Sync + 'static,
236{
237    async fn into_discord_bot(self, token: &str) -> Result<serenity::Client, DiscordBotError> {
238        let intents = GatewayIntents::GUILDS
239            | GatewayIntents::GUILD_MESSAGES
240            | GatewayIntents::MESSAGE_CONTENT;
241
242        let state = Arc::new(BotState::new(self));
243        let handler = Handler {
244            state: state.clone(),
245        };
246
247        serenity::Client::builder(token, intents)
248            .event_handler(handler)
249            .await
250            .map_err(DiscordBotError::from)
251    }
252}