Skip to main content

lux_cli/
format.rs

1use std::path::{Path, PathBuf};
2
3use clap::Args;
4use emmylua_formatter as luafmt;
5use eyre::{bail, Context, Result};
6use lux_lib::{
7    config::Config, lua_version::LuaVersion, package::PackageName, project::Project,
8    workspace::Workspace,
9};
10use path_slash::PathExt;
11use walkdir::WalkDir;
12
13#[derive(Args)]
14pub struct Fmt {
15    /// Path to a workspace, directory, or Lua file to format. Defaults to the current workspace.
16    path: Option<PathBuf>,
17
18    #[clap(default_value = "stylua")]
19    #[arg(long)]
20    backend: FmtBackend,
21
22    /// Package to format.
23    #[arg(short, long, visible_short_alias = 'p')]
24    package: Option<PackageName>,
25}
26
27#[derive(clap::ValueEnum, Clone, Debug)]
28enum FmtBackend {
29    /// Mainly follows the [Roblox Lua style guide](https://roblox.github.io/lua-style-guide/).
30    Stylua,
31    /// The default formatter used by [emmylua-analyzer-rust](https://github.com/EmmyLuaLs/emmylua-analyzer-rust).
32    /// If invoked with `lx --lua-version=<version> fmt`, Lux will configure the luafmt syntax level
33    /// to match the specified Lua version.
34    Luafmt,
35    /// The default formatter used by [lua-language-server](https://luals.github.io/).
36    EmmyluaCodestyle,
37}
38
39// TODO: For `lx check` #1407 and `lx lint` #1409, move `PathTarget` + `classify_path` into a shared module.
40enum PathTarget {
41    Workspace(Box<Workspace>),
42    Directory(PathBuf),
43    File(PathBuf),
44}
45
46fn classify_path(path: &Path) -> Result<PathTarget> {
47    if !path.exists() {
48        bail!("path does not exist: {}", path.display());
49    }
50    if let Some(workspace) = Workspace::from_exact(path)? {
51        return Ok(PathTarget::Workspace(Box::new(workspace)));
52    }
53    let path = std::path::absolute(path)?;
54    if path.is_file() {
55        Ok(PathTarget::File(path))
56    } else {
57        Ok(PathTarget::Directory(path))
58    }
59}
60
61pub fn format(args: Fmt, config: Config) -> Result<()> {
62    let target = match args.path.as_deref() {
63        None => PathTarget::Workspace(Box::new(Workspace::current_or_err()?)),
64        Some(path) => classify_path(path)?,
65    };
66    match target {
67        PathTarget::Workspace(workspace) => {
68            if let Some(package) = &args.package {
69                let project = workspace.select_member(package)?;
70                format_project(&args, &workspace, project, &config)?;
71            } else {
72                for project in workspace.members() {
73                    format_project(&args, &workspace, project, &config)?;
74                }
75            }
76        }
77        PathTarget::File(file) => {
78            ensure_no_package(&args)?;
79            let root = file
80                .parent()
81                .unwrap_or_else(|| Path::new("."))
82                .to_path_buf();
83            format_loose(std::iter::once(file), &root, &args.backend, &config)?;
84        }
85        PathTarget::Directory(dir) => {
86            ensure_no_package(&args)?;
87            let files = WalkDir::new(&dir)
88                .into_iter()
89                .filter_map(Result::ok)
90                .filter(|entry| entry.file_type().is_file())
91                .map(|entry| entry.into_path())
92                .filter(|path| is_lua_source(path));
93            format_loose(files, &dir, &args.backend, &config)?;
94        }
95    }
96    Ok(())
97}
98
99struct FmtConfig {
100    stylua: stylua_lib::Config,
101    luafmt: luafmt::LuaFormatConfig,
102    luafmt_syntax_level: luafmt::LuaSyntaxLevel,
103    editorconfig: PathBuf,
104}
105
106impl FmtConfig {
107    fn resolve(root: &Path, lua_version: Option<LuaVersion>) -> Self {
108        let stylua: stylua_lib::Config = std::fs::read_to_string(root.join("stylua.toml"))
109            .or_else(|_| std::fs::read_to_string(root.join(".stylua.toml")))
110            .map(|config: String| toml::from_str(&config).unwrap_or_default())
111            .or_else(|_| {
112                stylua_lib::editorconfig::parse(stylua_lib::Config::new(), &root.join("*.lua"))
113            })
114            .unwrap_or_default();
115
116        let luafmt = luafmt::resolve_config_for_path(Some(root), None)
117            .map(|resolved| resolved.config)
118            .unwrap_or_default();
119        let luafmt_syntax_level = lua_version
120            .map(lua_version_to_luafmt_syntax_level)
121            .unwrap_or(luafmt.syntax.level);
122
123        Self {
124            stylua,
125            luafmt,
126            luafmt_syntax_level,
127            editorconfig: root.join(".editorconfig"),
128        }
129    }
130
131    fn format(&self, backend: &FmtBackend, path: &Path, code: &str) -> Result<String> {
132        Ok(match backend {
133            FmtBackend::Stylua => stylua_lib::format_code(
134                code,
135                self.stylua,
136                None,
137                stylua_lib::OutputVerification::Full,
138            )
139            .context(format!("error formatting {} with stylua.", path.display()))?,
140            FmtBackend::Luafmt => {
141                luafmt::check_text(code, self.luafmt_syntax_level.into(), &self.luafmt).formatted
142            }
143            FmtBackend::EmmyluaCodestyle => {
144                let uri = path.to_slash_lossy().to_string();
145                if self.editorconfig.is_file() {
146                    emmylua_codestyle::update_code_style(&uri, &self.editorconfig.to_slash_lossy());
147                }
148                emmylua_codestyle::reformat_code(
149                    code,
150                    &uri,
151                    emmylua_codestyle::FormattingOptions::default(),
152                )
153            }
154        })
155    }
156}
157
158fn format_files(
159    files: impl Iterator<Item = PathBuf>,
160    configs: &FmtConfig,
161    backend: &FmtBackend,
162) -> Result<()> {
163    files.into_iter().try_for_each(|file| {
164        let unformatted_code = std::fs::read_to_string(&file)?;
165        let formatted_code = configs.format(backend, &file, &unformatted_code)?;
166        std::fs::write(&file, formatted_code)
167            .context(format!("error writing formatted file {}.", file.display()))
168    })
169}
170
171fn format_project(
172    args: &Fmt,
173    workspace: &Workspace,
174    project: &Project,
175    config: &Config,
176) -> Result<()> {
177    let configs = FmtConfig::resolve(
178        workspace.root().as_ref(),
179        workspace.lua_version(config).ok(),
180    );
181
182    let lua_files = ["src", "lua", "lib", "spec", "test", "tests"]
183        .iter()
184        .flat_map(|dir| WalkDir::new(project.root().join(dir)))
185        .filter_map(Result::ok)
186        .map(walkdir::DirEntry::into_path)
187        .filter(|path| is_lua_source(path));
188
189    let rockspec = project.root().join("extra.rockspec");
190
191    format_files(
192        lua_files.chain(rockspec.exists().then_some(rockspec)),
193        &configs,
194        &args.backend,
195    )
196}
197
/// Returns `true` when `path` has exactly the extension `lua` or `rockspec`.
///
/// The comparison is case-sensitive. Paths with no extension (including
/// dotfiles such as `.lua`) are rejected.
fn is_lua_source(path: &Path) -> bool {
    let extension = path.extension().and_then(|ext| ext.to_str());
    matches!(extension, Some("lua" | "rockspec"))
}
202
203fn ensure_no_package(args: &Fmt) -> Result<()> {
204    if args.package.is_some() {
205        bail!("--package is only valid within a workspace");
206    }
207    Ok(())
208}
209
210fn format_loose(
211    files: impl Iterator<Item = PathBuf>,
212    root: &Path,
213    backend: &FmtBackend,
214    config: &Config,
215) -> Result<()> {
216    let (config_root, lua_version) = match Workspace::from(root)? {
217        Some(workspace) => (
218            workspace.root().as_ref().to_path_buf(),
219            workspace.lua_version(config).ok(),
220        ),
221        None => (root.to_path_buf(), config.lua_version().cloned()),
222    };
223    let configs = FmtConfig::resolve(&config_root, lua_version);
224    format_files(files, &configs, backend)
225}
226
227fn lua_version_to_luafmt_syntax_level(lua_version: LuaVersion) -> luafmt::LuaSyntaxLevel {
228    match lua_version {
229        LuaVersion::Lua51 => luafmt::LuaSyntaxLevel::Lua51,
230        LuaVersion::Lua52 => luafmt::LuaSyntaxLevel::Lua52,
231        LuaVersion::Lua53 => luafmt::LuaSyntaxLevel::Lua53,
232        LuaVersion::Lua54 => luafmt::LuaSyntaxLevel::Lua54,
233        LuaVersion::Lua55 => luafmt::LuaSyntaxLevel::Lua55,
234        LuaVersion::LuaJIT | LuaVersion::LuaJIT52 => luafmt::LuaSyntaxLevel::LuaJIT,
235    }
236}
237
#[cfg(test)]
mod tests {
    use assert_fs::fixture::PathChild;
    use assert_fs::{prelude::PathCopy, TempDir};
    use lux_lib::config::ConfigBuilder;
    use serial_test::serial;

    use super::*;
    use std::path::{Path, PathBuf};

    // Every test in this module is `#[serial]`. `test_format_while_in_another_workspace`
    // changes the process-wide working directory, so any test running concurrently with
    // it would otherwise observe the temporary cwd.

    /// Absolute path to a test fixture, independent of the current working directory.
    fn fixture(relative: &str) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join(relative)
    }

    /// Restores the saved working directory when dropped, even if the test panics.
    struct RestoreCwd(PathBuf);

    impl Drop for RestoreCwd {
        fn drop(&mut self) {
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    #[serial]
    #[tokio::test]
    async fn test_format_while_in_another_workspace() {
        let unformatted_project_root = TempDir::new().unwrap();
        unformatted_project_root
            .copy_from(
                fixture("resources/test/sample-projects/unformatted/"),
                &["**"],
            )
            .unwrap();

        let cwd_project_root = TempDir::new().unwrap();
        cwd_project_root
            .copy_from(fixture("resources/test/sample-projects/init/"), &["**"])
            .unwrap();

        // Declared after the temp dirs so it is dropped first: the cwd is restored
        // before the directory it points into is deleted.
        let _restore_cwd = RestoreCwd(std::env::current_dir().unwrap());
        std::env::set_current_dir(&cwd_project_root).unwrap();

        let config = ConfigBuilder::new().unwrap().build().unwrap();
        let fmt = Fmt {
            path: Some(unformatted_project_root.to_path_buf()),
            backend: FmtBackend::Stylua,
            package: None,
        };

        format(fmt, config).unwrap();

        let unformatted_file_path = unformatted_project_root.child("src").child("main.lua");
        let content = std::fs::read_to_string(&unformatted_file_path).unwrap();

        // the unformatted variant contains too many spaces
        assert!(content.contains("print(1 * 2)"));
    }

    fn loose_lua_temp_dir() -> TempDir {
        let dir = TempDir::new().unwrap();
        dir.copy_from(fixture("resources/test/loose-lua/"), &["**"])
            .unwrap();
        dir
    }

    fn fmt(path: Option<PathBuf>) -> Fmt {
        Fmt {
            path,
            backend: FmtBackend::Stylua,
            package: None,
        }
    }

    #[test]
    #[serial]
    fn test_format_plain_directory_without_lux_toml() {
        let dir = loose_lua_temp_dir();
        let config = ConfigBuilder::new().unwrap().build().unwrap();

        format(fmt(Some(dir.to_path_buf())), config).unwrap();

        let top = std::fs::read_to_string(dir.child("a.lua")).unwrap();
        let nested = std::fs::read_to_string(dir.child("nested").child("b.lua")).unwrap();
        let other = std::fs::read_to_string(dir.child("notes.txt")).unwrap();
        assert!(top.contains("print(1 * 2)"));
        assert!(nested.contains("print(3 + 4)"));
        // non-Lua files are left untouched
        assert!(other.contains("print( 5 *    6 )"));
    }

    #[test]
    #[serial]
    fn test_format_single_lua_file() {
        let dir = loose_lua_temp_dir();
        let config = ConfigBuilder::new().unwrap().build().unwrap();

        format(fmt(Some(dir.child("a.lua").to_path_buf())), config).unwrap();

        let top = std::fs::read_to_string(dir.child("a.lua")).unwrap();
        let nested = std::fs::read_to_string(dir.child("nested").child("b.lua")).unwrap();
        assert!(top.contains("print(1 * 2)"));
        // a sibling file is not touched when a single file is targeted
        assert!(nested.contains("print( 3 +    4 )"));
    }

    #[test]
    #[serial]
    fn test_format_nonexistent_path_errors() {
        let config = ConfigBuilder::new().unwrap().build().unwrap();
        let result = format(fmt(Some("/no/such/path".into())), config);
        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn test_format_subdir_inherits_workspace_config() {
        // must resolve workspace's stylua.toml (Spaces/2-width), not stylua default.
        let workspace = TempDir::new().unwrap();
        workspace
            .copy_from(
                fixture("resources/test/sample-projects/stylua-config/"),
                &["**"],
            )
            .unwrap();
        let config = ConfigBuilder::new().unwrap().build().unwrap();

        format(fmt(Some(workspace.child("src").to_path_buf())), config).unwrap();

        let content = std::fs::read_to_string(workspace.child("src").child("main.lua")).unwrap();
        assert!(content.contains("\n  print(1 * 2)"));
        assert!(!content.contains('\t'));
    }
}