rig/integrations/
discord_bot.rs

1//! Integration for deploying your Rig agents (and more) as Discord bots.
2//! This feature is not WASM-compatible (and as such, is incompatible with the `worker` feature).
3use crate::OneOrMany;
4use crate::agent::Agent;
5use crate::completion::{AssistantContent, CompletionModel, request::Chat};
6use crate::message::{Message as RigMessage, UserContent};
7use serenity::all::{
8    Command, CommandInteraction, Context, CreateCommand, CreateThread, EventHandler,
9    GatewayIntents, Interaction, Message, Ready, async_trait,
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15// Bot state containing the agent and conversation histories
16struct BotState<M: CompletionModel> {
17    agent: Agent<M>,
18    conversations: Arc<RwLock<HashMap<u64, Vec<RigMessage>>>>,
19}
20
21impl<M: CompletionModel> BotState<M> {
22    fn new(agent: Agent<M>) -> Self {
23        Self {
24            agent,
25            conversations: Arc::new(RwLock::new(HashMap::new())),
26        }
27    }
28}
29
30// Event handler for the Discord bot
31struct Handler<M: CompletionModel> {
32    state: Arc<BotState<M>>,
33}
34
35#[async_trait]
36impl<M> EventHandler for Handler<M>
37where
38    M: CompletionModel + Send + Sync + 'static,
39{
40    async fn ready(&self, ctx: Context, ready: Ready) {
41        println!("{} is connected!", ready.user.name);
42
43        let register_cmd =
44            CreateCommand::new("new").description("Start a new chat session with the bot");
45
46        // Register slash command globally
47        let command = Command::create_global_command(&ctx.http, register_cmd).await;
48
49        match command {
50            Ok(cmd) => println!("Registered global command: {}", cmd.name),
51            Err(e) => eprintln!("Failed to register command: {}", e),
52        }
53    }
54
55    async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
56        if let Interaction::Command(command) = interaction {
57            self.handle_command(&ctx, &command).await;
58        }
59    }
60
61    async fn message(&self, ctx: Context, msg: Message) {
62        // Ignore bot's own messages
63        if msg.author.bot {
64            return;
65        }
66
67        // Only respond to messages in threads created by the bot
68        let conversations = self.state.conversations.read().await;
69        if conversations.contains_key(&msg.channel_id.get()) {
70            drop(conversations);
71            self.handle_thread_message(&ctx, &msg).await;
72        }
73    }
74}
75
76impl<M> Handler<M>
77where
78    M: CompletionModel + Send + Sync + 'static,
79{
80    async fn handle_command(&self, ctx: &Context, command: &CommandInteraction) {
81        if command.data.name.as_str() == "new" {
82            // Defer the response to prevent timeout
83            if let Err(e) = command.defer(&ctx.http).await {
84                eprintln!("Failed to defer command: {}", e);
85                return;
86            }
87
88            // Create a new thread
89            let thread_name = format!("AI Conversation - {}", command.user.name);
90
91            let thread = match command
92                .channel_id
93                .create_thread(
94                    &ctx.http,
95                    CreateThread::new(thread_name)
96                        .kind(serenity::all::ChannelType::PublicThread)
97                        .auto_archive_duration(serenity::all::AutoArchiveDuration::OneDay),
98                )
99                .await
100            {
101                Ok(t) => t,
102                Err(e) => {
103                    eprintln!("Failed to create thread: {}", e);
104                    let _ = command
105                        .edit_response(
106                            &ctx.http,
107                            serenity::all::EditInteractionResponse::new()
108                                .content("Failed to create thread. Please try again."),
109                        )
110                        .await;
111                    return;
112                }
113            };
114
115            // Initialize conversation history for this thread
116            let mut conversations = self.state.conversations.write().await;
117            conversations.insert(thread.id.get(), Vec::new());
118            drop(conversations);
119
120            // Edit the deferred response
121            if let Err(e) = command
122                .edit_response(
123                    &ctx.http,
124                    serenity::all::EditInteractionResponse::new()
125                        .content(format!(
126                            "Started a new conversation in <#{}>! Send messages there to chat with the AI.",
127                            thread.id
128                        ))
129                )
130                .await
131            {
132                eprintln!("Failed to edit response: {}", e);
133            }
134
135            // Send welcome message to the thread
136            if let Err(e) = thread
137                .send_message(
138                    &ctx.http,
139                    serenity::all::CreateMessage::new()
140                        .content("Hello! I'm ready to help. What would you like to talk about?"),
141                )
142                .await
143            {
144                eprintln!("Failed to send welcome message: {}", e);
145            }
146        }
147    }
148
149    async fn handle_thread_message(&self, ctx: &Context, msg: &Message) {
150        let thread_id = msg.channel_id.get();
151
152        // Add user message to history
153        {
154            let mut conversations = self.state.conversations.write().await;
155            if let Some(history) = conversations.get_mut(&thread_id) {
156                history.push(RigMessage::User {
157                    content: OneOrMany::one(UserContent::text(msg.content.clone())),
158                });
159            }
160        }
161
162        // Show typing indicator
163        let _ = msg.channel_id.broadcast_typing(&ctx.http).await;
164
165        // Get conversation history
166        let conversations = self.state.conversations.read().await;
167        let history = if let Some(history) = conversations.get(&thread_id) {
168            history.clone()
169        } else {
170            vec![]
171        };
172        drop(conversations);
173
174        // Generate response using the agent with conversation history
175        let response = match self.state.agent.chat(&msg.content, history).await {
176            Ok(resp) => resp,
177            Err(e) => {
178                eprintln!("Agent error: {}", e);
179                let _ = msg
180                    .channel_id
181                    .say(
182                        &ctx.http,
183                        "Sorry, I encountered an error processing your message.",
184                    )
185                    .await;
186                return;
187            }
188        };
189
190        // Add assistant response to history
191        {
192            let mut conversations = self.state.conversations.write().await;
193            if let Some(history) = conversations.get_mut(&thread_id) {
194                history.push(RigMessage::Assistant {
195                    content: OneOrMany::one(AssistantContent::text(msg.content.clone())),
196                    id: None,
197                });
198            }
199        }
200
201        // Send response (split if too long for Discord's 2000 char limit)
202        let chunks: Vec<String> = response
203            .chars()
204            .collect::<Vec<_>>()
205            .chunks(1900)
206            .map(|c| c.iter().collect())
207            .collect();
208
209        for chunk in chunks {
210            if let Err(e) = msg.channel_id.say(&ctx.http, &chunk).await {
211                eprintln!("Failed to send message: {}", e);
212            }
213        }
214    }
215}
216
217/// A trait for turning a type into a `serenity` client.
218///
219pub trait DiscordExt: Sized + Send + Sync
220where
221    Self: 'static,
222{
223    fn into_discord_bot(
224        self,
225        token: &str,
226    ) -> impl std::future::Future<Output = serenity::Client> + Send;
227
228    fn into_discord_bot_from_env(
229        self,
230    ) -> impl std::future::Future<Output = serenity::Client> + Send {
231        let token = std::env::var("DISCORD_BOT_TOKEN")
232            .expect("DISCORD_BOT_TOKEN should exist as an env var");
233
234        async move { DiscordExt::into_discord_bot(self, &token).await }
235    }
236}
237
238impl<M> DiscordExt for Agent<M>
239where
240    M: CompletionModel + Send + Sync + 'static,
241{
242    async fn into_discord_bot(self, token: &str) -> serenity::Client {
243        let intents = GatewayIntents::GUILDS
244            | GatewayIntents::GUILD_MESSAGES
245            | GatewayIntents::MESSAGE_CONTENT;
246
247        let state = Arc::new(BotState::new(self));
248        let handler = Handler {
249            state: state.clone(),
250        };
251
252        serenity::Client::builder(token, intents)
253            .event_handler(handler)
254            .await
255            .expect("Failed to create Discord client")
256    }
257}