use crate::compile::{flat_package_src_dirs, flat_package_test_dirs};
use anyhow::{bail, Context, Result};
use crate::build::central_repos;
use curie_deps::resolver::{resolve, DepEntry, ResolveOptions};
use std::path::{Path, PathBuf};
use std::process::Command;
use walkdir::WalkDir;
/// Maven coordinate (`groupId:artifactId`) of palantir-java-format, the Java formatter.
const PJF_COORD: &str = "com.palantir.javaformat:palantir-java-format";
/// Pinned palantir-java-format version requested by `resolve_pjf`.
const PJF_VERSION: &str = "2.90.0";
/// Fully-qualified class passed to `java` (after `-cp`) to run palantir-java-format.
const PJF_MAIN: &str = "com.palantir.javaformat.java.Main";
/// Maven coordinate (`groupId:artifactId`) of ktfmt, the Kotlin formatter.
const KTFMT_COORD: &str = "com.facebook:ktfmt";
/// Pinned ktfmt version requested by `resolve_ktfmt`.
const KTFMT_VERSION: &str = "0.51";
/// Fully-qualified class passed to `java` (after `-cp`) to run ktfmt.
const KTFMT_MAIN: &str = "com.facebook.ktfmt.cli.Main";
/// Resolves palantir-java-format at the pinned `PJF_VERSION` and returns the
/// jar paths produced by the resolver (used as the formatter's classpath).
///
/// `offline` is forwarded to the resolver's offline mode.
///
/// # Errors
/// Fails with context "failed to resolve palantir-java-format from Maven Central"
/// when the resolver returns an error.
pub fn resolve_pjf(offline: bool) -> Result<Vec<PathBuf>> {
    resolve_formatter(PJF_COORD, PJF_VERSION, "palantir-java-format", offline)
}

/// Resolves ktfmt at the pinned `KTFMT_VERSION` and returns the jar paths
/// produced by the resolver (used as the formatter's classpath).
///
/// `offline` is forwarded to the resolver's offline mode.
///
/// # Errors
/// Fails with context "failed to resolve ktfmt from Maven Central" when the
/// resolver returns an error.
pub fn resolve_ktfmt(offline: bool) -> Result<Vec<PathBuf>> {
    resolve_formatter(KTFMT_COORD, KTFMT_VERSION, "ktfmt", offline)
}

/// Shared resolution of a single formatter artifact against `central_repos()`.
///
/// `tool` is used only to label the error context. The coordinate and version
/// are `'static` because both callers pass module-level constants, and that
/// satisfies whatever lifetime `DepEntry` requires.
fn resolve_formatter(
    coord: &'static str,
    version: &'static str,
    tool: &str,
    offline: bool,
) -> Result<Vec<PathBuf>> {
    resolve(
        &[DepEntry { key: coord, version, repo_id: None }],
        &ResolveOptions {
            default_repos: central_repos(),
            named_repos: vec![],
            progress: false,
            bom_imports: vec![],
            offline,
        },
    )
    .with_context(|| format!("failed to resolve {tool} from Maven Central"))
}
/// Reports whether at least one regular `.kt` file exists under any Kotlin
/// source root of `project_root`.
///
/// The search stops at the first match. Unreadable directory entries are
/// ignored rather than treated as errors.
pub fn has_kotlin_sources(project_root: &Path) -> bool {
    for root in kotlin_source_roots(project_root) {
        // `flatten` keeps only the `Ok` entries of the walk.
        let mut entries = WalkDir::new(root).into_iter().flatten();
        let hit = entries.any(|entry| {
            entry.file_type().is_file()
                && entry.path().extension().and_then(|ext| ext.to_str()) == Some("kt")
        });
        if hit {
            return true;
        }
    }
    false
}
pub fn run_fmt(project_root: &Path, check_only: bool, offline: bool) -> Result<()> {
let pjf_jars = resolve_pjf(offline)?;
let ktfmt_jars = if has_kotlin_sources(project_root) {
resolve_ktfmt(offline)?
} else {
Vec::new()
};
run_fmt_with_jars(project_root, check_only, &pjf_jars, &ktfmt_jars)
}
/// Formats (or, with `check_only`, verifies) the project's sources using
/// formatter classpaths that have already been resolved.
///
/// Java files are always collected. Kotlin files are collected only when
/// `ktfmt_jars` is non-empty, so an empty slice disables Kotlin handling.
/// When no files are found, this returns `Ok(())` without emitting a
/// progress line.
///
/// # Errors
/// Both formatters are run even if the first one fails. A single failure is
/// returned as-is. If both fail, their messages are joined with a newline
/// into one error.
pub fn run_fmt_with_jars(
    project_root: &Path,
    check_only: bool,
    pjf_jars: &[PathBuf],
    ktfmt_jars: &[PathBuf],
) -> Result<()> {
    let java_files = collect_java_files(project_root);
    let kotlin_files = if !ktfmt_jars.is_empty() {
        collect_kotlin_files(project_root)
    } else {
        Vec::new()
    };

    let nothing_to_do = java_files.is_empty() && kotlin_files.is_empty();
    if nothing_to_do {
        return Ok(());
    }

    let summary = if kotlin_files.is_empty() {
        format!("{} Java file(s)", java_files.len())
    } else if java_files.is_empty() {
        format!("{} Kotlin file(s)", kotlin_files.len())
    } else {
        format!("{} Java, {} Kotlin file(s)", java_files.len(), kotlin_files.len())
    };
    let verb = if check_only { "Check" } else { "Format" };
    crate::parallel::emit(&crate::style::fmt_step(verb, &summary));

    let java_outcome = if java_files.is_empty() {
        Ok(())
    } else {
        fmt_java(&java_files, pjf_jars, check_only)
    };
    let kotlin_outcome = if kotlin_files.is_empty() {
        Ok(())
    } else {
        fmt_kotlin(&kotlin_files, ktfmt_jars, check_only)
    };

    match (java_outcome, kotlin_outcome) {
        (Ok(()), Ok(())) => Ok(()),
        (Err(e), Ok(())) | (Ok(()), Err(e)) => Err(e),
        (Err(je), Err(ke)) => bail!("{:#}\n{:#}", je, ke),
    }
}
/// Static description of how to launch one JVM-based formatter via `java`.
struct FormatterSpec {
    /// Tool name, used in the error reported when a reformat run exits non-zero.
    name: &'static str,
    /// Source-language label, used in the "not correctly formatted" check error.
    language: &'static str,
    /// Fully-qualified main class passed to `java` right after the classpath.
    main_class: &'static str,
    /// JVM options placed before `-cp` on the command line.
    jvm_flags: Vec<String>,
    /// Tool arguments used when `check_only` is false (reformatting).
    reformat_args: &'static [&'static str],
    /// Tool arguments used when `check_only` is true (verification only).
    check_args: &'static [&'static str],
}
/// Launches `java` to run the formatter described by `spec` over `files`.
///
/// The command line is, in order: the JVM flags, `-cp <jars>`, the main class,
/// the mode-specific arguments (check or reformat), and then every file path.
///
/// # Errors
/// - If `java` cannot be spawned.
/// - If the process exits unsuccessfully. In check mode the message says the
///   files are not correctly formatted; otherwise it names the tool.
fn run_formatter(
    files: &[PathBuf],
    jars: &[PathBuf],
    check_only: bool,
    spec: &FormatterSpec,
) -> Result<()> {
    let mode_args = if check_only { spec.check_args } else { spec.reformat_args };

    let mut java = Command::new("java");
    java.args(&spec.jvm_flags)
        .arg("-cp")
        .arg(classpath(jars))
        .arg(spec.main_class)
        .args(mode_args)
        .args(files);

    let exit = crate::proc::spawn_cmd(&mut java)
        .context("failed to launch `java` — is a JDK installed and on PATH?")?;
    if exit.success() {
        return Ok(());
    }
    if check_only {
        bail!(
            "fmt: one or more {} files are not correctly formatted. \
             Run `curie fmt` (without --check) to fix them.",
            spec.language
        );
    }
    bail!("{} exited non-zero", spec.name);
}
/// Runs palantir-java-format (AOSP style) over `java_files`.
///
/// In check mode it uses `--dry-run --set-exit-if-changed`; otherwise it
/// rewrites the files with `--replace`. The JVM gets the `--add-exports`
/// flags from `jvm_add_exports`.
fn fmt_java(java_files: &[PathBuf], pjf_jars: &[PathBuf], check_only: bool) -> Result<()> {
    let spec = FormatterSpec {
        name: "palantir-java-format",
        language: "Java",
        main_class: PJF_MAIN,
        jvm_flags: jvm_add_exports(),
        check_args: &["--aosp", "--dry-run", "--set-exit-if-changed"],
        reformat_args: &["--aosp", "--replace"],
    };
    run_formatter(java_files, pjf_jars, check_only, &spec)
}
/// Runs ktfmt (kotlinlang style) over `kotlin_files`.
///
/// In check mode it adds `--dry-run --set-exit-if-changed`. The JVM is started
/// with `--enable-native-access=ALL-UNNAMED`.
fn fmt_kotlin(kotlin_files: &[PathBuf], ktfmt_jars: &[PathBuf], check_only: bool) -> Result<()> {
    let native_access = String::from("--enable-native-access=ALL-UNNAMED");
    let spec = FormatterSpec {
        name: "ktfmt",
        language: "Kotlin",
        main_class: KTFMT_MAIN,
        jvm_flags: vec![native_access],
        check_args: &["--kotlinlang-style", "--dry-run", "--set-exit-if-changed"],
        reformat_args: &["--kotlinlang-style"],
    };
    run_formatter(kotlin_files, ktfmt_jars, check_only, &spec)
}
/// Joins jar paths into a single JVM `-cp` value, separated by the host's
/// classpath separator (`;` on Windows, `:` everywhere else).
///
/// Non-UTF-8 path components are replaced lossily. An empty slice yields an
/// empty string.
fn classpath(jars: &[PathBuf]) -> String {
    // The JVM splits `-cp` on the platform path separator. A hard-coded ':'
    // breaks on Windows, where ':' also appears in drive letters (`C:\...`).
    let separator = if cfg!(windows) { ";" } else { ":" };
    jars.iter()
        .map(|p| p.to_string_lossy())
        .collect::<Vec<_>>()
        .join(separator)
}
pub(crate) fn collect_java_files(project_root: &Path) -> Vec<PathBuf> {
let mut roots: Vec<PathBuf> = Vec::new();
let main_java = project_root.join("src").join("main").join("java");
if main_java.exists() {
roots.push(main_java);
}
let test_java = project_root.join("src").join("test").join("java");
if test_java.exists() {
roots.push(test_java);
}
roots.extend(flat_package_src_dirs(project_root));
let mut files: Vec<PathBuf> = roots
.iter()
.flat_map(|root| {
WalkDir::new(root)
.into_iter()
.filter_map(|e| e.ok())
.filter(|e| {
e.file_type().is_file()
&& e.path().extension().map_or(false, |x| x == "java")
})
.map(|e| e.into_path())
})
.collect();
files.sort();
files
}
/// Collects every regular `.kt` file under the Kotlin source roots of
/// `project_root`, sorted and with duplicates removed.
///
/// Duplicates can arise when the same file is reachable from more than one
/// root. Unreadable directory entries are skipped.
pub(crate) fn collect_kotlin_files(project_root: &Path) -> Vec<PathBuf> {
    let mut found = Vec::new();
    for root in kotlin_source_roots(project_root) {
        for entry in WalkDir::new(root).into_iter().flatten() {
            let is_kotlin = entry.path().extension().and_then(|ext| ext.to_str()) == Some("kt");
            if entry.file_type().is_file() && is_kotlin {
                found.push(entry.into_path());
            }
        }
    }
    found.sort();
    found.dedup();
    found
}
/// Lists the directories searched for Kotlin sources, in this order:
/// 1. `src/main/kotlin` and `src/test/kotlin`, each only if it exists;
/// 2. the flat-package source directories;
/// 3. the flat-package test directories.
fn kotlin_source_roots(project_root: &Path) -> Vec<PathBuf> {
    let src = project_root.join("src");
    let mut roots: Vec<PathBuf> = ["main", "test"]
        .iter()
        .map(|source_set| src.join(source_set).join("kotlin"))
        .filter(|dir| dir.exists())
        .collect();
    roots.extend(flat_package_src_dirs(project_root));
    roots.extend(flat_package_test_dirs(project_root));
    roots
}
/// Builds one `--add-exports=jdk.compiler/<pkg>=ALL-UNNAMED` JVM option per
/// javac-internal package that palantir-java-format is launched with
/// (see `fmt_java`). The options are returned in a fixed order.
pub(crate) fn jvm_add_exports() -> Vec<String> {
    const JAVAC_PACKAGES: [&str; 7] = [
        "com.sun.tools.javac.api",
        "com.sun.tools.javac.code",
        "com.sun.tools.javac.file",
        "com.sun.tools.javac.main",
        "com.sun.tools.javac.parser",
        "com.sun.tools.javac.tree",
        "com.sun.tools.javac.util",
    ];
    let mut flags = Vec::with_capacity(JAVAC_PACKAGES.len());
    for pkg in JAVAC_PACKAGES {
        flags.push(format!("--add-exports=jdk.compiler/{pkg}=ALL-UNNAMED"));
    }
    flags
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// javac-internal packages that must each get an `--add-exports` flag.
    const REQUIRED_PACKAGES: [&str; 7] = [
        "com.sun.tools.javac.api",
        "com.sun.tools.javac.code",
        "com.sun.tools.javac.file",
        "com.sun.tools.javac.main",
        "com.sun.tools.javac.parser",
        "com.sun.tools.javac.tree",
        "com.sun.tools.javac.util",
    ];

    /// Creates `base/<segments...>` (including parents) and returns its path.
    fn make_dir(base: &Path, segments: &[&str]) -> PathBuf {
        let dir = segments
            .iter()
            .fold(base.to_path_buf(), |acc, segment| acc.join(segment));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Writes each `(file name, contents)` pair into `dir`.
    fn write_files(dir: &Path, entries: &[(&str, &str)]) {
        for (name, contents) in entries {
            fs::write(dir.join(name), contents).unwrap();
        }
    }

    /// Returns the bare file names of `paths`, in their original order.
    fn file_names(paths: &[PathBuf]) -> Vec<&str> {
        paths
            .iter()
            .map(|p| p.file_name().unwrap().to_str().unwrap())
            .collect()
    }

    /// True when the file names of `paths` appear in non-decreasing order.
    fn names_sorted(paths: &[PathBuf]) -> bool {
        file_names(paths).windows(2).all(|pair| pair[0] <= pair[1])
    }

    #[test]
    fn jvm_add_exports_covers_required_packages() {
        let flags = jvm_add_exports();
        for pkg in REQUIRED_PACKAGES {
            let needle = format!("jdk.compiler/{}=ALL-UNNAMED", pkg);
            let present = flags.iter().any(|f| f.contains(&needle));
            assert!(present, "missing --add-exports for {pkg}");
        }
    }

    #[test]
    fn jvm_add_exports_all_start_with_flag() {
        let flags = jvm_add_exports();
        for flag in &flags {
            assert!(
                flag.starts_with("--add-exports="),
                "unexpected flag format: {flag}"
            );
        }
    }

    #[test]
    fn collect_java_files_empty_project() {
        let tmp = TempDir::new().unwrap();
        let found = collect_java_files(tmp.path());
        assert!(found.is_empty(), "expected no files in empty project");
    }

    #[test]
    fn collect_java_files_maven_layout() {
        let tmp = TempDir::new().unwrap();
        let main_java = make_dir(tmp.path(), &["src", "main", "java"]);
        let test_java = make_dir(tmp.path(), &["src", "test", "java"]);
        write_files(
            &main_java,
            &[
                ("Foo.java", "class Foo {}"),
                ("Bar.java", "class Bar {}"),
                ("README.txt", "docs"),
            ],
        );
        write_files(&test_java, &[("FooTest.java", "class FooTest {}")]);
        let files = collect_java_files(tmp.path());
        assert_eq!(files.len(), 3, "expected 3 .java files, got {:?}", files);
        for f in &files {
            assert_eq!(f.extension().unwrap(), "java");
        }
    }

    #[test]
    fn collect_java_files_returns_sorted() {
        let tmp = TempDir::new().unwrap();
        let src = make_dir(tmp.path(), &["src", "main", "java"]);
        write_files(
            &src,
            &[
                ("Zoo.java", "class Zoo {}"),
                ("Alpha.java", "class Alpha {}"),
                ("Mango.java", "class Mango {}"),
            ],
        );
        let files = collect_java_files(tmp.path());
        assert!(names_sorted(&files), "files should be returned sorted");
    }

    #[test]
    fn collect_java_files_recursive() {
        let tmp = TempDir::new().unwrap();
        let pkg = make_dir(tmp.path(), &["src", "main", "java", "com", "example"]);
        write_files(&pkg, &[("Deep.java", "class Deep {}")]);
        let files = collect_java_files(tmp.path());
        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("Deep.java"));
    }

    #[test]
    fn collect_kotlin_files_empty_project() {
        let tmp = TempDir::new().unwrap();
        assert!(collect_kotlin_files(tmp.path()).is_empty());
    }

    #[test]
    fn collect_kotlin_files_maven_layout() {
        let tmp = TempDir::new().unwrap();
        let main_kt = make_dir(tmp.path(), &["src", "main", "kotlin"]);
        let test_kt = make_dir(tmp.path(), &["src", "test", "kotlin"]);
        write_files(
            &main_kt,
            &[("Greeting.kt", "class Greeting"), ("notes.txt", "docs")],
        );
        write_files(&test_kt, &[("GreetingTest.kt", "class GreetingTest")]);
        let files = collect_kotlin_files(tmp.path());
        assert_eq!(files.len(), 2, "expected 2 .kt files, got {:?}", files);
        for f in &files {
            assert_eq!(f.extension().unwrap(), "kt");
        }
    }

    #[test]
    fn collect_kotlin_files_flat_package() {
        let tmp = TempDir::new().unwrap();
        let pkg = make_dir(tmp.path(), &["src", "com.example.mixed"]);
        write_files(
            &pkg,
            &[("Greeting.kt", "class Greeting"), ("Main.java", "class Main {}")],
        );
        let files = collect_kotlin_files(tmp.path());
        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("Greeting.kt"));
    }

    #[test]
    fn collect_kotlin_files_includes_test_files() {
        let tmp = TempDir::new().unwrap();
        let pkg = make_dir(tmp.path(), &["src", "com.example"]);
        write_files(
            &pkg,
            &[
                ("Foo.kt", "class Foo"),
                ("FooTest.kt", "class FooTest"),
                ("FooSpec.kt", "class FooSpec"),
            ],
        );
        let files = collect_kotlin_files(tmp.path());
        assert_eq!(files.len(), 3, "test/spec files must be included in fmt: {:?}", files);
    }

    #[test]
    fn collect_kotlin_files_includes_tests_dir() {
        let tmp = TempDir::new().unwrap();
        let tests_pkg = make_dir(tmp.path(), &["tests", "com.example"]);
        write_files(&tests_pkg, &[("IntTest.kt", "class IntTest")]);
        let files = collect_kotlin_files(tmp.path());
        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("IntTest.kt"));
    }

    #[test]
    fn collect_kotlin_files_returns_sorted() {
        let tmp = TempDir::new().unwrap();
        let src = make_dir(tmp.path(), &["src", "main", "kotlin"]);
        write_files(&src, &[("Zoo.kt", "class Zoo"), ("Alpha.kt", "class Alpha")]);
        let files = collect_kotlin_files(tmp.path());
        assert!(names_sorted(&files));
    }

    #[test]
    fn has_kotlin_sources_false_for_empty_project() {
        let tmp = TempDir::new().unwrap();
        assert!(!has_kotlin_sources(tmp.path()));
    }

    #[test]
    fn has_kotlin_sources_false_for_java_only() {
        let tmp = TempDir::new().unwrap();
        let src = make_dir(tmp.path(), &["src", "main", "java"]);
        write_files(&src, &[("Foo.java", "class Foo {}")]);
        assert!(!has_kotlin_sources(tmp.path()));
    }

    #[test]
    fn has_kotlin_sources_true_when_kt_present() {
        let tmp = TempDir::new().unwrap();
        let src = make_dir(tmp.path(), &["src", "main", "kotlin"]);
        write_files(&src, &[("Greeting.kt", "class Greeting")]);
        assert!(has_kotlin_sources(tmp.path()));
    }

    #[test]
    fn has_kotlin_sources_true_for_flat_package_kt() {
        let tmp = TempDir::new().unwrap();
        let pkg = make_dir(tmp.path(), &["src", "com.example"]);
        write_files(&pkg, &[("App.kt", "fun main() {}")]);
        assert!(has_kotlin_sources(tmp.path()));
    }
}