Skip to main content

lux_cli/
format.rs

1use std::path::PathBuf;
2
3use clap::Args;
4use emmylua_formatter as luafmt;
5use eyre::{Context, Result};
6use lux_lib::{
7    config::Config, lua_version::LuaVersion, package::PackageName, project::Project,
8    workspace::Workspace,
9};
10use path_slash::PathExt;
11use walkdir::WalkDir;
12
13#[derive(Args)]
14pub struct Fmt {
15    /// Optional path to a workspace or Lua file to format.
16    workspace_or_file: Option<PathBuf>,
17
18    #[clap(default_value = "stylua")]
19    #[arg(long)]
20    backend: FmtBackend,
21
22    /// Package to format.
23    #[arg(short, long, visible_short_alias = 'p')]
24    package: Option<PackageName>,
25}
26
27#[derive(clap::ValueEnum, Clone, Debug)]
28enum FmtBackend {
29    /// Mainly follows the [Roblox Lua style guide](https://roblox.github.io/lua-style-guide/).
30    Stylua,
31    /// The default formatter used by [emmylua-analyzer-rust](https://github.com/EmmyLuaLs/emmylua-analyzer-rust).
32    /// If invoked with `lx --lua-version=<version> fmt`, Lux will configure the luafmt syntax level
33    /// to match the specified Lua version.
34    Luafmt,
35    /// The default formatter used by [lua-language-server](https://luals.github.io/).
36    EmmyluaCodestyle,
37}
38
39pub fn format(args: Fmt, config: Config) -> Result<()> {
40    let workspace: Workspace = match args.workspace_or_file {
41        Some(ref ws) => match Workspace::from_exact(ws)? {
42            Some(ws) => ws,
43            None => Workspace::current_or_err()?,
44        },
45        None => Workspace::current_or_err()?,
46    };
47
48    if let Some(package) = &args.package {
49        let project = workspace.select_member(package)?;
50        format_project(&args, &workspace, project, &config)?;
51    } else {
52        for project in workspace.members() {
53            format_project(&args, &workspace, project, &config)?;
54        }
55    }
56    Ok(())
57}
58
59fn format_project(
60    args: &Fmt,
61    workspace: &Workspace,
62    project: &Project,
63    config: &Config,
64) -> Result<()> {
65    let root = workspace.root();
66
67    let stylua_config: stylua_lib::Config = std::fs::read_to_string(root.join("stylua.toml"))
68        .or_else(|_| std::fs::read_to_string(root.join(".stylua.toml")))
69        .map(|config: String| toml::from_str(&config).unwrap_or_default())
70        .or_else(|_| {
71            stylua_lib::editorconfig::parse(stylua_lib::Config::new(), &root.join("*.lua"))
72        })
73        .unwrap_or_default();
74
75    let luafmt_config = luafmt::resolve_config_for_path(Some(root.as_ref()), None)
76        .map(|resolved| resolved.config)
77        .unwrap_or_default();
78    let luafmt_syntax_level = workspace
79        .lua_version(config)
80        .map(lua_version_to_luafmt_syntax_level)
81        .unwrap_or(luafmt_config.syntax.level);
82
83    let emmylua_config = root.join(".editorconfig");
84
85    let workspace_or_file = args
86        .workspace_or_file
87        .as_ref()
88        .map(std::path::absolute)
89        .transpose()?;
90
91    WalkDir::new(project.root().join("src"))
92        .into_iter()
93        .chain(WalkDir::new(project.root().join("lua")))
94        .chain(WalkDir::new(project.root().join("lib")))
95        .chain(WalkDir::new(project.root().join("spec")))
96        .chain(WalkDir::new(project.root().join("test")))
97        .chain(WalkDir::new(project.root().join("tests")))
98        .filter_map(Result::ok)
99        .filter(|file| {
100            workspace_or_file
101                .as_ref()
102                .is_none_or(|workspace_or_file| file.path().starts_with(workspace_or_file))
103        })
104        .try_for_each(|file| {
105            if PathBuf::from(file.file_name())
106                .extension()
107                .is_some_and(|ext| ext == "lua")
108            {
109                let file = file.path();
110                let unformatted_code = std::fs::read_to_string(file)?;
111                let formatted_code = match args.backend {
112                    FmtBackend::Stylua => stylua_lib::format_code(
113                        &unformatted_code,
114                        stylua_config,
115                        None,
116                        stylua_lib::OutputVerification::Full,
117                    )
118                    .context(format!("error formatting {} with stylua.", file.display()))?,
119                    FmtBackend::Luafmt => {
120                        luafmt::check_text(
121                            &unformatted_code,
122                            luafmt_syntax_level.into(),
123                            &luafmt_config,
124                        )
125                        .formatted
126                    }
127                    FmtBackend::EmmyluaCodestyle => {
128                        let uri = file.to_slash_lossy().to_string();
129                        if emmylua_config.is_file() {
130                            emmylua_codestyle::update_code_style(
131                                &uri,
132                                &emmylua_config.to_slash_lossy(),
133                            );
134                        }
135                        emmylua_codestyle::reformat_code(
136                            &unformatted_code,
137                            &uri,
138                            emmylua_codestyle::FormattingOptions::default(),
139                        )
140                    }
141                };
142
143                std::fs::write(file, formatted_code)
144                    .context(format!("error writing formatted file {}.", file.display()))?
145            };
146            Ok::<_, eyre::Report>(())
147        })?;
148
149    // Format the rockspec
150
151    let rockspec = project.root().join("extra.rockspec");
152
153    if rockspec.exists() {
154        let unformatted_code = std::fs::read_to_string(&rockspec)?;
155        let formatted_code = match args.backend {
156            FmtBackend::Stylua => stylua_lib::format_code(
157                &unformatted_code,
158                stylua_config,
159                None,
160                stylua_lib::OutputVerification::Full,
161            )?,
162            FmtBackend::Luafmt => {
163                luafmt::check_text(
164                    &unformatted_code,
165                    luafmt_syntax_level.into(),
166                    &luafmt_config,
167                )
168                .formatted
169            }
170            FmtBackend::EmmyluaCodestyle => {
171                let uri = rockspec.to_slash_lossy().to_string();
172                if emmylua_config.is_file() {
173                    emmylua_codestyle::update_code_style(&uri, &emmylua_config.to_slash_lossy());
174                }
175                emmylua_codestyle::reformat_code(
176                    &unformatted_code,
177                    &uri,
178                    emmylua_codestyle::FormattingOptions::default(),
179                )
180            }
181        };
182
183        std::fs::write(rockspec, formatted_code)?;
184    }
185    Ok(())
186}
187
188fn lua_version_to_luafmt_syntax_level(lua_version: LuaVersion) -> luafmt::LuaSyntaxLevel {
189    match lua_version {
190        LuaVersion::Lua51 => luafmt::LuaSyntaxLevel::Lua51,
191        LuaVersion::Lua52 => luafmt::LuaSyntaxLevel::Lua52,
192        LuaVersion::Lua53 => luafmt::LuaSyntaxLevel::Lua53,
193        LuaVersion::Lua54 => luafmt::LuaSyntaxLevel::Lua54,
194        LuaVersion::Lua55 => luafmt::LuaSyntaxLevel::Lua55,
195        LuaVersion::LuaJIT | LuaVersion::LuaJIT52 => luafmt::LuaSyntaxLevel::LuaJIT,
196    }
197}
198
#[cfg(test)]
mod tests {
    use assert_fs::fixture::PathChild;
    use assert_fs::{prelude::PathCopy, TempDir};
    use lux_lib::config::ConfigBuilder;
    use serial_test::serial;

    use super::*;
    use std::path::PathBuf;

    /// Switches the process-wide working directory and restores the previous one on drop.
    ///
    /// The restore runs even if the test panics part-way through. Otherwise the changed cwd
    /// would leak into the other `#[serial]` tests, which rely on relative `resources/...`
    /// paths.
    struct CwdGuard(PathBuf);

    impl CwdGuard {
        /// Records the current working directory, then changes into `dir`.
        fn change_to(dir: impl AsRef<std::path::Path>) -> Self {
            let previous = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self(previous)
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Errors cannot be propagated out of `drop`, and panicking while already
            // unwinding would abort the test binary, so the result is ignored.
            let _ = std::env::set_current_dir(&self.0);
        }
    }

    #[serial]
    #[tokio::test]
    async fn test_format_while_in_another_workspace() {
        let unformatted_sample_project: PathBuf =
            "resources/test/sample-projects/unformatted/".into();
        let unformatted_project_root = TempDir::new().unwrap();
        unformatted_project_root
            .copy_from(&unformatted_sample_project, &["**"])
            .unwrap();

        let cwd_sample_project: PathBuf = "resources/test/sample-projects/init/".into();
        let cwd_project_root = TempDir::new().unwrap();
        cwd_project_root
            .copy_from(&cwd_sample_project, &["**"])
            .unwrap();

        // Declared after the temp dirs, so it is dropped (and the cwd restored) before
        // they are deleted.
        let _cwd = CwdGuard::change_to(&cwd_project_root);

        let config = ConfigBuilder::new().unwrap().build().unwrap();
        let fmt = Fmt {
            workspace_or_file: Some(unformatted_project_root.to_path_buf()),
            backend: FmtBackend::Stylua,
            package: None,
        };

        format(fmt, config).unwrap();

        let unformatted_file_path = unformatted_project_root.child("src").child("main.lua");
        let content = std::fs::read_to_string(&unformatted_file_path).unwrap();

        // the unformatted variant contains too many spaces
        assert!(
            content.contains("print(1 * 2)"),
            "src/main.lua was not formatted:\n{content}"
        );
    }
}