1use std::path::PathBuf;
2
3use clap::Args;
4use emmylua_formatter as luafmt;
5use eyre::{Context, Result};
6use lux_lib::{
7 config::Config, lua_version::LuaVersion, package::PackageName, project::Project,
8 workspace::Workspace,
9};
10use path_slash::PathExt;
11use walkdir::WalkDir;
12
13#[derive(Args)]
14pub struct Fmt {
15 workspace_or_file: Option<PathBuf>,
17
18 #[clap(default_value = "stylua")]
19 #[arg(long)]
20 backend: FmtBackend,
21
22 #[arg(short, long, visible_short_alias = 'p')]
24 package: Option<PackageName>,
25}
26
// Formatter backends selectable with the `--backend` flag of `Fmt`.
// Each variant's doc comment is shown by clap as the help for that value.
#[derive(clap::ValueEnum, Clone, Debug)]
enum FmtBackend {
    /// Format with StyLua, configured by `stylua.toml`/`.stylua.toml` or
    /// EditorConfig settings in the workspace root.
    Stylua,
    /// Format with emmylua_formatter; the syntax level follows the
    /// workspace's Lua version when it is known.
    Luafmt,
    /// Format with EmmyLua CodeStyle, configured by `.editorconfig` in the
    /// workspace root.
    EmmyluaCodestyle,
}
38
39pub fn format(args: Fmt, config: Config) -> Result<()> {
40 let workspace: Workspace = match args.workspace_or_file {
41 Some(ref ws) => match Workspace::from_exact(ws)? {
42 Some(ws) => ws,
43 None => Workspace::current_or_err()?,
44 },
45 None => Workspace::current_or_err()?,
46 };
47
48 if let Some(package) = &args.package {
49 let project = workspace.select_member(package)?;
50 format_project(&args, &workspace, project, &config)?;
51 } else {
52 for project in workspace.members() {
53 format_project(&args, &workspace, project, &config)?;
54 }
55 }
56 Ok(())
57}
58
59fn format_project(
60 args: &Fmt,
61 workspace: &Workspace,
62 project: &Project,
63 config: &Config,
64) -> Result<()> {
65 let root = workspace.root();
66
67 let stylua_config: stylua_lib::Config = std::fs::read_to_string(root.join("stylua.toml"))
68 .or_else(|_| std::fs::read_to_string(root.join(".stylua.toml")))
69 .map(|config: String| toml::from_str(&config).unwrap_or_default())
70 .or_else(|_| {
71 stylua_lib::editorconfig::parse(stylua_lib::Config::new(), &root.join("*.lua"))
72 })
73 .unwrap_or_default();
74
75 let luafmt_config = luafmt::resolve_config_for_path(Some(root.as_ref()), None)
76 .map(|resolved| resolved.config)
77 .unwrap_or_default();
78 let luafmt_syntax_level = workspace
79 .lua_version(config)
80 .map(lua_version_to_luafmt_syntax_level)
81 .unwrap_or(luafmt_config.syntax.level);
82
83 let emmylua_config = root.join(".editorconfig");
84
85 let workspace_or_file = args
86 .workspace_or_file
87 .as_ref()
88 .map(std::path::absolute)
89 .transpose()?;
90
91 WalkDir::new(project.root().join("src"))
92 .into_iter()
93 .chain(WalkDir::new(project.root().join("lua")))
94 .chain(WalkDir::new(project.root().join("lib")))
95 .chain(WalkDir::new(project.root().join("spec")))
96 .chain(WalkDir::new(project.root().join("test")))
97 .chain(WalkDir::new(project.root().join("tests")))
98 .filter_map(Result::ok)
99 .filter(|file| {
100 workspace_or_file
101 .as_ref()
102 .is_none_or(|workspace_or_file| file.path().starts_with(workspace_or_file))
103 })
104 .try_for_each(|file| {
105 if PathBuf::from(file.file_name())
106 .extension()
107 .is_some_and(|ext| ext == "lua")
108 {
109 let file = file.path();
110 let unformatted_code = std::fs::read_to_string(file)?;
111 let formatted_code = match args.backend {
112 FmtBackend::Stylua => stylua_lib::format_code(
113 &unformatted_code,
114 stylua_config,
115 None,
116 stylua_lib::OutputVerification::Full,
117 )
118 .context(format!("error formatting {} with stylua.", file.display()))?,
119 FmtBackend::Luafmt => {
120 luafmt::check_text(
121 &unformatted_code,
122 luafmt_syntax_level.into(),
123 &luafmt_config,
124 )
125 .formatted
126 }
127 FmtBackend::EmmyluaCodestyle => {
128 let uri = file.to_slash_lossy().to_string();
129 if emmylua_config.is_file() {
130 emmylua_codestyle::update_code_style(
131 &uri,
132 &emmylua_config.to_slash_lossy(),
133 );
134 }
135 emmylua_codestyle::reformat_code(
136 &unformatted_code,
137 &uri,
138 emmylua_codestyle::FormattingOptions::default(),
139 )
140 }
141 };
142
143 std::fs::write(file, formatted_code)
144 .context(format!("error writing formatted file {}.", file.display()))?
145 };
146 Ok::<_, eyre::Report>(())
147 })?;
148
149 let rockspec = project.root().join("extra.rockspec");
152
153 if rockspec.exists() {
154 let unformatted_code = std::fs::read_to_string(&rockspec)?;
155 let formatted_code = match args.backend {
156 FmtBackend::Stylua => stylua_lib::format_code(
157 &unformatted_code,
158 stylua_config,
159 None,
160 stylua_lib::OutputVerification::Full,
161 )?,
162 FmtBackend::Luafmt => {
163 luafmt::check_text(
164 &unformatted_code,
165 luafmt_syntax_level.into(),
166 &luafmt_config,
167 )
168 .formatted
169 }
170 FmtBackend::EmmyluaCodestyle => {
171 let uri = rockspec.to_slash_lossy().to_string();
172 if emmylua_config.is_file() {
173 emmylua_codestyle::update_code_style(&uri, &emmylua_config.to_slash_lossy());
174 }
175 emmylua_codestyle::reformat_code(
176 &unformatted_code,
177 &uri,
178 emmylua_codestyle::FormattingOptions::default(),
179 )
180 }
181 };
182
183 std::fs::write(rockspec, formatted_code)?;
184 }
185 Ok(())
186}
187
188fn lua_version_to_luafmt_syntax_level(lua_version: LuaVersion) -> luafmt::LuaSyntaxLevel {
189 match lua_version {
190 LuaVersion::Lua51 => luafmt::LuaSyntaxLevel::Lua51,
191 LuaVersion::Lua52 => luafmt::LuaSyntaxLevel::Lua52,
192 LuaVersion::Lua53 => luafmt::LuaSyntaxLevel::Lua53,
193 LuaVersion::Lua54 => luafmt::LuaSyntaxLevel::Lua54,
194 LuaVersion::Lua55 => luafmt::LuaSyntaxLevel::Lua55,
195 LuaVersion::LuaJIT | LuaVersion::LuaJIT52 => luafmt::LuaSyntaxLevel::LuaJIT,
196 }
197}
198
#[cfg(test)]
mod tests {
    use assert_fs::fixture::PathChild;
    use assert_fs::{prelude::PathCopy, TempDir};
    use lux_lib::config::ConfigBuilder;
    use serial_test::serial;

    use super::*;
    use std::path::PathBuf;

    /// Restores the process working directory when dropped.
    ///
    /// Without this, a panicking test would leave the cwd pointing into a
    /// temporary directory that is deleted right afterwards, breaking every
    /// cwd-dependent test that runs after it.
    struct RestoreCwd(PathBuf);

    impl Drop for RestoreCwd {
        fn drop(&mut self) {
            // Errors are ignored: panicking in `drop` while already unwinding aborts.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    /// Formatting an explicit workspace path must format that workspace, not
    /// the workspace containing the current directory.
    #[serial]
    #[tokio::test]
    async fn test_format_while_in_another_workspace() {
        let unformatted_sample_project: PathBuf =
            "resources/test/sample-projects/unformatted/".into();
        let unformatted_project_root = TempDir::new().unwrap();
        unformatted_project_root
            .copy_from(&unformatted_sample_project, &["**"])
            .unwrap();

        let cwd_sample_project: PathBuf = "resources/test/sample-projects/init/".into();
        let cwd_project_root = TempDir::new().unwrap();
        cwd_project_root
            .copy_from(&cwd_sample_project, &["**"])
            .unwrap();

        // Declared after `cwd_project_root`, so it is dropped first: the cwd is
        // restored before that temporary directory is deleted, even on panic.
        let _restore_cwd = RestoreCwd(std::env::current_dir().unwrap());
        std::env::set_current_dir(&cwd_project_root).unwrap();

        let config = ConfigBuilder::new().unwrap().build().unwrap();
        let fmt = Fmt {
            workspace_or_file: Some(unformatted_project_root.to_path_buf()),
            backend: FmtBackend::Stylua,
            package: None,
        };

        format(fmt, config).unwrap();

        let unformatted_file_path = unformatted_project_root.child("src").child("main.lua");
        let content = std::fs::read_to_string(&unformatted_file_path).unwrap();

        assert!(content.contains("print(1 * 2)"));
    }
}
245}