Skip to main content

ncu/
updater.rs

1use check_updates_core::{DependencyCheck, UpdateSeverity};
2use anyhow::{Context, Result};
3use std::collections::HashSet;
4use std::fs;
5use std::path::PathBuf;
6
/// Updates package.json with new versions
///
/// The type is stateless: every operation receives the data it needs as
/// arguments, so values are freely copyable.
#[derive(Debug, Clone, Copy)]
pub struct FileUpdater;
9
10impl FileUpdater {
11    pub fn new() -> Self {
12        Self
13    }
14
15    /// Apply updates to package.json based on severity filter
16    pub fn apply_updates(
17        &self,
18        checks: &[DependencyCheck],
19        include_minor: bool,
20        force: bool,
21    ) -> Result<UpdateResult> {
22        let mut modified_files = HashSet::new();
23
24        // Group checks by file, filtering by severity
25        let mut file_updates: std::collections::HashMap<PathBuf, Vec<(&DependencyCheck, String)>> =
26            std::collections::HashMap::new();
27
28        for check in checks {
29            let version_spec = if force {
30                check.force_spec.as_ref()
31            } else {
32                match check.severity {
33                    Some(UpdateSeverity::Patch) => check.target_spec.as_ref(),
34                    Some(UpdateSeverity::Minor) if include_minor => check.target_spec.as_ref(),
35                    _ => None,
36                }
37            };
38
39            if let Some(spec) = version_spec
40                && spec.is_rewritable() {
41                // For npm, preserve the original prefix (^, ~, etc.)
42                let new_version = spec.to_string();
43                file_updates
44                    .entry(check.dependency.source_file.clone())
45                    .or_default()
46                    .push((check, new_version));
47            }
48        }
49
50        for (file_path, updates) in file_updates {
51            self.update_file(&file_path, &updates)
52                .with_context(|| format!("Failed to update file: {}", file_path.display()))?;
53            modified_files.insert(file_path);
54        }
55
56        Ok(UpdateResult { modified_files })
57    }
58
59    fn update_file(
60        &self,
61        file_path: &PathBuf,
62        updates: &[(&DependencyCheck, String)],
63    ) -> Result<()> {
64        let content = fs::read_to_string(file_path)
65            .with_context(|| format!("Failed to read file: {}", file_path.display()))?;
66
67        // Parse as JSON, update, and write back preserving formatting
68        let mut parsed: serde_json::Value = serde_json::from_str(&content)
69            .with_context(|| format!("Failed to parse JSON: {}", file_path.display()))?;
70
71        for (check, new_version) in updates {
72            self.update_dependency(&mut parsed, &check.dependency.name, new_version);
73        }
74
75        // Write back with pretty formatting
76        let updated = serde_json::to_string_pretty(&parsed)
77            .with_context(|| "Failed to serialize JSON")?;
78
79        fs::write(file_path, updated + "\n")
80            .with_context(|| format!("Failed to write file: {}", file_path.display()))?;
81
82        Ok(())
83    }
84
85    fn update_dependency(&self, doc: &mut serde_json::Value, name: &str, new_version: &str) {
86        let sections = [
87            "dependencies",
88            "devDependencies",
89            "peerDependencies",
90            "optionalDependencies",
91        ];
92
93        for section in sections {
94            if let Some(deps) = doc.get_mut(section).and_then(|v| v.as_object_mut())
95                && deps.contains_key(name) {
96                    deps.insert(name.to_string(), serde_json::Value::String(new_version.to_string()));
97                }
98        }
99    }
100}
101
102impl Default for FileUpdater {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
/// Outcome of [`FileUpdater::apply_updates`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UpdateResult {
    /// Paths of the manifest files that were processed by the update run.
    pub modified_files: HashSet<PathBuf>,
}
112
113impl UpdateResult {
114    pub fn print_summary(&self) {
115        if !self.modified_files.is_empty() {
116            println!("Run `npm install` to install updated packages");
117        }
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use check_updates_core::{Dependency, Version, VersionSpec};
    use std::io::Write;
    use std::str::FromStr;
    use tempfile::NamedTempFile;

    /// Builds a `DependencyCheck` for `name`, currently declared as `spec_str`
    /// in the manifest at `path`, whose target and forced spec are both
    /// `^{target_version}` with the given `severity`.
    fn create_check(
        name: &str,
        spec_str: &str,
        path: PathBuf,
        target_version: &str,
        severity: UpdateSeverity,
    ) -> DependencyCheck {
        let target = Version::from_str(target_version).unwrap();
        // The installed version is the declared spec without its range prefix.
        let installed_str = spec_str.trim_start_matches('^').trim_start_matches('~');
        DependencyCheck {
            dependency: Dependency {
                name: name.to_string(),
                version_spec: VersionSpec::parse(spec_str).unwrap(),
                source_file: path,
                line_number: 2,
                original_line: format!("\"{}\": \"{}\"", name, spec_str),
            },
            installed: Some(Version::from_str(installed_str).unwrap()),
            in_range: Some(target.clone()),
            latest: target.clone(),
            target: Some(target.clone()),
            target_spec: Some(VersionSpec::parse(&format!("^{}", target_version)).unwrap()),
            severity: Some(severity),
            force_spec: Some(VersionSpec::parse(&format!("^{}", target_version)).unwrap()),
        }
    }

    /// With `include_minor == false` and `force == false`, only the patch
    /// update (express) may be written; the minor update (lodash) must keep
    /// its original spec.
    #[test]
    fn test_update_patch_only() -> Result<()> {
        let mut file = NamedTempFile::new()?;
        writeln!(
            file,
            r#"{{
  "dependencies": {{
    "express": "^4.18.0",
    "lodash": "^4.17.0"
  }}
}}"#
        )?;
        file.flush()?;

        let temp_path = file.path().to_path_buf();

        let checks = vec![
            create_check("express", "^4.18.0", temp_path.clone(), "4.18.2", UpdateSeverity::Patch),
            create_check("lodash", "^4.17.0", temp_path.clone(), "4.18.0", UpdateSeverity::Minor),
        ];

        let updater = FileUpdater::new();
        updater.apply_updates(&checks, false, false)?;

        let content = fs::read_to_string(&temp_path)?;
        assert!(content.contains("4.18.2"), "express should be updated: {}", content);
        // lodash's original spec must survive. The previous check also passed
        // when lodash had been bumped to "^4.18.0", so it could not detect the
        // regression it was meant to guard against.
        assert!(
            content.contains("^4.17.0"),
            "lodash should NOT be updated: {}",
            content
        );

        Ok(())
    }
}