headjack/
lib.rs

1use lazy_static::lazy_static;
2use matrix_sdk::ruma::events::room::member::StrippedRoomMemberEvent;
3use matrix_sdk::ruma::events::room::message::MessageType;
4use matrix_sdk::ruma::events::room::message::OriginalSyncRoomMessageEvent;
5use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
6use matrix_sdk::ruma::events::AnySyncMessageLikeEvent;
7use matrix_sdk::ruma::OwnedUserId;
8use matrix_sdk::RoomMemberships;
9use matrix_sdk::RoomState;
10use matrix_sdk::{
11    config::SyncSettings, matrix_auth::MatrixSession, ruma::api::client::filter::FilterDefinition,
12    Client, Error, LoopCtrl, Room,
13};
14use rand::{distributions::Alphanumeric, thread_rng, Rng};
15use regex::Regex;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::io::{self, Write};
19use std::path::{Path, PathBuf};
20use std::time::Duration;
21use tokio::fs;
22use tokio::sync::Mutex;
23use tokio::time::sleep;
24use tracing::{error, info, warn};
25
26mod utils;
27pub use utils::*;
28
29// The structure of the matrix rust sdk requires that any state that you need access to in the callbacks
30// is 'static.
31// This is a bit of a pain, so we need to use a global state to store the actual bot state for ease of use.
32
lazy_static! {
    /// Stores the global state for all bots in this process.
    ///
    /// The key is the bot's name as returned by `Bot::name` (the configured
    /// name, or the login username when no name is set); the entry is created
    /// in `Bot::new`. The outer mutex guards the map itself, the inner mutex
    /// guards a single bot's `State`.
    static ref GLOBAL_STATE: Mutex<HashMap<String, Mutex<State>>> = Mutex::new(HashMap::new());
}
38
/// The data needed to re-build a client.
///
/// Serialized as part of `FullSession` into the session file.
#[derive(Debug, Serialize, Deserialize)]
struct ClientSession {
    /// The URL of the homeserver of the user.
    homeserver: String,

    /// The path of the SQLite store database handed to the client builder.
    db_path: PathBuf,

    /// The passphrase passed to the SQLite store together with `db_path`.
    passphrase: String,
}
51
/// Help entry describing a single registered text command.
#[derive(Debug)]
struct HelpText {
    /// The command string that triggers this command
    command: String,
    /// Single line of help text
    short: Option<String>,
    /// Argument format.
    args: Option<String>,
}

/// Per-bot state kept in the global state map.
#[derive(Debug)]
struct State {
    /// Descriptions of the commands, in registration order
    help: Vec<HelpText>,
}
65
/// The full session to persist.
/// It contains the data to re-build the client and the Matrix user session.
/// This is serialized as JSON into the session file so that the session can be
/// restored later.
#[derive(Debug, Serialize, Deserialize)]
struct FullSession {
    /// The data to re-build the client.
    client_session: ClientSession,

    /// The Matrix user session.
    user_session: MatrixSession,

    /// The latest sync token.
    /// Left out of the serialized output entirely while it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    sync_token: Option<String>,
}
81
/// Login information for the Matrix account the bot runs as.
///
/// `Debug` is implemented by hand so that the password is never written out
/// when a `Login` (or a struct containing one) is debug-formatted or logged.
#[derive(Clone)]
pub struct Login {
    /// The homeserver URL to connect to
    pub homeserver_url: String,
    /// The username to login with
    pub username: String,
    /// Optionally specify the password, if not set it will be asked for on cmd line
    pub password: Option<String>,
}

impl std::fmt::Debug for Login {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Login")
            .field("homeserver_url", &self.homeserver_url)
            .field("username", &self.username)
            // Only reveal whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
91
/// The bot struct, holds all configuration needed for the bot
#[derive(Debug, Clone)]
pub struct BotConfig {
    /// Login info for matrix
    pub login: Login,
    /// Name to use for the bot
    /// Defaults to login.username
    pub name: Option<String>,
    /// Allow list of which accounts we will respond to.
    /// This is a regular expression matched against the sender's user ID.
    /// When unset, no sender is allowed.
    pub allow_list: Option<String>,
    /// Set the state directory to use (a leading `~/` is expanded).
    /// Defaults to a directory named after the bot's name inside the
    /// platform state directory, see `Bot::state_dir`.
    pub state_dir: Option<String>,
    /// Set the prefix for bot commands. Defaults to "!($name) "
    /// A prefix longer than one character gets a trailing space appended
    /// if it does not already end with one.
    pub command_prefix: Option<String>,
    /// The Room size limit.
    /// After joining an invited room, the bot leaves it again when the room
    /// has more active members than this limit. No limit when unset.
    pub room_size_limit: Option<usize>,
}
111
/// A Matrix Bot
#[derive(Debug, Clone)]
pub struct Bot {
    /// Configuration for the bot.
    config: BotConfig,

    /// The current sync token.
    /// `None` until a session is restored with a stored token or a sync completes.
    sync_token: Option<String>,

    /// The matrix client.
    /// `None` until `Bot::login` has succeeded; most methods panic before that.
    client: Option<Client>,
}
124
125impl Bot {
126    pub async fn new(config: BotConfig) -> Self {
127        let bot = Bot {
128            config,
129            sync_token: None,
130            client: None,
131        };
132        // Initialize the global state for the bot if it doesn't exist
133        let mut global_state = GLOBAL_STATE.lock().await;
134        global_state
135            .entry(bot.name())
136            .or_insert_with(|| Mutex::new(State { help: Vec::new() }));
137        bot
138    }
139
140    /// Get the path to the session file
141    fn session_file(&self) -> PathBuf {
142        self.state_dir().join("session")
143    }
144
145    /// Login to the matrix server
146    /// Performs everything needed to login or relogin
147    pub async fn login(&mut self) -> anyhow::Result<()> {
148        let state_dir = self.state_dir();
149        let session_file = self.session_file();
150
151        let (client, sync_token) = if session_file.exists() {
152            restore_session(&session_file).await?
153        } else {
154            (
155                login(
156                    &state_dir,
157                    &session_file,
158                    &self.config.login.homeserver_url,
159                    &self.config.login.username,
160                    &self.config.login.password,
161                )
162                .await?,
163                None,
164            )
165        };
166
167        self.sync_token = sync_token;
168        self.client = Some(client);
169
170        Ok(())
171    }
172
173    /// Sync to the current state of the homeserver
174    pub async fn sync(&mut self) -> anyhow::Result<()> {
175        let client = self.client.as_ref().expect("client not initialized");
176
177        // Enable room members lazy-loading, it will speed up the initial sync a lot
178        // with accounts in lots of rooms.
179        // See <https://spec.matrix.org/v1.6/client-server-api/#lazy-loading-room-members>.
180        let filter = FilterDefinition::with_lazy_loading();
181        let mut sync_settings = SyncSettings::default().filter(filter.into());
182
183        // If we've already synced through a certain point, we'll sync the latest.
184        if let Some(sync_token) = &self.sync_token {
185            sync_settings = sync_settings.token(sync_token);
186        }
187
188        loop {
189            match client.sync_once(sync_settings.clone()).await {
190                Ok(response) => {
191                    self.sync_token = Some(response.next_batch.clone());
192                    persist_sync_token(&self.session_file(), response.next_batch.clone()).await?;
193                    break;
194                }
195                Err(error) => {
196                    error!("An error occurred during initial sync: {error}");
197                    error!("Trying again…");
198                }
199            }
200        }
201        Ok(())
202    }
203
204    /// Create the help command
205    /// This adds a command that prints the help
206    async fn register_help_command(&self) {
207        let name = self.name();
208        let command_prefix = self.command_prefix();
209        self.register_text_command(
210            "help",
211            None,
212            Some("Show this message".to_string()),
213            |_, _, room| async move {
214                let global_state = GLOBAL_STATE.lock().await;
215                let state = global_state.get(&name).unwrap();
216                let state = state.lock().await;
217                let help = &state.help;
218                let mut response = format!("`{}help`\n\nAvailable commands:", command_prefix);
219
220                for h in help {
221                    response.push_str(&format!("\n`{}{}", command_prefix, h.command));
222                    if let Some(args) = &h.args {
223                        response.push_str(&format!(" {}", args));
224                    }
225                    if let Some(short) = &h.short {
226                        response.push_str(&format!("` - {}", short));
227                    }
228                }
229                room.send(RoomMessageEventContent::text_markdown(response))
230                    .await
231                    .map_err(|_| ())?;
232                Ok(())
233            },
234        )
235        .await;
236    }
237
238    /// Adds a callback to join rooms we've been invited to
239    /// Ignores invites from anyone who is not on the allow_list
240    pub fn join_rooms(&self) {
241        let client = self.client.as_ref().expect("client not initialized");
242        let allow_list = self.config.allow_list.clone();
243        let username = self.full_name();
244        let room_size_limit = self.config.room_size_limit;
245        client.add_event_handler(
246            move |room_member: StrippedRoomMemberEvent, client: Client, room: Room| async move {
247                if room_member.state_key != client.user_id().unwrap() {
248                    // the invite we've seen isn't for us, but for someone else. ignore
249                    return;
250                }
251                if !is_allowed(allow_list, room_member.sender.as_str(), &username) {
252                    // Sender is not on the allowlist
253                    return;
254                }
255                info!("Received stripped room member event: {:?}", room_member);
256
257                // The event handlers are called before the next sync begins, but
258                // methods that change the state of a room (joining, leaving a room)
259                // wait for the sync to return the new room state so we need to spawn
260                // a new task for them.
261                tokio::spawn(async move {
262                    info!("Autojoining room {}", room.room_id());
263                    let mut delay = 2;
264
265                    while let Err(err) = room.join().await {
266                        // retry autojoin due to synapse sending invites, before the
267                        // invited user can join for more information see
268                        // https://github.com/matrix-org/synapse/issues/4345
269                        warn!(
270                            "Failed to join room {} ({err:?}), retrying in {delay}s",
271                            room.room_id()
272                        );
273
274                        sleep(Duration::from_secs(delay)).await;
275                        delay *= 2;
276
277                        if delay > 3600 {
278                            error!("Can't join room {} ({err:?})", room.room_id());
279                            break;
280                        }
281                    }
282                    // Immediately leave if the room is too large
283                    if is_room_too_large(&room, room_size_limit).await {
284                        warn!(
285                            "Room {} has too many members, refusing to join",
286                            room.room_id()
287                        );
288                        if let Err(e) = room.leave().await {
289                            error!("Error leaving room: {:?}", e);
290                        }
291                        return;
292                    }
293                    info!("Successfully joined room {}", room.room_id());
294                });
295            },
296        );
297    }
298
299    /// Adds a callback to join rooms we've been invited to
300    /// Ignores invites from anyone who is not on the allow_list
301    /// Calls the callback each time a room is joined
302    pub fn join_rooms_callback<F, Fut>(&self, callback: Option<F>)
303    where
304        F: FnOnce(Room) -> Fut + Send + 'static + Clone + Sync,
305        Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
306    {
307        let client = self.client.as_ref().expect("client not initialized");
308        let allow_list = self.config.allow_list.clone();
309        let username = self.full_name();
310        let room_size_limit = self.config.room_size_limit;
311        client.add_event_handler(
312            move |room_member: StrippedRoomMemberEvent, client: Client, room: Room| async move {
313                if room_member.state_key != client.user_id().unwrap() {
314                    // the invite we've seen isn't for us, but for someone else. ignore
315                    return;
316                }
317                if !is_allowed(allow_list, room_member.sender.as_str(), &username) {
318                    // Sender is not on the allowlist
319                    return;
320                }
321                info!("Received stripped room member event: {:?}", room_member);
322
323                // The event handlers are called before the next sync begins, but
324                // methods that change the state of a room (joining, leaving a room)
325                // wait for the sync to return the new room state so we need to spawn
326                // a new task for them.
327                tokio::spawn(async move {
328                    info!("Autojoining room {}", room.room_id());
329                    let mut delay = 2;
330
331                    while let Err(err) = room.join().await {
332                        // retry autojoin due to synapse sending invites, before the
333                        // invited user can join for more information see
334                        // https://github.com/matrix-org/synapse/issues/4345
335                        warn!(
336                            "Failed to join room {} ({err:?}), retrying in {delay}s",
337                            room.room_id()
338                        );
339
340                        sleep(Duration::from_secs(delay)).await;
341                        delay *= 2;
342
343                        if delay > 3600 {
344                            error!("Can't join room {} ({err:?})", room.room_id());
345                            break;
346                        }
347                    }
348                    // Immediately leave if the room is too large
349                    if is_room_too_large(&room, room_size_limit).await {
350                        warn!(
351                            "Room {} has too many members, refusing to join",
352                            room.room_id()
353                        );
354                        if let Err(e) = room.leave().await {
355                            error!("Error leaving room: {:?}", e);
356                        }
357                        return;
358                    }
359                    info!("Successfully joined room {}", room.room_id());
360                    if let Some(callback) = callback {
361                        if let Err(e) = callback(room).await {
362                            error!("Error joining room: {:?}", e)
363                        }
364                    }
365                });
366            },
367        );
368    }
369
370    /// Register a command that will be called for every non-command message
371    /// Useful for bots that want to act more like chatbots, having some response to every message
372    pub fn register_text_handler<F, Fut>(&self, callback: F)
373    where
374        F: FnOnce(OwnedUserId, String, Room, OriginalSyncRoomMessageEvent) -> Fut
375            + Send
376            + 'static
377            + Clone
378            + Sync,
379        Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
380    {
381        let client = self.client.as_ref().expect("client not initialized");
382        let allow_list = self.config.allow_list.clone();
383        let username = self.full_name();
384        let command_prefix = self.command_prefix();
385        client.add_event_handler(
386            move |event: OriginalSyncRoomMessageEvent, room: Room| async move {
387                // Ignore messages from rooms we're not in
388                if room.state() != RoomState::Joined {
389                    return;
390                }
391                let MessageType::Text(text_content) = &event.content.msgtype.clone() else {
392                    return;
393                };
394                if !is_allowed(allow_list, event.sender.as_str(), &username) {
395                    // Sender is not on the allowlist
396                    return;
397                }
398                let body = text_content.body.trim_start();
399                // _Ignore_ the message if it's a command
400                if is_command(&command_prefix, body) {
401                    return;
402                }
403                if let Err(e) =
404                    callback(event.sender.clone(), body.to_string(), room, event.clone()).await
405                {
406                    error!("Error responding to: {}\nError: {:?}", body, e);
407                }
408            },
409        );
410    }
411
412    /// Register a text command
413    /// This will call the callback when the command is received
414    /// Sending no help text will make the command not show up in the help
415    /// FIXME: This adds a separate handler for every command, this can be made more efficient
416    /// by storing the commands in the State struct
417    pub async fn register_text_command<F, Fut, OptString>(
418        &self,
419        command: &str,
420        args: OptString,
421        short_help: OptString,
422        callback: F,
423    ) where
424        F: FnOnce(OwnedUserId, String, Room) -> Fut + Send + 'static + Clone + Sync,
425        Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
426        OptString: Into<Option<String>>,
427    {
428        {
429            // Add the command to the help list
430            let mut global_state = GLOBAL_STATE.lock().await;
431            let state = global_state.get_mut(&self.name()).unwrap();
432            let mut state = state.lock().await;
433            state.help.push(HelpText {
434                command: command.to_string(),
435                args: args.into(),
436                short: short_help.into(),
437            });
438        }
439        let client = self.client.as_ref().expect("client not initialized");
440        let allow_list = self.config.allow_list.clone();
441        let username = self.full_name();
442        let command = command.to_owned();
443        let command_prefix = self.command_prefix();
444        client.add_event_handler(
445            // This handler matches pretty much every sync event, we'll use that and then filter ourselves
446            move |event: AnySyncMessageLikeEvent, room: Room| async move {
447                // Ignore messages from rooms we're not in
448                if room.state() != RoomState::Joined {
449                    return;
450                }
451                // Ignore non-message events
452                let AnySyncMessageLikeEvent::RoomMessage(event) = event else {
453                    return;
454                };
455                // Must be unredacted
456                let Some(event) = event.as_original() else {
457                    return;
458                };
459                // Only look at text messages
460                let MessageType::Text(_) = event.content.msgtype else {
461                    return;
462                };
463                let text_content = event.content.body();
464                if !is_allowed(allow_list, event.sender.as_str(), &username) {
465                    // Sender is not on the allowlist
466                    return;
467                }
468                let body = text_content.trim_start();
469                if let Some(input_command) = get_command(&command_prefix, body) {
470                    if input_command == command {
471                        // Call the callback
472                        if let Err(e) = callback(event.sender.clone(), body.to_string(), room).await
473                        {
474                            error!("Error running command: {} - {:?}", command, e);
475                        }
476                    }
477                }
478            },
479        );
480    }
481
482    /// Run the bot continuously
483    /// This function takes ownership of the bot, we'll be moving data out of it for use in the function closures
484    pub async fn run(&self) -> anyhow::Result<()> {
485        self.register_help_command().await;
486        let client = self.client.as_ref().expect("client not initialized");
487
488        let filter = FilterDefinition::with_lazy_loading();
489        let mut sync_settings = SyncSettings::default().filter(filter.into());
490
491        // If we've already synced through a certain point, we'll sync the latest.
492        if let Some(sync_token) = &self.sync_token {
493            sync_settings = sync_settings.token(sync_token);
494        }
495        // This loops until we kill the program or an error happens.
496        client
497            .sync_with_result_callback(sync_settings, |sync_result| async move {
498                let response = sync_result?;
499
500                // We persist the token each time to be able to restore our session
501                self.persist_sync_token(response.next_batch)
502                    .await
503                    .map_err(|err| Error::UnknownError(err.into()))?;
504
505                Ok(LoopCtrl::Continue)
506            })
507            .await?;
508
509        Ok(())
510    }
511
512    async fn persist_sync_token(&self, sync_token: String) -> anyhow::Result<()> {
513        let serialized_session = fs::read_to_string(self.session_file().clone()).await?;
514        let mut full_session: FullSession = serde_json::from_str(&serialized_session)?;
515
516        full_session.sync_token = Some(sync_token);
517        let serialized_session = serde_json::to_string(&full_session)?;
518        fs::write(self.session_file().clone(), serialized_session).await?;
519
520        Ok(())
521    }
522
523    /// Get the state directory for the bot
524    pub fn state_dir(&self) -> PathBuf {
525        if let Some(state_dir) = &self.config.state_dir {
526            PathBuf::from(expand_tilde(state_dir))
527        } else {
528            dirs::state_dir()
529                .expect("no state_dir directory found")
530                .join(self.name())
531        }
532    }
533
534    /// Get the name of the bot
535    pub fn name(&self) -> String {
536        self.config
537            .name
538            .clone()
539            .unwrap_or_else(|| self.config.login.username.clone())
540    }
541
542    /// Get the full name of the bot
543    pub fn full_name(&self) -> String {
544        self.client().user_id().unwrap().to_string()
545    }
546
547    /// Get the client used by the bot
548    pub fn client(&self) -> &Client {
549        self.client.as_ref().expect("client not initialized")
550    }
551
552    /// Get the command prefix for the bot
553    pub fn command_prefix(&self) -> String {
554        let prefix = self
555            .config
556            .command_prefix
557            .clone()
558            .unwrap_or_else(|| format!("!{} ", self.name()));
559        // If the prefix is 1 character, we'll return it as it. If it's more than 1 character, we'll ensure it ends with a space
560        if prefix.len() == 1 || prefix.ends_with(' ') {
561            prefix
562        } else {
563            format!("{} ", prefix)
564        }
565    }
566}
567
568/// Verify if the sender is on the allow_list
569fn is_allowed(allow_list: Option<String>, sender: &str, username: &str) -> bool {
570    // Check to see if it's from ourselves, in which case we should ignore it
571    if sender == username {
572        false
573    } else if let Some(allow_list) = allow_list {
574        let regex = Regex::new(&allow_list).expect("Invalid regular expression");
575        regex.is_match(sender)
576    } else {
577        false
578    }
579}
580
/// Tell whether `text` begins with `command_prefix`, i.e. is a bot command.
pub fn is_command(command_prefix: &str, text: &str) -> bool {
    text.strip_prefix(command_prefix).is_some()
}
585
/// Get the command, if it is a command.
///
/// Returns the first whitespace-separated word that follows `command_prefix`,
/// or `None` when `text` does not start with the prefix or nothing follows it.
/// The prefix is removed exactly once, so a repeated prefix (e.g. `!!help`
/// with prefix `!`) is not silently collapsed into the plain command.
pub fn get_command<'a>(command_prefix: &str, text: &'a str) -> Option<&'a str> {
    text.strip_prefix(command_prefix)?
        .split_whitespace()
        .next()
}
596
597/// Fixup the path if they've provided a ~
598fn expand_tilde(path: &str) -> String {
599    if path.starts_with("~/") {
600        if let Some(home_dir) = dirs::home_dir() {
601            let without_tilde = &path[1..]; // Remove the '~' and keep the rest of the path
602            return home_dir.display().to_string() + without_tilde;
603        }
604    }
605    path.to_string()
606}
607
608/// Restore a previous session.
609async fn restore_session(session_file: &Path) -> anyhow::Result<(Client, Option<String>)> {
610    info!(
611        "Previous session found in '{}'",
612        session_file.to_string_lossy()
613    );
614
615    // The session was serialized as JSON in a file.
616    let serialized_session = fs::read_to_string(session_file).await?;
617    let FullSession {
618        client_session,
619        user_session,
620        sync_token,
621    } = serde_json::from_str(&serialized_session)?;
622
623    // Build the client with the previous settings from the session.
624    let client = Client::builder()
625        .homeserver_url(client_session.homeserver)
626        .sqlite_store(client_session.db_path, Some(&client_session.passphrase))
627        .build()
628        .await?;
629
630    info!("Restoring session for {}…", &user_session.meta.user_id);
631
632    // Restore the Matrix user session.
633    client.restore_session(user_session).await?;
634
635    info!("Done!");
636
637    Ok((client, sync_token))
638}
639
640/// Login with a new device.
641async fn login(
642    state_dir: &Path,
643    session_file: &Path,
644    homeserver_url: &str,
645    username: &str,
646    password: &Option<String>,
647) -> anyhow::Result<Client> {
648    info!("No previous session found, logging in…");
649
650    let (client, client_session) = build_client(state_dir, homeserver_url.to_owned()).await?;
651    let matrix_auth = client.matrix_auth();
652
653    // If there's no password, ask for it
654    let password = match password {
655        Some(password) => password.clone(),
656        None => {
657            print!("Password: ");
658            io::stdout().flush().expect("Unable to write to stdout");
659            let mut password = String::new();
660            io::stdin()
661                .read_line(&mut password)
662                .expect("Unable to read user input");
663            password.trim().to_owned()
664        }
665    };
666
667    match matrix_auth
668        .login_username(username, &password)
669        .initial_device_display_name("headjack client")
670        .await
671    {
672        Ok(_) => {
673            info!("Logged in as {username}");
674        }
675        Err(error) => {
676            error!("Error logging in: {error}");
677            return Err(error.into());
678        }
679    }
680
681    // Persist the session to reuse it later.
682    let user_session = matrix_auth
683        .session()
684        .expect("A logged-in client should have a session");
685    let serialized_session = serde_json::to_string(&FullSession {
686        client_session,
687        user_session,
688        sync_token: None,
689    })?;
690    fs::write(session_file, serialized_session).await?;
691
692    info!("Session persisted in {}", session_file.to_string_lossy());
693
694    Ok(client)
695}
696
697/// Build a new client.
698async fn build_client(
699    state_dir: &Path,
700    homeserver: String,
701) -> anyhow::Result<(Client, ClientSession)> {
702    let mut rng = thread_rng();
703
704    // Place the db into a subfolder, just in case multiple clients are running
705    let db_subfolder: String = (&mut rng)
706        .sample_iter(Alphanumeric)
707        .take(7)
708        .map(char::from)
709        .collect();
710    let db_path = state_dir.join(db_subfolder);
711
712    // Generate a random passphrase.
713    // It will be saved in the session file and used to encrypt the database.
714    let passphrase: String = (&mut rng)
715        .sample_iter(Alphanumeric)
716        .take(32)
717        .map(char::from)
718        .collect();
719
720    match Client::builder()
721        .homeserver_url(&homeserver)
722        // We use the SQLite store, which is enabled by default. This is the crucial part to
723        // persist the encryption setup.
724        // Note that other store backends are available and you can even implement your own.
725        .sqlite_store(&db_path, Some(&passphrase))
726        .build()
727        .await
728    {
729        Ok(client) => Ok((
730            client,
731            ClientSession {
732                homeserver,
733                db_path,
734                passphrase,
735            },
736        )),
737        Err(error) => Err(error.into()),
738    }
739}
740
741/// Write the sync_token to the session file
742async fn persist_sync_token(session_file: &Path, sync_token: String) -> anyhow::Result<()> {
743    let serialized_session = fs::read_to_string(session_file).await?;
744    let mut full_session: FullSession = serde_json::from_str(&serialized_session)?;
745
746    full_session.sync_token = Some(sync_token);
747    let serialized_session = serde_json::to_string(&full_session)?;
748    fs::write(session_file, serialized_session).await?;
749
750    Ok(())
751}
752
753/// Check if the room exceeds the size limit
754async fn is_room_too_large(room: &Room, room_size_limit: Option<usize>) -> bool {
755    if let Some(room_size_limit) = room_size_limit {
756        if let Ok(members) = room.members(RoomMemberships::ACTIVE).await {
757            members.len() > room_size_limit
758        } else {
759            false
760        }
761    } else {
762        false
763    }
764}