1use crate::resolver::DependencyCheck;
2use crate::detector::PackageManager;
3use anyhow::{Context, Result};
4use std::collections::{HashSet, HashMap};
5use std::path::PathBuf;
6use std::fs;
7
/// Rewrites dependency files in place so that outdated dependencies carry
/// their new version specifier.
///
/// The type holds no state; it only groups the update routines. `Default` is
/// derived so that `FileUpdater::default()` and `FileUpdater::new()` agree.
#[derive(Debug, Default, Clone, Copy)]
pub struct FileUpdater;
10
11impl FileUpdater {
    /// Creates a new `FileUpdater`.
    ///
    /// The type is a unit struct, so there is nothing to configure.
    pub fn new() -> Self {
        Self
    }
15
16 pub fn apply_updates(&self, checks: &[DependencyCheck]) -> Result<UpdateResult> {
18 let mut modified_files = HashSet::new();
19 let mut package_file_map: HashMap<String, Vec<PathBuf>> = HashMap::new();
20 let mut package_managers = HashSet::new();
21
22 let mut file_updates: HashMap<PathBuf, Vec<&DependencyCheck>> = HashMap::new();
24 for check in checks {
25 if check.has_update() {
26 file_updates
27 .entry(check.dependency.source_file.clone())
28 .or_insert_with(Vec::new)
29 .push(check);
30
31 package_file_map
33 .entry(check.dependency.name.clone())
34 .or_insert_with(Vec::new)
35 .push(check.dependency.source_file.clone());
36 }
37 }
38
39 for (file_path, updates) in file_updates {
41 self.update_file(&file_path, updates)
42 .with_context(|| format!("Failed to update file: {}", file_path.display()))?;
43
44 modified_files.insert(file_path.clone());
45
46 if let Some(pm) = detect_package_manager(&file_path) {
48 package_managers.insert(pm);
49 }
50 }
51
52 let mut multi_file_packages: Vec<String> = package_file_map
54 .iter()
55 .filter_map(|(pkg, files)| {
56 let unique_files: HashSet<_> = files.iter().collect();
57 if unique_files.len() > 1 {
58 Some(pkg.clone())
59 } else {
60 None
61 }
62 })
63 .collect();
64 multi_file_packages.sort();
65
66 Ok(UpdateResult {
67 modified_files,
68 multi_file_packages,
69 package_managers,
70 })
71 }
72
73 fn update_file(&self, file_path: &PathBuf, updates: Vec<&DependencyCheck>) -> Result<()> {
75 let content = fs::read_to_string(file_path)
77 .with_context(|| format!("Failed to read file: {}", file_path.display()))?;
78
79 let mut lines: Vec<String> = content.lines().map(|s| s.to_string()).collect();
80
81 let mut sorted_updates = updates;
83 sorted_updates.sort_by(|a, b| b.dependency.line_number.cmp(&a.dependency.line_number));
84
85 for update in sorted_updates {
87 let line_idx = update.dependency.line_number.saturating_sub(1);
88
89 if line_idx >= lines.len() {
90 continue; }
92
93 let original_line = &lines[line_idx];
94
95 if let Some(new_spec) = &update.update_to {
97 let updated_line = self.replace_version_in_line(
98 original_line,
99 &update.dependency.name,
100 &update.dependency.version_spec.to_string(),
101 &new_spec.to_string(),
102 file_path,
103 )?;
104
105 lines[line_idx] = updated_line;
106 }
107 }
108
109 let new_content = lines.join("\n");
111 let new_content = if content.ends_with('\n') {
113 format!("{}\n", new_content)
114 } else {
115 new_content
116 };
117
118 fs::write(file_path, new_content)
119 .with_context(|| format!("Failed to write file: {}", file_path.display()))?;
120
121 Ok(())
122 }
123
124 fn replace_version_in_line(
126 &self,
127 line: &str,
128 package_name: &str,
129 old_spec: &str,
130 new_spec: &str,
131 file_path: &PathBuf,
132 ) -> Result<String> {
133 let file_name = file_path.file_name()
134 .and_then(|n| n.to_str())
135 .unwrap_or("");
136
137 if file_name.starts_with("requirements") || file_name.ends_with(".txt") {
139 self.replace_in_requirements(line, package_name, old_spec, new_spec)
140 } else if file_name == "pyproject.toml" {
141 self.replace_in_pyproject(line, package_name, old_spec, new_spec)
142 } else if file_name.starts_with("environment.") &&
143 (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
144 self.replace_in_conda(line, package_name, old_spec, new_spec)
145 } else {
146 self.replace_in_requirements(line, package_name, old_spec, new_spec)
148 }
149 }
150
151 fn replace_in_requirements(
153 &self,
154 line: &str,
155 package_name: &str,
156 old_spec: &str,
157 new_spec: &str,
158 ) -> Result<String> {
159 if let Some(new_line) = line.replace(&format!("{}{}", package_name, old_spec),
163 &format!("{}{}", package_name, new_spec))
164 .into() {
165 if new_line != line {
166 return Ok(new_line);
167 }
168 }
169
170 if line.contains('[') {
172 if let Some(bracket_start) = line.find('[') {
173 if let Some(bracket_end) = line.find(']') {
174 let before_bracket = &line[..bracket_start];
175 let extras = &line[bracket_start..=bracket_end];
176 let after_bracket = &line[bracket_end + 1..];
177
178 if before_bracket.trim() == package_name {
179 let new_after = after_bracket.replace(old_spec, new_spec);
180 return Ok(format!("{}{}{}", before_bracket, extras, new_after));
181 }
182 }
183 }
184 }
185
186 Ok(line.replace(old_spec, new_spec))
188 }
189
190 fn replace_in_pyproject(
192 &self,
193 line: &str,
194 package_name: &str,
195 old_spec: &str,
196 new_spec: &str,
197 ) -> Result<String> {
198 if line.to_lowercase().contains(&package_name.to_lowercase()) {
202 let result = line.replace(
204 &format!("\"{}\"", old_spec),
205 &format!("\"{}\"", new_spec)
206 );
207 if result != line {
208 return Ok(result);
209 }
210
211 let result = line.replace(
213 &format!("'{}'", old_spec),
214 &format!("'{}'", new_spec)
215 );
216 if result != line {
217 return Ok(result);
218 }
219 }
220
221 Ok(line.replace(old_spec, new_spec))
223 }
224
225 fn replace_in_conda(
227 &self,
228 line: &str,
229 package_name: &str,
230 old_spec: &str,
231 new_spec: &str,
232 ) -> Result<String> {
233 let conda_old_spec = old_spec.replace("==", "=");
237 let conda_new_spec = new_spec.replace("==", "=");
238
239 let result = line.replace(
241 &format!("{}{}", package_name, old_spec),
242 &format!("{}{}", package_name, new_spec)
243 );
244 if result != line {
245 return Ok(result);
246 }
247
248 let result = line.replace(
250 &format!("{}{}", package_name, conda_old_spec),
251 &format!("{}{}", package_name, conda_new_spec)
252 );
253 if result != line {
254 return Ok(result);
255 }
256
257 Ok(line.replace(old_spec, new_spec))
259 }
260}
261
262fn detect_package_manager(path: &PathBuf) -> Option<PackageManager> {
264 let file_name = path.file_name()?.to_str()?;
265
266 if file_name.starts_with("requirements") {
267 Some(PackageManager::Pip)
268 } else if file_name == "pyproject.toml" {
269 Some(PackageManager::Uv)
272 } else if file_name.starts_with("environment.") &&
273 (file_name.ends_with(".yml") || file_name.ends_with(".yaml")) {
274 Some(PackageManager::Conda)
275 } else if file_name == "uv.lock" {
276 Some(PackageManager::Uv)
277 } else if file_name == "poetry.lock" {
278 Some(PackageManager::Poetry)
279 } else if file_name == "pdm.lock" {
280 Some(PackageManager::Pdm)
281 } else {
282 None
283 }
284}
285
/// Outcome of `FileUpdater::apply_updates`.
#[derive(Debug)]
pub struct UpdateResult {
    /// Files that were rewritten.
    pub modified_files: HashSet<PathBuf>,
    /// Names of packages that had updates in more than one distinct file,
    /// sorted.
    pub multi_file_packages: Vec<String>,
    /// Package managers inferred from the names of the rewritten files.
    pub package_managers: HashSet<PackageManager>,
}
296
297impl UpdateResult {
298 pub fn print_summary(&self) {
300 if !self.multi_file_packages.is_empty() {
301 println!(
302 "\nNote: The following packages were updated in multiple files: {}",
303 self.multi_file_packages.join(", ")
304 );
305 }
306
307 for pm in &self.package_managers {
308 let cmd = match pm {
309 PackageManager::Pip => "pip install -r requirements.txt",
310 PackageManager::Uv => "uv lock",
311 PackageManager::Poetry => "poetry lock",
312 PackageManager::Pdm => "pdm lock",
313 PackageManager::Conda => "conda env update",
314 };
315 println!("Run {} to sync dependencies", cmd);
316 }
317 }
318}
319
320#[cfg(test)]
321mod tests {
322 use super::*;
323 use std::io::Write;
324 use tempfile::NamedTempFile;
325
326 #[test]
327 fn test_replace_in_requirements() {
328 let updater = FileUpdater::new();
329
330 let result = updater.replace_in_requirements(
332 "requests==2.28.0",
333 "requests",
334 "==2.28.0",
335 "==2.32.3"
336 ).unwrap();
337 assert_eq!(result, "requests==2.32.3");
338
339 let result = updater.replace_in_requirements(
341 "numpy>=1.24.0,<2.0.0",
342 "numpy",
343 ">=1.24.0,<2.0.0",
344 ">=1.26.0,<2.0.0"
345 ).unwrap();
346 assert_eq!(result, "numpy>=1.26.0,<2.0.0");
347
348 let result = updater.replace_in_requirements(
350 "requests[security]==2.28.0",
351 "requests",
352 "==2.28.0",
353 "==2.32.3"
354 ).unwrap();
355 assert_eq!(result, "requests[security]==2.32.3");
356 }
357
358 #[test]
359 fn test_replace_in_pyproject() {
360 let updater = FileUpdater::new();
361
362 let result = updater.replace_in_pyproject(
364 "requests = \"^2.28.0\"",
365 "requests",
366 "^2.28.0",
367 "^2.32.3"
368 ).unwrap();
369 assert_eq!(result, "requests = \"^2.32.3\"");
370
371 let result = updater.replace_in_pyproject(
373 "numpy = '^1.24.0'",
374 "numpy",
375 "^1.24.0",
376 "^1.26.0"
377 ).unwrap();
378 assert_eq!(result, "numpy = '^1.26.0'");
379 }
380
381 #[test]
382 fn test_replace_in_conda() {
383 let updater = FileUpdater::new();
384
385 let result = updater.replace_in_conda(
387 " - numpy==1.24.0",
388 "numpy",
389 "==1.24.0",
390 "==1.26.0"
391 ).unwrap();
392 assert_eq!(result, " - numpy==1.26.0");
393
394 let result = updater.replace_in_conda(
396 " - requests=2.28.0",
397 "requests",
398 "==2.28.0",
399 "==2.32.3"
400 ).unwrap();
401 assert_eq!(result, " - requests=2.32.3");
402 }
403
404 #[test]
405 fn test_detect_package_manager() {
406 assert_eq!(
407 detect_package_manager(&PathBuf::from("/path/to/requirements.txt")),
408 Some(PackageManager::Pip)
409 );
410
411 assert_eq!(
412 detect_package_manager(&PathBuf::from("/path/to/requirements-dev.txt")),
413 Some(PackageManager::Pip)
414 );
415
416 assert_eq!(
417 detect_package_manager(&PathBuf::from("/path/to/pyproject.toml")),
418 Some(PackageManager::Uv)
419 );
420
421 assert_eq!(
422 detect_package_manager(&PathBuf::from("/path/to/environment.yml")),
423 Some(PackageManager::Conda)
424 );
425
426 assert_eq!(
427 detect_package_manager(&PathBuf::from("/path/to/poetry.lock")),
428 Some(PackageManager::Poetry)
429 );
430 }
431
432 #[test]
433 fn test_update_file_integration() -> Result<()> {
434 use crate::parsers::Dependency;
435 use crate::version::{Version, VersionSpec};
436
437 let updater = FileUpdater::new();
438
439 let mut temp_file = NamedTempFile::new()?;
441 writeln!(temp_file, "requests==2.28.0")?;
442 writeln!(temp_file, "numpy>=1.24.0,<2.0.0")?;
443 writeln!(temp_file, "flask==2.0.3")?;
444 temp_file.flush()?;
445
446 let temp_path = temp_file.path().to_path_buf();
447
448 let checks = vec![
450 DependencyCheck {
451 dependency: Dependency {
452 name: "requests".to_string(),
453 version_spec: VersionSpec::Pinned(Version::new(2, 28, 0)),
454 source_file: temp_path.clone(),
455 line_number: 1,
456 original_line: "requests==2.28.0".to_string(),
457 },
458 installed: Some(Version::new(2, 28, 0)),
459 in_range: Some(Version::new(2, 32, 3)),
460 latest: Version::new(2, 32, 3),
461 update_to: Some(VersionSpec::Pinned(Version::new(2, 32, 3))),
462 },
463 DependencyCheck {
464 dependency: Dependency {
465 name: "flask".to_string(),
466 version_spec: VersionSpec::Pinned(Version::new(2, 0, 3)),
467 source_file: temp_path.clone(),
468 line_number: 3,
469 original_line: "flask==2.0.3".to_string(),
470 },
471 installed: Some(Version::new(2, 0, 3)),
472 in_range: Some(Version::new(2, 3, 3)),
473 latest: Version::new(2, 3, 3),
474 update_to: Some(VersionSpec::Pinned(Version::new(2, 3, 3))),
475 },
476 ];
477
478 updater.update_file(&temp_path, checks.iter().collect())?;
480
481 let updated_content = fs::read_to_string(&temp_path)?;
483 let lines: Vec<&str> = updated_content.lines().collect();
484
485 assert_eq!(lines[0], "requests==2.32.3");
487 assert_eq!(lines[1], "numpy>=1.24.0,<2.0.0"); assert_eq!(lines[2], "flask==2.3.3");
489
490 Ok(())
491 }
492}