rig/integrations/
discord_bot.rs1use crate::OneOrMany;
4use crate::agent::Agent;
5use crate::completion::{AssistantContent, CompletionModel, request::Chat};
6use crate::message::{Message as RigMessage, UserContent};
7use serenity::all::{
8 Command, CommandInteraction, Context, CreateCommand, CreateThread, EventHandler,
9 GatewayIntents, Interaction, Message, Ready, async_trait,
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15struct BotState<M: CompletionModel> {
17 agent: Agent<M>,
18 conversations: Arc<RwLock<HashMap<u64, Vec<RigMessage>>>>,
19}
20
21impl<M: CompletionModel> BotState<M> {
22 fn new(agent: Agent<M>) -> Self {
23 Self {
24 agent,
25 conversations: Arc::new(RwLock::new(HashMap::new())),
26 }
27 }
28}
29
30struct Handler<M: CompletionModel> {
32 state: Arc<BotState<M>>,
33}
34
35#[async_trait]
36impl<M> EventHandler for Handler<M>
37where
38 M: CompletionModel + Send + Sync + 'static,
39{
40 async fn ready(&self, ctx: Context, ready: Ready) {
41 println!("{} is connected!", ready.user.name);
42
43 let register_cmd =
44 CreateCommand::new("new").description("Start a new chat session with the bot");
45
46 let command = Command::create_global_command(&ctx.http, register_cmd).await;
48
49 match command {
50 Ok(cmd) => println!("Registered global command: {}", cmd.name),
51 Err(e) => eprintln!("Failed to register command: {}", e),
52 }
53 }
54
55 async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
56 if let Interaction::Command(command) = interaction {
57 self.handle_command(&ctx, &command).await;
58 }
59 }
60
61 async fn message(&self, ctx: Context, msg: Message) {
62 if msg.author.bot {
64 return;
65 }
66
67 let conversations = self.state.conversations.read().await;
69 if conversations.contains_key(&msg.channel_id.get()) {
70 drop(conversations);
71 self.handle_thread_message(&ctx, &msg).await;
72 }
73 }
74}
75
76impl<M> Handler<M>
77where
78 M: CompletionModel + Send + Sync + 'static,
79{
80 async fn handle_command(&self, ctx: &Context, command: &CommandInteraction) {
81 if command.data.name.as_str() == "new" {
82 if let Err(e) = command.defer(&ctx.http).await {
84 eprintln!("Failed to defer command: {}", e);
85 return;
86 }
87
88 let thread_name = format!("AI Conversation - {}", command.user.name);
90
91 let thread = match command
92 .channel_id
93 .create_thread(
94 &ctx.http,
95 CreateThread::new(thread_name)
96 .kind(serenity::all::ChannelType::PublicThread)
97 .auto_archive_duration(serenity::all::AutoArchiveDuration::OneDay),
98 )
99 .await
100 {
101 Ok(t) => t,
102 Err(e) => {
103 eprintln!("Failed to create thread: {}", e);
104 let _ = command
105 .edit_response(
106 &ctx.http,
107 serenity::all::EditInteractionResponse::new()
108 .content("Failed to create thread. Please try again."),
109 )
110 .await;
111 return;
112 }
113 };
114
115 let mut conversations = self.state.conversations.write().await;
117 conversations.insert(thread.id.get(), Vec::new());
118 drop(conversations);
119
120 if let Err(e) = command
122 .edit_response(
123 &ctx.http,
124 serenity::all::EditInteractionResponse::new()
125 .content(format!(
126 "Started a new conversation in <#{}>! Send messages there to chat with the AI.",
127 thread.id
128 ))
129 )
130 .await
131 {
132 eprintln!("Failed to edit response: {}", e);
133 }
134
135 if let Err(e) = thread
137 .send_message(
138 &ctx.http,
139 serenity::all::CreateMessage::new()
140 .content("Hello! I'm ready to help. What would you like to talk about?"),
141 )
142 .await
143 {
144 eprintln!("Failed to send welcome message: {}", e);
145 }
146 }
147 }
148
149 async fn handle_thread_message(&self, ctx: &Context, msg: &Message) {
150 let thread_id = msg.channel_id.get();
151
152 {
154 let mut conversations = self.state.conversations.write().await;
155 if let Some(history) = conversations.get_mut(&thread_id) {
156 history.push(RigMessage::User {
157 content: OneOrMany::one(UserContent::text(msg.content.clone())),
158 });
159 }
160 }
161
162 let _ = msg.channel_id.broadcast_typing(&ctx.http).await;
164
165 let conversations = self.state.conversations.read().await;
167 let history = if let Some(history) = conversations.get(&thread_id) {
168 history.clone()
169 } else {
170 vec![]
171 };
172 drop(conversations);
173
174 let response = match self.state.agent.chat(&msg.content, history).await {
176 Ok(resp) => resp,
177 Err(e) => {
178 eprintln!("Agent error: {}", e);
179 let _ = msg
180 .channel_id
181 .say(
182 &ctx.http,
183 "Sorry, I encountered an error processing your message.",
184 )
185 .await;
186 return;
187 }
188 };
189
190 {
192 let mut conversations = self.state.conversations.write().await;
193 if let Some(history) = conversations.get_mut(&thread_id) {
194 history.push(RigMessage::Assistant {
195 content: OneOrMany::one(AssistantContent::text(msg.content.clone())),
196 id: None,
197 });
198 }
199 }
200
201 let chunks: Vec<String> = response
203 .chars()
204 .collect::<Vec<_>>()
205 .chunks(1900)
206 .map(|c| c.iter().collect())
207 .collect();
208
209 for chunk in chunks {
210 if let Err(e) = msg.channel_id.say(&ctx.http, &chunk).await {
211 eprintln!("Failed to send message: {}", e);
212 }
213 }
214 }
215}
216
217pub trait DiscordExt: Sized + Send + Sync
220where
221 Self: 'static,
222{
223 fn into_discord_bot(
224 self,
225 token: &str,
226 ) -> impl std::future::Future<Output = serenity::Client> + Send;
227
228 fn into_discord_bot_from_env(
229 self,
230 ) -> impl std::future::Future<Output = serenity::Client> + Send {
231 let token = std::env::var("DISCORD_BOT_TOKEN")
232 .expect("DISCORD_BOT_TOKEN should exist as an env var");
233
234 async move { DiscordExt::into_discord_bot(self, &token).await }
235 }
236}
237
238impl<M> DiscordExt for Agent<M>
239where
240 M: CompletionModel + Send + Sync + 'static,
241{
242 async fn into_discord_bot(self, token: &str) -> serenity::Client {
243 let intents = GatewayIntents::GUILDS
244 | GatewayIntents::GUILD_MESSAGES
245 | GatewayIntents::MESSAGE_CONTENT;
246
247 let state = Arc::new(BotState::new(self));
248 let handler = Handler {
249 state: state.clone(),
250 };
251
252 serenity::Client::builder(token, intents)
253 .event_handler(handler)
254 .await
255 .expect("Failed to create Discord client")
256 }
257}