1use lazy_static::lazy_static;
2use matrix_sdk::ruma::events::room::member::StrippedRoomMemberEvent;
3use matrix_sdk::ruma::events::room::message::MessageType;
4use matrix_sdk::ruma::events::room::message::OriginalSyncRoomMessageEvent;
5use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
6use matrix_sdk::ruma::events::AnySyncMessageLikeEvent;
7use matrix_sdk::ruma::OwnedUserId;
8use matrix_sdk::RoomMemberships;
9use matrix_sdk::RoomState;
10use matrix_sdk::{
11 config::SyncSettings, matrix_auth::MatrixSession, ruma::api::client::filter::FilterDefinition,
12 Client, Error, LoopCtrl, Room,
13};
14use rand::{distributions::Alphanumeric, thread_rng, Rng};
15use regex::Regex;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::io::{self, Write};
19use std::path::{Path, PathBuf};
20use std::time::Duration;
21use tokio::fs;
22use tokio::sync::Mutex;
23use tokio::time::sleep;
24use tracing::{error, info, warn};
25
26mod utils;
27pub use utils::*;
28
29lazy_static! {
34 static ref GLOBAL_STATE: Mutex<HashMap<String, Mutex<State>>> = Mutex::new(HashMap::new());
37}
38
39#[derive(Debug, Serialize, Deserialize)]
41struct ClientSession {
42 homeserver: String,
44
45 db_path: PathBuf,
47
48 passphrase: String,
50}
51
/// One entry of the listing printed by the `help` command.
struct HelpText {
    /// Command name, printed after the command prefix.
    command: String,
    /// Optional one-line description, printed after the command as ` - <short>`.
    short: Option<String>,
    /// Optional argument summary, printed right after the command name.
    args: Option<String>,
}
60
61struct State {
62 help: Vec<HelpText>,
64}
65
66#[derive(Debug, Serialize, Deserialize)]
70struct FullSession {
71 client_session: ClientSession,
73
74 user_session: MatrixSession,
76
77 #[serde(skip_serializing_if = "Option::is_none")]
79 sync_token: Option<String>,
80}
81
/// Credentials used when the bot has to log in from scratch.
#[derive(Debug, Clone)]
pub struct Login {
    /// URL of the homeserver to connect to.
    pub homeserver_url: String,
    /// Username to log in with.
    pub username: String,
    /// Password to log in with; when `None`, it is read from standard input at login time.
    pub password: Option<String>,
}
91
92#[derive(Debug, Clone)]
94pub struct BotConfig {
95 pub login: Login,
97 pub name: Option<String>,
100 pub allow_list: Option<String>,
102 pub state_dir: Option<String>,
105 pub command_prefix: Option<String>,
107 pub room_size_limit: Option<usize>,
110}
111
112#[derive(Debug, Clone)]
114pub struct Bot {
115 config: BotConfig,
117
118 sync_token: Option<String>,
120
121 client: Option<Client>,
123}
124
125impl Bot {
126 pub async fn new(config: BotConfig) -> Self {
127 let bot = Bot {
128 config,
129 sync_token: None,
130 client: None,
131 };
132 let mut global_state = GLOBAL_STATE.lock().await;
134 global_state
135 .entry(bot.name())
136 .or_insert_with(|| Mutex::new(State { help: Vec::new() }));
137 bot
138 }
139
140 fn session_file(&self) -> PathBuf {
142 self.state_dir().join("session")
143 }
144
145 pub async fn login(&mut self) -> anyhow::Result<()> {
148 let state_dir = self.state_dir();
149 let session_file = self.session_file();
150
151 let (client, sync_token) = if session_file.exists() {
152 restore_session(&session_file).await?
153 } else {
154 (
155 login(
156 &state_dir,
157 &session_file,
158 &self.config.login.homeserver_url,
159 &self.config.login.username,
160 &self.config.login.password,
161 )
162 .await?,
163 None,
164 )
165 };
166
167 self.sync_token = sync_token;
168 self.client = Some(client);
169
170 Ok(())
171 }
172
173 pub async fn sync(&mut self) -> anyhow::Result<()> {
175 let client = self.client.as_ref().expect("client not initialized");
176
177 let filter = FilterDefinition::with_lazy_loading();
181 let mut sync_settings = SyncSettings::default().filter(filter.into());
182
183 if let Some(sync_token) = &self.sync_token {
185 sync_settings = sync_settings.token(sync_token);
186 }
187
188 loop {
189 match client.sync_once(sync_settings.clone()).await {
190 Ok(response) => {
191 self.sync_token = Some(response.next_batch.clone());
192 persist_sync_token(&self.session_file(), response.next_batch.clone()).await?;
193 break;
194 }
195 Err(error) => {
196 error!("An error occurred during initial sync: {error}");
197 error!("Trying again…");
198 }
199 }
200 }
201 Ok(())
202 }
203
204 async fn register_help_command(&self) {
207 let name = self.name();
208 let command_prefix = self.command_prefix();
209 self.register_text_command(
210 "help",
211 None,
212 Some("Show this message".to_string()),
213 |_, _, room| async move {
214 let global_state = GLOBAL_STATE.lock().await;
215 let state = global_state.get(&name).unwrap();
216 let state = state.lock().await;
217 let help = &state.help;
218 let mut response = format!("`{}help`\n\nAvailable commands:", command_prefix);
219
220 for h in help {
221 response.push_str(&format!("\n`{}{}", command_prefix, h.command));
222 if let Some(args) = &h.args {
223 response.push_str(&format!(" {}", args));
224 }
225 if let Some(short) = &h.short {
226 response.push_str(&format!("` - {}", short));
227 }
228 }
229 room.send(RoomMessageEventContent::text_markdown(response))
230 .await
231 .map_err(|_| ())?;
232 Ok(())
233 },
234 )
235 .await;
236 }
237
238 pub fn join_rooms(&self) {
241 let client = self.client.as_ref().expect("client not initialized");
242 let allow_list = self.config.allow_list.clone();
243 let username = self.full_name();
244 let room_size_limit = self.config.room_size_limit;
245 client.add_event_handler(
246 move |room_member: StrippedRoomMemberEvent, client: Client, room: Room| async move {
247 if room_member.state_key != client.user_id().unwrap() {
248 return;
250 }
251 if !is_allowed(allow_list, room_member.sender.as_str(), &username) {
252 return;
254 }
255 info!("Received stripped room member event: {:?}", room_member);
256
257 tokio::spawn(async move {
262 info!("Autojoining room {}", room.room_id());
263 let mut delay = 2;
264
265 while let Err(err) = room.join().await {
266 warn!(
270 "Failed to join room {} ({err:?}), retrying in {delay}s",
271 room.room_id()
272 );
273
274 sleep(Duration::from_secs(delay)).await;
275 delay *= 2;
276
277 if delay > 3600 {
278 error!("Can't join room {} ({err:?})", room.room_id());
279 break;
280 }
281 }
282 if is_room_too_large(&room, room_size_limit).await {
284 warn!(
285 "Room {} has too many members, refusing to join",
286 room.room_id()
287 );
288 if let Err(e) = room.leave().await {
289 error!("Error leaving room: {:?}", e);
290 }
291 return;
292 }
293 info!("Successfully joined room {}", room.room_id());
294 });
295 },
296 );
297 }
298
299 pub fn join_rooms_callback<F, Fut>(&self, callback: Option<F>)
303 where
304 F: FnOnce(Room) -> Fut + Send + 'static + Clone + Sync,
305 Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
306 {
307 let client = self.client.as_ref().expect("client not initialized");
308 let allow_list = self.config.allow_list.clone();
309 let username = self.full_name();
310 let room_size_limit = self.config.room_size_limit;
311 client.add_event_handler(
312 move |room_member: StrippedRoomMemberEvent, client: Client, room: Room| async move {
313 if room_member.state_key != client.user_id().unwrap() {
314 return;
316 }
317 if !is_allowed(allow_list, room_member.sender.as_str(), &username) {
318 return;
320 }
321 info!("Received stripped room member event: {:?}", room_member);
322
323 tokio::spawn(async move {
328 info!("Autojoining room {}", room.room_id());
329 let mut delay = 2;
330
331 while let Err(err) = room.join().await {
332 warn!(
336 "Failed to join room {} ({err:?}), retrying in {delay}s",
337 room.room_id()
338 );
339
340 sleep(Duration::from_secs(delay)).await;
341 delay *= 2;
342
343 if delay > 3600 {
344 error!("Can't join room {} ({err:?})", room.room_id());
345 break;
346 }
347 }
348 if is_room_too_large(&room, room_size_limit).await {
350 warn!(
351 "Room {} has too many members, refusing to join",
352 room.room_id()
353 );
354 if let Err(e) = room.leave().await {
355 error!("Error leaving room: {:?}", e);
356 }
357 return;
358 }
359 info!("Successfully joined room {}", room.room_id());
360 if let Some(callback) = callback {
361 if let Err(e) = callback(room).await {
362 error!("Error joining room: {:?}", e)
363 }
364 }
365 });
366 },
367 );
368 }
369
370 pub fn register_text_handler<F, Fut>(&self, callback: F)
373 where
374 F: FnOnce(OwnedUserId, String, Room, OriginalSyncRoomMessageEvent) -> Fut
375 + Send
376 + 'static
377 + Clone
378 + Sync,
379 Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
380 {
381 let client = self.client.as_ref().expect("client not initialized");
382 let allow_list = self.config.allow_list.clone();
383 let username = self.full_name();
384 let command_prefix = self.command_prefix();
385 client.add_event_handler(
386 move |event: OriginalSyncRoomMessageEvent, room: Room| async move {
387 if room.state() != RoomState::Joined {
389 return;
390 }
391 let MessageType::Text(text_content) = &event.content.msgtype.clone() else {
392 return;
393 };
394 if !is_allowed(allow_list, event.sender.as_str(), &username) {
395 return;
397 }
398 let body = text_content.body.trim_start();
399 if is_command(&command_prefix, body) {
401 return;
402 }
403 if let Err(e) =
404 callback(event.sender.clone(), body.to_string(), room, event.clone()).await
405 {
406 error!("Error responding to: {}\nError: {:?}", body, e);
407 }
408 },
409 );
410 }
411
412 pub async fn register_text_command<F, Fut, OptString>(
418 &self,
419 command: &str,
420 args: OptString,
421 short_help: OptString,
422 callback: F,
423 ) where
424 F: FnOnce(OwnedUserId, String, Room) -> Fut + Send + 'static + Clone + Sync,
425 Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
426 OptString: Into<Option<String>>,
427 {
428 {
429 let mut global_state = GLOBAL_STATE.lock().await;
431 let state = global_state.get_mut(&self.name()).unwrap();
432 let mut state = state.lock().await;
433 state.help.push(HelpText {
434 command: command.to_string(),
435 args: args.into(),
436 short: short_help.into(),
437 });
438 }
439 let client = self.client.as_ref().expect("client not initialized");
440 let allow_list = self.config.allow_list.clone();
441 let username = self.full_name();
442 let command = command.to_owned();
443 let command_prefix = self.command_prefix();
444 client.add_event_handler(
445 move |event: AnySyncMessageLikeEvent, room: Room| async move {
447 if room.state() != RoomState::Joined {
449 return;
450 }
451 let AnySyncMessageLikeEvent::RoomMessage(event) = event else {
453 return;
454 };
455 let Some(event) = event.as_original() else {
457 return;
458 };
459 let MessageType::Text(_) = event.content.msgtype else {
461 return;
462 };
463 let text_content = event.content.body();
464 if !is_allowed(allow_list, event.sender.as_str(), &username) {
465 return;
467 }
468 let body = text_content.trim_start();
469 if let Some(input_command) = get_command(&command_prefix, body) {
470 if input_command == command {
471 if let Err(e) = callback(event.sender.clone(), body.to_string(), room).await
473 {
474 error!("Error running command: {} - {:?}", command, e);
475 }
476 }
477 }
478 },
479 );
480 }
481
482 pub async fn run(&self) -> anyhow::Result<()> {
485 self.register_help_command().await;
486 let client = self.client.as_ref().expect("client not initialized");
487
488 let filter = FilterDefinition::with_lazy_loading();
489 let mut sync_settings = SyncSettings::default().filter(filter.into());
490
491 if let Some(sync_token) = &self.sync_token {
493 sync_settings = sync_settings.token(sync_token);
494 }
495 client
497 .sync_with_result_callback(sync_settings, |sync_result| async move {
498 let response = sync_result?;
499
500 self.persist_sync_token(response.next_batch)
502 .await
503 .map_err(|err| Error::UnknownError(err.into()))?;
504
505 Ok(LoopCtrl::Continue)
506 })
507 .await?;
508
509 Ok(())
510 }
511
512 async fn persist_sync_token(&self, sync_token: String) -> anyhow::Result<()> {
513 let serialized_session = fs::read_to_string(self.session_file().clone()).await?;
514 let mut full_session: FullSession = serde_json::from_str(&serialized_session)?;
515
516 full_session.sync_token = Some(sync_token);
517 let serialized_session = serde_json::to_string(&full_session)?;
518 fs::write(self.session_file().clone(), serialized_session).await?;
519
520 Ok(())
521 }
522
523 pub fn state_dir(&self) -> PathBuf {
525 if let Some(state_dir) = &self.config.state_dir {
526 PathBuf::from(expand_tilde(state_dir))
527 } else {
528 dirs::state_dir()
529 .expect("no state_dir directory found")
530 .join(self.name())
531 }
532 }
533
534 pub fn name(&self) -> String {
536 self.config
537 .name
538 .clone()
539 .unwrap_or_else(|| self.config.login.username.clone())
540 }
541
542 pub fn full_name(&self) -> String {
544 self.client().user_id().unwrap().to_string()
545 }
546
547 pub fn client(&self) -> &Client {
549 self.client.as_ref().expect("client not initialized")
550 }
551
552 pub fn command_prefix(&self) -> String {
554 let prefix = self
555 .config
556 .command_prefix
557 .clone()
558 .unwrap_or_else(|| format!("!{} ", self.name()));
559 if prefix.len() == 1 || prefix.ends_with(' ') {
561 prefix
562 } else {
563 format!("{} ", prefix)
564 }
565 }
566}
567
568fn is_allowed(allow_list: Option<String>, sender: &str, username: &str) -> bool {
570 if sender == username {
572 false
573 } else if let Some(allow_list) = allow_list {
574 let regex = Regex::new(&allow_list).expect("Invalid regular expression");
575 regex.is_match(sender)
576 } else {
577 false
578 }
579}
580
/// Returns `true` when `text` begins with `command_prefix`.
pub fn is_command(command_prefix: &str, text: &str) -> bool {
    text.strip_prefix(command_prefix).is_some()
}
585
/// Extracts the command name from `text`.
///
/// Returns the first whitespace-separated word that follows `command_prefix`, or
/// `None` when `text` does not start with the prefix or nothing follows it.
///
/// The prefix is removed exactly once, so a repeated prefix (for example `"!!help"`
/// with prefix `"!"`) is not mistaken for the command itself.
pub fn get_command<'a>(command_prefix: &str, text: &'a str) -> Option<&'a str> {
    text.strip_prefix(command_prefix)?
        .split_whitespace()
        .next()
}
596
597fn expand_tilde(path: &str) -> String {
599 if path.starts_with("~/") {
600 if let Some(home_dir) = dirs::home_dir() {
601 let without_tilde = &path[1..]; return home_dir.display().to_string() + without_tilde;
603 }
604 }
605 path.to_string()
606}
607
608async fn restore_session(session_file: &Path) -> anyhow::Result<(Client, Option<String>)> {
610 info!(
611 "Previous session found in '{}'",
612 session_file.to_string_lossy()
613 );
614
615 let serialized_session = fs::read_to_string(session_file).await?;
617 let FullSession {
618 client_session,
619 user_session,
620 sync_token,
621 } = serde_json::from_str(&serialized_session)?;
622
623 let client = Client::builder()
625 .homeserver_url(client_session.homeserver)
626 .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
627 .build()
628 .await?;
629
630 info!("Restoring session for {}…", &user_session.meta.user_id);
631
632 client.restore_session(user_session).await?;
634
635 info!("Done!");
636
637 Ok((client, sync_token))
638}
639
640async fn login(
642 state_dir: &Path,
643 session_file: &Path,
644 homeserver_url: &str,
645 username: &str,
646 password: &Option<String>,
647) -> anyhow::Result<Client> {
648 info!("No previous session found, logging in…");
649
650 let (client, client_session) = build_client(state_dir, homeserver_url.to_owned()).await?;
651 let matrix_auth = client.matrix_auth();
652
653 let password = match password {
655 Some(password) => password.clone(),
656 None => {
657 print!("Password: ");
658 io::stdout().flush().expect("Unable to write to stdout");
659 let mut password = String::new();
660 io::stdin()
661 .read_line(&mut password)
662 .expect("Unable to read user input");
663 password.trim().to_owned()
664 }
665 };
666
667 match matrix_auth
668 .login_username(username, &password)
669 .initial_device_display_name("headjack client")
670 .await
671 {
672 Ok(_) => {
673 info!("Logged in as {username}");
674 }
675 Err(error) => {
676 error!("Error logging in: {error}");
677 return Err(error.into());
678 }
679 }
680
681 let user_session = matrix_auth
683 .session()
684 .expect("A logged-in client should have a session");
685 let serialized_session = serde_json::to_string(&FullSession {
686 client_session,
687 user_session,
688 sync_token: None,
689 })?;
690 fs::write(session_file, serialized_session).await?;
691
692 info!("Session persisted in {}", session_file.to_string_lossy());
693
694 Ok(client)
695}
696
697async fn build_client(
699 state_dir: &Path,
700 homeserver: String,
701) -> anyhow::Result<(Client, ClientSession)> {
702 let mut rng = thread_rng();
703
704 let db_subfolder: String = (&mut rng)
706 .sample_iter(Alphanumeric)
707 .take(7)
708 .map(char::from)
709 .collect();
710 let db_path = state_dir.join(db_subfolder);
711
712 let passphrase: String = (&mut rng)
715 .sample_iter(Alphanumeric)
716 .take(32)
717 .map(char::from)
718 .collect();
719
720 match Client::builder()
721 .homeserver_url(&homeserver)
722 .sqlite_store(&db_path, Some(&passphrase))
726 .build()
727 .await
728 {
729 Ok(client) => Ok((
730 client,
731 ClientSession {
732 homeserver,
733 db_path,
734 passphrase,
735 },
736 )),
737 Err(error) => Err(error.into()),
738 }
739}
740
741async fn persist_sync_token(session_file: &Path, sync_token: String) -> anyhow::Result<()> {
743 let serialized_session = fs::read_to_string(session_file).await?;
744 let mut full_session: FullSession = serde_json::from_str(&serialized_session)?;
745
746 full_session.sync_token = Some(sync_token);
747 let serialized_session = serde_json::to_string(&full_session)?;
748 fs::write(session_file, serialized_session).await?;
749
750 Ok(())
751}
752
753async fn is_room_too_large(room: &Room, room_size_limit: Option<usize>) -> bool {
755 if let Some(room_size_limit) = room_size_limit {
756 if let Ok(members) = room.members(RoomMemberships::ACTIVE).await {
757 members.len() > room_size_limit
758 } else {
759 false
760 }
761 } else {
762 false
763 }
764}