// File: python_check_updates/updater.rs

use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};

use anyhow::{Context, Result};

use crate::detector::PackageManager;
use crate::resolver::DependencyCheck;

/// Updates dependency files in place with new version specifiers.
///
/// The updater holds no state, so it is cheap to create (`FileUpdater::new()`
/// or `FileUpdater::default()`) and to copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct FileUpdater;

11impl FileUpdater {
12    pub fn new() -> Self {
13        Self
14    }
15
16    /// Apply updates to dependency files
17    pub fn apply_updates(&self, checks: &[DependencyCheck]) -> Result<UpdateResult> {
18        let mut modified_files = HashSet::new();
19        let mut package_file_map: HashMap<String, Vec<PathBuf>> = HashMap::new();
20        let mut package_managers = HashSet::new();
21
22        // Group checks by file
23        let mut file_updates: HashMap<PathBuf, Vec<&DependencyCheck>> = HashMap::new();
24        for check in checks {
25            if check.has_update() {
26                file_updates
27                    .entry(check.dependency.source_file.clone())
28                    .or_insert_with(Vec::new)
29                    .push(check);
30
31                // Track which packages appear in which files
32                package_file_map
33                    .entry(check.dependency.name.clone())
34                    .or_insert_with(Vec::new)
35                    .push(check.dependency.source_file.clone());
36            }
37        }
38
39        // Update each file
40        for (file_path, updates) in file_updates {
41            self.update_file(&file_path, updates)
42                .with_context(|| format!("Failed to update file: {}", file_path.display()))?;
43
44            modified_files.insert(file_path.clone());
45
46            // Detect package manager from file name
47            if let Some(pm) = detect_package_manager(&file_path) {
48                package_managers.insert(pm);
49            }
50        }
51
52        // Find packages updated in multiple files
53        let mut multi_file_packages: Vec<String> = package_file_map
54            .iter()
55            .filter_map(|(pkg, files)| {
56                let unique_files: HashSet<_> = files.iter().collect();
57                if unique_files.len() > 1 {
58                    Some(pkg.clone())
59                } else {
60                    None
61                }
62            })
63            .collect();
64        multi_file_packages.sort();
65
66        Ok(UpdateResult {
67            modified_files,
68            multi_file_packages,
69            package_managers,
70        })
71    }
72
73    /// Update a single file with the given dependency updates
74    fn update_file(&self, file_path: &PathBuf, updates: Vec<&DependencyCheck>) -> Result<()> {
75        // Read the entire file
76        let content = fs::read_to_string(file_path)
77            .with_context(|| format!("Failed to read file: {}", file_path.display()))?;
78
79        let mut lines: Vec<String> = content.lines().map(|s| s.to_string()).collect();
80
81        // Sort updates by line number in descending order to avoid offset issues
82        let mut sorted_updates = updates;
83        sorted_updates.sort_by(|a, b| b.dependency.line_number.cmp(&a.dependency.line_number));
84
85        // Apply each update
86        for update in sorted_updates {
87            let line_idx = update.dependency.line_number.saturating_sub(1);
88
89            if line_idx >= lines.len() {
90                continue; // Skip if line number is out of bounds
91            }
92
93            let original_line = &lines[line_idx];
94
95            // Get the new version spec
96            if let Some(new_spec) = &update.update_to {
97                let updated_line = self.replace_version_in_line(
98                    original_line,
99                    &update.dependency.name,
100                    &update.dependency.version_spec.to_string(),
101                    &new_spec.to_string(),
102                    file_path,
103                )?;
104
105                lines[line_idx] = updated_line;
106            }
107        }
108
109        // Write the file back
110        let new_content = lines.join("\n");
111        // Add trailing newline if original had one
112        let new_content = if content.ends_with('\n') {
113            format!("{}\n", new_content)
114        } else {
115            new_content
116        };
117
118        fs::write(file_path, new_content)
119            .with_context(|| format!("Failed to write file: {}", file_path.display()))?;
120
121        Ok(())
122    }
123
124    /// Replace version specification in a line
125    fn replace_version_in_line(
126        &self,
127        line: &str,
128        package_name: &str,
129        old_spec: &str,
130        new_spec: &str,
131        file_path: &PathBuf,
132    ) -> Result<String> {
133        let file_name = file_path.file_name()
134            .and_then(|n| n.to_str())
135            .unwrap_or("");
136
137        // Determine file type and use appropriate replacement strategy
138        if file_name.starts_with("requirements") || file_name.ends_with(".txt") {
139            self.replace_in_requirements(line, package_name, old_spec, new_spec)
140        } else if file_name == "pyproject.toml" {
141            self.replace_in_pyproject(line, package_name, old_spec, new_spec)
142        } else if file_name.starts_with("environment.") &&
143                  (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
144            self.replace_in_conda(line, package_name, old_spec, new_spec)
145        } else {
146            // Default to requirements.txt style
147            self.replace_in_requirements(line, package_name, old_spec, new_spec)
148        }
149    }
150
151    /// Replace version in requirements.txt format
152    fn replace_in_requirements(
153        &self,
154        line: &str,
155        package_name: &str,
156        old_spec: &str,
157        new_spec: &str,
158    ) -> Result<String> {
159        // Format: package==1.0.0 or package>=1.0.0,<2.0.0 or package[extras]==1.0.0
160
161        // Try exact match first
162        if let Some(new_line) = line.replace(&format!("{}{}", package_name, old_spec),
163                                              &format!("{}{}", package_name, new_spec))
164                                    .into() {
165            if new_line != line {
166                return Ok(new_line);
167            }
168        }
169
170        // Try with brackets (extras)
171        if line.contains('[') {
172            if let Some(bracket_start) = line.find('[') {
173                if let Some(bracket_end) = line.find(']') {
174                    let before_bracket = &line[..bracket_start];
175                    let extras = &line[bracket_start..=bracket_end];
176                    let after_bracket = &line[bracket_end + 1..];
177
178                    if before_bracket.trim() == package_name {
179                        let new_after = after_bracket.replace(old_spec, new_spec);
180                        return Ok(format!("{}{}{}", before_bracket, extras, new_after));
181                    }
182                }
183            }
184        }
185
186        // Fallback: simple string replacement
187        Ok(line.replace(old_spec, new_spec))
188    }
189
190    /// Replace version in pyproject.toml format
191    fn replace_in_pyproject(
192        &self,
193        line: &str,
194        package_name: &str,
195        old_spec: &str,
196        new_spec: &str,
197    ) -> Result<String> {
198        // Format: package = "^1.0.0" or package = {version = "^1.0.0", ...}
199
200        // Check if line contains the package name (case-insensitive for TOML keys)
201        if line.to_lowercase().contains(&package_name.to_lowercase()) {
202            // Replace the version spec, preserving quotes
203            let result = line.replace(
204                &format!("\"{}\"", old_spec),
205                &format!("\"{}\"", new_spec)
206            );
207            if result != line {
208                return Ok(result);
209            }
210
211            // Try single quotes
212            let result = line.replace(
213                &format!("'{}'", old_spec),
214                &format!("'{}'", new_spec)
215            );
216            if result != line {
217                return Ok(result);
218            }
219        }
220
221        // Fallback
222        Ok(line.replace(old_spec, new_spec))
223    }
224
225    /// Replace version in conda environment.yml format
226    fn replace_in_conda(
227        &self,
228        line: &str,
229        package_name: &str,
230        old_spec: &str,
231        new_spec: &str,
232    ) -> Result<String> {
233        // Format: - package==1.0.0 or - package=1.0.0
234
235        // Conda uses = instead of == sometimes
236        let conda_old_spec = old_spec.replace("==", "=");
237        let conda_new_spec = new_spec.replace("==", "=");
238
239        // Try with ==
240        let result = line.replace(
241            &format!("{}{}", package_name, old_spec),
242            &format!("{}{}", package_name, new_spec)
243        );
244        if result != line {
245            return Ok(result);
246        }
247
248        // Try with single =
249        let result = line.replace(
250            &format!("{}{}", package_name, conda_old_spec),
251            &format!("{}{}", package_name, conda_new_spec)
252        );
253        if result != line {
254            return Ok(result);
255        }
256
257        // Fallback
258        Ok(line.replace(old_spec, new_spec))
259    }
260}
261
262/// Detect package manager from file path
263fn detect_package_manager(path: &PathBuf) -> Option<PackageManager> {
264    let file_name = path.file_name()?.to_str()?;
265
266    if file_name.starts_with("requirements") {
267        Some(PackageManager::Pip)
268    } else if file_name == "pyproject.toml" {
269        // We'd need to read the file to determine if it's uv, poetry, or pdm
270        // For now, default to uv as it's the most common
271        Some(PackageManager::Uv)
272    } else if file_name.starts_with("environment.") &&
273              (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
274        Some(PackageManager::Conda)
275    } else if file_name == "uv.lock" {
276        Some(PackageManager::Uv)
277    } else if file_name == "poetry.lock" {
278        Some(PackageManager::Poetry)
279    } else if file_name == "pdm.lock" {
280        Some(PackageManager::Pdm)
281    } else {
282        None
283    }
284}
285
286/// Result of applying updates
287#[derive(Debug)]
288pub struct UpdateResult {
289    /// Files that were modified
290    pub modified_files: HashSet<PathBuf>,
291    /// Packages that were updated in multiple files
292    pub multi_file_packages: Vec<String>,
293    /// Package managers detected (for sync command suggestions)
294    pub package_managers: HashSet<PackageManager>,
295}
296
297impl UpdateResult {
298    /// Print post-update messages
299    pub fn print_summary(&self) {
300        if !self.multi_file_packages.is_empty() {
301            println!(
302                "\nNote: The following packages were updated in multiple files: {}",
303                self.multi_file_packages.join(", ")
304            );
305        }
306
307        for pm in &self.package_managers {
308            let cmd = match pm {
309                PackageManager::Pip => "pip install -r requirements.txt",
310                PackageManager::Uv => "uv lock",
311                PackageManager::Poetry => "poetry lock",
312                PackageManager::Pdm => "pdm lock",
313                PackageManager::Conda => "conda env update",
314            };
315            println!("Run {} to sync dependencies", cmd);
316        }
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsers::Dependency;
    use crate::version::{Version, VersionSpec};
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_replace_in_requirements() {
        let updater = FileUpdater::new();
        // (input line, package, old spec, new spec, expected line)
        let cases = [
            // Pinned version
            ("requests==2.28.0", "requests", "==2.28.0", "==2.32.3", "requests==2.32.3"),
            // Range version
            (
                "numpy>=1.24.0,<2.0.0",
                "numpy",
                ">=1.24.0,<2.0.0",
                ">=1.26.0,<2.0.0",
                "numpy>=1.26.0,<2.0.0",
            ),
            // Extras between name and spec
            (
                "requests[security]==2.28.0",
                "requests",
                "==2.28.0",
                "==2.32.3",
                "requests[security]==2.32.3",
            ),
        ];

        for (line, package, old_spec, new_spec, expected) in cases.iter() {
            let actual = updater
                .replace_in_requirements(line, package, old_spec, new_spec)
                .unwrap();
            assert_eq!(actual, *expected, "input line: {}", line);
        }
    }

    #[test]
    fn test_replace_in_pyproject() {
        let updater = FileUpdater::new();
        let cases = [
            // Double-quoted spec
            ("requests = \"^2.28.0\"", "requests", "^2.28.0", "^2.32.3", "requests = \"^2.32.3\""),
            // Single-quoted spec
            ("numpy = '^1.24.0'", "numpy", "^1.24.0", "^1.26.0", "numpy = '^1.26.0'"),
        ];

        for (line, package, old_spec, new_spec, expected) in cases.iter() {
            let actual = updater
                .replace_in_pyproject(line, package, old_spec, new_spec)
                .unwrap();
            assert_eq!(actual, *expected, "input line: {}", line);
        }
    }

    #[test]
    fn test_replace_in_conda() {
        let updater = FileUpdater::new();
        let cases = [
            // `==` operator
            ("  - numpy==1.24.0", "numpy", "==1.24.0", "==1.26.0", "  - numpy==1.26.0"),
            // Conda's single `=` operator
            ("  - requests=2.28.0", "requests", "==2.28.0", "==2.32.3", "  - requests=2.32.3"),
        ];

        for (line, package, old_spec, new_spec, expected) in cases.iter() {
            let actual = updater
                .replace_in_conda(line, package, old_spec, new_spec)
                .unwrap();
            assert_eq!(actual, *expected, "input line: {}", line);
        }
    }

    #[test]
    fn test_detect_package_manager() {
        let cases = [
            ("/path/to/requirements.txt", Some(PackageManager::Pip)),
            ("/path/to/requirements-dev.txt", Some(PackageManager::Pip)),
            ("/path/to/pyproject.toml", Some(PackageManager::Uv)),
            ("/path/to/environment.yml", Some(PackageManager::Conda)),
            ("/path/to/poetry.lock", Some(PackageManager::Poetry)),
        ];

        for (path, expected) in cases.iter() {
            assert_eq!(&detect_package_manager(&PathBuf::from(path)), expected, "path: {}", path);
        }
    }

    #[test]
    fn test_update_file_integration() -> Result<()> {
        let updater = FileUpdater::new();

        let mut temp_file = NamedTempFile::new()?;
        temp_file.write_all(b"requests==2.28.0\nnumpy>=1.24.0,<2.0.0\nflask==2.0.3\n")?;
        temp_file.flush()?;
        let temp_path = temp_file.path().to_path_buf();

        // numpy (line 2) deliberately has no update and must stay untouched.
        let requests = DependencyCheck {
            dependency: Dependency {
                name: "requests".to_string(),
                version_spec: VersionSpec::Pinned(Version::new(2, 28, 0)),
                source_file: temp_path.clone(),
                line_number: 1,
                original_line: "requests==2.28.0".to_string(),
            },
            installed: Some(Version::new(2, 28, 0)),
            in_range: Some(Version::new(2, 32, 3)),
            latest: Version::new(2, 32, 3),
            update_to: Some(VersionSpec::Pinned(Version::new(2, 32, 3))),
        };
        let flask = DependencyCheck {
            dependency: Dependency {
                name: "flask".to_string(),
                version_spec: VersionSpec::Pinned(Version::new(2, 0, 3)),
                source_file: temp_path.clone(),
                line_number: 3,
                original_line: "flask==2.0.3".to_string(),
            },
            installed: Some(Version::new(2, 0, 3)),
            in_range: Some(Version::new(2, 3, 3)),
            latest: Version::new(2, 3, 3),
            update_to: Some(VersionSpec::Pinned(Version::new(2, 3, 3))),
        };

        updater.update_file(&temp_path, vec![&requests, &flask])?;

        let rewritten = fs::read_to_string(&temp_path)?;
        let lines: Vec<&str> = rewritten.lines().collect();
        assert_eq!(
            lines,
            ["requests==2.32.3", "numpy>=1.24.0,<2.0.0", "flask==2.3.3"]
        );

        Ok(())
    }
}