use anyhow::{bail, Context, Result};
use clap::StructOpt;
use console::style;
use ignore::{gitignore::Gitignore, overrides::OverrideBuilder, WalkBuilder};
use log::{LevelFilter, *};
use serde_json::json;
use std::fs;
use std::io::{stderr, stdin, stdout, Read, Write};
use std::path::Path;
use std::sync::atomic::{AtomicI32, AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use threadpool::ThreadPool;
use stylua_lib::{format_code, Config, OutputVerification, Range};
use crate::config::find_ignore_file_path;
mod config;
mod opt;
mod output_diff;
/// Process exit code shared between the worker/output threads and the logger.
/// 0 = success; 1 = `--check` produced a diff for at least one input; 2 = an error-level record
/// was logged (the logger installed in `main` stores 2 for every error record).
static EXIT_CODE: AtomicI32 = AtomicI32::new(0);
/// Number of inputs for which a diff was produced; reported by `--output-format=summary`.
static UNFORMATTED_FILE_COUNT: AtomicU32 = AtomicU32::new(0);
/// Outcome of processing a single input, sent from a worker thread to the output thread.
#[derive(Debug)]
enum FormatResult {
    /// Nothing further to print (file rewritten in place, or `--check` found no differences).
    Complete,
    /// Formatted source that should be written to stdout (used when formatting stdin).
    SuccessBufferedOutput(Vec<u8>),
    /// Rendered diff, in the selected output format, for an input that is not correctly formatted.
    Diff(Vec<u8>),
}
/// An error tagged with the name of the input it came from (a file path, or `"stdin"`).
///
/// The output thread downcasts to this type in `--output-format=json` mode so that parse errors
/// can be reported as structured JSON attributed to the right file. Its `Display` output is just
/// the wrapped error rendered with `{:#}`.
#[derive(Error, Debug)]
#[error("{:#}", .error)]
struct ErrorFileWrapper {
    /// Display name of the input that failed.
    file: String,
    /// The underlying failure.
    error: anyhow::Error,
}
fn convert_parse_error_to_json(file: &str, err: &full_moon::Error) -> Option<serde_json::Value> {
Some(match err {
full_moon::Error::AstError(full_moon::ast::AstError::UnexpectedToken {
token,
additional,
}) => json!({
"type": "parse_error",
"message": format!("unexpected token `{}`{}", token, additional.as_ref().map(|x| format!(": {}", x)).unwrap_or_default()),
"filename": file,
"location": {
"start": token.start_position().bytes(),
"start_line": token.start_position().line(),
"start_column": token.start_position().character(),
"end": token.end_position().bytes(),
"end_line": token.end_position().line(),
"end_column": token.end_position().character(),
},
}),
full_moon::Error::TokenizerError(error) => json!({
"type": "parse_error",
"message": match error.error() {
full_moon::tokenizer::TokenizerErrorType::UnclosedComment => {
"unclosed comment".to_string()
}
full_moon::tokenizer::TokenizerErrorType::UnclosedString => {
"unclosed string".to_string()
}
full_moon::tokenizer::TokenizerErrorType::UnexpectedShebang => {
"unexpected shebang".to_string()
}
full_moon::tokenizer::TokenizerErrorType::UnexpectedToken(
character,
) => {
format!("unexpected character {}", character)
}
full_moon::tokenizer::TokenizerErrorType::InvalidSymbol(symbol) => {
format!("invalid symbol {}", symbol)
}
},
"filename": file,
"location": {
"start": error.position().bytes(),
"start_line": error.position().line(),
"start_column": error.position().character(),
"end": error.position().bytes(),
"end_line": error.position().line(),
"end_column": error.position().character(),
},
}),
_ => {
error!("{:#}", err);
return None;
}
})
}
/// Renders the difference between `original` and `expected` in the output format chosen by `opt`.
///
/// Returns `Ok(None)` when there is nothing to report. The standard, unified and JSON renderers
/// decide this themselves; the summary format compares the strings directly and emits just the
/// file name followed by a newline. JSON output is a single object terminated by `\n`.
///
/// # Errors
/// Propagates renderer errors and JSON serialization failures.
fn create_diff(
    opt: &opt::Opt,
    original: &str,
    expected: &str,
    file_name: &str,
) -> Result<Option<Vec<u8>>> {
    match opt.output_format {
        opt::OutputFormat::Standard => {
            let header = format!("Diff in {}:", file_name);
            output_diff::output_diff(original, expected, 3, &header, opt.color)
        }
        opt::OutputFormat::Unified => output_diff::output_diff_unified(original, expected),
        opt::OutputFormat::Json => {
            let mismatches = match output_diff::output_diff_json(original, expected) {
                Some(mismatches) => mismatches,
                None => return Ok(None),
            };
            let mut buffer = serde_json::to_vec(&json!({
                "file": file_name,
                "mismatches": mismatches
            }))?;
            buffer.push(b'\n');
            Ok(Some(buffer))
        }
        opt::OutputFormat::Summary => Ok(if original != expected {
            Some(format!("{}\n", file_name).into_bytes())
        } else {
            None
        }),
    }
}
fn format_file(
path: &Path,
config: Config,
range: Option<Range>,
opt: &opt::Opt,
verify_output: OutputVerification,
) -> Result<FormatResult> {
let contents =
fs::read_to_string(path).with_context(|| format!("failed to read {}", path.display()))?;
let before_formatting = Instant::now();
let formatted_contents = format_code(&contents, config, range, verify_output)
.with_context(|| format!("could not format file {}", path.display()))?;
let after_formatting = Instant::now();
debug!(
"formatted {} in {:?}",
path.display(),
after_formatting.duration_since(before_formatting)
);
if opt.check {
let diff = create_diff(
opt,
&contents,
&formatted_contents,
path.display().to_string().as_str(),
)
.context("failed to create diff")?;
match diff {
Some(diff) => Ok(FormatResult::Diff(diff)),
None => Ok(FormatResult::Complete),
}
} else {
fs::write(path, formatted_contents)
.with_context(|| format!("could not write to {}", path.display()))?;
Ok(FormatResult::Complete)
}
}
/// Formats source text read from stdin.
///
/// When `should_skip` is true (the stdin path is ignored) the input is passed through unchanged.
/// Without `opt.check` the result is returned as buffered output for stdout; with `opt.check` a
/// diff against the input is returned instead, if any.
///
/// # Errors
/// Fails if formatting fails or the diff cannot be built.
fn format_string(
    input: String,
    config: Config,
    range: Option<Range>,
    opt: &opt::Opt,
    verify_output: OutputVerification,
    should_skip: bool,
) -> Result<FormatResult> {
    let formatted_contents = match should_skip {
        true => input.clone(),
        false => format_code(&input, config, range, verify_output)
            .context("failed to format from stdin")?,
    };

    if !opt.check {
        return Ok(FormatResult::SuccessBufferedOutput(
            formatted_contents.into_bytes(),
        ));
    }

    let diff = create_diff(opt, &input, &formatted_contents, "stdin")
        .context("failed to create diff")?;
    Ok(diff.map_or(FormatResult::Complete, FormatResult::Diff))
}
/// Loads the ignore file that applies to `directory`, as located by `find_ignore_file_path`.
///
/// Returns an empty matcher when no ignore file is found.
///
/// # Errors
/// Returns the error reported by `Gitignore::new` if the ignore file could not be fully parsed.
fn get_ignore(
    directory: &Path,
    search_parent_directories: bool,
) -> Result<Gitignore, ignore::Error> {
    match find_ignore_file_path(directory.to_path_buf(), search_parent_directories) {
        None => Ok(Gitignore::empty()),
        Some(file_path) => match Gitignore::new(file_path) {
            (_, Some(parse_error)) => Err(parse_error),
            (matcher, None) => Ok(matcher),
        },
    }
}
fn format(opt: opt::Opt) -> Result<i32> {
if opt.files.is_empty() {
bail!("no files provided");
}
if !opt.check
&& matches!(
opt.output_format,
opt::OutputFormat::Unified | opt::OutputFormat::Summary
)
{
bail!("--output-format=unified and --output-format=standard can only be used when --check is enabled");
}
let config = config::load_config(&opt)?;
let config = config::load_overrides(config, &opt);
debug!("config: {:#?}", config);
let range = if opt.range_start.is_some() || opt.range_end.is_some() {
Some(Range::from_values(opt.range_start, opt.range_end))
} else {
None
};
let verify_output = if opt.verify {
OutputVerification::Full
} else {
OutputVerification::None
};
let cwd = std::env::current_dir()?;
let mut walker_builder = WalkBuilder::new(&opt.files[0]);
for file_path in &opt.files[1..] {
walker_builder.add(file_path);
}
walker_builder
.standard_filters(false)
.hidden(!opt.allow_hidden)
.parents(true)
.add_custom_ignore_filename(".styluaignore");
let ignore_path = cwd.join(".styluaignore");
if ignore_path.is_file() {
walker_builder.add_ignore(ignore_path);
}
let use_default_glob = match opt.glob {
Some(ref globs) => {
let mut overrides = OverrideBuilder::new(cwd);
for pattern in globs {
overrides.add(pattern)?;
}
let overrides = overrides.build()?;
walker_builder.overrides(overrides);
false
}
None => true,
};
debug!("creating a pool with {} threads", opt.num_threads);
let pool = ThreadPool::new(std::cmp::max(opt.num_threads, 2)); let (tx, rx) = crossbeam_channel::unbounded::<Result<_>>();
let output_format = opt.output_format;
let opt = Arc::new(opt);
if matches!(opt.output_format, opt::OutputFormat::Summary) {
println!(
"{} Checking formatting...",
style("!")
.cyan()
.bold()
.force_styling(opt.color.should_use_color())
);
}
pool.execute(move || {
for output in rx {
match output {
Ok(result) => match result {
FormatResult::Complete => (),
FormatResult::SuccessBufferedOutput(output) => {
let stdout = stdout();
let mut handle = stdout.lock();
match handle.write_all(&output) {
Ok(_) => (),
Err(err) => {
error!("could not output to stdout: {:#}", err)
}
};
}
FormatResult::Diff(diff) => {
if EXIT_CODE.load(Ordering::SeqCst) != 2 {
EXIT_CODE.store(1, Ordering::SeqCst);
}
UNFORMATTED_FILE_COUNT.fetch_add(1, Ordering::SeqCst);
let stdout = stdout();
let mut handle = stdout.lock();
match handle.write_all(&diff) {
Ok(_) => (),
Err(err) => error!("{:#}", err),
}
}
},
Err(err) if matches!(output_format, opt::OutputFormat::Json) => {
match err.downcast_ref::<ErrorFileWrapper>() {
Some(ErrorFileWrapper { file, error }) => {
match error.downcast_ref::<stylua_lib::Error>() {
Some(stylua_lib::Error::ParseError(err)) => {
if let Some(structured_err) =
convert_parse_error_to_json(file, err)
{
let stderr = stderr();
let mut handle = stderr.lock();
match handle
.write_all(structured_err.to_string().as_bytes())
{
Ok(_) => (),
Err(err) => {
error!("could not output to stdout: {:#}", err)
}
};
}
}
_ => error!("{:#}", err),
}
}
_ => error!("{:#}", err),
}
}
Err(err) => error!("{:#}", err),
}
}
});
let walker = walker_builder.build();
for result in walker {
match result {
Ok(entry) => {
if entry.is_stdin() {
let tx = tx.clone();
let opt = opt.clone();
let should_skip_format = match &opt.stdin_filepath {
Some(filepath) => {
let ignore = get_ignore(
filepath.parent().expect("cannot get parent directory"),
opt.search_parent_directories,
)
.context("failed to parse ignore file")?;
matches!(ignore.matched(filepath, false), ignore::Match::Ignore(_))
}
None => false,
};
pool.execute(move || {
let mut buf = String::new();
match stdin().read_to_string(&mut buf) {
Ok(_) => tx.send(format_string(
buf,
config,
range,
&opt,
verify_output,
should_skip_format,
)),
Err(error) => tx.send(
Err(ErrorFileWrapper {
file: "stdin".to_string(),
error: error.into(),
})
.context("could not format from stdin"),
),
}
.unwrap();
});
} else {
let path = entry.path().to_owned(); let opt = opt.clone();
if path.is_file() {
if use_default_glob && !opt.files.iter().any(|p| path == *p) {
lazy_static::lazy_static! {
static ref DEFAULT_GLOB: globset::GlobSet = {
let mut builder = globset::GlobSetBuilder::new();
builder.add(globset::Glob::new("**/*.lua").expect("cannot create default glob"));
#[cfg(feature = "luau")]
builder.add(globset::Glob::new("**/*.luau").expect("cannot create default luau glob"));
builder.build().expect("cannot build default globset")
};
}
if !DEFAULT_GLOB.is_match(&path) {
continue;
}
}
let tx = tx.clone();
pool.execute(move || {
tx.send(
format_file(&path, config, range, &opt, verify_output).map_err(
|error| {
ErrorFileWrapper {
file: path.display().to_string(),
error,
}
.into()
},
),
)
.unwrap()
});
}
}
}
Err(error) => match error {
ignore::Error::WithPath { path, err } => match *err {
ignore::Error::Io(error) => match error.kind() {
std::io::ErrorKind::NotFound => {
error!("no file or directory found matching '{:#}'", path.display())
}
_ => error!("{:#}", error),
},
_ => error!("{:#}", err),
},
_ => error!("{:#}", error),
},
}
}
drop(tx);
pool.join();
if matches!(opt.output_format, opt::OutputFormat::Summary) {
let file_count = UNFORMATTED_FILE_COUNT.load(Ordering::SeqCst);
if file_count == 0 {
println!(
"{} All files are correctly formatted.",
style("✓")
.green()
.bold()
.force_styling(opt.color.should_use_color())
);
} else {
println!(
"{} Code style issues found in {} file{} above.",
style("✕")
.red()
.bold()
.force_styling(opt.color.should_use_color()),
style(file_count)
.yellow()
.bold()
.force_styling(opt.color.should_use_color()),
if file_count == 1 { "" } else { "s" }
);
}
}
let output_code = if pool.panic_count() > 0 {
2
} else {
EXIT_CODE.load(Ordering::SeqCst)
};
Ok(output_code)
}
fn main() {
let opt = opt::Opt::parse();
let output_format = opt.output_format;
let should_use_color = opt.color.should_use_color_stderr();
let level_filter = if opt.verbose {
LevelFilter::Debug
} else {
LevelFilter::Warn
};
env_logger::Builder::from_env("STYLUA_LOG")
.filter(None, level_filter)
.format(move |buf, record| {
if let Level::Error = record.level() {
EXIT_CODE.store(2, Ordering::SeqCst);
}
let tag = match record.level() {
Level::Error => style("error").red(),
Level::Warn => style("warn").yellow(),
Level::Info => style("info").green(),
Level::Debug => style("debug").cyan(),
Level::Trace => style("trace").magenta(),
}
.bold()
.force_styling(should_use_color);
if let opt::OutputFormat::Json = output_format {
writeln!(
buf,
"{}",
json!({
"type": record.level().to_string().to_lowercase(),
"message": record.args().to_string(),
})
)
} else {
writeln!(
buf,
"{}{} {}",
tag,
style(":").bold().force_styling(should_use_color),
record.args()
)
}
})
.init();
let exit_code = match format(opt) {
Ok(code) => code,
Err(err) => {
error!("{:#}", err);
2
}
};
std::process::exit(exit_code);
}
#[cfg(test)]
mod tests {
    use assert_cmd::Command;

    /// Builds a command that runs the compiled binary under test.
    fn stylua() -> Command {
        Command::cargo_bin(env!("CARGO_PKG_NAME")).unwrap()
    }

    /// Running with no inputs is an error reported on stderr with exit code 2.
    #[test]
    fn test_no_files_provided() {
        stylua()
            .assert()
            .failure()
            .code(2)
            .stderr("error: no files provided\n");
    }

    /// Already-formatted code piped through stdin (`-`) is echoed back on stdout.
    #[test]
    fn test_format_stdin() {
        stylua()
            .arg("-")
            .write_stdin("local x = 1")
            .assert()
            .success()
            .stdout("local x = 1\n");
    }
}