Skip to main content

rig/integrations/
discord_bot.rs

1//! Integration for deploying your Rig agents (and more) as Discord bots.
2//! This feature is not WASM-compatible (and as such, is incompatible with the `worker` feature).
3use crate::OneOrMany;
4use crate::agent::Agent;
5use crate::completion::{AssistantContent, CompletionModel, request::Chat};
6use crate::message::{Message as RigMessage, UserContent};
7use serenity::all::{
8    Command, CommandInteraction, Context, CreateCommand, CreateThread, EventHandler,
9    GatewayIntents, Interaction, Message, Ready, async_trait,
10};
11use std::collections::HashMap;
12use std::env;
13use std::sync::Arc;
14use thiserror::Error;
15use tokio::sync::RwLock;
16
17#[derive(Debug, Error)]
18pub enum DiscordBotError {
19    #[error("Discord bot token missing from environment: {0}")]
20    MissingToken(#[from] env::VarError),
21    #[error("Failed to build Discord client: {0}")]
22    ClientBuild(#[from] serenity::Error),
23}
24
25// Bot state containing the agent and conversation histories
26struct BotState<M: CompletionModel> {
27    agent: Agent<M>,
28    conversations: Arc<RwLock<HashMap<u64, Vec<RigMessage>>>>,
29}
30
31impl<M: CompletionModel> BotState<M> {
32    fn new(agent: Agent<M>) -> Self {
33        Self {
34            agent,
35            conversations: Arc::new(RwLock::new(HashMap::new())),
36        }
37    }
38}
39
40// Event handler for the Discord bot
41struct Handler<M: CompletionModel> {
42    state: Arc<BotState<M>>,
43}
44
45#[async_trait]
46impl<M> EventHandler for Handler<M>
47where
48    M: CompletionModel + Send + Sync + 'static,
49{
50    async fn ready(&self, ctx: Context, ready: Ready) {
51        println!("{} is connected!", ready.user.name);
52
53        let register_cmd =
54            CreateCommand::new("new").description("Start a new chat session with the bot");
55
56        // Register slash command globally
57        let command = Command::create_global_command(&ctx.http, register_cmd).await;
58
59        match command {
60            Ok(cmd) => println!("Registered global command: {}", cmd.name),
61            Err(e) => eprintln!("Failed to register command: {}", e),
62        }
63    }
64
65    async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
66        if let Interaction::Command(command) = interaction {
67            self.handle_command(&ctx, &command).await;
68        }
69    }
70
71    async fn message(&self, ctx: Context, msg: Message) {
72        // Ignore bot's own messages
73        if msg.author.bot {
74            return;
75        }
76
77        // Only respond to messages in threads created by the bot
78        let conversations = self.state.conversations.read().await;
79        if conversations.contains_key(&msg.channel_id.get()) {
80            drop(conversations);
81            self.handle_thread_message(&ctx, &msg).await;
82        }
83    }
84}
85
86impl<M> Handler<M>
87where
88    M: CompletionModel + Send + Sync + 'static,
89{
90    async fn handle_command(&self, ctx: &Context, command: &CommandInteraction) {
91        if command.data.name.as_str() == "new" {
92            // Defer the response to prevent timeout
93            if let Err(e) = command.defer(&ctx.http).await {
94                eprintln!("Failed to defer command: {}", e);
95                return;
96            }
97
98            // Create a new thread
99            let thread_name = format!("AI Conversation - {}", command.user.name);
100
101            let thread = match command
102                .channel_id
103                .create_thread(
104                    &ctx.http,
105                    CreateThread::new(thread_name)
106                        .kind(serenity::all::ChannelType::PublicThread)
107                        .auto_archive_duration(serenity::all::AutoArchiveDuration::OneDay),
108                )
109                .await
110            {
111                Ok(t) => t,
112                Err(e) => {
113                    eprintln!("Failed to create thread: {}", e);
114                    let _ = command
115                        .edit_response(
116                            &ctx.http,
117                            serenity::all::EditInteractionResponse::new()
118                                .content("Failed to create thread. Please try again."),
119                        )
120                        .await;
121                    return;
122                }
123            };
124
125            // Initialize conversation history for this thread
126            let mut conversations = self.state.conversations.write().await;
127            conversations.insert(thread.id.get(), Vec::new());
128            drop(conversations);
129
130            // Edit the deferred response
131            if let Err(e) = command
132                .edit_response(
133                    &ctx.http,
134                    serenity::all::EditInteractionResponse::new()
135                        .content(format!(
136                            "Started a new conversation in <#{}>! Send messages there to chat with the AI.",
137                            thread.id
138                        ))
139                )
140                .await
141            {
142                eprintln!("Failed to edit response: {}", e);
143            }
144
145            // Send welcome message to the thread
146            if let Err(e) = thread
147                .send_message(
148                    &ctx.http,
149                    serenity::all::CreateMessage::new()
150                        .content("Hello! I'm ready to help. What would you like to talk about?"),
151                )
152                .await
153            {
154                eprintln!("Failed to send welcome message: {}", e);
155            }
156        }
157    }
158
159    async fn handle_thread_message(&self, ctx: &Context, msg: &Message) {
160        let thread_id = msg.channel_id.get();
161
162        // Add user message to history
163        {
164            let mut conversations = self.state.conversations.write().await;
165            if let Some(history) = conversations.get_mut(&thread_id) {
166                history.push(RigMessage::User {
167                    content: OneOrMany::one(UserContent::text(msg.content.clone())),
168                });
169            }
170        }
171
172        // Show typing indicator
173        let _ = msg.channel_id.broadcast_typing(&ctx.http).await;
174
175        // Get conversation history
176        let conversations = self.state.conversations.read().await;
177        let history = if let Some(history) = conversations.get(&thread_id) {
178            history.clone()
179        } else {
180            vec![]
181        };
182        drop(conversations);
183
184        // Generate response using the agent with conversation history
185        let response = match self.state.agent.chat(&msg.content, history).await {
186            Ok(resp) => resp,
187            Err(e) => {
188                eprintln!("Agent error: {}", e);
189                let _ = msg
190                    .channel_id
191                    .say(
192                        &ctx.http,
193                        "Sorry, I encountered an error processing your message.",
194                    )
195                    .await;
196                return;
197            }
198        };
199
200        // Add assistant response to history
201        {
202            let mut conversations = self.state.conversations.write().await;
203            if let Some(history) = conversations.get_mut(&thread_id) {
204                history.push(RigMessage::Assistant {
205                    content: OneOrMany::one(AssistantContent::text(msg.content.clone())),
206                    id: None,
207                });
208            }
209        }
210
211        // Send response (split if too long for Discord's 2000 char limit)
212        let chunks: Vec<String> = response
213            .chars()
214            .collect::<Vec<_>>()
215            .chunks(1900)
216            .map(|c| c.iter().collect())
217            .collect();
218
219        for chunk in chunks {
220            if let Err(e) = msg.channel_id.say(&ctx.http, &chunk).await {
221                eprintln!("Failed to send message: {}", e);
222            }
223        }
224    }
225}
226
227/// A trait for turning a type into a `serenity` client.
228///
229pub trait DiscordExt: Sized + Send + Sync
230where
231    Self: 'static,
232{
233    fn into_discord_bot(
234        self,
235        token: &str,
236    ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send;
237
238    fn into_discord_bot_from_env(
239        self,
240    ) -> impl std::future::Future<Output = Result<serenity::Client, DiscordBotError>> + Send {
241        async move {
242            let token = std::env::var("DISCORD_BOT_TOKEN")?;
243            DiscordExt::into_discord_bot(self, &token).await
244        }
245    }
246}
247
248impl<M> DiscordExt for Agent<M>
249where
250    M: CompletionModel + Send + Sync + 'static,
251{
252    async fn into_discord_bot(self, token: &str) -> Result<serenity::Client, DiscordBotError> {
253        let intents = GatewayIntents::GUILDS
254            | GatewayIntents::GUILD_MESSAGES
255            | GatewayIntents::MESSAGE_CONTENT;
256
257        let state = Arc::new(BotState::new(self));
258        let handler = Handler {
259            state: state.clone(),
260        };
261
262        serenity::Client::builder(token, intents)
263            .event_handler(handler)
264            .await
265            .map_err(DiscordBotError::from)
266    }
267}