use crate::compile::flat_package_src_dirs;
use anyhow::{bail, Context, Result};
use curie_deps::resolver::{resolve, ResolveOptions};
use std::path::{Path, PathBuf};
use std::process::Command;
use walkdir::WalkDir;
/// Maven `groupId:artifactId` of palantir-java-format, resolved by [`resolve_pjf`].
const PJF_COORD: &str = "com.palantir.javaformat:palantir-java-format";
/// palantir-java-format version that is resolved and run.
const PJF_VERSION: &str = "2.90.0";
/// Fully qualified main class launched as `java ... -cp <jars> PJF_MAIN ...`.
const PJF_MAIN: &str = "com.palantir.javaformat.java.Main";
/// Resolves palantir-java-format ([`PJF_COORD`] at [`PJF_VERSION`]) to a list
/// of local jar paths suitable for a `java -cp` classpath.
///
/// `offline` is forwarded to the resolver unchanged. No extra repositories or
/// BOM imports are configured, and resolver progress output is disabled.
///
/// # Errors
///
/// Returns the resolver's error, wrapped with a message naming the artifact.
pub fn resolve_pjf(offline: bool) -> Result<Vec<PathBuf>> {
    let options = ResolveOptions {
        offline,
        progress: false,
        extra_repos: vec![],
        bom_imports: vec![],
    };
    let coordinates = [(PJF_COORD, PJF_VERSION)];
    resolve(&coordinates, &options)
        .context("failed to resolve palantir-java-format from Maven Central")
}
/// Resolves palantir-java-format, then formats every Java source file under
/// `project_root`. With `check_only`, the files are only checked.
///
/// `offline` controls whether the formatter jars may be downloaded.
///
/// # Errors
///
/// Fails if the formatter cannot be resolved, or for any reason
/// [`run_fmt_with_jars`] fails.
pub fn run_fmt(project_root: &Path, check_only: bool, offline: bool) -> Result<()> {
    resolve_pjf(offline).and_then(|jars| run_fmt_with_jars(project_root, check_only, &jars))
}
/// Runs palantir-java-format over every Java source file in `project_root`.
/// The already-resolved formatter jars in `pjf_jars` form the classpath.
///
/// The formatter is launched as
/// `java <jvm_add_exports flags> -cp <jars> PJF_MAIN --aosp ... <files>`.
/// With `check_only`, it also receives `--dry-run --set-exit-if-changed`, so
/// files are only checked. Otherwise `--replace` rewrites them in place.
/// If the project has no `.java` files, `java` is not launched and `Ok(())`
/// is returned.
///
/// # Errors
///
/// - A jar path cannot be placed on a classpath because it contains the
///   platform's path-list separator.
/// - `java` cannot be launched.
/// - The formatter exits unsuccessfully. In check mode this is reported as
///   "files not correctly formatted".
pub fn run_fmt_with_jars(
    project_root: &Path,
    check_only: bool,
    pjf_jars: &[PathBuf],
) -> Result<()> {
    let java_files = collect_java_files(project_root);
    if java_files.is_empty() {
        return Ok(());
    }

    // `java -cp` expects the platform's path-list separator: `:` on Unix,
    // `;` on Windows. A hard-coded ":" breaks Windows. `join_paths` picks the
    // right separator and, unlike `to_string_lossy`, keeps non-UTF-8 paths
    // intact.
    let cp = std::env::join_paths(pjf_jars)
        .context("a palantir-java-format jar path cannot be used on a classpath")?;

    let mut cmd = Command::new("java");
    cmd.args(jvm_add_exports());
    cmd.arg("-cp").arg(&cp).arg(PJF_MAIN);
    cmd.arg("--aosp");
    if check_only {
        cmd.args(["--dry-run", "--set-exit-if-changed"]);
    } else {
        cmd.arg("--replace");
    }
    // NOTE(review): every file goes on a single command line. Very large
    // projects may exceed the OS argument-length limit (much lower on Windows).
    cmd.args(&java_files);

    let status = cmd
        .status()
        .context("failed to launch `java` — is a JDK installed and on PATH?")?;
    if !status.success() {
        if check_only {
            bail!(
                "fmt: one or more files are not correctly formatted. \
                 Run `curie fmt` (without --check) to fix them."
            );
        } else {
            bail!("palantir-java-format exited with status {}", status);
        }
    }
    Ok(())
}
/// Collects every `.java` file under the project's source roots.
///
/// The roots searched are `src/main/java` and `src/test/java` (each only if
/// it exists), plus every directory returned by [`flat_package_src_dirs`].
/// Each root is walked recursively. Entries that cannot be read are skipped.
///
/// The returned paths are sorted and de-duplicated. They are formed by
/// joining each root with the relative file path and are not canonicalized.
pub(crate) fn collect_java_files(project_root: &Path) -> Vec<PathBuf> {
    let mut roots: Vec<PathBuf> = Vec::new();
    let main_java = project_root.join("src").join("main").join("java");
    if main_java.exists() {
        roots.push(main_java);
    }
    let test_java = project_root.join("src").join("test").join("java");
    if test_java.exists() {
        roots.push(test_java);
    }
    roots.extend(flat_package_src_dirs(project_root));

    let mut files: Vec<PathBuf> = roots
        .iter()
        .flat_map(|root| {
            WalkDir::new(root)
                .into_iter()
                // Best-effort: unreadable entries are skipped rather than
                // aborting the whole run.
                .filter_map(|entry| entry.ok())
                .filter(|entry| {
                    entry.file_type().is_file()
                        && entry.path().extension().map_or(false, |ext| ext == "java")
                })
                .map(|entry| entry.into_path())
        })
        .collect();
    files.sort();
    // Roots can overlap, e.g. a flat source dir that equals or contains a
    // Maven root. Without this, the same file would be handed to the
    // formatter twice.
    files.dedup();
    files
}
/// Builds one `--add-exports=jdk.compiler/<package>=ALL-UNNAMED` JVM flag for
/// each listed `com.sun.tools.javac` package.
///
/// Each flag exports that package to code on the classpath (the unnamed
/// module). These flags are passed to `java` before the formatter's main
/// class.
pub(crate) fn jvm_add_exports() -> Vec<String> {
    const JAVAC_PACKAGES: [&str; 7] = [
        "com.sun.tools.javac.api",
        "com.sun.tools.javac.code",
        "com.sun.tools.javac.file",
        "com.sun.tools.javac.main",
        "com.sun.tools.javac.parser",
        "com.sun.tools.javac.tree",
        "com.sun.tools.javac.util",
    ];
    let mut flags = Vec::with_capacity(JAVAC_PACKAGES.len());
    for package in JAVAC_PACKAGES {
        flags.push(format!("--add-exports=jdk.compiler/{package}=ALL-UNNAMED"));
    }
    flags
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// javac packages that must each have an `--add-exports` flag.
    const REQUIRED_PACKAGES: [&str; 7] = [
        "com.sun.tools.javac.api",
        "com.sun.tools.javac.code",
        "com.sun.tools.javac.file",
        "com.sun.tools.javac.main",
        "com.sun.tools.javac.parser",
        "com.sun.tools.javac.tree",
        "com.sun.tools.javac.util",
    ];

    /// Returns `<root>/src/<kind>/java`, the Maven source dir for `kind`.
    fn maven_dir(root: &std::path::Path, kind: &str) -> std::path::PathBuf {
        root.join("src").join(kind).join("java")
    }

    #[test]
    fn jvm_add_exports_covers_required_packages() {
        let flags = jvm_add_exports();
        for pkg in REQUIRED_PACKAGES {
            let needle = format!("jdk.compiler/{pkg}=ALL-UNNAMED");
            let found = flags.iter().any(|flag| flag.contains(needle.as_str()));
            assert!(found, "missing --add-exports for {pkg}");
        }
    }

    #[test]
    fn jvm_add_exports_all_start_with_flag() {
        let flags = jvm_add_exports();
        if let Some(bad) = flags.iter().find(|f| !f.starts_with("--add-exports=")) {
            panic!("unexpected flag format: {bad}");
        }
    }

    #[test]
    fn collect_java_files_empty_project() {
        let tmp = TempDir::new().unwrap();
        let found = collect_java_files(tmp.path());
        assert!(found.is_empty(), "expected no files in empty project");
    }

    #[test]
    fn collect_java_files_maven_layout() {
        let tmp = TempDir::new().unwrap();
        let main_dir = maven_dir(tmp.path(), "main");
        let test_dir = maven_dir(tmp.path(), "test");
        for dir in [&main_dir, &test_dir] {
            fs::create_dir_all(dir).unwrap();
        }
        let fixtures = [
            (&main_dir, "Foo.java", "class Foo {}"),
            (&main_dir, "Bar.java", "class Bar {}"),
            (&test_dir, "FooTest.java", "class FooTest {}"),
            (&main_dir, "README.txt", "docs"),
        ];
        for (dir, name, body) in fixtures {
            fs::write(dir.join(name), body).unwrap();
        }

        let found = collect_java_files(tmp.path());
        assert_eq!(found.len(), 3, "expected 3 .java files, got {:?}", found);
        assert!(found.iter().all(|f| f.extension().unwrap() == "java"));
    }

    #[test]
    fn collect_java_files_returns_sorted() {
        let tmp = TempDir::new().unwrap();
        let dir = maven_dir(tmp.path(), "main");
        fs::create_dir_all(&dir).unwrap();
        for class in ["Zoo", "Alpha", "Mango"] {
            fs::write(dir.join(format!("{class}.java")), format!("class {class} {{}}")).unwrap();
        }

        let found = collect_java_files(tmp.path());
        let names: Vec<&str> = found
            .iter()
            .map(|p| p.file_name().unwrap().to_str().unwrap())
            .collect();
        assert!(
            names.windows(2).all(|pair| pair[0] <= pair[1]),
            "files should be returned sorted"
        );
    }

    #[test]
    fn collect_java_files_recursive() {
        let tmp = TempDir::new().unwrap();
        let pkg = maven_dir(tmp.path(), "main").join("com").join("example");
        fs::create_dir_all(&pkg).unwrap();
        fs::write(pkg.join("Deep.java"), "class Deep {}").unwrap();

        let found = collect_java_files(tmp.path());
        assert_eq!(found.len(), 1);
        assert!(found[0].ends_with("Deep.java"));
    }
}