use std::collections::HashMap;
use std::io::Write;
use std::os::unix::ffi::OsStrExt;
use std::path::{
Path,
PathBuf,
};
use std::pin::Pin;
use std::sync::{
Arc,
Mutex,
MutexGuard,
RwLock,
};
use anyhow::{
Context,
anyhow,
};
use serde::de::Error;
use matrix_sdk::event_handler::{
EventHandler,
EventHandlerHandle,
SyncEvent,
};
use matrix_sdk::ruma::{
OwnedUserId,
RoomId,
UserId,
};
use matrix_sdk::{
AuthSession,
Client,
Room,
SendOutsideWasm,
};
use rusqlite;
use serde::Deserialize;
use serde::de::DeserializeOwned;
use tokio::task::JoinSet;
use crate::prelude::*;
use crate::utils;
// Mode bits for `access(2)`, using the usual Linux/glibc <unistd.h> values.
// Not every bit is used by this crate, hence the per-item `dead_code` allows.
#[allow(dead_code)]
const R_OK: i32 = 0b100; // test read permission
#[allow(dead_code)]
const W_OK: i32 = 0b010; // test write permission
#[allow(dead_code)]
const X_OK: i32 = 0b001; // test execute/search permission
#[allow(dead_code)]
const F_OK: i32 = 0; // test existence only
/// Checks the calling process's permissions for `pathname` using POSIX `access(2)`.
///
/// `mode` is `F_OK` (existence) or a bitwise OR of `R_OK`, `W_OK` and `X_OK`.
/// Returns `true` only when `access` reports success (returns 0).
///
/// A path containing an interior NUL byte cannot be represented as a C string.
/// Such a path is reported as inaccessible (`false`) instead of panicking.
fn access(pathname: &std::path::Path, mode: i32) -> bool {
    #[link(name = "c")]
    unsafe extern "C" {
        // C prototype: int access(const char *pathname, int mode);
        fn access(
            pathname: *const std::os::raw::c_char,
            mode: std::os::raw::c_int,
        ) -> std::os::raw::c_int;
    }
    // CString::new appends the terminating NUL and rejects interior NULs.
    let Ok(c_path) = std::ffi::CString::new(pathname.as_os_str().as_bytes()) else {
        return false;
    };
    // SAFETY: `c_path` is a valid NUL-terminated string that stays alive for the
    // whole call, and `access` only reads it without retaining the pointer.
    let ret = unsafe { access(c_path.as_ptr(), mode) };
    ret == 0
}
/// Package name used as the last component of every XDG/system directory.
const XDG_PACKAGE: &str = "rdzobot";

/// Resolves this package's per-user XDG base directory.
///
/// If the environment variable `envvar` (e.g. `XDG_STATE_HOME`) holds an
/// absolute path, the result is `$envvar/rdzobot`. Otherwise the result is
/// `$HOME/<path>/rdzobot`, where `path` is the default location relative to the
/// home directory (e.g. `.local/state`).
///
/// Following the XDG Base Directory Specification, an empty or relative value
/// is treated as if the variable were unset.
///
/// # Panics
/// Panics if the home directory cannot be determined.
fn xdg_dir_user(envvar: &str, path: &str) -> PathBuf {
    let from_env = std::env::var_os(envvar)
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute());
    let base = match from_env {
        Some(dir) => dir,
        None => {
            // `home_dir` carries a deprecation marker on some toolchains; this
            // module is Unix-only (it uses `std::os::unix`), where it is reliable.
            #[allow(deprecated)]
            let home = std::env::home_dir().expect("cannot determine home directory");
            home.join(path)
        }
    };
    base.join(XDG_PACKAGE)
}
/// Bot configuration, deserialized from `rdzobot.toml`.
#[derive(Debug, Deserialize)]
pub struct Config {
    /// Matrix user ID the bot runs as; parsed into a user ID at startup.
    pub user_id: String,
    /// Optional `owner` user ID (not read in this module; consumers live elsewhere).
    pub owner: Option<OwnedUserId>,
    /// Path of the secrets TOML file. When omitted, `read_config` fills in
    /// `secrets-<user_id>.toml` next to the config file it loaded.
    pub secrets: Option<PathBuf>,
    /// IANA timezone name (e.g. `Europe/Brussels`); defaults to `Europe/Warsaw`.
    #[serde(default = "timezone_default", deserialize_with = "timezone_de")]
    pub timezone: chrono_tz::Tz,
    /// Address the embedded web server binds to; defaults to `0.0.0.0:5000`.
    #[serde(default = "web_listen_default")]
    pub web_listen: String,
    /// Per-module configuration; uses its `Default` when the section is absent.
    #[serde(default)]
    pub module: crate::ModConfig,
}
/// Timezone used when the config file has no `timezone` key.
fn timezone_default() -> chrono_tz::Tz {
    chrono_tz::Europe::Warsaw
}
fn timezone_de<'de, D>(deserializer: D) -> Result<chrono_tz::Tz, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: String = serde::Deserialize::deserialize(deserializer)?;
s.parse().map_err(D::Error::custom)
}
/// Listen address used when the config file has no `web_listen` key.
fn web_listen_default() -> String {
    String::from("0.0.0.0:5000")
}
/// Future accepted from command and regex handlers.
///
/// It must resolve to `anyhow::Result<()>`, be `Send` and `'static` so it can be
/// boxed into a stored handler and spawned on a task set.
#[doc(hidden)]
pub trait HandlerFuture: Future<Output = anyhow::Result<()>> + Send + 'static {}

// Blanket impl: any future meeting the bounds qualifies, so handlers can simply
// be `async fn`s.
impl<Fut: Future<Output = anyhow::Result<()>> + Send + 'static> HandlerFuture for Fut {}
/// Type-erased, heap-allocated future produced by every stored handler.
type BoxedHandlerFuture = Pin<Box<dyn HandlerFuture>>;

/// Command handler as passed to `Rdzobot::add_command`: a plain function pointer
/// receiving the parsed subcommand arguments; `R` is the future it returns.
type CommandHandler<R> =
    fn(clap::ArgMatches, OriginalSyncRoomMessageEvent, Client, Room, Rdzobot) -> R;

/// Command handler as stored in the dispatch table: a boxed closure that calls
/// the user's function and boxes its future.
type StoredCommandHandler = Box<
    dyn Fn(clap::ArgMatches, OriginalSyncRoomMessageEvent, Client, Room, Rdzobot) -> BoxedHandlerFuture
        + Send
        + Sync
        + 'static,
>;

/// Regex handler as passed to `Rdzobot::add_regex`: receives the matching
/// pattern and the message body in addition to the event context.
type RegexHandler<R> =
    fn(regex::Regex, String, OriginalSyncRoomMessageEvent, Client, Room, Rdzobot) -> R;

/// Regex handler as stored alongside its pattern.
type StoredRegexHandler = Box<
    dyn Fn(
            regex::Regex,
            String,
            OriginalSyncRoomMessageEvent,
            Client,
            Room,
            Rdzobot,
        ) -> BoxedHandlerFuture
        + Send
        + Sync
        + 'static,
>;
/// Handle to the running bot.
///
/// Cloning shares `inner` through its `Arc`; clones are handed to every event,
/// command and regex handler as context.
#[derive(Clone)]
pub struct Rdzobot {
    /// Shared state: config, secrets, database and registered handlers.
    pub(crate) inner: Arc<RdzobotInner>,
    /// Matrix client used for syncing, handler registration and sending.
    pub client: Client,
}
/// State shared by all clones of [`Rdzobot`].
pub(crate) struct RdzobotInner {
    /// Configuration the bot was built from.
    config: Config,
    /// Secrets keyed by `(group, identifier)`, loaded from the secrets TOML file.
    secrets: RwLock<HashMap<(String, String), String>>,
    /// Per-user state directory: holds `session.json`, the Matrix SDK SQLite
    /// store and `rdzobot.sqlite3`.
    state_dir: PathBuf,
    /// Connection to the bot's own `rdzobot.sqlite3` database.
    sqlite: Mutex<rusqlite::Connection>,
    /// Multicall clap command; each registered command is one subcommand.
    command: RwLock<clap::Command>,
    /// Web routes merged in by modules; served by `Rdzobot::run`.
    web_router: RwLock<axum::Router<Rdzobot>>,
    /// Handles of handlers registered through `add_event_handler` /
    /// `add_room_event_handler` (never removed in this module).
    event_handlers: Mutex<Vec<EventHandlerHandle>>,
    /// Command handlers keyed by command name and by every alias.
    command_handlers: RwLock<HashMap<String, StoredCommandHandler>>,
    /// `(pattern, handler)` pairs tried against non-command messages.
    regex_handlers: RwLock<Vec<(regex::Regex, StoredRegexHandler)>>,
}
impl Rdzobot {
pub fn add_command(&self, cmd: clap::Command, handler: CommandHandler<impl HandlerFuture>) {
let all_aliases = cmd.get_all_aliases().map(|i| i.to_string()).collect::<Vec<_>>();
tracing::debug!(
"adding command {}{}",
cmd.get_name(),
if all_aliases.is_empty() {
"".to_string()
} else {
format!(" with aliases {}", all_aliases.join(" "))
}
);
let mut command_handlers = self.inner.command_handlers.write().unwrap();
let mut all_names: Vec<String> = Vec::new();
all_names.push(cmd.get_name().to_string());
all_names.extend(all_aliases);
for alias in all_names.into_iter() {
command_handlers.insert(
alias,
Box::new(move |arg_matches, event, client, room, bot| {
Box::pin(handler(arg_matches, event, client, room, bot))
}),
);
}
let mut command = self.inner.command.write().unwrap();
*command = command.clone().subcommand(cmd);
}
pub fn add_regex(&mut self, pattern: regex::Regex, handler: RegexHandler<impl HandlerFuture>) {
self.inner.regex_handlers.write().unwrap().push((
pattern,
Box::new(move |re, body, event, client, room, bot| {
Box::pin(handler(re, body, event, client, room, bot))
}),
));
}
pub fn add_event_handler<Ev, Ctx, H>(&self, handler: H) -> EventHandlerHandle
where
Ev: SyncEvent + DeserializeOwned + SendOutsideWasm + 'static,
H: EventHandler<Ev, Ctx>,
{
let handle = self.client.add_event_handler(handler);
self.inner
.event_handlers
.lock()
.expect("can't lock event_handlers")
.push(handle.clone());
handle
}
pub fn add_room_event_handler<Ev, Ctx, H>(
&self,
room_id: &RoomId,
handler: H,
) -> EventHandlerHandle
where
Ev: SyncEvent + DeserializeOwned + SendOutsideWasm + 'static,
H: EventHandler<Ev, Ctx>,
{
let handle = self.client.add_room_event_handler(room_id, handler);
self.inner
.event_handlers
.lock()
.expect("can't lock event_handlers")
.push(handle.clone());
handle
}
pub fn merge_web_router(&self, router: axum::Router<Self>) {
let mut web_router = self.inner.web_router.write().unwrap();
*web_router = web_router.clone().merge(router);
}
}
impl Rdzobot {
pub async fn new() -> anyhow::Result<Self> {
Self::new_from_config(Self::read_config().context("failed to read config file")?).await
}
pub async fn new_no_restore() -> anyhow::Result<Self> {
Self::new_from_config_no_restore(Self::read_config().expect("failed to read config file"))
.await
}
async fn new_from_config(config: Config) -> anyhow::Result<Self> {
let bot = Self::new_from_config_no_restore(config).await?;
let session_path = bot.session_path();
if !session_path.exists() {
return Err(anyhow!("no session file, did you log in?"));
}
let file = std::fs::File::open(session_path)?;
let reader = std::io::BufReader::new(file);
let session = AuthSession::Matrix(
serde_json::from_reader(reader).expect("failed to parse session.json"),
);
bot.client.restore_session(session).await.expect("failed to restore session");
Ok(bot)
}
async fn new_from_config_no_restore(config: Config) -> anyhow::Result<Self> {
let user_id =
<OwnedUserId>::try_from(config.user_id.as_str()).context("failed to parse user_id")?;
let xdg_state_dir_user: PathBuf = [
xdg_dir_user("XDG_STATE_HOME", ".local/state"),
user_id.as_str().into(),
]
.iter()
.collect();
let xdg_state_dir_system: PathBuf =
["/var/lib", XDG_PACKAGE, user_id.as_str()].iter().collect();
let state_dir = {
if xdg_state_dir_system.is_dir() && access(&xdg_state_dir_system, W_OK) {
xdg_state_dir_system
} else if xdg_state_dir_user.is_dir() {
xdg_state_dir_user
} else if std::fs::create_dir_all(&xdg_state_dir_system).is_ok() {
xdg_state_dir_system
} else if std::fs::create_dir_all(&xdg_state_dir_user).is_ok() {
xdg_state_dir_user
} else {
return Err(anyhow!("failed to create state directory"));
}
};
let client = matrix_sdk::Client::builder()
.server_name(user_id.server_name())
.sqlite_store(&state_dir, None)
.build()
.await
.expect("failed to build matrix client");
let mut sqlite_path = state_dir.clone();
sqlite_path.push("rdzobot.sqlite3");
let secrets = Self::load_secrets(&config.secrets)?;
let this = Self {
inner: Arc::new(RdzobotInner {
config,
secrets: RwLock::new(secrets),
state_dir,
sqlite: Mutex::new(
rusqlite::Connection::open(sqlite_path).context("failed to open sqlite")?,
),
command: RwLock::new(
clap::Command::new("rdzobot")
.color(clap::ColorChoice::Never)
.multicall(true)
.disable_help_subcommand(true),
),
web_router: RwLock::new(axum::Router::new()),
event_handlers: Mutex::new(Vec::new()),
command_handlers: RwLock::new(HashMap::new()),
regex_handlers: RwLock::new(Vec::new()),
}),
client,
};
this.run_migrations()?;
Ok(this)
}
fn run_migrations(&self) -> anyhow::Result<()> {
Ok(self.sqlite().execute_batch(include_str!("../migrations/000_init.sql"))?)
}
pub fn load_modules(&self) {
self.client.add_event_handler_context(self.clone());
self.add_command(clap::Command::new("!help"), on_cmd_help);
crate::load_modules(self.clone());
self.client.add_event_handler(on_room_message_command_or_regex);
}
fn session_path(&self) -> PathBuf { self.inner.state_dir.join("session.json") }
fn read_config() -> anyhow::Result<Config> {
let xdg_config_dir_system: PathBuf = ["/etc", XDG_PACKAGE, "rdzobot.toml"].iter().collect();
let mut xdg_config_dir_user = xdg_dir_user("XDG_CONFIG_HOME", ".config");
xdg_config_dir_user.push("rdzobot.toml");
for path in [xdg_config_dir_user, xdg_config_dir_system] {
if let Ok(contents) = std::fs::read_to_string(&path) {
let mut config: Config = toml::from_str(&contents)?;
if config.secrets.is_none() {
config.secrets = Some(
path.with_file_name(format!("secrets-{}.toml", config.user_id.as_str())),
);
}
return Ok(config);
}
}
Err(anyhow!("failed to read config"))
}
fn load_secrets(path: &Option<PathBuf>) -> anyhow::Result<HashMap<(String, String), String>> {
let mut secrets = HashMap::new();
match path {
None => {
tracing::warn!("no path to secrets configured, no secrets available");
}
Some(path) => {
if let Ok(contents) = std::fs::read_to_string(path) {
tracing::debug!("loading secrets from {}", path.display());
let table: toml::Table = toml::from_str(&contents)?;
for (group, value) in table.iter() {
for (iden, secret) in
value.as_table().ok_or(anyhow::anyhow!("malformed secrets"))?.iter()
{
secrets.insert(
(group.to_owned(), iden.to_owned()),
secret
.as_str()
.ok_or(anyhow::anyhow!("malformed secrets"))?
.to_owned(),
);
}
}
} else {
tracing::warn!("failed to load secrets from {}", path.display());
}
}
}
Ok(secrets)
}
pub async fn run(&self) -> anyhow::Result<()> {
let sync_settings = matrix_sdk::config::SyncSettings::default().filter(
matrix_sdk::ruma::api::client::filter::FilterDefinition::with_lazy_loading().into(),
);
tracing::info!("initial sync");
loop {
match self.client.sync_once(sync_settings.clone()).await {
Ok(_) => {
break;
}
Err(err) => {
tracing::error!("error occured during initial sync: {err}; trying again");
}
}
}
self.load_modules();
let listener =
tokio::net::TcpListener::bind(self.config().web_listen.as_str()).await.unwrap();
let _web_handle = tokio::spawn(
axum::serve(
listener,
self.inner.web_router.read().unwrap().clone().with_state(self.clone()),
)
.into_future(),
);
loop {
tracing::info!("starting sync");
let err = self.client.sync(sync_settings.clone()).await.unwrap_err();
match err {
matrix_sdk::Error::Http(matrix_sdk::HttpError::Reqwest(e)) if e.is_timeout() => {
tracing::debug!("sync timed out, continuing");
}
_ => {
tracing::error!("sync error: {:?}", err);
return Err(err.into());
}
}
}
}
pub async fn login(&self, password: &str) -> anyhow::Result<()> {
let session_path = self.session_path();
if session_path.exists() {
return Err(anyhow!("already logged in at {}", self.inner.state_dir.display()));
}
let file = std::fs::File::create(&session_path).context("can't create session file")?;
let matrix_auth = self.client.matrix_auth();
matrix_auth
.login_username(self.user_id()?, password)
.initial_device_display_name("rdzobot")
.send()
.await
.context("failed to login")?;
let session = matrix_auth.session().expect("a logged-in client should have session?");
let mut writer = std::io::BufWriter::new(file);
serde_json::to_writer_pretty(&mut writer, &session)?;
writer.write_all(b"\n")?;
writer.flush()?;
Ok(())
}
}
impl Rdzobot {
    /// Returns the configuration the bot was built from.
    pub fn config(&self) -> &Config {
        &self.inner.config
    }

    /// Locks and returns the connection to the bot's `rdzobot.sqlite3`.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    pub fn sqlite(&self) -> MutexGuard<'_, rusqlite::Connection> {
        let connection = &self.inner.sqlite;
        connection.lock().unwrap()
    }

    /// Returns a copy of secret `iden` from group `group`, if present.
    pub fn get_secret(&self, group: &str, iden: &str) -> Option<String> {
        let key = (group.to_owned(), iden.to_owned());
        let secrets = self.inner.secrets.read().unwrap();
        secrets.get(&key).map(String::clone)
    }

    /// Parses the configured `user_id` as a Matrix user ID.
    pub fn user_id(&self) -> anyhow::Result<&UserId> {
        let raw = self.inner.config.user_id.as_str();
        <&UserId>::try_from(raw).context("failed to parse user_id")
    }

    /// Returns the state directory (session file, Matrix store, bot database).
    pub fn state_dir(&self) -> &Path {
        self.inner.state_dir.as_path()
    }
}
async fn on_room_message_command_or_regex(
event: OriginalSyncRoomMessageEvent,
client: Client,
room: Room,
bot: Ctx<Rdzobot>,
) -> anyhow::Result<()> {
let Some(body) = text_message_gate(&event, &client, &room) else {
return Ok(());
};
if body.starts_with("!") {
let args: Vec<_> = shlex::Shlex::new(&body).collect();
let result = { bot.inner.command.read().unwrap().clone().try_get_matches_from_mut(args) };
match result {
Ok(mut matches) => {
let (sub_name, sub_matches) = matches.remove_subcommand().unwrap();
let fut = {
bot.inner.command_handlers.read().unwrap().get(&sub_name).unwrap()(
sub_matches,
event,
client,
room,
bot.0.clone(),
)
};
fut.await?;
}
Err(err) => match err.kind() {
clap::error::ErrorKind::DisplayHelp => {
room.send(RoomMessageEventContent::notice_plain(format!("{}", err.render())))
.await?;
}
_ => {
tracing::debug!("error parsing command: {:?}", err);
utils::nie_zesraj_się(&event, &room).await?;
}
},
}
} else {
let mut joins = JoinSet::new();
for (re, handler) in bot.0.inner.regex_handlers.read().unwrap().iter() {
if re.is_match(body.as_str()) {
joins.spawn(handler(
re.clone(),
body.clone(),
event.clone(),
client.clone(),
room.clone(),
bot.0.clone(),
));
}
}
while let Some(result) = joins.join_next().await {
let out = result?;
if let Err(e) = out {
tracing::error!("handler errored out: {}", e);
}
}
}
Ok(())
}
/// `!help` handler: replies with a sorted list of registered commands.
///
/// Each entry shows the command name and, if it has any, its *visible* aliases.
/// Hidden aliases stay hidden, and the "alias"/"aliases" label always matches
/// the number of aliases actually printed.
async fn on_cmd_help(
    _arg_matches: clap::ArgMatches,
    _event: OriginalSyncRoomMessageEvent,
    _client: Client,
    room: Room,
    bot: Rdzobot,
) -> anyhow::Result<()> {
    // The read guard is a temporary of this statement and is released before the
    // `send` below is awaited.
    let mut names: Vec<String> = bot
        .inner
        .command
        .read()
        .unwrap()
        .get_subcommands()
        .map(|subcommand| {
            let name = subcommand.get_name();
            let aliases: Vec<&str> = subcommand.get_visible_aliases().collect();
            if aliases.is_empty() {
                return name.to_string();
            }
            let label = if aliases.len() == 1 { "alias" } else { "aliases" };
            format!("{name} ({label}: {})", aliases.join(", "))
        })
        .collect();
    names.sort();
    #[rustfmt::skip]
    room.send(RoomMessageEventContent::notice_plain(
        format!("Available commands: {}\nUse !<command> --help for more info", names.join(", "))
    )).await?;
    Ok(())
}
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use super::*;

    #[test]
    fn test_parse_empty_config() {
        // `user_id` has no default, so an empty file must be rejected.
        assert!(toml::from_str::<Config>("").is_err());
    }

    #[test]
    fn test_parse_minimal_config() {
        let config = toml::from_str::<Config>("
user_id = '@woju:hackerspace.pl'
").unwrap();
        assert_eq!(config.user_id, "@woju:hackerspace.pl");
        assert!(config.owner.is_none());
        // `secrets` is only filled in by `read_config`, not by deserialization.
        assert!(config.secrets.is_none());
        assert_eq!(config.timezone, chrono_tz::Europe::Warsaw);
        assert_eq!(config.web_listen, "0.0.0.0:5000");
    }

    #[test]
    fn test_parse_timezone() {
        let config = toml::from_str::<Config>("
user_id = '@woju:hackerspace.pl'
timezone = 'Europe/Brussels'
").unwrap();
        assert_eq!(config.timezone, chrono_tz::Europe::Brussels);
    }

    #[test]
    fn test_parse_invalid_timezone() {
        assert!(toml::from_str::<Config>("
user_id = '@woju:hackerspace.pl'
timezone = 'Not/A_Zone'
").is_err());
    }
}