1use std::path::{Path, PathBuf};
2
3use clap::Args;
4use emmylua_formatter as luafmt;
5use eyre::{bail, Context, Result};
6use lux_lib::{
7 config::Config, lua_version::LuaVersion, package::PackageName, project::Project,
8 workspace::Workspace,
9};
10use path_slash::PathExt;
11use walkdir::WalkDir;
12
13#[derive(Args)]
14pub struct Fmt {
15 path: Option<PathBuf>,
17
18 #[clap(default_value = "stylua")]
19 #[arg(long)]
20 backend: FmtBackend,
21
22 #[arg(short, long, visible_short_alias = 'p')]
24 package: Option<PackageName>,
25}
26
27#[derive(clap::ValueEnum, Clone, Debug)]
28enum FmtBackend {
29 Stylua,
31 Luafmt,
35 EmmyluaCodestyle,
37}
38
39enum PathTarget {
41 Workspace(Box<Workspace>),
42 Directory(PathBuf),
43 File(PathBuf),
44}
45
46fn classify_path(path: &Path) -> Result<PathTarget> {
47 if !path.exists() {
48 bail!("path does not exist: {}", path.display());
49 }
50 if let Some(workspace) = Workspace::from_exact(path)? {
51 return Ok(PathTarget::Workspace(Box::new(workspace)));
52 }
53 let path = std::path::absolute(path)?;
54 if path.is_file() {
55 Ok(PathTarget::File(path))
56 } else {
57 Ok(PathTarget::Directory(path))
58 }
59}
60
61pub fn format(args: Fmt, config: Config) -> Result<()> {
62 let target = match args.path.as_deref() {
63 None => PathTarget::Workspace(Box::new(Workspace::current_or_err()?)),
64 Some(path) => classify_path(path)?,
65 };
66 match target {
67 PathTarget::Workspace(workspace) => {
68 if let Some(package) = &args.package {
69 let project = workspace.select_member(package)?;
70 format_project(&args, &workspace, project, &config)?;
71 } else {
72 for project in workspace.members() {
73 format_project(&args, &workspace, project, &config)?;
74 }
75 }
76 }
77 PathTarget::File(file) => {
78 ensure_no_package(&args)?;
79 let root = file
80 .parent()
81 .unwrap_or_else(|| Path::new("."))
82 .to_path_buf();
83 format_loose(std::iter::once(file), &root, &args.backend, &config)?;
84 }
85 PathTarget::Directory(dir) => {
86 ensure_no_package(&args)?;
87 let files = WalkDir::new(&dir)
88 .into_iter()
89 .filter_map(Result::ok)
90 .filter(|entry| entry.file_type().is_file())
91 .map(|entry| entry.into_path())
92 .filter(|path| is_lua_source(path));
93 format_loose(files, &dir, &args.backend, &config)?;
94 }
95 }
96 Ok(())
97}
98
99struct FmtConfig {
100 stylua: stylua_lib::Config,
101 luafmt: luafmt::LuaFormatConfig,
102 luafmt_syntax_level: luafmt::LuaSyntaxLevel,
103 editorconfig: PathBuf,
104}
105
106impl FmtConfig {
107 fn resolve(root: &Path, lua_version: Option<LuaVersion>) -> Self {
108 let stylua: stylua_lib::Config = std::fs::read_to_string(root.join("stylua.toml"))
109 .or_else(|_| std::fs::read_to_string(root.join(".stylua.toml")))
110 .map(|config: String| toml::from_str(&config).unwrap_or_default())
111 .or_else(|_| {
112 stylua_lib::editorconfig::parse(stylua_lib::Config::new(), &root.join("*.lua"))
113 })
114 .unwrap_or_default();
115
116 let luafmt = luafmt::resolve_config_for_path(Some(root), None)
117 .map(|resolved| resolved.config)
118 .unwrap_or_default();
119 let luafmt_syntax_level = lua_version
120 .map(lua_version_to_luafmt_syntax_level)
121 .unwrap_or(luafmt.syntax.level);
122
123 Self {
124 stylua,
125 luafmt,
126 luafmt_syntax_level,
127 editorconfig: root.join(".editorconfig"),
128 }
129 }
130
131 fn format(&self, backend: &FmtBackend, path: &Path, code: &str) -> Result<String> {
132 Ok(match backend {
133 FmtBackend::Stylua => stylua_lib::format_code(
134 code,
135 self.stylua,
136 None,
137 stylua_lib::OutputVerification::Full,
138 )
139 .context(format!("error formatting {} with stylua.", path.display()))?,
140 FmtBackend::Luafmt => {
141 luafmt::check_text(code, self.luafmt_syntax_level.into(), &self.luafmt).formatted
142 }
143 FmtBackend::EmmyluaCodestyle => {
144 let uri = path.to_slash_lossy().to_string();
145 if self.editorconfig.is_file() {
146 emmylua_codestyle::update_code_style(&uri, &self.editorconfig.to_slash_lossy());
147 }
148 emmylua_codestyle::reformat_code(
149 code,
150 &uri,
151 emmylua_codestyle::FormattingOptions::default(),
152 )
153 }
154 })
155 }
156}
157
158fn format_files(
159 files: impl Iterator<Item = PathBuf>,
160 configs: &FmtConfig,
161 backend: &FmtBackend,
162) -> Result<()> {
163 files.into_iter().try_for_each(|file| {
164 let unformatted_code = std::fs::read_to_string(&file)?;
165 let formatted_code = configs.format(backend, &file, &unformatted_code)?;
166 std::fs::write(&file, formatted_code)
167 .context(format!("error writing formatted file {}.", file.display()))
168 })
169}
170
171fn format_project(
172 args: &Fmt,
173 workspace: &Workspace,
174 project: &Project,
175 config: &Config,
176) -> Result<()> {
177 let configs = FmtConfig::resolve(
178 workspace.root().as_ref(),
179 workspace.lua_version(config).ok(),
180 );
181
182 let lua_files = ["src", "lua", "lib", "spec", "test", "tests"]
183 .iter()
184 .flat_map(|dir| WalkDir::new(project.root().join(dir)))
185 .filter_map(Result::ok)
186 .map(walkdir::DirEntry::into_path)
187 .filter(|path| is_lua_source(path));
188
189 let rockspec = project.root().join("extra.rockspec");
190
191 format_files(
192 lua_files.chain(rockspec.exists().then_some(rockspec)),
193 &configs,
194 &args.backend,
195 )
196}
197
/// Returns `true` for paths whose extension is exactly `lua` or `rockspec`
/// (case-sensitive).
fn is_lua_source(path: &Path) -> bool {
    let extension = path.extension().and_then(std::ffi::OsStr::to_str);
    matches!(extension, Some("lua" | "rockspec"))
}
202
203fn ensure_no_package(args: &Fmt) -> Result<()> {
204 if args.package.is_some() {
205 bail!("--package is only valid within a workspace");
206 }
207 Ok(())
208}
209
210fn format_loose(
211 files: impl Iterator<Item = PathBuf>,
212 root: &Path,
213 backend: &FmtBackend,
214 config: &Config,
215) -> Result<()> {
216 let (config_root, lua_version) = match Workspace::from(root)? {
217 Some(workspace) => (
218 workspace.root().as_ref().to_path_buf(),
219 workspace.lua_version(config).ok(),
220 ),
221 None => (root.to_path_buf(), config.lua_version().cloned()),
222 };
223 let configs = FmtConfig::resolve(&config_root, lua_version);
224 format_files(files, &configs, backend)
225}
226
227fn lua_version_to_luafmt_syntax_level(lua_version: LuaVersion) -> luafmt::LuaSyntaxLevel {
228 match lua_version {
229 LuaVersion::Lua51 => luafmt::LuaSyntaxLevel::Lua51,
230 LuaVersion::Lua52 => luafmt::LuaSyntaxLevel::Lua52,
231 LuaVersion::Lua53 => luafmt::LuaSyntaxLevel::Lua53,
232 LuaVersion::Lua54 => luafmt::LuaSyntaxLevel::Lua54,
233 LuaVersion::Lua55 => luafmt::LuaSyntaxLevel::Lua55,
234 LuaVersion::LuaJIT | LuaVersion::LuaJIT52 => luafmt::LuaSyntaxLevel::LuaJIT,
235 }
236}
237
#[cfg(test)]
mod tests {
    // Every test in this module is `#[serial]`. `test_format_while_in_another_workspace`
    // changes the process-wide working directory, while the fixtures are copied
    // from cwd-relative paths (`resources/test/...`). `serial_test` only keeps
    // `#[serial]` tests apart from each other, so an unmarked test could run
    // while the cwd points at a temporary directory and fail to find its fixture.
    use assert_fs::fixture::PathChild;
    use assert_fs::{prelude::PathCopy, TempDir};
    use lux_lib::config::ConfigBuilder;
    use serial_test::serial;

    use super::*;
    use std::path::{Path, PathBuf};

    /// Switches the working directory and switches it back on drop, so a failing
    /// assertion cannot leave later tests running in the wrong directory.
    struct CwdGuard(PathBuf);

    impl CwdGuard {
        fn change_to(dir: &Path) -> Self {
            let previous = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            CwdGuard(previous)
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Ignore failure: panicking in `drop` while already unwinding aborts.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    #[serial]
    #[tokio::test]
    async fn test_format_while_in_another_workspace() {
        let unformatted_sample_project: PathBuf =
            "resources/test/sample-projects/unformatted/".into();
        let unformatted_project_root = TempDir::new().unwrap();
        unformatted_project_root
            .copy_from(&unformatted_sample_project, &["**"])
            .unwrap();

        let cwd_sample_project: PathBuf = "resources/test/sample-projects/init/".into();
        let cwd_project_root = TempDir::new().unwrap();
        cwd_project_root
            .copy_from(&cwd_sample_project, &["**"])
            .unwrap();

        // Run from inside an unrelated project: the explicit path must win.
        // Declared after the temp dirs, so it is dropped (cwd restored) before
        // they are deleted.
        let _cwd = CwdGuard::change_to(cwd_project_root.path());

        let config = ConfigBuilder::new().unwrap().build().unwrap();
        let fmt = Fmt {
            path: Some(unformatted_project_root.to_path_buf()),
            backend: FmtBackend::Stylua,
            package: None,
        };

        format(fmt, config).unwrap();

        let unformatted_file_path = unformatted_project_root.child("src").child("main.lua");
        let content = std::fs::read_to_string(&unformatted_file_path).unwrap();

        assert!(content.contains("print(1 * 2)"));
    }

    /// Copies the cwd-relative `loose-lua` fixture into a fresh temp dir.
    fn loose_lua_temp_dir() -> TempDir {
        let fixture: PathBuf = "resources/test/loose-lua/".into();
        let dir = TempDir::new().unwrap();
        dir.copy_from(&fixture, &["**"]).unwrap();
        dir
    }

    /// `Fmt` arguments for `path` with the StyLua backend and no `--package`.
    fn fmt(path: Option<PathBuf>) -> Fmt {
        Fmt {
            path,
            backend: FmtBackend::Stylua,
            package: None,
        }
    }

    #[serial]
    #[test]
    fn test_format_plain_directory_without_lux_toml() {
        let dir = loose_lua_temp_dir();
        let config = ConfigBuilder::new().unwrap().build().unwrap();

        format(fmt(Some(dir.to_path_buf())), config).unwrap();

        let top = std::fs::read_to_string(dir.child("a.lua")).unwrap();
        let nested = std::fs::read_to_string(dir.child("nested").child("b.lua")).unwrap();
        let other = std::fs::read_to_string(dir.child("notes.txt")).unwrap();
        assert!(top.contains("print(1 * 2)"));
        assert!(nested.contains("print(3 + 4)"));
        // Non-Lua files in the tree must be left untouched.
        assert!(other.contains("print( 5 * 6 )"));
    }

    #[serial]
    #[test]
    fn test_format_single_lua_file() {
        let dir = loose_lua_temp_dir();
        let config = ConfigBuilder::new().unwrap().build().unwrap();

        format(fmt(Some(dir.child("a.lua").to_path_buf())), config).unwrap();

        let top = std::fs::read_to_string(dir.child("a.lua")).unwrap();
        let nested = std::fs::read_to_string(dir.child("nested").child("b.lua")).unwrap();
        assert!(top.contains("print(1 * 2)"));
        // Only the named file is formatted; siblings stay as they were.
        assert!(nested.contains("print( 3 + 4 )"));
    }

    #[serial]
    #[test]
    fn test_format_nonexistent_path_errors() {
        let config = ConfigBuilder::new().unwrap().build().unwrap();
        let result = format(fmt(Some("/no/such/path".into())), config);
        assert!(result.is_err());
    }

    #[serial]
    #[test]
    fn test_format_subdir_inherits_workspace_config() {
        let fixture: PathBuf = "resources/test/sample-projects/stylua-config/".into();
        let workspace = TempDir::new().unwrap();
        workspace.copy_from(&fixture, &["**"]).unwrap();
        let config = ConfigBuilder::new().unwrap().build().unwrap();

        format(fmt(Some(workspace.child("src").to_path_buf())), config).unwrap();

        let content = std::fs::read_to_string(workspace.child("src").child("main.lua")).unwrap();
        assert!(content.contains("\n    print(1 * 2)"));
        assert!(!content.contains('\t'));
    }
}