use crate::{
model::{Artifact, Group, Version},
parser::{get_scala_version_from_build_sbt, span::Edit, Dependency, DependencyParser, Span},
};
use anyhow::Result;
use std::{collections::HashMap, path::Path};
use std::{fs, path::PathBuf};
mod file_cache;
/// A region (`span`) inside one particular file (`path`).
///
/// In this module it records where a dependency's version text sits on disk, so
/// that [`write_version_updates`] can later replace exactly that span.
#[derive(Debug, Clone, PartialEq)]
pub struct Location {
    /// File that contains the span.
    pub path: PathBuf,
    /// Region of `path` covered by this location.
    pub span: Span,
}

impl Location {
    /// Pairs a file path with a span inside that file.
    pub fn new(path: PathBuf, span: Span) -> Self {
        Location { span, path }
    }
}
/// The highest version seen for one dependency, plus every distinct location
/// that declares a version for it.
#[derive(Debug, Clone, PartialEq)]
pub struct VersionWithLocations {
    /// Highest version among all declarations recorded via `new`/`add`.
    pub version: Version,
    /// Distinct declaration sites, in the order they were first recorded.
    pub locations: Vec<Location>,
}

impl VersionWithLocations {
    /// Starts tracking a dependency from its first declaration.
    pub fn new(version: &Version, location: &Location) -> Self {
        Self {
            version: version.clone(),
            locations: vec![location.clone()],
        }
    }

    /// Records another declaration of the same dependency.
    ///
    /// Keeps `version` at the maximum seen so far. A location that is already
    /// recorded is not added again: a single version `val` can back several
    /// declarations (e.g. the same artifact in multiple subprojects). Recording
    /// its span twice would later produce two edits for one span in
    /// `write_version_updates`.
    pub fn add(&mut self, version: &Version, location: &Location) {
        if version > &self.version {
            self.version = version.clone();
        }
        if !self.locations.contains(location) {
            self.locations.push(location.clone());
        }
    }
}
#[derive(Debug)]
pub struct DependencyMap {
map: HashMap<(Group, Artifact), VersionWithLocations>,
}
impl DependencyMap {
pub fn iter(
&self,
) -> std::collections::hash_map::Iter<(Group, Artifact), VersionWithLocations> {
self.map.iter()
}
}
impl IntoIterator for DependencyMap {
type Item = ((Group, Artifact), VersionWithLocations);
type IntoIter = std::collections::hash_map::IntoIter<(Group, Artifact), VersionWithLocations>;
fn into_iter(self) -> Self::IntoIter {
self.map.into_iter()
}
}
impl Default for DependencyMap {
fn default() -> Self {
Self::new()
}
}
impl DependencyMap {
pub fn new() -> Self {
Self {
map: HashMap::new(),
}
}
pub fn from_dependencies(dependencies: Vec<Dependency>) -> Self {
let mut map = Self::new();
for dependency in dependencies {
map.add_dependency(&dependency);
}
map
}
pub fn add_dependency(&mut self, dependency: &Dependency) {
let key = (dependency.group.clone(), dependency.artifact.clone());
let location = &dependency.version.location;
self.map
.entry(key)
.and_modify(|existing| existing.add(&dependency.version.value, location))
.or_insert_with(|| VersionWithLocations::new(&dependency.version.value, location));
}
}
/// Scans an sbt project rooted at `project_path` and gathers its dependencies.
///
/// Every file from `all_dependency_paths` first gets a `parse_val_defs` pass.
/// Only after that does any file get a `parse_dependencies` pass. Presumably
/// this lets `val`s defined in one file be resolved when another file
/// references them (the tests rely on `Versions.zio` from `project/versions.scala`).
/// If `build.sbt` exists, whatever `get_scala_version_from_build_sbt` returns
/// for it (when `Some`) is added as one more dependency.
///
/// # Errors
/// Returns an error if any of the files cannot be read.
pub fn collect_sbt_dependencies(project_path: &Path) -> Result<DependencyMap> {
    let mut parser = DependencyParser::new();
    let source_files = all_dependency_paths(project_path);
    let mut cache = file_cache::FileCache::new();

    // Pass 1: collect `val` definitions from every file.
    for file in source_files.iter() {
        let contents = cache.read_to_string(file)?;
        parser.parse_val_defs(file, &contents);
    }

    // Pass 2: collect dependency declarations from every file.
    for file in source_files.iter() {
        let contents = cache.read_to_string(file)?;
        parser.parse_dependencies(file, &contents);
    }

    let mut found = parser.dependencies;
    let build_sbt = project_path.join("build.sbt");
    if build_sbt.exists() {
        let contents = cache.read_to_string(&build_sbt)?;
        found.extend(get_scala_version_from_build_sbt(&build_sbt, &contents));
    }
    Ok(DependencyMap::from_dependencies(found))
}
/// Rewrites version literals on disk.
///
/// For each `(version, locations)` pair, the span of every location is replaced
/// by the version wrapped in double quotes (`"<version>"`). Edits are grouped per
/// file, so each file is read once, patched with `Edit::apply_edits`, and
/// written back once.
///
/// Identical edits (same span, same replacement text) are queued only once. A
/// version `val` shared by several dependencies appears in each of their location
/// lists, and `apply_edits` should not receive the same span twice.
/// NOTE(review): two *different* replacements for one span (a shared `val` asked
/// to take two versions) are still both passed through. `apply_edits` is not
/// shown here, so how it resolves that conflict needs confirming.
///
/// # Errors
/// Returns the first I/O error from reading or writing a file. Files written
/// before the failure keep their new contents.
pub fn write_version_updates(updates: &[(Version, Vec<Location>)]) -> std::io::Result<()> {
    let mut updates_by_file: HashMap<PathBuf, Vec<Edit>> = HashMap::new();
    for (version, locations) in updates {
        let text = format!("\"{}\"", version);
        for location in locations {
            let edits = updates_by_file.entry(location.path.clone()).or_default();
            let already_queued = edits
                .iter()
                .any(|edit| edit.span == location.span && edit.text == text);
            if !already_queued {
                edits.push(Edit {
                    span: location.span.clone(),
                    text: text.clone(),
                });
            }
        }
    }
    for (file_path, edits) in updates_by_file {
        let original_content = fs::read_to_string(&file_path)?;
        let updated_content = Edit::apply_edits(edits, &original_content);
        fs::write(file_path, updated_content)?;
    }
    Ok(())
}
fn all_dependency_paths(project_path: &Path) -> Vec<PathBuf> {
let mut paths = vec![
project_path.join("build.sbt"),
project_path.join("project/plugins.sbt"),
];
collect_scala_files(project_path, &mut paths);
paths.into_iter().filter(|path| path.exists()).collect()
}
/// Recursively appends every `.scala` file under `dir` to `paths`.
///
/// Traversal is best effort: unreadable directories and entries are skipped
/// silently. Entries are visited in sorted name order so the result does not
/// depend on the platform's `read_dir` order.
///
/// Directories are never entered when they:
/// - are reached through a symlink (`DirEntry::file_type` does not follow links),
///   which prevents infinite recursion on symlink cycles;
/// - are hidden (`.git`, `.metals`, `.bloop`, `.bsp`, …) or named `target`
///   (sbt build output), since these hold tool/generated files rather than the
///   project's own build sources.
///
/// Only children are filtered; `dir` itself is always scanned, even if hidden.
fn collect_scala_files(dir: &Path, paths: &mut Vec<PathBuf>) {
    let mut entries: Vec<fs::DirEntry> = match fs::read_dir(dir) {
        Ok(entries) => entries.flatten().collect(),
        Err(_) => return,
    };
    entries.sort_by_key(|entry| entry.file_name());

    for entry in entries {
        let file_type = match entry.file_type() {
            Ok(file_type) => file_type,
            Err(_) => continue,
        };
        let path = entry.path();
        if file_type.is_dir() {
            if !is_skipped_dir(&path) {
                collect_scala_files(&path, paths);
            }
        } else if path.extension().and_then(|e| e.to_str()) == Some("scala") && path.is_file() {
            // `is_file` follows symlinks: links to real files are kept, while
            // broken links and links to directories are dropped.
            paths.push(path);
        }
    }
}

/// Returns true for directories `collect_scala_files` should not descend into:
/// hidden directories and sbt's `target` output directories.
fn is_skipped_dir(path: &Path) -> bool {
    path.file_name()
        .and_then(|name| name.to_str())
        .map_or(false, |name| name.starts_with('.') || name == "target")
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::fs::{self, File};
    use std::io::Write;
    use tempfile::tempdir;

    /// Manual smoke test against a real sbt checkout.
    ///
    /// It asserts nothing beyond "collection succeeds" and depends on a local
    /// directory, so it is ignored by default. Point `SCALA_UPDATE_SMOKE_DIR` at
    /// an sbt project and run with `cargo test -- --ignored` to use it.
    #[test]
    #[ignore = "needs a local sbt project; set SCALA_UPDATE_SMOKE_DIR and run with --ignored"]
    fn test_collect_dependencies_from_dir() {
        let dir = std::env::var("SCALA_UPDATE_SMOKE_DIR")
            .unwrap_or_else(|_| "/Users/kit/code/archive/scala-update-2/".to_string());
        let deps = collect_sbt_dependencies(Path::new(&dir)).unwrap_or_else(|err| {
            panic!("failed to collect dependencies from {}: {:?}", dir, err)
        });
        for (_, dep) in deps.iter() {
            println!("{:?}", dep);
        }
    }

    /// End to end: parse a small project, bump every version, write it back,
    /// and check both files were patched in place.
    #[test]
    fn test_full_stack_version_update() -> std::io::Result<()> {
        println!("Creating temporary directory and sample build.sbt and versions.scala files...");
        let dir = tempdir()?;
        let build_sbt_path = dir.path().join("build.sbt");
        let versions_scala_path = dir.path().join("project/versions.scala");
        let mut build_sbt_file = File::create(&build_sbt_path)?;
        writeln!(
            build_sbt_file,
            r#"
import Versions._
libraryDependencies ++= Seq(
"dev.zio" %% "zio" % zio,
"io.github.kitlangton" %% "neotype" % Versions.neotype,
"org.postgresql" % "postgresql" % "42.5.1"
)
"#
        )?;
        println!("Sample build.sbt file created at {:?}", build_sbt_path);
        fs::create_dir_all(versions_scala_path.parent().unwrap())?;
        let mut versions_scala_file = File::create(&versions_scala_path)?;
        writeln!(
            versions_scala_file,
            r#"
object Versions {{
val zio = "2.0.0"
val neotype = "0.1.0"
}}
"#
        )?;
        println!(
            "Sample versions.scala file created at {:?}",
            versions_scala_path
        );

        println!("Reading dependencies from the files...");
        let dependencies = collect_sbt_dependencies(dir.path());
        println!("Dependencies read: {:?}", dependencies);

        println!("Selecting new versions for the dependencies...");
        let updates: Vec<(Version, Vec<Location>)> = dependencies
            .unwrap()
            .map
            .iter()
            .map(|(_, dep)| (Version::new("999.999.999"), dep.locations.clone()))
            .collect();
        println!("Updates selected: {:?}", updates);

        println!("Writing updated dependencies back to the files...");
        write_version_updates(&updates)?;
        println!("Updated dependencies written to the files.");

        println!("Verifying the updates...");
        let updated_build_sbt_content = fs::read_to_string(&build_sbt_path)?;
        let updated_versions_scala_content = fs::read_to_string(&versions_scala_path)?;
        println!("Updated build.sbt content: {}", updated_build_sbt_content);
        println!(
            "Updated versions.scala content: {}",
            updated_versions_scala_content
        );
        let expected_build_sbt_content = r#"
import Versions._
libraryDependencies ++= Seq(
"dev.zio" %% "zio" % zio,
"io.github.kitlangton" %% "neotype" % Versions.neotype,
"org.postgresql" % "postgresql" % "999.999.999"
)
"#;
        let expected_versions_scala_content = r#"
object Versions {
val zio = "999.999.999"
val neotype = "999.999.999"
}
"#;
        assert_eq!(
            updated_build_sbt_content.trim(),
            expected_build_sbt_content.trim()
        );
        assert_eq!(
            updated_versions_scala_content.trim(),
            expected_versions_scala_content.trim()
        );
        println!("Updates verified successfully.");
        Ok(())
    }
}