use crate::{LUA, RUNTIME_FLAGS, components::database::DATABASE_POOLS};
use std::path::PathBuf;
use tracing::error;
pub async fn run_command(
file_path: Option<String>,
code: Option<String>,
stdlib_path: Option<String>,
extra_args: Option<Vec<String>>,
) {
#[allow(clippy::expect_used)]
let lua = LUA.get().expect("Could not get access to the global VM");
let mut actual_path: String = "init.lua".to_string();
#[allow(clippy::expect_used)]
let (user_file, actual_path_str) = if let Some(code) = code {
actual_path = "<commandline>".to_string();
(code, actual_path.clone())
} else {
let file = if let Some(file_path) = file_path {
check_for_default_file(&mut actual_path, file_path)
} else {
check_for_default_file(&mut actual_path, ".".to_string())
};
(file, actual_path.clone())
};
run_command_prerequisite(lua, &actual_path_str, stdlib_path, extra_args).await;
spawn_termination_task(lua.clone());
let user_file = user_file
.lines()
.filter(|line| !line.starts_with("#!"))
.collect::<Vec<_>>()
.join("\n");
if let Err(e) = lua
.load(user_file)
.set_name(actual_path_str)
.exec_async()
.await
{
error!("{}", e);
}
let metrics = tokio::runtime::Handle::current().metrics();
loop {
let alive_tasks = metrics.num_alive_tasks();
if alive_tasks == 1 {
break;
}
}
}
async fn run_command_prerequisite(
lua: &mlua::Lua,
file_path: &str,
stdlib_path: Option<String>,
extra_args: Option<Vec<String>>,
) {
if let Err(e) = super::remove_old_runtime() {
error!("{e:?}");
}
let stdlib_path = stdlib_path.unwrap_or("astra".to_string());
if let Err(e) = RUNTIME_FLAGS.set(crate::RuntimeFlags {
stdlib_path: PathBuf::from(stdlib_path.clone()),
}) {
error!("Could not set the global STDLIB_PATH: {e:?}");
}
if let Err(e) = super::registration(lua, stdlib_path).await {
error!("Error setting up the standard library: {e:?}");
}
if let Some(extra_args) = extra_args
&& let Ok(args) = lua.create_table()
{
if let Err(e) = args.set(0, file_path) {
error!("Error adding arg to the args list: {e:?}");
}
for (index, value) in extra_args.into_iter().enumerate() {
if let Err(e) = args.set((index + 1) as i32, value) {
error!("Error adding arg to the args list: {e:?}");
}
}
if let Err(e) = lua.globals().set("arg", args) {
error!("Error setting the global variable ARGS: {e:?}");
}
}
#[allow(clippy::expect_used)]
let astra_table = lua
.globals()
.get::<mlua::Table>("Astra")
.expect("Could not get the global Astra table");
#[allow(clippy::expect_used)]
astra_table
.set("current_script", file_path)
.expect("Couldn't set the script path");
#[allow(clippy::expect_used)]
astra_table
.set("main_script", file_path)
.expect("Couldn't set the script path");
}
/// Locates the script to run and returns its source text.
///
/// `file_path` is tried as-is first. If it is not a regular file, it is treated as a directory
/// and `init.lua`, then `init.luau`, are looked up inside it. The path that was actually chosen
/// is written to `actual_path`.
///
/// # Panics
/// Panics if none of the candidates is a regular file, or if the chosen file cannot be read.
#[allow(clippy::expect_used)]
fn check_for_default_file(actual_path: &mut String, file_path: String) -> String {
    let base = std::path::Path::new(&file_path);
    // Candidates in priority order. `is_file` (rather than `exists`) is used for every entry, so
    // e.g. a directory named `init.lua` is skipped instead of being picked and failing to read.
    let candidates = [
        base.to_path_buf(),
        base.join("init.lua"),
        base.join("init.luau"),
    ];
    let Some(script) = candidates.iter().find(|candidate| candidate.is_file()) else {
        panic!("Could not find any file to run...");
    };
    *actual_path = script.to_string_lossy().into_owned();
    std::fs::read_to_string(script).expect("Couldn't read file")
}
fn spawn_termination_task(lua: mlua::Lua) {
tokio::spawn(async move {
let sigint = tokio::signal::ctrl_c();
#[cfg(unix)]
if let Ok(mut sigterm) =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
&& let Ok(mut sigquit) =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::quit())
{
tokio::select! {
_ = sigterm.recv() => {}
_ = sigquit.recv() => {}
_ = sigint => {}
}
}
#[cfg(not(unix))]
{
tokio::select! {
_ = sigint => {}
}
}
if let Ok(exit_function) = lua.globals().get::<mlua::Function>("ASTRA_SHUTDOWN_CODE")
&& let Err(e) = exit_function.call_async::<()>(()).await
{
error!("{e}");
}
let database_pools = DATABASE_POOLS.lock().await.clone();
for i in database_pools {
match i {
crate::components::database::DatabaseType::Postgres(pool) => pool.close().await,
crate::components::database::DatabaseType::Sqlite(pool) => pool.close().await,
}
}
std::process::exit(
#[cfg(unix)]
0,
#[cfg(not(unix))]
256,
);
});
}