Skip to main content

affected_core/resolvers/
sbt.rs

1use anyhow::{Context, Result};
2use regex::Regex;
3use std::collections::HashMap;
4use std::path::Path;
5
6use crate::resolvers::{file_to_package, Resolver};
7use crate::types::{Ecosystem, Package, PackageId, ProjectGraph};
8
9/// SbtResolver detects Scala sbt multi-project builds via `build.sbt`.
10///
11/// Uses regex to parse `lazy val` project declarations and `.dependsOn()` dependency references.
12pub struct SbtResolver;
13impl super::sealed::Sealed for SbtResolver {}
14
15impl Resolver for SbtResolver {
16    fn ecosystem(&self) -> Ecosystem {
17        Ecosystem::Sbt
18    }
19
20    fn detect(&self, root: &Path) -> bool {
21        root.join("build.sbt").exists()
22    }
23
24    fn resolve(&self, root: &Path) -> Result<ProjectGraph> {
25        let build_sbt_path = root.join("build.sbt");
26        let content =
27            std::fs::read_to_string(&build_sbt_path).context("Failed to read build.sbt")?;
28
29        let projects = parse_sbt_projects(&content);
30        let dependencies = parse_sbt_dependencies(&content);
31
32        tracing::debug!("Sbt: found {} projects: {:?}", projects.len(), projects);
33
34        // Build a mapping from variable name to PackageId
35        let var_to_id: HashMap<String, PackageId> = projects
36            .iter()
37            .map(|(var_name, _)| (var_name.clone(), PackageId(var_name.clone())))
38            .collect();
39
40        let mut packages = HashMap::new();
41
42        for (var_name, dir_path) in &projects {
43            let module_dir = root.join(dir_path);
44
45            if !module_dir.exists() {
46                tracing::debug!(
47                    "Sbt: project '{}' directory '{}' does not exist, skipping",
48                    var_name,
49                    dir_path
50                );
51                continue;
52            }
53
54            let pkg_id = PackageId(var_name.clone());
55            let manifest_path = build_sbt_path.clone();
56
57            packages.insert(
58                pkg_id.clone(),
59                Package {
60                    id: pkg_id,
61                    name: var_name.clone(),
62                    version: None,
63                    path: module_dir,
64                    manifest_path,
65                },
66            );
67        }
68
69        // Build dependency edges from .dependsOn() references
70        let mut edges = Vec::new();
71
72        for (var_name, deps) in &dependencies {
73            if let Some(from_id) = var_to_id.get(var_name) {
74                // Only add edges for packages we actually resolved
75                if !packages.contains_key(from_id) {
76                    continue;
77                }
78
79                for dep_name in deps {
80                    if let Some(to_id) = var_to_id.get(dep_name) {
81                        if packages.contains_key(to_id) && to_id != from_id {
82                            edges.push((from_id.clone(), to_id.clone()));
83                        }
84                    }
85                }
86            }
87        }
88
89        Ok(ProjectGraph {
90            packages,
91            edges,
92            root: root.to_path_buf(),
93        })
94    }
95
96    fn package_for_file(&self, graph: &ProjectGraph, file: &Path) -> Option<PackageId> {
97        file_to_package(graph, file)
98    }
99
100    fn test_command(&self, package_id: &PackageId) -> Vec<String> {
101        vec!["sbt".into(), format!("{}/test", package_id.0)]
102    }
103}
104
105/// Parse `lazy val` project declarations from a `build.sbt` file.
106///
107/// Handles two forms:
108/// - `lazy val core = (project in file("core"))` -- explicit directory
109/// - `lazy val core = project` -- directory defaults to the variable name
110///
111/// Returns a vec of `(variable_name, directory_path)` tuples.
112fn parse_sbt_projects(content: &str) -> Vec<(String, String)> {
113    let mut projects = Vec::new();
114    let mut matched_vars: std::collections::HashSet<String> = std::collections::HashSet::new();
115
116    // Pattern for `lazy val foo = (project in file("bar"))`
117    let re_with_file =
118        Regex::new(r#"lazy\s+val\s+(\w+)\s*=\s*\(?\s*project\s+in\s+file\("([^"]+)"\)"#).unwrap();
119
120    // Pattern for bare `lazy val foo = project` (end of line or followed by newline + dot)
121    let re_bare_eol = Regex::new(r#"lazy\s+val\s+(\w+)\s*=\s*\(?\s*project\s*$"#).unwrap();
122    let re_bare_chain = Regex::new(r#"lazy\s+val\s+(\w+)\s*=\s*\(?\s*project\s*\n\s*\."#).unwrap();
123
124    // First, find all projects with explicit file("...") declarations
125    for cap in re_with_file.captures_iter(content) {
126        let var_name = cap[1].to_string();
127        let dir_path = cap[2].to_string();
128        matched_vars.insert(var_name.clone());
129        projects.push((var_name, dir_path));
130    }
131
132    // Find bare `lazy val foo = project` at end of line
133    for line in content.lines() {
134        let trimmed = line.trim();
135        if let Some(cap) = re_bare_eol.captures(trimmed) {
136            let var_name = cap[1].to_string();
137            if !matched_vars.contains(&var_name) {
138                matched_vars.insert(var_name.clone());
139                projects.push((var_name.clone(), var_name));
140            }
141        }
142    }
143
144    // Also handle bare project with chained calls (multiline)
145    for cap in re_bare_chain.captures_iter(content) {
146        let var_name = cap[1].to_string();
147        if !matched_vars.contains(&var_name) {
148            matched_vars.insert(var_name.clone());
149            projects.push((var_name.clone(), var_name));
150        }
151    }
152
153    projects
154}
155
156/// Parse `.dependsOn()` dependency declarations from a `build.sbt` file.
157///
158/// Handles:
159/// - `.dependsOn(a)` -- single dependency
160/// - `.dependsOn(a, b)` -- multiple comma-separated dependencies
161/// - `.dependsOn(a).dependsOn(b)` -- chained calls
162///
163/// Returns a map of variable_name -> vec of dependency variable names.
164fn parse_sbt_dependencies(content: &str) -> HashMap<String, Vec<String>> {
165    let mut deps: HashMap<String, Vec<String>> = HashMap::new();
166
167    let lazy_val_re = Regex::new(r#"lazy\s+val\s+(\w+)\s*="#).unwrap();
168    let depends_on_re = Regex::new(r#"\.dependsOn\(([^)]+)\)"#).unwrap();
169
170    // Find the byte offset of each `lazy val` declaration to split into blocks.
171    let mut block_starts: Vec<(String, usize)> = Vec::new();
172    for cap in lazy_val_re.captures_iter(content) {
173        let var_name = cap[1].to_string();
174        let start = cap.get(0).unwrap().start();
175        block_starts.push((var_name, start));
176    }
177
178    // Process each block: from this lazy val to the next (or end of string)
179    for i in 0..block_starts.len() {
180        let (ref var_name, start) = block_starts[i];
181        let end = if i + 1 < block_starts.len() {
182            block_starts[i + 1].1
183        } else {
184            content.len()
185        };
186        let block = &content[start..end];
187
188        let mut var_deps = Vec::new();
189
190        for dep_cap in depends_on_re.captures_iter(block) {
191            let dep_list = &dep_cap[1];
192            for dep in dep_list.split(',') {
193                let dep = dep.trim();
194                if !dep.is_empty() && !var_deps.contains(&dep.to_string()) {
195                    var_deps.push(dep.to_string());
196                }
197            }
198        }
199
200        if !var_deps.is_empty() {
201            deps.insert(var_name.clone(), var_deps);
202        }
203    }
204
205    deps
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned `(variable, directory)` pair, the element type returned by `parse_sbt_projects`.
    fn pair(var: &str, dir: &str) -> (String, String) {
        (var.to_string(), dir.to_string())
    }

    /// Shorthand for building a `PackageId` from a string slice.
    fn pkg(name: &str) -> PackageId {
        PackageId(name.into())
    }

    #[test]
    fn test_detect_build_sbt() {
        let tmp = tempfile::tempdir().unwrap();
        let manifest = tmp.path().join("build.sbt");
        std::fs::write(manifest, "ThisBuild / scalaVersion := \"3.3.1\"\n").unwrap();
        assert!(SbtResolver.detect(tmp.path()));
    }

    #[test]
    fn test_detect_no_sbt() {
        let tmp = tempfile::tempdir().unwrap();
        let detected = SbtResolver.detect(tmp.path());
        assert!(!detected);
    }

    #[test]
    fn test_parse_sbt_projects() {
        let content = r#"
ThisBuild / scalaVersion := "3.3.1"

lazy val common = (project in file("common"))

lazy val core = (project in file("core"))
  .dependsOn(common)

lazy val api = (project in file("api"))
  .dependsOn(core, common)
"#;
        let projects = parse_sbt_projects(content);
        assert_eq!(projects.len(), 3);
        for &name in &["common", "core", "api"] {
            assert!(projects.contains(&pair(name, name)));
        }
    }

    #[test]
    fn test_parse_sbt_project_without_file() {
        let projects = parse_sbt_projects("lazy val core = project\n");
        assert_eq!(projects, vec![pair("core", "core")]);
    }

    #[test]
    fn test_parse_sbt_depends_on() {
        let content = r#"
lazy val common = (project in file("common"))

lazy val core = (project in file("core"))
  .dependsOn(common)

lazy val api = (project in file("api"))
  .dependsOn(core, common)
"#;
        let deps = parse_sbt_dependencies(content);

        // A project without `.dependsOn` gets no entry at all.
        assert!(deps.get("common").is_none());

        assert_eq!(deps["core"], vec!["common".to_string()]);

        let api = &deps["api"];
        assert_eq!(api.len(), 2);
        for &expected in &["core", "common"] {
            assert!(api.iter().any(|d| d == expected));
        }
    }

    #[test]
    fn test_parse_sbt_chained_depends_on() {
        let content = r#"
lazy val common = (project in file("common"))

lazy val core = (project in file("core"))

lazy val api = (project in file("api"))
  .dependsOn(common).dependsOn(core)
"#;
        let deps = parse_sbt_dependencies(content);
        let api = &deps["api"];
        assert_eq!(api.len(), 2);
        for &expected in &["common", "core"] {
            assert!(api.iter().any(|d| d == expected));
        }
    }

    #[test]
    fn test_resolve_sbt_project() {
        let dir = tempfile::tempdir().unwrap();

        let build_sbt = r#"
ThisBuild / scalaVersion := "3.3.1"
ThisBuild / version      := "0.1.0"

lazy val common = (project in file("common"))

lazy val core = (project in file("core"))
  .dependsOn(common)

lazy val api = (project in file("api"))
  .dependsOn(core, common)

lazy val root = (project in file("."))
  .aggregate(common, core, api)
"#;
        std::fs::write(dir.path().join("build.sbt"), build_sbt).unwrap();

        for &sub in &["common", "core", "api"] {
            std::fs::create_dir_all(dir.path().join(sub)).unwrap();
        }

        let graph = SbtResolver.resolve(dir.path()).unwrap();

        // `root` lives in "." (the temp dir itself), so it resolves alongside the modules.
        assert_eq!(graph.packages.len(), 4);
        for &name in &["common", "core", "api", "root"] {
            assert!(graph.packages.contains_key(&pkg(name)));
        }

        // core -> common, api -> core, api -> common
        for &(from, to) in &[("core", "common"), ("api", "core"), ("api", "common")] {
            assert!(graph.edges.contains(&(pkg(from), pkg(to))));
        }
    }

    #[test]
    fn test_test_command() {
        let cmd = SbtResolver.test_command(&pkg("core"));
        assert_eq!(cmd, ["sbt", "core/test"]);
    }
}