Skip to main content

pro_core/
affected.rs

1//! Affected detection for workspace members
2//!
3//! Detects which workspace members have changed based on git diff.
4//! Useful for CI/CD pipelines to only build/test affected packages.
5//!
6//! ```bash
7//! # List affected packages
8//! rx affected
9//!
10//! # Run tests only on affected packages
11//! rx run --affected pytest
12//! ```
13
14use std::collections::{HashMap, HashSet};
15use std::path::{Path, PathBuf};
16use std::process::Command;
17
18use crate::workspace::Workspace;
19use crate::{Error, Result};
20
/// Configuration for affected detection.
///
/// Build one with [`AffectedConfig::new`] (or `Default`) and refine it with
/// the `with_*` builder methods; all fields are also public.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AffectedConfig {
    /// Base ref to compare against. Defaults to `"main"`; detection falls
    /// back to other common branch names when this ref does not resolve.
    pub base: String,
    /// Head ref to compare. Defaults to `"HEAD"`.
    pub head: String,
    /// Include uncommitted (staged and unstaged) changes. Defaults to `true`.
    pub uncommitted: bool,
    /// Include untracked files. Defaults to `true`.
    pub untracked: bool,
}

impl Default for AffectedConfig {
    fn default() -> Self {
        Self {
            base: "main".to_string(),
            head: "HEAD".to_string(),
            uncommitted: true,
            untracked: true,
        }
    }
}

impl AffectedConfig {
    /// Creates a configuration with the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the base ref to compare against.
    #[must_use]
    pub fn with_base(mut self, base: impl Into<String>) -> Self {
        self.base = base.into();
        self
    }

    /// Sets the head ref to compare.
    #[must_use]
    pub fn with_head(mut self, head: impl Into<String>) -> Self {
        self.head = head.into();
        self
    }

    /// Chooses whether staged and unstaged working-tree changes are included.
    #[must_use]
    pub fn with_uncommitted(mut self, uncommitted: bool) -> Self {
        self.uncommitted = uncommitted;
        self
    }

    /// Chooses whether untracked files are included.
    #[must_use]
    pub fn with_untracked(mut self, untracked: bool) -> Self {
        self.untracked = untracked;
        self
    }
}
60
/// Outcome of an affected-detection run.
#[derive(Clone, Debug)]
pub struct AffectedResult {
    /// Members that own at least one changed file.
    pub direct: Vec<PathBuf>,
    /// Every affected member: the direct ones plus, when transitive
    /// detection has been run, members that depend on them.
    pub all: Vec<PathBuf>,
    /// The changed files the detection was based on.
    pub changed_files: Vec<PathBuf>,
}
71
72/// Detect affected workspace members based on git changes
73pub fn detect_affected(workspace: &Workspace, config: &AffectedConfig) -> Result<AffectedResult> {
74    let root = &workspace.root;
75
76    // Get changed files
77    let changed_files = get_changed_files(root, config)?;
78
79    if changed_files.is_empty() {
80        return Ok(AffectedResult {
81            direct: Vec::new(),
82            all: Vec::new(),
83            changed_files: Vec::new(),
84        });
85    }
86
87    // Map files to members
88    let members = workspace.members();
89    let mut directly_affected: HashSet<PathBuf> = HashSet::new();
90
91    for file in &changed_files {
92        // Make file path relative to workspace root if absolute
93        let relative_file = if file.is_absolute() {
94            file.strip_prefix(root).unwrap_or(file)
95        } else {
96            file.as_path()
97        };
98
99        // Check which member this file belongs to
100        for member in members {
101            let relative_member = member.strip_prefix(root).unwrap_or(member);
102
103            if relative_file.starts_with(relative_member) {
104                directly_affected.insert(member.clone());
105                break;
106            }
107        }
108    }
109
110    // For now, all affected = directly affected
111    // Future: Add transitive dependency detection
112    let direct: Vec<PathBuf> = directly_affected.iter().cloned().collect();
113    let all = direct.clone();
114
115    Ok(AffectedResult {
116        direct,
117        all,
118        changed_files,
119    })
120}
121
122/// Get list of changed files from git
123fn get_changed_files(repo_root: &Path, config: &AffectedConfig) -> Result<Vec<PathBuf>> {
124    let mut changed_files = HashSet::new();
125
126    // Detect default branch if "main" doesn't exist
127    let base = detect_base_branch(repo_root, &config.base)?;
128
129    // Get committed changes between base and head
130    let diff_output = Command::new("git")
131        .args([
132            "diff",
133            "--name-only",
134            &format!("{}...{}", base, config.head),
135        ])
136        .current_dir(repo_root)
137        .output()
138        .map_err(|e| Error::Config(format!("Failed to run git diff: {}", e)))?;
139
140    if diff_output.status.success() {
141        let stdout = String::from_utf8_lossy(&diff_output.stdout);
142        for line in stdout.lines() {
143            if !line.is_empty() {
144                changed_files.insert(PathBuf::from(line));
145            }
146        }
147    }
148
149    // Include uncommitted changes (staged + unstaged)
150    if config.uncommitted {
151        // Staged changes
152        let staged_output = Command::new("git")
153            .args(["diff", "--name-only", "--cached"])
154            .current_dir(repo_root)
155            .output()
156            .map_err(|e| Error::Config(format!("Failed to run git diff --cached: {}", e)))?;
157
158        if staged_output.status.success() {
159            let stdout = String::from_utf8_lossy(&staged_output.stdout);
160            for line in stdout.lines() {
161                if !line.is_empty() {
162                    changed_files.insert(PathBuf::from(line));
163                }
164            }
165        }
166
167        // Unstaged changes
168        let unstaged_output = Command::new("git")
169            .args(["diff", "--name-only"])
170            .current_dir(repo_root)
171            .output()
172            .map_err(|e| Error::Config(format!("Failed to run git diff: {}", e)))?;
173
174        if unstaged_output.status.success() {
175            let stdout = String::from_utf8_lossy(&unstaged_output.stdout);
176            for line in stdout.lines() {
177                if !line.is_empty() {
178                    changed_files.insert(PathBuf::from(line));
179                }
180            }
181        }
182    }
183
184    // Include untracked files
185    if config.untracked {
186        let untracked_output = Command::new("git")
187            .args(["ls-files", "--others", "--exclude-standard"])
188            .current_dir(repo_root)
189            .output()
190            .map_err(|e| Error::Config(format!("Failed to run git ls-files: {}", e)))?;
191
192        if untracked_output.status.success() {
193            let stdout = String::from_utf8_lossy(&untracked_output.stdout);
194            for line in stdout.lines() {
195                if !line.is_empty() {
196                    changed_files.insert(PathBuf::from(line));
197                }
198            }
199        }
200    }
201
202    let mut result: Vec<PathBuf> = changed_files.into_iter().collect();
203    result.sort();
204    Ok(result)
205}
206
207/// Detect the default base branch (main, master, or specified)
208fn detect_base_branch(repo_root: &Path, preferred: &str) -> Result<String> {
209    // Check if preferred branch exists
210    let check = Command::new("git")
211        .args(["rev-parse", "--verify", preferred])
212        .current_dir(repo_root)
213        .output();
214
215    if let Ok(output) = check {
216        if output.status.success() {
217            return Ok(preferred.to_string());
218        }
219    }
220
221    // Try common alternatives
222    let alternatives = ["main", "master", "develop", "dev"];
223    for alt in alternatives {
224        if alt == preferred {
225            continue;
226        }
227
228        let check = Command::new("git")
229            .args(["rev-parse", "--verify", alt])
230            .current_dir(repo_root)
231            .output();
232
233        if let Ok(output) = check {
234            if output.status.success() {
235                return Ok(alt.to_string());
236            }
237        }
238    }
239
240    // Try to get the default branch from origin
241    let remote_check = Command::new("git")
242        .args(["symbolic-ref", "refs/remotes/origin/HEAD"])
243        .current_dir(repo_root)
244        .output();
245
246    if let Ok(output) = remote_check {
247        if output.status.success() {
248            let stdout = String::from_utf8_lossy(&output.stdout);
249            if let Some(branch) = stdout.trim().strip_prefix("refs/remotes/origin/") {
250                return Ok(branch.to_string());
251            }
252        }
253    }
254
255    // Fall back to HEAD~1 if nothing else works
256    Ok("HEAD~1".to_string())
257}
258
259/// Build dependency graph for workspace members
260/// Returns a map of member path -> paths it depends on
261pub fn build_dependency_graph(workspace: &Workspace) -> Result<HashMap<PathBuf, Vec<PathBuf>>> {
262    use crate::pep::PyProject;
263
264    let mut graph: HashMap<PathBuf, Vec<PathBuf>> = HashMap::new();
265    let members = workspace.members();
266
267    // Build a map of package names to member paths
268    let mut name_to_path: HashMap<String, PathBuf> = HashMap::new();
269    for member in members {
270        if let Ok(pyproject) = PyProject::load(member) {
271            if let Some(name) = pyproject.name() {
272                name_to_path.insert(name.to_lowercase(), member.clone());
273                // Also add with underscores converted to dashes and vice versa
274                name_to_path.insert(name.to_lowercase().replace('_', "-"), member.clone());
275                name_to_path.insert(name.to_lowercase().replace('-', "_"), member.clone());
276            }
277        }
278    }
279
280    // Build the dependency graph
281    for member in members {
282        let mut deps = Vec::new();
283
284        if let Ok(pyproject) = PyProject::load(member) {
285            // Check regular dependencies
286            for dep in pyproject.dependencies() {
287                let dep_name = dep
288                    .split(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
289                    .next()
290                    .unwrap_or("")
291                    .to_lowercase();
292
293                if let Some(dep_path) = name_to_path.get(&dep_name) {
294                    if dep_path != member {
295                        deps.push(dep_path.clone());
296                    }
297                }
298            }
299
300            // Check dev dependencies
301            for dep in pyproject.dev_dependencies() {
302                let dep_name = dep
303                    .split(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')
304                    .next()
305                    .unwrap_or("")
306                    .to_lowercase();
307
308                if let Some(dep_path) = name_to_path.get(&dep_name) {
309                    if dep_path != member {
310                        deps.push(dep_path.clone());
311                    }
312                }
313            }
314
315            // Check path dependencies
316            if let Ok(path_deps) = crate::load_path_dependencies(member) {
317                for (_name, path_dep) in path_deps {
318                    let resolved = path_dep.resolve_path(member);
319                    if members.contains(&resolved) && &resolved != member {
320                        deps.push(resolved);
321                    }
322                }
323            }
324        }
325
326        graph.insert(member.clone(), deps);
327    }
328
329    Ok(graph)
330}
331
/// Get transitively affected members.
///
/// If A depends on B, and B changed, then A is also affected. `dep_graph`
/// maps each member to the members it depends on. The result contains every
/// path in `directly_affected` (whether or not it appears in the graph) plus
/// every graph member that reaches one of them through dependency edges,
/// sorted and without duplicates.
///
/// The graph is inverted once and walked with a worklist, so each edge is
/// visited a bounded number of times and cycles terminate naturally.
pub fn get_transitive_affected(
    directly_affected: &[PathBuf],
    dep_graph: &HashMap<PathBuf, Vec<PathBuf>>,
) -> Vec<PathBuf> {
    // Invert the graph: dependency -> members that depend on it.
    let mut dependents: HashMap<&PathBuf, Vec<&PathBuf>> = HashMap::new();
    for (member, deps) in dep_graph {
        for dep in deps {
            dependents.entry(dep).or_default().push(member);
        }
    }

    let mut all_affected: HashSet<&PathBuf> = directly_affected.iter().collect();
    let mut pending: Vec<&PathBuf> = directly_affected.iter().collect();

    // Propagate "affected" from each package to everything depending on it.
    while let Some(current) = pending.pop() {
        if let Some(parents) = dependents.get(current) {
            for &parent in parents {
                // `insert` returns false for already-seen members, which is
                // what stops the walk on cycles and diamonds.
                if all_affected.insert(parent) {
                    pending.push(parent);
                }
            }
        }
    }

    let mut result: Vec<PathBuf> = all_affected.into_iter().cloned().collect();
    result.sort();
    result
}
366
367/// Detect affected with transitive dependencies
368pub fn detect_affected_with_transitive(
369    workspace: &Workspace,
370    config: &AffectedConfig,
371) -> Result<AffectedResult> {
372    let mut result = detect_affected(workspace, config)?;
373
374    if !result.direct.is_empty() {
375        let dep_graph = build_dependency_graph(workspace)?;
376        result.all = get_transitive_affected(&result.direct, &dep_graph);
377    }
378
379    Ok(result)
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: compare `main` against `HEAD`, including working-tree changes.
    #[test]
    fn test_affected_config_default() {
        let defaults = AffectedConfig::default();

        assert_eq!(defaults.base, "main");
        assert_eq!(defaults.head, "HEAD");
        assert!(defaults.uncommitted);
        assert!(defaults.untracked);
    }

    /// The builder methods override the base and head refs.
    #[test]
    fn test_affected_config_builder() {
        let built = AffectedConfig::new();
        let built = built.with_base("develop");
        let built = built.with_head("feature-branch");

        assert_eq!(built.base, "develop");
        assert_eq!(built.head, "feature-branch");
    }

    /// A change propagates upward through a dependency chain and nowhere else.
    #[test]
    fn test_transitive_affected() {
        let member = |name: &str| PathBuf::from("/workspace").join(name);

        // a -> b -> c, while d stands alone.
        let graph: HashMap<PathBuf, Vec<PathBuf>> = vec![
            (member("a"), vec![member("b")]),
            (member("b"), vec![member("c")]),
            (member("c"), Vec::new()),
            (member("d"), Vec::new()),
        ]
        .into_iter()
        .collect();

        // Changing c must pull in b and a, but leave d untouched.
        let changed = vec![member("c")];
        let all_affected = get_transitive_affected(&changed, &graph);

        for expected in ["a", "b", "c"] {
            assert!(all_affected.contains(&member(expected)));
        }
        assert!(!all_affected.contains(&member("d")));
    }
}