use lazy_static::lazy_static;
use matrix_sdk::ruma::events::room::member::StrippedRoomMemberEvent;
use matrix_sdk::ruma::events::room::message::MessageType;
use matrix_sdk::ruma::events::room::message::OriginalSyncRoomMessageEvent;
use matrix_sdk::ruma::events::room::message::RoomMessageEventContent;
use matrix_sdk::ruma::events::AnySyncMessageLikeEvent;
use matrix_sdk::ruma::OwnedUserId;
use matrix_sdk::RoomMemberships;
use matrix_sdk::RoomState;
use matrix_sdk::{
config::SyncSettings, matrix_auth::MatrixSession, ruma::api::client::filter::FilterDefinition,
Client, Error, LoopCtrl, Room,
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::fs;
use tokio::sync::Mutex;
use tokio::time::sleep;
use tracing::{error, info, warn};
mod utils;
pub use utils::*;
lazy_static! {
/// Process-wide help registry, keyed by bot name (`Bot::name`).
///
/// `Bot::new` inserts an empty entry, `Bot::register_text_command` appends to it and the
/// built-in `help` command reads it. Both levels use `tokio::sync::Mutex`.
static ref GLOBAL_STATE: Mutex<HashMap<String, Mutex<State>>> = Mutex::new(HashMap::new());
}
/// Settings needed to rebuild the local Matrix client; persisted inside the session file.
///
/// `Debug` is implemented by hand so the store passphrase never ends up in logs
/// (the derived impl printed it verbatim).
#[derive(Serialize, Deserialize)]
struct ClientSession {
    /// Homeserver URL the client was built with.
    homeserver: String,
    /// Location of the SQLite store (a randomly named subfolder of the state directory).
    db_path: PathBuf,
    /// Passphrase handed to `sqlite_store` when opening the store.
    passphrase: String,
}

impl std::fmt::Debug for ClientSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientSession")
            .field("homeserver", &self.homeserver)
            .field("db_path", &self.db_path)
            .field("passphrase", &"<redacted>")
            .finish()
    }
}
/// One line of the `help` command's listing.
struct HelpText {
    /// Command name, without the command prefix.
    command: String,
    /// Optional argument synopsis printed after the command name.
    args: Option<String>,
    /// Optional one-line description printed after the command.
    short: Option<String>,
}
/// Per-bot mutable state stored in `GLOBAL_STATE`.
struct State {
/// Help entries in registration order; rendered by the `help` command.
help: Vec<HelpText>,
}
/// Everything stored (as JSON via `serde_json`) in the session file `<state_dir>/session`.
#[derive(Debug, Serialize, Deserialize)]
struct FullSession {
/// Settings needed to rebuild the client and reopen its store.
client_session: ClientSession,
/// Login session obtained from `matrix_auth().session()` and later passed to
/// `Client::restore_session`.
user_session: MatrixSession,
/// Latest `next_batch` token; left out of the JSON while it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
sync_token: Option<String>,
}
/// Credentials used when no saved session exists yet.
///
/// `Debug` is implemented by hand so that printing a `Login` (or any `BotConfig` /
/// `Bot` that contains one) does not leak the password; the derived impl printed it.
#[derive(Clone)]
pub struct Login {
    /// Homeserver URL, e.g. `https://matrix.org`.
    pub homeserver_url: String,
    /// Username passed to the password login.
    pub username: String,
    /// Password; when `None` the user is prompted on stdin at login time.
    pub password: Option<String>,
}

impl std::fmt::Debug for Login {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Login")
            .field("homeserver_url", &self.homeserver_url)
            .field("username", &self.username)
            // Only reveal whether a password is configured, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// User-supplied configuration for a [`Bot`].
#[derive(Debug, Clone)]
pub struct BotConfig {
/// Homeserver and credentials used when no saved session exists.
pub login: Login,
/// Bot name; falls back to `login.username`. Used as the help-registry key, as the
/// default state directory name and in the default command prefix `!<name> `.
pub name: Option<String>,
/// Regular expression matched against sender user IDs. When `None`, every sender is
/// rejected (no commands are handled and no invites are accepted).
pub allow_list: Option<String>,
/// Directory for the session file and store (a leading `~/` is expanded); defaults to
/// the platform state directory joined with the bot name.
pub state_dir: Option<String>,
/// Prefix that marks a message as a command, default `!<name> `. A trailing space is
/// appended unless the prefix is a single ASCII character or already ends with a space.
pub command_prefix: Option<String>,
/// Maximum number of active members a room may have for the bot to stay after
/// autojoining; `None` disables the check.
pub room_size_limit: Option<usize>,
}
/// A Matrix bot: its configuration plus, after `Bot::login`, the client and the
/// most recent sync token.
///
/// Methods that register handlers or sync panic with "client not initialized" unless
/// `Bot::login` has been called first.
#[derive(Debug, Clone)]
pub struct Bot {
/// Configuration supplied to `Bot::new`.
config: BotConfig,
/// `next_batch` token to resume syncing from, if known.
sync_token: Option<String>,
/// Logged-in client; `None` until `Bot::login` succeeds.
client: Option<Client>,
}
impl Bot {
pub async fn new(config: BotConfig) -> Self {
let bot = Bot {
config,
sync_token: None,
client: None,
};
let mut global_state = GLOBAL_STATE.lock().await;
global_state
.entry(bot.name())
.or_insert_with(|| Mutex::new(State { help: Vec::new() }));
bot
}
fn session_file(&self) -> PathBuf {
self.state_dir().join("session")
}
pub async fn login(&mut self) -> anyhow::Result<()> {
let state_dir = self.state_dir();
let session_file = self.session_file();
let (client, sync_token) = if session_file.exists() {
restore_session(&session_file).await?
} else {
(
login(
&state_dir,
&session_file,
&self.config.login.homeserver_url,
&self.config.login.username,
&self.config.login.password,
)
.await?,
None,
)
};
self.sync_token = sync_token;
self.client = Some(client);
Ok(())
}
pub async fn sync(&mut self) -> anyhow::Result<()> {
let client = self.client.as_ref().expect("client not initialized");
let filter = FilterDefinition::with_lazy_loading();
let mut sync_settings = SyncSettings::default().filter(filter.into());
if let Some(sync_token) = &self.sync_token {
sync_settings = sync_settings.token(sync_token);
}
loop {
match client.sync_once(sync_settings.clone()).await {
Ok(response) => {
self.sync_token = Some(response.next_batch.clone());
persist_sync_token(&self.session_file(), response.next_batch.clone()).await?;
break;
}
Err(error) => {
error!("An error occurred during initial sync: {error}");
error!("Trying again…");
}
}
}
Ok(())
}
async fn register_help_command(&self) {
let name = self.name();
let command_prefix = self.command_prefix();
self.register_text_command(
"help",
None,
Some("Show this message".to_string()),
|_, _, room| async move {
let global_state = GLOBAL_STATE.lock().await;
let state = global_state.get(&name).unwrap();
let state = state.lock().await;
let help = &state.help;
let mut response = format!("`{}help`\n\nAvailable commands:", command_prefix);
for h in help {
response.push_str(&format!("\n`{}{}", command_prefix, h.command));
if let Some(args) = &h.args {
response.push_str(&format!(" {}", args));
}
if let Some(short) = &h.short {
response.push_str(&format!("` - {}", short));
}
}
room.send(RoomMessageEventContent::text_markdown(response))
.await
.map_err(|_| ())?;
Ok(())
},
)
.await;
}
pub fn join_rooms(&self) {
let client = self.client.as_ref().expect("client not initialized");
let allow_list = self.config.allow_list.clone();
let username = self.full_name();
let room_size_limit = self.config.room_size_limit;
client.add_event_handler(
move |room_member: StrippedRoomMemberEvent, client: Client, room: Room| async move {
if room_member.state_key != client.user_id().unwrap() {
return;
}
if !is_allowed(allow_list, room_member.sender.as_str(), &username) {
return;
}
info!("Received stripped room member event: {:?}", room_member);
tokio::spawn(async move {
info!("Autojoining room {}", room.room_id());
let mut delay = 2;
while let Err(err) = room.join().await {
warn!(
"Failed to join room {} ({err:?}), retrying in {delay}s",
room.room_id()
);
sleep(Duration::from_secs(delay)).await;
delay *= 2;
if delay > 3600 {
error!("Can't join room {} ({err:?})", room.room_id());
break;
}
}
if is_room_too_large(&room, room_size_limit).await {
warn!(
"Room {} has too many members, refusing to join",
room.room_id()
);
if let Err(e) = room.leave().await {
error!("Error leaving room: {:?}", e);
}
return;
}
info!("Successfully joined room {}", room.room_id());
});
},
);
}
pub fn join_rooms_callback<F, Fut>(&self, callback: Option<F>)
where
F: FnOnce(Room) -> Fut + Send + 'static + Clone + Sync,
Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
{
let client = self.client.as_ref().expect("client not initialized");
let allow_list = self.config.allow_list.clone();
let username = self.full_name();
let room_size_limit = self.config.room_size_limit;
client.add_event_handler(
move |room_member: StrippedRoomMemberEvent, client: Client, room: Room| async move {
if room_member.state_key != client.user_id().unwrap() {
return;
}
if !is_allowed(allow_list, room_member.sender.as_str(), &username) {
return;
}
info!("Received stripped room member event: {:?}", room_member);
tokio::spawn(async move {
info!("Autojoining room {}", room.room_id());
let mut delay = 2;
while let Err(err) = room.join().await {
warn!(
"Failed to join room {} ({err:?}), retrying in {delay}s",
room.room_id()
);
sleep(Duration::from_secs(delay)).await;
delay *= 2;
if delay > 3600 {
error!("Can't join room {} ({err:?})", room.room_id());
break;
}
}
if is_room_too_large(&room, room_size_limit).await {
warn!(
"Room {} has too many members, refusing to join",
room.room_id()
);
if let Err(e) = room.leave().await {
error!("Error leaving room: {:?}", e);
}
return;
}
info!("Successfully joined room {}", room.room_id());
if let Some(callback) = callback {
if let Err(e) = callback(room).await {
error!("Error joining room: {:?}", e)
}
}
});
},
);
}
pub fn register_text_handler<F, Fut>(&self, callback: F)
where
F: FnOnce(OwnedUserId, String, Room, OriginalSyncRoomMessageEvent) -> Fut
+ Send
+ 'static
+ Clone
+ Sync,
Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
{
let client = self.client.as_ref().expect("client not initialized");
let allow_list = self.config.allow_list.clone();
let username = self.full_name();
let command_prefix = self.command_prefix();
client.add_event_handler(
move |event: OriginalSyncRoomMessageEvent, room: Room| async move {
if room.state() != RoomState::Joined {
return;
}
let MessageType::Text(text_content) = &event.content.msgtype.clone() else {
return;
};
if !is_allowed(allow_list, event.sender.as_str(), &username) {
return;
}
let body = text_content.body.trim_start();
if is_command(&command_prefix, body) {
return;
}
if let Err(e) =
callback(event.sender.clone(), body.to_string(), room, event.clone()).await
{
error!("Error responding to: {}\nError: {:?}", body, e);
}
},
);
}
pub async fn register_text_command<F, Fut, OptString>(
&self,
command: &str,
args: OptString,
short_help: OptString,
callback: F,
) where
F: FnOnce(OwnedUserId, String, Room) -> Fut + Send + 'static + Clone + Sync,
Fut: std::future::Future<Output = Result<(), ()>> + Send + 'static,
OptString: Into<Option<String>>,
{
{
let mut global_state = GLOBAL_STATE.lock().await;
let state = global_state.get_mut(&self.name()).unwrap();
let mut state = state.lock().await;
state.help.push(HelpText {
command: command.to_string(),
args: args.into(),
short: short_help.into(),
});
}
let client = self.client.as_ref().expect("client not initialized");
let allow_list = self.config.allow_list.clone();
let username = self.full_name();
let command = command.to_owned();
let command_prefix = self.command_prefix();
client.add_event_handler(
move |event: AnySyncMessageLikeEvent, room: Room| async move {
if room.state() != RoomState::Joined {
return;
}
let AnySyncMessageLikeEvent::RoomMessage(event) = event else {
return;
};
let Some(event) = event.as_original() else {
return;
};
let MessageType::Text(_) = event.content.msgtype else {
return;
};
let text_content = event.content.body();
if !is_allowed(allow_list, event.sender.as_str(), &username) {
return;
}
let body = text_content.trim_start();
if let Some(input_command) = get_command(&command_prefix, body) {
if input_command == command {
if let Err(e) = callback(event.sender.clone(), body.to_string(), room).await
{
error!("Error running command: {} - {:?}", command, e);
}
}
}
},
);
}
pub async fn run(&self) -> anyhow::Result<()> {
self.register_help_command().await;
let client = self.client.as_ref().expect("client not initialized");
let filter = FilterDefinition::with_lazy_loading();
let mut sync_settings = SyncSettings::default().filter(filter.into());
if let Some(sync_token) = &self.sync_token {
sync_settings = sync_settings.token(sync_token);
}
client
.sync_with_result_callback(sync_settings, |sync_result| async move {
let response = sync_result?;
self.persist_sync_token(response.next_batch)
.await
.map_err(|err| Error::UnknownError(err.into()))?;
Ok(LoopCtrl::Continue)
})
.await?;
Ok(())
}
async fn persist_sync_token(&self, sync_token: String) -> anyhow::Result<()> {
let serialized_session = fs::read_to_string(self.session_file().clone()).await?;
let mut full_session: FullSession = serde_json::from_str(&serialized_session)?;
full_session.sync_token = Some(sync_token);
let serialized_session = serde_json::to_string(&full_session)?;
fs::write(self.session_file().clone(), serialized_session).await?;
Ok(())
}
pub fn state_dir(&self) -> PathBuf {
if let Some(state_dir) = &self.config.state_dir {
PathBuf::from(expand_tilde(state_dir))
} else {
dirs::state_dir()
.expect("no state_dir directory found")
.join(self.name())
}
}
pub fn name(&self) -> String {
self.config
.name
.clone()
.unwrap_or_else(|| self.config.login.username.clone())
}
pub fn full_name(&self) -> String {
self.client().user_id().unwrap().to_string()
}
pub fn client(&self) -> &Client {
self.client.as_ref().expect("client not initialized")
}
pub fn command_prefix(&self) -> String {
let prefix = self
.config
.command_prefix
.clone()
.unwrap_or_else(|| format!("!{} ", self.name()));
if prefix.len() == 1 || prefix.ends_with(' ') {
prefix
} else {
format!("{} ", prefix)
}
}
}
fn is_allowed(allow_list: Option<String>, sender: &str, username: &str) -> bool {
if sender == username {
false
} else if let Some(allow_list) = allow_list {
let regex = Regex::new(&allow_list).expect("Invalid regular expression");
regex.is_match(sender)
} else {
false
}
}
/// Returns `true` when `text` begins with `command_prefix`, i.e. when the message should be
/// treated as a command rather than as ordinary text.
pub fn is_command(command_prefix: &str, text: &str) -> bool {
    text.strip_prefix(command_prefix).is_some()
}
/// Extracts the command name from `text`: the first whitespace-separated word after a
/// single leading `command_prefix`.
///
/// Returns `None` when `text` does not start with the prefix or nothing follows it. The
/// prefix is stripped exactly once. `trim_start_matches` stripped it repeatedly, so with
/// prefix `!` the text `!!ping` used to be parsed as the command `ping`.
pub fn get_command<'a>(command_prefix: &str, text: &'a str) -> Option<&'a str> {
    text.strip_prefix(command_prefix)?.split_whitespace().next()
}
/// Expands a leading `~` (either `~` alone or `~/…`) to the user's home directory.
///
/// Any other input is returned unchanged, including `~user/…` forms or the case where no
/// home directory is known. A bare `~` used to be returned literally, which made the bot
/// create a directory named `~` in the working directory.
fn expand_tilde(path: &str) -> String {
    let Some(rest) = path.strip_prefix('~') else {
        return path.to_string();
    };
    if !(rest.is_empty() || rest.starts_with('/')) {
        return path.to_string();
    }
    match dirs::home_dir() {
        // `rest` keeps its leading `/` (if any), exactly as before.
        Some(home_dir) => format!("{}{}", home_dir.display(), rest),
        None => path.to_string(),
    }
}
async fn restore_session(session_file: &Path) -> anyhow::Result<(Client, Option<String>)> {
info!(
"Previous session found in '{}'",
session_file.to_string_lossy()
);
let serialized_session = fs::read_to_string(session_file).await?;
let FullSession {
client_session,
user_session,
sync_token,
} = serde_json::from_str(&serialized_session)?;
let client = Client::builder()
.homeserver_url(client_session.homeserver)
.sqlite_store(client_session.db_path, Some(&client_session.passphrase))
.build()
.await?;
info!("Restoring session for {}…", &user_session.meta.user_id);
client.restore_session(user_session).await?;
info!("Done!");
Ok((client, sync_token))
}
/// Performs a fresh password login and persists the resulting session to `session_file`.
///
/// If `password` is `None`, the user is prompted on stdin. The session file contains the
/// access token and the store passphrase, so on Unix it is created with mode `0o600`.
///
/// # Errors
/// Fails on client construction, stdin/stdout I/O, login, serialization or file errors.
/// Stdin errors used to panic.
async fn login(
    state_dir: &Path,
    session_file: &Path,
    homeserver_url: &str,
    username: &str,
    password: &Option<String>,
) -> anyhow::Result<Client> {
    info!("No previous session found, logging in…");
    let (client, client_session) = build_client(state_dir, homeserver_url.to_owned()).await?;
    let matrix_auth = client.matrix_auth();
    let password = match password {
        Some(password) => password.clone(),
        None => prompt_password()?,
    };
    if let Err(error) = matrix_auth
        .login_username(username, &password)
        .initial_device_display_name("headjack client")
        .await
    {
        error!("Error logging in: {error}");
        return Err(error.into());
    }
    info!("Logged in as {username}");
    let user_session = matrix_auth
        .session()
        .expect("A logged-in client should have a session");
    let serialized_session = serde_json::to_string(&FullSession {
        client_session,
        user_session,
        sync_token: None,
    })?;
    if let Some(parent) = session_file.parent() {
        fs::create_dir_all(parent).await?;
    }
    write_private_file(session_file, serialized_session.as_bytes()).await?;
    info!("Session persisted in {}", session_file.to_string_lossy());
    Ok(client)
}

/// Prints `Password: ` and reads one line from stdin.
///
/// Only the line terminator is removed. Leading or trailing spaces are part of the
/// password, whereas `trim()` used to strip them.
fn prompt_password() -> io::Result<String> {
    print!("Password: ");
    io::stdout().flush()?;
    let mut line = String::new();
    io::stdin().read_line(&mut line)?;
    Ok(line.trim_end_matches(['\r', '\n']).to_owned())
}

/// Writes `contents` to `path`. If the file does not exist yet it is first created with
/// owner-only permissions (`0o600`) on Unix.
async fn write_private_file(path: &Path, contents: &[u8]) -> io::Result<()> {
    let mut options = fs::OpenOptions::new();
    options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    options.mode(0o600);
    // Creating the file first fixes its mode; `fs::write` then truncates the same file
    // and keeps those permissions.
    drop(options.open(path).await?);
    fs::write(path, contents).await
}
/// Builds a client for `homeserver` backed by a new SQLite store in a random 7-character
/// alphanumeric subfolder of `state_dir`, using a random 32-character passphrase.
///
/// Returns the client plus the [`ClientSession`] needed to reopen it later.
///
/// # Errors
/// Fails if the client or store cannot be built.
async fn build_client(
    state_dir: &Path,
    homeserver: String,
) -> anyhow::Result<(Client, ClientSession)> {
    // Generate the random names inside a block so the non-`Send` `ThreadRng` is dropped
    // before the `.await` below. The caller's future then does not have to hold it.
    let (db_path, passphrase) = {
        let mut rng = thread_rng();
        let mut random_alphanumeric = |len: usize| -> String {
            (&mut rng)
                .sample_iter(Alphanumeric)
                .take(len)
                .map(char::from)
                .collect()
        };
        // Same draw order as before: folder name first, then passphrase.
        let db_path = state_dir.join(random_alphanumeric(7));
        let passphrase = random_alphanumeric(32);
        (db_path, passphrase)
    };
    let client = Client::builder()
        .homeserver_url(&homeserver)
        .sqlite_store(&db_path, Some(&passphrase))
        .build()
        .await?;
    Ok((
        client,
        ClientSession {
            homeserver,
            db_path,
            passphrase,
        },
    ))
}
async fn persist_sync_token(session_file: &Path, sync_token: String) -> anyhow::Result<()> {
let serialized_session = fs::read_to_string(session_file).await?;
let mut full_session: FullSession = serde_json::from_str(&serialized_session)?;
full_session.sync_token = Some(sync_token);
let serialized_session = serde_json::to_string(&full_session)?;
fs::write(session_file, serialized_session).await?;
Ok(())
}
/// Reports whether `room` has more active members than `room_size_limit`.
///
/// Without a limit, or when the member list cannot be fetched, the room is treated as
/// acceptable (`false`).
async fn is_room_too_large(room: &Room, room_size_limit: Option<usize>) -> bool {
    let Some(limit) = room_size_limit else {
        return false;
    };
    room.members(RoomMemberships::ACTIVE)
        .await
        .map(|members| members.len() > limit)
        .unwrap_or(false)
}