use std::path::PathBuf;
use clap::Args;
use emmylua_formatter as luafmt;
use eyre::{Context, Result};
use lux_lib::{
config::Config, lua_version::LuaVersion, package::PackageName, project::Project,
workspace::Workspace,
};
use path_slash::PathExt;
use walkdir::WalkDir;
// Command-line arguments of the formatting command.
//
// The field-level doc comments below are surfaced by clap as `--help` text.
// NOTE(review): this header is deliberately a plain comment; a doc comment on
// the struct itself may be picked up by clap as the command description —
// confirm before converting it.
#[derive(Args)]
pub struct Fmt {
    /// Workspace to format, or a file/directory to restrict formatting to.
    ///
    /// Only Lua sources located under this path are formatted. When the path
    /// is not itself a workspace, the current workspace is used instead.
    workspace_or_file: Option<PathBuf>,

    /// Formatter backend to use.
    #[arg(long, default_value = "stylua")]
    backend: FmtBackend,

    /// Only format this workspace member. All members are formatted when omitted.
    #[arg(short, long, visible_short_alias = 'p')]
    package: Option<PackageName>,
}
// Formatter implementation selected through the `backend` argument of `Fmt`.
//
// NOTE(review): plain comments are used on purpose; doc comments on
// `ValueEnum` variants would presumably show up in the generated help output.
#[derive(clap::ValueEnum, Clone, Debug)]
enum FmtBackend {
    // Formats through `stylua_lib::format_code`.
    Stylua,
    // Formats through `emmylua_formatter` (imported in this file as `luafmt`).
    Luafmt,
    // Formats through `emmylua_codestyle::reformat_code`.
    EmmyluaCodestyle,
}
/// Entry point of the formatting command.
///
/// Resolves the workspace to operate on and formats either the single member
/// selected through `args.package` or, when no package was given, every
/// member of the workspace.
///
/// # Errors
///
/// Fails when no workspace can be resolved, when the requested package is
/// not a member of the workspace, or when formatting a member fails.
pub fn format(args: Fmt, config: Config) -> Result<()> {
    // A workspace located exactly at the given path takes precedence; in
    // every other case we fall back to the current workspace.
    let explicit_workspace = match args.workspace_or_file.as_ref() {
        Some(path) => Workspace::from_exact(path)?,
        None => None,
    };
    let workspace: Workspace = match explicit_workspace {
        Some(found) => found,
        None => Workspace::current_or_err()?,
    };

    // Without an explicit package, walk over all workspace members.
    let Some(package) = args.package.as_ref() else {
        for member in workspace.members() {
            format_project(&args, &workspace, member, &config)?;
        }
        return Ok(());
    };

    let member = workspace.select_member(package)?;
    format_project(&args, &workspace, member, &config)
}
fn format_project(
args: &Fmt,
workspace: &Workspace,
project: &Project,
config: &Config,
) -> Result<()> {
let root = workspace.root();
let stylua_config: stylua_lib::Config = std::fs::read_to_string(root.join("stylua.toml"))
.or_else(|_| std::fs::read_to_string(root.join(".stylua.toml")))
.map(|config: String| toml::from_str(&config).unwrap_or_default())
.or_else(|_| {
stylua_lib::editorconfig::parse(stylua_lib::Config::new(), &root.join("*.lua"))
})
.unwrap_or_default();
let luafmt_config = luafmt::resolve_config_for_path(Some(root.as_ref()), None)
.map(|resolved| resolved.config)
.unwrap_or_default();
let luafmt_syntax_level = workspace
.lua_version(config)
.map(lua_version_to_luafmt_syntax_level)
.unwrap_or(luafmt_config.syntax.level);
let emmylua_config = root.join(".editorconfig");
let workspace_or_file = args
.workspace_or_file
.as_ref()
.map(std::path::absolute)
.transpose()?;
WalkDir::new(project.root().join("src"))
.into_iter()
.chain(WalkDir::new(project.root().join("lua")))
.chain(WalkDir::new(project.root().join("lib")))
.chain(WalkDir::new(project.root().join("spec")))
.chain(WalkDir::new(project.root().join("test")))
.chain(WalkDir::new(project.root().join("tests")))
.filter_map(Result::ok)
.filter(|file| {
workspace_or_file
.as_ref()
.is_none_or(|workspace_or_file| file.path().starts_with(workspace_or_file))
})
.try_for_each(|file| {
if PathBuf::from(file.file_name())
.extension()
.is_some_and(|ext| ext == "lua")
{
let file = file.path();
let unformatted_code = std::fs::read_to_string(file)?;
let formatted_code = match args.backend {
FmtBackend::Stylua => stylua_lib::format_code(
&unformatted_code,
stylua_config,
None,
stylua_lib::OutputVerification::Full,
)
.context(format!("error formatting {} with stylua.", file.display()))?,
FmtBackend::Luafmt => {
luafmt::check_text(
&unformatted_code,
luafmt_syntax_level.into(),
&luafmt_config,
)
.formatted
}
FmtBackend::EmmyluaCodestyle => {
let uri = file.to_slash_lossy().to_string();
if emmylua_config.is_file() {
emmylua_codestyle::update_code_style(
&uri,
&emmylua_config.to_slash_lossy(),
);
}
emmylua_codestyle::reformat_code(
&unformatted_code,
&uri,
emmylua_codestyle::FormattingOptions::default(),
)
}
};
std::fs::write(file, formatted_code)
.context(format!("error writing formatted file {}.", file.display()))?
};
Ok::<_, eyre::Report>(())
})?;
let rockspec = project.root().join("extra.rockspec");
if rockspec.exists() {
let unformatted_code = std::fs::read_to_string(&rockspec)?;
let formatted_code = match args.backend {
FmtBackend::Stylua => stylua_lib::format_code(
&unformatted_code,
stylua_config,
None,
stylua_lib::OutputVerification::Full,
)?,
FmtBackend::Luafmt => {
luafmt::check_text(
&unformatted_code,
luafmt_syntax_level.into(),
&luafmt_config,
)
.formatted
}
FmtBackend::EmmyluaCodestyle => {
let uri = rockspec.to_slash_lossy().to_string();
if emmylua_config.is_file() {
emmylua_codestyle::update_code_style(&uri, &emmylua_config.to_slash_lossy());
}
emmylua_codestyle::reformat_code(
&unformatted_code,
&uri,
emmylua_codestyle::FormattingOptions::default(),
)
}
};
std::fs::write(rockspec, formatted_code)?;
}
Ok(())
}
/// Translates a `LuaVersion` into the syntax level understood by luafmt.
///
/// Both LuaJIT flavours share the single `LuaJIT` syntax level; every other
/// version maps to the level of the same name.
fn lua_version_to_luafmt_syntax_level(lua_version: LuaVersion) -> luafmt::LuaSyntaxLevel {
    type Level = luafmt::LuaSyntaxLevel;

    match lua_version {
        LuaVersion::LuaJIT | LuaVersion::LuaJIT52 => Level::LuaJIT,
        LuaVersion::Lua55 => Level::Lua55,
        LuaVersion::Lua54 => Level::Lua54,
        LuaVersion::Lua53 => Level::Lua53,
        LuaVersion::Lua52 => Level::Lua52,
        LuaVersion::Lua51 => Level::Lua51,
    }
}
#[cfg(test)]
mod tests {
    use assert_fs::fixture::PathChild;
    use assert_fs::{prelude::PathCopy, TempDir};
    use lux_lib::config::ConfigBuilder;
    use serial_test::serial;

    use super::*;
    use std::path::PathBuf;

    /// Switches the process-wide working directory and switches it back when
    /// dropped, so a failing assertion cannot leave later tests running in a
    /// temporary directory that is about to be deleted.
    struct CwdGuard {
        original: PathBuf,
    }

    impl CwdGuard {
        fn change_to(dir: impl AsRef<std::path::Path>) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            let restored = std::env::set_current_dir(&self.original);
            // Panicking while already unwinding would abort the test binary
            // and hide the original failure.
            if !std::thread::panicking() {
                restored.unwrap();
            }
        }
    }

    /// Formatting a workspace passed by path must work even when the current
    /// directory belongs to a different workspace.
    #[serial]
    #[tokio::test]
    async fn test_format_while_in_another_workspace() {
        let unformatted_sample_project: PathBuf =
            "resources/test/sample-projects/unformatted/".into();
        let unformatted_project_root = TempDir::new().unwrap();
        unformatted_project_root
            .copy_from(&unformatted_sample_project, &["**"])
            .unwrap();

        let cwd_sample_project: PathBuf = "resources/test/sample-projects/init/".into();
        let cwd_project_root = TempDir::new().unwrap();
        cwd_project_root
            .copy_from(&cwd_sample_project, &["**"])
            .unwrap();

        // Declared after both temp dirs so it is dropped (and the working
        // directory restored) before they are removed.
        let _cwd_guard = CwdGuard::change_to(&cwd_project_root);

        let config = ConfigBuilder::new().unwrap().build().unwrap();
        let fmt = Fmt {
            workspace_or_file: Some(unformatted_project_root.to_path_buf()),
            backend: FmtBackend::Stylua,
            package: None,
        };
        format(fmt, config).unwrap();

        let unformatted_file_path = unformatted_project_root.child("src").child("main.lua");
        let content = std::fs::read_to_string(&unformatted_file_path).unwrap();
        assert!(content.contains("print(1 * 2)"));
    }
}