Skip to main content

pcu/
updater.rs

1use check_updates_core::{DependencyCheck, UpdateSeverity};
2use crate::detector::PackageManager;
3use anyhow::{Context, Result};
4use std::collections::{HashSet, HashMap};
5use std::path::{Path, PathBuf};
6use std::fs;
7
/// Updates dependency files with new versions.
///
/// Rewrites version specifiers in place in requirements-style files, `pyproject.toml`,
/// and conda `environment.yml` / `environment.yaml` files. The type carries no state;
/// [`FileUpdater::apply_updates`] is the entry point.
pub struct FileUpdater;
10
11impl Default for FileUpdater {
12    fn default() -> Self {
13        Self::new()
14    }
15}
16
17impl FileUpdater {
18    pub fn new() -> Self {
19        Self
20    }
21
22    /// Apply updates to dependency files based on severity filter
23    /// - include_minor: false = patch only, true = patch + minor
24    /// - force: true = all severities AND use absolute latest version
25    pub fn apply_updates(
26        &self,
27        checks: &[DependencyCheck],
28        include_minor: bool,
29        force: bool,
30    ) -> Result<UpdateResult> {
31        let mut modified_files = HashSet::new();
32        let mut package_file_map: HashMap<String, Vec<PathBuf>> = HashMap::new();
33        let mut package_managers = HashSet::new();
34
35        // Group checks by file, filtering by severity
36        let mut file_updates: HashMap<PathBuf, Vec<(&DependencyCheck, String)>> = HashMap::new();
37
38        for check in checks {
39            // Determine which version spec to use
40            let version_spec = if force {
41                // Force mode: use absolute latest for all packages
42                check.force_spec.as_ref()
43            } else {
44                // Normal mode: filter by severity and use target_spec
45                match check.severity {
46                    Some(UpdateSeverity::Patch) => check.target_spec.as_ref(),
47                    Some(UpdateSeverity::Minor) if include_minor => check.target_spec.as_ref(),
48                    _ => None, // Skip major updates and minor (if not included)
49                }
50            };
51
52            if let Some(spec) = version_spec
53                && spec.is_rewritable() {
54                let new_version = spec.to_string();
55                file_updates
56                    .entry(check.dependency.source_file.clone())
57                    .or_default()
58                    .push((check, new_version));
59
60                // Track which packages appear in which files
61                package_file_map
62                    .entry(check.dependency.name.clone())
63                    .or_default()
64                    .push(check.dependency.source_file.clone());
65            }
66        }
67
68        // Update each file
69        for (file_path, updates) in file_updates {
70            self.update_file(&file_path, &updates)
71                .with_context(|| format!("Failed to update file: {}", file_path.display()))?;
72
73            modified_files.insert(file_path.clone());
74
75            // Detect package manager from file name
76            if let Some(pm) = detect_package_manager(&file_path) {
77                package_managers.insert(pm);
78            }
79        }
80
81        // Find packages updated in multiple files
82        let mut multi_file_packages: Vec<String> = package_file_map
83            .iter()
84            .filter_map(|(pkg, files)| {
85                let unique_files: HashSet<_> = files.iter().collect();
86                if unique_files.len() > 1 {
87                    Some(pkg.clone())
88                } else {
89                    None
90                }
91            })
92            .collect();
93        multi_file_packages.sort();
94
95        Ok(UpdateResult {
96            modified_files,
97            multi_file_packages,
98            package_managers,
99        })
100    }
101
102    /// Update a single file with the given dependency updates
103    fn update_file(&self, file_path: &Path, updates: &[(&DependencyCheck, String)]) -> Result<()> {
104        // Read the entire file
105        let content = fs::read_to_string(file_path)
106            .with_context(|| format!("Failed to read file: {}", file_path.display()))?;
107
108        let mut lines: Vec<String> = content.lines().map(std::string::ToString::to_string).collect();
109
110        // Sort updates by line number in descending order to avoid offset issues
111        let mut sorted_updates: Vec<_> = updates.iter().collect();
112        sorted_updates.sort_by_key(|x| std::cmp::Reverse(x.0.dependency.line_number));
113
114        // Apply each update
115        for (check, new_version) in sorted_updates {
116            let line_idx = check.dependency.line_number.saturating_sub(1);
117
118            if line_idx >= lines.len() {
119                continue; // Skip if line number is out of bounds
120            }
121
122            let original_line = &lines[line_idx];
123
124            let updated_line = self.replace_version_in_line(
125                original_line,
126                &check.dependency.name,
127                &check.dependency.version_spec.to_string(),
128                new_version,
129                file_path,
130            )?;
131
132            lines[line_idx] = updated_line;
133        }
134
135        // Write the file back
136        let new_content = lines.join("\n");
137        // Add trailing newline if original had one
138        let new_content = if content.ends_with('\n') {
139            format!("{new_content}\n")
140        } else {
141            new_content
142        };
143
144        fs::write(file_path, new_content)
145            .with_context(|| format!("Failed to write file: {}", file_path.display()))?;
146
147        Ok(())
148    }
149
150    /// Replace version specification in a line
151    fn replace_version_in_line(
152        &self,
153        line: &str,
154        package_name: &str,
155        old_spec: &str,
156        new_spec: &str,
157        file_path: &Path,
158    ) -> Result<String> {
159        let file_name = file_path.file_name()
160            .and_then(|n| n.to_str())
161            .unwrap_or("");
162
163        // Determine file type and use appropriate replacement strategy
164        if file_name.starts_with("requirements") || file_name.ends_with(".txt") {
165            self.replace_in_requirements(line, package_name, old_spec, new_spec)
166        } else if file_name == "pyproject.toml" {
167            self.replace_in_pyproject(line, package_name, old_spec, new_spec)
168        } else if file_name.starts_with("environment.") &&
169                  (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
170            self.replace_in_conda(line, package_name, old_spec, new_spec)
171        } else {
172            // Default to requirements.txt style
173            self.replace_in_requirements(line, package_name, old_spec, new_spec)
174        }
175    }
176
177    /// Replace version in requirements.txt format
178    fn replace_in_requirements(
179        &self,
180        line: &str,
181        package_name: &str,
182        old_spec: &str,
183        new_spec: &str,
184    ) -> Result<String> {
185        // Format: package==1.0.0 or package>=1.0.0,<2.0.0 or package[extras]==1.0.0
186
187        // Try exact match first
188        if let Some(new_line) = line.replace(&format!("{package_name}{old_spec}"),
189                                              &format!("{package_name}{new_spec}"))
190                                    .into()
191            && new_line != line {
192                return Ok(new_line);
193            }
194
195        // Try with brackets (extras)
196        if line.contains('[')
197            && let Some(bracket_start) = line.find('[')
198                && let Some(bracket_end) = line.find(']') {
199                    let before_bracket = &line[..bracket_start];
200                    let extras = &line[bracket_start..=bracket_end];
201                    let after_bracket = &line[bracket_end + 1..];
202
203                    if before_bracket.trim() == package_name {
204                        let new_after = after_bracket.replace(old_spec, new_spec);
205                        return Ok(format!("{before_bracket}{extras}{new_after}"));
206                    }
207                }
208
209        // Fallback: simple string replacement
210        Ok(line.replace(old_spec, new_spec))
211    }
212
213    /// Replace version in pyproject.toml format
214    fn replace_in_pyproject(
215        &self,
216        line: &str,
217        package_name: &str,
218        old_spec: &str,
219        new_spec: &str,
220    ) -> Result<String> {
221        // Format: package = "^1.0.0" or package = {version = "^1.0.0", ...}
222
223        // Check if line contains the package name (case-insensitive for TOML keys)
224        if line.to_lowercase().contains(&package_name.to_lowercase()) {
225            // Replace the version spec, preserving quotes
226            let result = line.replace(
227                &format!("\"{old_spec}\""),
228                &format!("\"{new_spec}\"")
229            );
230            if result != line {
231                return Ok(result);
232            }
233
234            // Try single quotes
235            let result = line.replace(
236                &format!("'{old_spec}'"),
237                &format!("'{new_spec}'")
238            );
239            if result != line {
240                return Ok(result);
241            }
242        }
243
244        // Fallback
245        Ok(line.replace(old_spec, new_spec))
246    }
247
248    /// Replace version in conda environment.yml format
249    fn replace_in_conda(
250        &self,
251        line: &str,
252        package_name: &str,
253        old_spec: &str,
254        new_spec: &str,
255    ) -> Result<String> {
256        // Format: - package==1.0.0 or - package=1.0.0
257
258        // Conda uses = instead of == sometimes
259        let conda_old_spec = old_spec.replace("==", "=");
260        let conda_new_spec = new_spec.replace("==", "=");
261
262        // Try with ==
263        let result = line.replace(
264            &format!("{package_name}{old_spec}"),
265            &format!("{package_name}{new_spec}")
266        );
267        if result != line {
268            return Ok(result);
269        }
270
271        // Try with single =
272        let result = line.replace(
273            &format!("{package_name}{conda_old_spec}"),
274            &format!("{package_name}{conda_new_spec}")
275        );
276        if result != line {
277            return Ok(result);
278        }
279
280        // Fallback
281        Ok(line.replace(old_spec, new_spec))
282    }
283}
284
285/// Detect package manager from file path
286fn detect_package_manager(path: &Path) -> Option<PackageManager> {
287    let file_name = path.file_name()?.to_str()?;
288
289    if file_name.starts_with("requirements") {
290        Some(PackageManager::Pip)
291    } else if file_name == "pyproject.toml" {
292        // We'd need to read the file to determine if it's uv, poetry, or pdm
293        // For now, default to uv as it's the most common
294        Some(PackageManager::Uv)
295    } else if file_name.starts_with("environment.") &&
296              (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
297        Some(PackageManager::Conda)
298    } else if file_name == "uv.lock" {
299        Some(PackageManager::Uv)
300    } else if file_name == "poetry.lock" {
301        Some(PackageManager::Poetry)
302    } else if file_name == "pdm.lock" {
303        Some(PackageManager::Pdm)
304    } else {
305        None
306    }
307}
308
/// Result of applying updates: the outcome of one [`FileUpdater::apply_updates`] call.
#[derive(Debug)]
pub struct UpdateResult {
    /// Files that were modified: one entry per file that had at least one update selected.
    pub modified_files: HashSet<PathBuf>,
    /// Names of packages whose updates were applied in more than one distinct file, sorted.
    pub multi_file_packages: Vec<String>,
    /// Package managers inferred from the modified files' names (used for sync command
    /// suggestions in [`UpdateResult::print_summary`]).
    pub package_managers: HashSet<PackageManager>,
}
319
320impl UpdateResult {
321    /// Print post-update messages
322    pub fn print_summary(&self) {
323        if !self.multi_file_packages.is_empty() {
324            println!(
325                "\nNote: The following packages were updated in multiple files: {}",
326                self.multi_file_packages.join(", ")
327            );
328        }
329
330        for pm in &self.package_managers {
331            let cmd = match pm {
332                PackageManager::Pip => "pip install -r requirements.txt",
333                PackageManager::Uv => "uv lock",
334                PackageManager::Poetry => "poetry lock",
335                PackageManager::Pdm => "pdm lock",
336                PackageManager::Conda => "conda env update",
337            };
338            println!("Run {cmd} to sync dependencies");
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsers::Dependency;
    use check_updates_core::{Version, VersionSpec};
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Builds a `DependencyCheck` for a pinned `name==old` requirement on the given line of
    /// `path`, whose in-range, latest, target and forced versions are all `new`.
    macro_rules! pinned_check {
        (
            $path:expr, $name:literal, line $line:literal,
            ($o1:literal, $o2:literal, $o3:literal) -> ($n1:literal, $n2:literal, $n3:literal),
            $severity:ident
        ) => {
            DependencyCheck {
                dependency: Dependency {
                    name: $name.to_string(),
                    version_spec: VersionSpec::Pinned(Version::new($o1, $o2, $o3)),
                    source_file: $path.clone(),
                    line_number: $line,
                    original_line: concat!($name, "==", $o1, ".", $o2, ".", $o3).to_string(),
                },
                installed: Some(Version::new($o1, $o2, $o3)),
                in_range: Some(Version::new($n1, $n2, $n3)),
                latest: Version::new($n1, $n2, $n3),
                target: Some(Version::new($n1, $n2, $n3)),
                target_spec: Some(VersionSpec::Pinned(Version::new($n1, $n2, $n3))),
                severity: Some(UpdateSeverity::$severity),
                force_spec: Some(VersionSpec::Pinned(Version::new($n1, $n2, $n3))),
            }
        };
    }

    /// Writes `serde==1.0.0` / `tokio==1.0.0` to a temp file and returns it (kept alive so it
    /// is not deleted) together with a patch bump for serde and a minor bump for tokio.
    fn serde_patch_tokio_minor() -> Result<(NamedTempFile, Vec<DependencyCheck>)> {
        let mut file = NamedTempFile::new()?;
        writeln!(file, "serde==1.0.0")?;
        writeln!(file, "tokio==1.0.0")?;
        file.flush()?;

        let path = file.path().to_path_buf();
        let checks = vec![
            pinned_check!(path, "serde", line 1, (1, 0, 0) -> (1, 0, 200), Patch),
            pinned_check!(path, "tokio", line 2, (1, 0, 0) -> (1, 5, 0), Minor),
        ];
        Ok((file, checks))
    }

    #[test]
    fn test_replace_in_requirements() {
        let updater = FileUpdater::new();
        // (line, package, old spec, new spec, expected)
        let cases = [
            // Basic pinned version.
            ("requests==2.28.0", "requests", "==2.28.0", "==2.32.3", "requests==2.32.3"),
            // Range version.
            (
                "numpy>=1.24.0,<2.0.0",
                "numpy",
                ">=1.24.0,<2.0.0",
                ">=1.26.0,<2.0.0",
                "numpy>=1.26.0,<2.0.0",
            ),
            // With extras.
            (
                "requests[security]==2.28.0",
                "requests",
                "==2.28.0",
                "==2.32.3",
                "requests[security]==2.32.3",
            ),
        ];

        for (line, package, old, new, expected) in cases {
            let actual = updater.replace_in_requirements(line, package, old, new).unwrap();
            assert_eq!(actual, expected, "line: {line}");
        }
    }

    #[test]
    fn test_replace_in_pyproject() {
        let updater = FileUpdater::new();
        let cases = [
            // Double quotes.
            ("requests = \"^2.28.0\"", "requests", "^2.28.0", "^2.32.3", "requests = \"^2.32.3\""),
            // Single quotes.
            ("numpy = '^1.24.0'", "numpy", "^1.24.0", "^1.26.0", "numpy = '^1.26.0'"),
        ];

        for (line, package, old, new, expected) in cases {
            let actual = updater.replace_in_pyproject(line, package, old, new).unwrap();
            assert_eq!(actual, expected, "line: {line}");
        }
    }

    #[test]
    fn test_replace_in_conda() {
        let updater = FileUpdater::new();
        let cases = [
            // `==` operator.
            ("  - numpy==1.24.0", "numpy", "==1.24.0", "==1.26.0", "  - numpy==1.26.0"),
            // Single `=` operator.
            ("  - requests=2.28.0", "requests", "==2.28.0", "==2.32.3", "  - requests=2.32.3"),
        ];

        for (line, package, old, new, expected) in cases {
            let actual = updater.replace_in_conda(line, package, old, new).unwrap();
            assert_eq!(actual, expected, "line: {line}");
        }
    }

    #[test]
    fn test_detect_package_manager() {
        let expectations = [
            ("/path/to/requirements.txt", PackageManager::Pip),
            ("/path/to/requirements-dev.txt", PackageManager::Pip),
            ("/path/to/pyproject.toml", PackageManager::Uv),
            ("/path/to/environment.yml", PackageManager::Conda),
            ("/path/to/poetry.lock", PackageManager::Poetry),
        ];

        for (path, expected) in expectations {
            assert_eq!(
                detect_package_manager(&PathBuf::from(path)),
                Some(expected),
                "path: {path}"
            );
        }
    }

    #[test]
    fn test_update_file_integration() -> Result<()> {
        let mut temp_file = NamedTempFile::new()?;
        for line in ["requests==2.28.0", "numpy>=1.24.0,<2.0.0", "flask==2.0.3"] {
            writeln!(temp_file, "{line}")?;
        }
        temp_file.flush()?;
        let temp_path = temp_file.path().to_path_buf();

        let requests = pinned_check!(temp_path, "requests", line 1, (2, 28, 0) -> (2, 32, 3), Minor);
        let flask = pinned_check!(temp_path, "flask", line 3, (2, 0, 3) -> (2, 3, 3), Minor);

        let updates: Vec<(&DependencyCheck, String)> = vec![
            (&requests, "==2.32.3".to_string()),
            (&flask, "==2.3.3".to_string()),
        ];
        FileUpdater::new().update_file(&temp_path, &updates)?;

        let updated_content = fs::read_to_string(&temp_path)?;
        let lines: Vec<&str> = updated_content.lines().collect();
        // numpy (line 2) has no update and must be left unchanged.
        assert_eq!(lines, ["requests==2.32.3", "numpy>=1.24.0,<2.0.0", "flask==2.3.3"]);

        Ok(())
    }

    #[test]
    fn test_update_patch_only() -> Result<()> {
        let (file, checks) = serde_patch_tokio_minor()?;

        FileUpdater::new().apply_updates(&checks, false, false)?; // patch only

        let content = fs::read_to_string(file.path())?;
        assert!(content.contains("==1.0.200"), "serde should be updated: {}", content);
        assert!(!content.contains("==1.5.0"), "tokio should NOT be updated: {}", content);

        Ok(())
    }

    #[test]
    fn test_update_patch_and_minor() -> Result<()> {
        let (file, checks) = serde_patch_tokio_minor()?;

        FileUpdater::new().apply_updates(&checks, true, false)?; // patch + minor

        let content = fs::read_to_string(file.path())?;
        assert!(content.contains("==1.0.200"), "serde should be updated: {}", content);
        assert!(content.contains("==1.5.0"), "tokio should be updated: {}", content);

        Ok(())
    }
}