1use check_updates_core::{DependencyCheck, UpdateSeverity};
2use crate::detector::PackageManager;
3use anyhow::{Context, Result};
4use std::collections::{HashSet, HashMap};
5use std::path::{Path, PathBuf};
6use std::fs;
7
/// Rewrites dependency version specifications in place inside pip requirement files,
/// `pyproject.toml` and conda `environment.{yml,yaml}` files.
///
/// The updater carries no state; construct it with [`FileUpdater::new`] or `Default`.
#[derive(Debug, Clone, Copy)]
pub struct FileUpdater;
10
11impl Default for FileUpdater {
12 fn default() -> Self {
13 Self::new()
14 }
15}
16
17impl FileUpdater {
18 pub fn new() -> Self {
19 Self
20 }
21
22 pub fn apply_updates(
26 &self,
27 checks: &[DependencyCheck],
28 include_minor: bool,
29 force: bool,
30 ) -> Result<UpdateResult> {
31 let mut modified_files = HashSet::new();
32 let mut package_file_map: HashMap<String, Vec<PathBuf>> = HashMap::new();
33 let mut package_managers = HashSet::new();
34
35 let mut file_updates: HashMap<PathBuf, Vec<(&DependencyCheck, String)>> = HashMap::new();
37
38 for check in checks {
39 let version_spec = if force {
41 check.force_spec.as_ref()
43 } else {
44 match check.severity {
46 Some(UpdateSeverity::Patch) => check.target_spec.as_ref(),
47 Some(UpdateSeverity::Minor) if include_minor => check.target_spec.as_ref(),
48 _ => None, }
50 };
51
52 if let Some(spec) = version_spec
53 && spec.is_rewritable() {
54 let new_version = spec.to_string();
55 file_updates
56 .entry(check.dependency.source_file.clone())
57 .or_default()
58 .push((check, new_version));
59
60 package_file_map
62 .entry(check.dependency.name.clone())
63 .or_default()
64 .push(check.dependency.source_file.clone());
65 }
66 }
67
68 for (file_path, updates) in file_updates {
70 self.update_file(&file_path, &updates)
71 .with_context(|| format!("Failed to update file: {}", file_path.display()))?;
72
73 modified_files.insert(file_path.clone());
74
75 if let Some(pm) = detect_package_manager(&file_path) {
77 package_managers.insert(pm);
78 }
79 }
80
81 let mut multi_file_packages: Vec<String> = package_file_map
83 .iter()
84 .filter_map(|(pkg, files)| {
85 let unique_files: HashSet<_> = files.iter().collect();
86 if unique_files.len() > 1 {
87 Some(pkg.clone())
88 } else {
89 None
90 }
91 })
92 .collect();
93 multi_file_packages.sort();
94
95 Ok(UpdateResult {
96 modified_files,
97 multi_file_packages,
98 package_managers,
99 })
100 }
101
102 fn update_file(&self, file_path: &Path, updates: &[(&DependencyCheck, String)]) -> Result<()> {
104 let content = fs::read_to_string(file_path)
106 .with_context(|| format!("Failed to read file: {}", file_path.display()))?;
107
108 let mut lines: Vec<String> = content.lines().map(std::string::ToString::to_string).collect();
109
110 let mut sorted_updates: Vec<_> = updates.iter().collect();
112 sorted_updates.sort_by_key(|x| std::cmp::Reverse(x.0.dependency.line_number));
113
114 for (check, new_version) in sorted_updates {
116 let line_idx = check.dependency.line_number.saturating_sub(1);
117
118 if line_idx >= lines.len() {
119 continue; }
121
122 let original_line = &lines[line_idx];
123
124 let updated_line = self.replace_version_in_line(
125 original_line,
126 &check.dependency.name,
127 &check.dependency.version_spec.to_string(),
128 new_version,
129 file_path,
130 )?;
131
132 lines[line_idx] = updated_line;
133 }
134
135 let new_content = lines.join("\n");
137 let new_content = if content.ends_with('\n') {
139 format!("{new_content}\n")
140 } else {
141 new_content
142 };
143
144 fs::write(file_path, new_content)
145 .with_context(|| format!("Failed to write file: {}", file_path.display()))?;
146
147 Ok(())
148 }
149
150 fn replace_version_in_line(
152 &self,
153 line: &str,
154 package_name: &str,
155 old_spec: &str,
156 new_spec: &str,
157 file_path: &Path,
158 ) -> Result<String> {
159 let file_name = file_path.file_name()
160 .and_then(|n| n.to_str())
161 .unwrap_or("");
162
163 if file_name.starts_with("requirements") || file_name.ends_with(".txt") {
165 self.replace_in_requirements(line, package_name, old_spec, new_spec)
166 } else if file_name == "pyproject.toml" {
167 self.replace_in_pyproject(line, package_name, old_spec, new_spec)
168 } else if file_name.starts_with("environment.") &&
169 (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
170 self.replace_in_conda(line, package_name, old_spec, new_spec)
171 } else {
172 self.replace_in_requirements(line, package_name, old_spec, new_spec)
174 }
175 }
176
177 fn replace_in_requirements(
179 &self,
180 line: &str,
181 package_name: &str,
182 old_spec: &str,
183 new_spec: &str,
184 ) -> Result<String> {
185 if let Some(new_line) = line.replace(&format!("{package_name}{old_spec}"),
189 &format!("{package_name}{new_spec}"))
190 .into()
191 && new_line != line {
192 return Ok(new_line);
193 }
194
195 if line.contains('[')
197 && let Some(bracket_start) = line.find('[')
198 && let Some(bracket_end) = line.find(']') {
199 let before_bracket = &line[..bracket_start];
200 let extras = &line[bracket_start..=bracket_end];
201 let after_bracket = &line[bracket_end + 1..];
202
203 if before_bracket.trim() == package_name {
204 let new_after = after_bracket.replace(old_spec, new_spec);
205 return Ok(format!("{before_bracket}{extras}{new_after}"));
206 }
207 }
208
209 Ok(line.replace(old_spec, new_spec))
211 }
212
213 fn replace_in_pyproject(
215 &self,
216 line: &str,
217 package_name: &str,
218 old_spec: &str,
219 new_spec: &str,
220 ) -> Result<String> {
221 if line.to_lowercase().contains(&package_name.to_lowercase()) {
225 let result = line.replace(
227 &format!("\"{old_spec}\""),
228 &format!("\"{new_spec}\"")
229 );
230 if result != line {
231 return Ok(result);
232 }
233
234 let result = line.replace(
236 &format!("'{old_spec}'"),
237 &format!("'{new_spec}'")
238 );
239 if result != line {
240 return Ok(result);
241 }
242 }
243
244 Ok(line.replace(old_spec, new_spec))
246 }
247
248 fn replace_in_conda(
250 &self,
251 line: &str,
252 package_name: &str,
253 old_spec: &str,
254 new_spec: &str,
255 ) -> Result<String> {
256 let conda_old_spec = old_spec.replace("==", "=");
260 let conda_new_spec = new_spec.replace("==", "=");
261
262 let result = line.replace(
264 &format!("{package_name}{old_spec}"),
265 &format!("{package_name}{new_spec}")
266 );
267 if result != line {
268 return Ok(result);
269 }
270
271 let result = line.replace(
273 &format!("{package_name}{conda_old_spec}"),
274 &format!("{package_name}{conda_new_spec}")
275 );
276 if result != line {
277 return Ok(result);
278 }
279
280 Ok(line.replace(old_spec, new_spec))
282 }
283}
284
285fn detect_package_manager(path: &Path) -> Option<PackageManager> {
287 let file_name = path.file_name()?.to_str()?;
288
289 if file_name.starts_with("requirements") {
290 Some(PackageManager::Pip)
291 } else if file_name == "pyproject.toml" {
292 Some(PackageManager::Uv)
295 } else if file_name.starts_with("environment.") &&
296 (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
297 Some(PackageManager::Conda)
298 } else if file_name == "uv.lock" {
299 Some(PackageManager::Uv)
300 } else if file_name == "poetry.lock" {
301 Some(PackageManager::Poetry)
302 } else if file_name == "pdm.lock" {
303 Some(PackageManager::Pdm)
304 } else {
305 None
306 }
307}
308
/// Outcome of `FileUpdater::apply_updates`.
#[derive(Debug)]
pub struct UpdateResult {
    /// Paths of the dependency files that `apply_updates` wrote back to disk.
    pub modified_files: HashSet<PathBuf>,
    /// Names of packages that received an update in more than one distinct file, sorted.
    pub multi_file_packages: Vec<String>,
    /// Package managers inferred (via `detect_package_manager`) from the names of the written
    /// files; `print_summary` prints one sync hint per entry.
    pub package_managers: HashSet<PackageManager>,
}
319
320impl UpdateResult {
321 pub fn print_summary(&self) {
323 if !self.multi_file_packages.is_empty() {
324 println!(
325 "\nNote: The following packages were updated in multiple files: {}",
326 self.multi_file_packages.join(", ")
327 );
328 }
329
330 for pm in &self.package_managers {
331 let cmd = match pm {
332 PackageManager::Pip => "pip install -r requirements.txt",
333 PackageManager::Uv => "uv lock",
334 PackageManager::Poetry => "poetry lock",
335 PackageManager::Pdm => "pdm lock",
336 PackageManager::Conda => "conda env update",
337 };
338 println!("Run {cmd} to sync dependencies");
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    // Pip requirement lines: an exact pin, a compound range spec, and a name with extras.
    #[test]
    fn test_replace_in_requirements() {
        let updater = FileUpdater::new();

        let result = updater.replace_in_requirements(
            "requests==2.28.0",
            "requests",
            "==2.28.0",
            "==2.32.3"
        ).unwrap();
        assert_eq!(result, "requests==2.32.3");

        // The whole compound spec is swapped for the new one.
        let result = updater.replace_in_requirements(
            "numpy>=1.24.0,<2.0.0",
            "numpy",
            ">=1.24.0,<2.0.0",
            ">=1.26.0,<2.0.0"
        ).unwrap();
        assert_eq!(result, "numpy>=1.26.0,<2.0.0");

        // The `[security]` extras must survive the rewrite.
        let result = updater.replace_in_requirements(
            "requests[security]==2.28.0",
            "requests",
            "==2.28.0",
            "==2.32.3"
        ).unwrap();
        assert_eq!(result, "requests[security]==2.32.3");
    }

    // pyproject.toml lines: the quote style (double or single) around the spec is preserved.
    #[test]
    fn test_replace_in_pyproject() {
        let updater = FileUpdater::new();

        let result = updater.replace_in_pyproject(
            "requests = \"^2.28.0\"",
            "requests",
            "^2.28.0",
            "^2.32.3"
        ).unwrap();
        assert_eq!(result, "requests = \"^2.32.3\"");

        let result = updater.replace_in_pyproject(
            "numpy = '^1.24.0'",
            "numpy",
            "^1.24.0",
            "^1.26.0"
        ).unwrap();
        assert_eq!(result, "numpy = '^1.26.0'");
    }

    // Conda lines: a pip-style `==` pin is rewritten as-is, and a conda-style single `=` pin
    // keeps its single `=` even though the specs passed in use `==`.
    #[test]
    fn test_replace_in_conda() {
        let updater = FileUpdater::new();

        let result = updater.replace_in_conda(
            " - numpy==1.24.0",
            "numpy",
            "==1.24.0",
            "==1.26.0"
        ).unwrap();
        assert_eq!(result, " - numpy==1.26.0");

        let result = updater.replace_in_conda(
            " - requests=2.28.0",
            "requests",
            "==2.28.0",
            "==2.32.3"
        ).unwrap();
        assert_eq!(result, " - requests=2.32.3");
    }

    // File-name based package-manager detection for the common dependency file kinds.
    #[test]
    fn test_detect_package_manager() {
        assert_eq!(
            detect_package_manager(&PathBuf::from("/path/to/requirements.txt")),
            Some(PackageManager::Pip)
        );

        // Any `requirements*` prefix counts as pip.
        assert_eq!(
            detect_package_manager(&PathBuf::from("/path/to/requirements-dev.txt")),
            Some(PackageManager::Pip)
        );

        assert_eq!(
            detect_package_manager(&PathBuf::from("/path/to/pyproject.toml")),
            Some(PackageManager::Uv)
        );

        assert_eq!(
            detect_package_manager(&PathBuf::from("/path/to/environment.yml")),
            Some(PackageManager::Conda)
        );

        assert_eq!(
            detect_package_manager(&PathBuf::from("/path/to/poetry.lock")),
            Some(PackageManager::Poetry)
        );
    }

    // End-to-end `update_file` on a temp file: only the targeted lines (1 and 3) change and
    // the untouched middle line is preserved. The temp file name presumably matches none of
    // the special file names, so the requirements-format path is exercised.
    #[test]
    fn test_update_file_integration() -> Result<()> {
        use crate::parsers::Dependency;
        use check_updates_core::{Version, VersionSpec};

        let updater = FileUpdater::new();

        let mut temp_file = NamedTempFile::new()?;
        writeln!(temp_file, "requests==2.28.0")?;
        writeln!(temp_file, "numpy>=1.24.0,<2.0.0")?;
        writeln!(temp_file, "flask==2.0.3")?;
        temp_file.flush()?;

        let temp_path = temp_file.path().to_path_buf();

        let check1 = DependencyCheck {
            dependency: Dependency {
                name: "requests".to_string(),
                version_spec: VersionSpec::Pinned(Version::new(2, 28, 0)),
                source_file: temp_path.clone(),
                line_number: 1,
                original_line: "requests==2.28.0".to_string(),
            },
            installed: Some(Version::new(2, 28, 0)),
            in_range: Some(Version::new(2, 32, 3)),
            latest: Version::new(2, 32, 3),
            target: Some(Version::new(2, 32, 3)),
            target_spec: Some(VersionSpec::Pinned(Version::new(2, 32, 3))),
            severity: Some(UpdateSeverity::Minor),
            force_spec: Some(VersionSpec::Pinned(Version::new(2, 32, 3))),
        };
        let check2 = DependencyCheck {
            dependency: Dependency {
                name: "flask".to_string(),
                version_spec: VersionSpec::Pinned(Version::new(2, 0, 3)),
                source_file: temp_path.clone(),
                line_number: 3,
                original_line: "flask==2.0.3".to_string(),
            },
            installed: Some(Version::new(2, 0, 3)),
            in_range: Some(Version::new(2, 3, 3)),
            latest: Version::new(2, 3, 3),
            target: Some(Version::new(2, 3, 3)),
            target_spec: Some(VersionSpec::Pinned(Version::new(2, 3, 3))),
            severity: Some(UpdateSeverity::Minor),
            force_spec: Some(VersionSpec::Pinned(Version::new(2, 3, 3))),
        };

        // `update_file` is called directly, so severity filtering does not apply here.
        let updates: Vec<(&DependencyCheck, String)> = vec![
            (&check1, "==2.32.3".to_string()),
            (&check2, "==2.3.3".to_string()),
        ];

        updater.update_file(&temp_path, &updates)?;

        let updated_content = fs::read_to_string(&temp_path)?;
        let lines: Vec<&str> = updated_content.lines().collect();

        assert_eq!(lines[0], "requests==2.32.3");
        assert_eq!(lines[1], "numpy>=1.24.0,<2.0.0");
        assert_eq!(lines[2], "flask==2.3.3");

        Ok(())
    }

    // With `include_minor = false` (and no force), only the patch-level update (serde) is
    // applied; the minor update (tokio) is skipped.
    #[test]
    fn test_update_patch_only() -> Result<()> {
        use crate::parsers::Dependency;
        use check_updates_core::{Version, VersionSpec};

        let mut file = NamedTempFile::new()?;
        writeln!(file, "serde==1.0.0")?;
        writeln!(file, "tokio==1.0.0")?;
        file.flush()?;

        let temp_path = file.path().to_path_buf();

        let checks = vec![
            DependencyCheck {
                dependency: Dependency {
                    name: "serde".to_string(),
                    version_spec: VersionSpec::Pinned(Version::new(1, 0, 0)),
                    source_file: temp_path.clone(),
                    line_number: 1,
                    original_line: "serde==1.0.0".to_string(),
                },
                installed: Some(Version::new(1, 0, 0)),
                in_range: Some(Version::new(1, 0, 200)),
                latest: Version::new(1, 0, 200),
                target: Some(Version::new(1, 0, 200)),
                target_spec: Some(VersionSpec::Pinned(Version::new(1, 0, 200))),
                severity: Some(UpdateSeverity::Patch),
                force_spec: Some(VersionSpec::Pinned(Version::new(1, 0, 200))),
            },
            DependencyCheck {
                dependency: Dependency {
                    name: "tokio".to_string(),
                    version_spec: VersionSpec::Pinned(Version::new(1, 0, 0)),
                    source_file: temp_path.clone(),
                    line_number: 2,
                    original_line: "tokio==1.0.0".to_string(),
                },
                installed: Some(Version::new(1, 0, 0)),
                in_range: Some(Version::new(1, 5, 0)),
                latest: Version::new(1, 5, 0),
                target: Some(Version::new(1, 5, 0)),
                target_spec: Some(VersionSpec::Pinned(Version::new(1, 5, 0))),
                severity: Some(UpdateSeverity::Minor),
                force_spec: Some(VersionSpec::Pinned(Version::new(1, 5, 0))),
            },
        ];

        let updater = FileUpdater::new();
        updater.apply_updates(&checks, false, false)?;
        let content = fs::read_to_string(&temp_path)?;
        assert!(content.contains("==1.0.200"), "serde should be updated: {}", content);
        assert!(!content.contains("==1.5.0"), "tokio should NOT be updated: {}", content);

        Ok(())
    }

    // With `include_minor = true` (and no force), both the patch (serde) and the minor
    // (tokio) updates are applied.
    #[test]
    fn test_update_patch_and_minor() -> Result<()> {
        use crate::parsers::Dependency;
        use check_updates_core::{Version, VersionSpec};

        let mut file = NamedTempFile::new()?;
        writeln!(file, "serde==1.0.0")?;
        writeln!(file, "tokio==1.0.0")?;
        file.flush()?;

        let temp_path = file.path().to_path_buf();

        let checks = vec![
            DependencyCheck {
                dependency: Dependency {
                    name: "serde".to_string(),
                    version_spec: VersionSpec::Pinned(Version::new(1, 0, 0)),
                    source_file: temp_path.clone(),
                    line_number: 1,
                    original_line: "serde==1.0.0".to_string(),
                },
                installed: Some(Version::new(1, 0, 0)),
                in_range: Some(Version::new(1, 0, 200)),
                latest: Version::new(1, 0, 200),
                target: Some(Version::new(1, 0, 200)),
                target_spec: Some(VersionSpec::Pinned(Version::new(1, 0, 200))),
                severity: Some(UpdateSeverity::Patch),
                force_spec: Some(VersionSpec::Pinned(Version::new(1, 0, 200))),
            },
            DependencyCheck {
                dependency: Dependency {
                    name: "tokio".to_string(),
                    version_spec: VersionSpec::Pinned(Version::new(1, 0, 0)),
                    source_file: temp_path.clone(),
                    line_number: 2,
                    original_line: "tokio==1.0.0".to_string(),
                },
                installed: Some(Version::new(1, 0, 0)),
                in_range: Some(Version::new(1, 5, 0)),
                latest: Version::new(1, 5, 0),
                target: Some(Version::new(1, 5, 0)),
                target_spec: Some(VersionSpec::Pinned(Version::new(1, 5, 0))),
                severity: Some(UpdateSeverity::Minor),
                force_spec: Some(VersionSpec::Pinned(Version::new(1, 5, 0))),
            },
        ];

        let updater = FileUpdater::new();
        updater.apply_updates(&checks, true, false)?;
        let content = fs::read_to_string(&temp_path)?;
        assert!(content.contains("==1.0.200"), "serde should be updated: {}", content);
        assert!(content.contains("==1.5.0"), "tokio should be updated: {}", content);

        Ok(())
    }
}