Skip to main content

embeddenator_workspace/
patch.rs

//! Cargo patch management for local development.
//!
//! This module redirects git dependencies to local checkouts during
//! development by writing `[patch]` entries into `.cargo/config.toml`, and
//! removes those entries again when you are done.
5
6use anyhow::{Context, Result};
7use colored::Colorize;
8use std::collections::{HashMap, HashSet};
9use std::path::{Path, PathBuf};
10use toml_edit::{value, DocumentMut, Item, Table};
11
12use crate::workspace::WorkspaceScanner;
13
/// A git dependency whose crate also exists as a local checkout, so it can be
/// redirected to that checkout with a `[patch]` entry.
#[derive(Debug, Clone)]
pub struct GitDependency {
    /// Name of the dependency; used as the key of its `[patch]` entry.
    pub name: String,
    /// Value of the dependency's `git` key; identifies the patched source.
    pub git_url: String,
    /// The pinned `branch` (consulted first) or `tag`, if the manifest gives one.
    pub branch_or_tag: Option<String>,
    /// Directory of the local checkout the patch should point at.
    pub local_path: PathBuf,
}
22
/// Applies and removes local-path `[patch]` overrides in the
/// `.cargo/config.toml` of a workspace.
// `Debug`/`Clone` are derived so the public type can be logged and shared,
// matching the other public types in this module.
#[derive(Debug, Clone)]
pub struct PatchManager {
    /// Directory that holds the sibling repository checkouts and the `.cargo`
    /// directory whose `config.toml` is managed.
    workspace_root: PathBuf,
}
27
28impl PatchManager {
29    /// Create a new patch manager.
30    pub fn new(workspace_root: impl AsRef<Path>) -> Self {
31        Self {
32            workspace_root: workspace_root.as_ref().to_path_buf(),
33        }
34    }
35
36    /// Discover all embeddenator repos and their git dependencies.
37    pub fn discover_patchable_dependencies(&self) -> Result<Vec<GitDependency>> {
38        let scanner = WorkspaceScanner::new(&self.workspace_root);
39        let manifests = scanner.find_manifests()?;
40
41        let mut git_deps: HashMap<String, GitDependency> = HashMap::new();
42        let mut available_repos: HashSet<String> = HashSet::new();
43
44        // First pass: identify all available local repos
45        for manifest in &manifests {
46            if manifest.package_name.starts_with("embeddenator") {
47                available_repos.insert(manifest.package_name.clone());
48            }
49        }
50
51        // Second pass: find git dependencies that have local equivalents
52        for manifest in &manifests {
53            let content = std::fs::read_to_string(&manifest.path)?;
54            let doc: DocumentMut = content.parse()?;
55
56            // Check dependencies, dev-dependencies, build-dependencies
57            for section in &["dependencies", "dev-dependencies", "build-dependencies"] {
58                if let Some(Item::Table(deps_table)) = doc.get(section) {
59                    for (name, dep_item) in deps_table.iter() {
60                        if let Some(git_dep) = Self::parse_git_dependency(name, dep_item) {
61                            // Check if we have this repo locally
62                            if available_repos.contains(name) {
63                                // Find the local path
64                                if let Some(local_path) = self.find_local_repo_path(name) {
65                                    git_deps.insert(
66                                        name.to_string(),
67                                        GitDependency {
68                                            name: name.to_string(),
69                                            git_url: git_dep.0,
70                                            branch_or_tag: git_dep.1,
71                                            local_path,
72                                        },
73                                    );
74                                }
75                            }
76                        }
77                    }
78                }
79            }
80        }
81
82        let mut deps: Vec<GitDependency> = git_deps.into_values().collect();
83        deps.sort_by(|a, b| a.name.cmp(&b.name));
84        Ok(deps)
85    }
86
87    /// Parse git dependency from TOML item.
88    fn parse_git_dependency(_name: &str, item: &Item) -> Option<(String, Option<String>)> {
89        // Handle both inline tables and regular tables
90        let git_url = item.get("git")?.as_str()?.to_string();
91        let branch_or_tag = item
92            .get("branch")
93            .or_else(|| item.get("tag"))
94            .and_then(|v| v.as_str())
95            .map(|s| s.to_string());
96        Some((git_url, branch_or_tag))
97    }
98
99    /// Find the local path for a repository.
100    fn find_local_repo_path(&self, repo_name: &str) -> Option<PathBuf> {
101        let expected_path = self.workspace_root.join(repo_name);
102        if expected_path.join("Cargo.toml").exists() {
103            Some(expected_path)
104        } else {
105            None
106        }
107    }
108
109    /// Apply local patches to .cargo/config.toml
110    pub fn apply_patches(&self, deps: &[GitDependency], verify: bool) -> Result<PatchReport> {
111        let cargo_dir = self.workspace_root.join(".cargo");
112        let config_path = cargo_dir.join("config.toml");
113
114        // Create .cargo directory if it doesn't exist
115        if !cargo_dir.exists() {
116            std::fs::create_dir(&cargo_dir).context("Failed to create .cargo directory")?;
117        }
118
119        // Load or create config.toml
120        let mut doc: DocumentMut = if config_path.exists() {
121            let content = std::fs::read_to_string(&config_path)?;
122            content.parse()?
123        } else {
124            DocumentMut::new()
125        };
126
127        let mut patched_count = 0;
128
129        // Group dependencies by git URL
130        let mut patches_by_url: HashMap<String, Vec<&GitDependency>> = HashMap::new();
131        for dep in deps {
132            patches_by_url
133                .entry(dep.git_url.clone())
134                .or_default()
135                .push(dep);
136        }
137
138        // Apply patches for each git URL
139        for (git_url, deps_for_url) in patches_by_url {
140            let patch_key = format!("patch.\"{}\"", git_url);
141
142            // Create patch section if it doesn't exist
143            if doc.get(&patch_key).is_none() {
144                doc[&patch_key] = Item::Table(Table::new());
145            }
146
147            if let Some(Item::Table(patch_table)) = doc.get_mut(&patch_key) {
148                for dep in deps_for_url {
149                    // Create patch entry
150                    let mut dep_table = Table::new();
151                    dep_table.insert("path", value(dep.local_path.to_string_lossy().to_string()));
152
153                    patch_table.insert(&dep.name, Item::Table(dep_table));
154                    patched_count += 1;
155                }
156            }
157        }
158
159        // Save the config file
160        std::fs::write(&config_path, doc.to_string())
161            .context("Failed to write .cargo/config.toml")?;
162
163        let mut report = PatchReport {
164            patched_count,
165            config_path: config_path.clone(),
166            verified: false,
167            verification_error: None,
168        };
169
170        // Verify patches if requested
171        if verify {
172            match self.verify_patches() {
173                Ok(_) => report.verified = true,
174                Err(e) => report.verification_error = Some(e.to_string()),
175            }
176        }
177
178        Ok(report)
179    }
180
181    /// Remove all patches from .cargo/config.toml
182    pub fn remove_patches(&self) -> Result<ResetReport> {
183        let cargo_dir = self.workspace_root.join(".cargo");
184        let config_path = cargo_dir.join("config.toml");
185
186        if !config_path.exists() {
187            return Ok(ResetReport {
188                removed_count: 0,
189                config_path,
190                config_deleted: false,
191            });
192        }
193
194        let content = std::fs::read_to_string(&config_path)?;
195        let mut doc: DocumentMut = content.parse()?;
196
197        let mut removed_count = 0;
198
199        // Find all patch.* sections (both dotted keys like patch."url" and nested [patch] table)
200        let mut keys_to_remove = Vec::new();
201
202        for (key, _) in doc.as_table().iter() {
203            if key == "patch" {
204                // Handle [patch] table with nested sources
205                if let Some(Item::Table(patch_table)) = doc.get("patch") {
206                    for (_source_url, dep_item) in patch_table.iter() {
207                        if let Item::Table(deps) = dep_item {
208                            removed_count += deps.len();
209                        }
210                    }
211                }
212                keys_to_remove.push(key.to_string());
213            } else if key.starts_with("patch.") {
214                // Handle dotted keys like [patch."https://..."]
215                if let Some(Item::Table(patch_deps)) = doc.get(key) {
216                    removed_count += patch_deps.len();
217                }
218                keys_to_remove.push(key.to_string());
219            }
220        }
221
222        // Remove all patch sections
223        for key in keys_to_remove {
224            doc.remove(&key);
225        }
226
227        // Check if the document is now empty or only has whitespace
228        let is_empty = doc.as_table().is_empty();
229
230        if is_empty {
231            // Delete the config file
232            std::fs::remove_file(&config_path)?;
233            Ok(ResetReport {
234                removed_count,
235                config_path,
236                config_deleted: true,
237            })
238        } else {
239            // Save the modified config
240            std::fs::write(&config_path, doc.to_string())?;
241            Ok(ResetReport {
242                removed_count,
243                config_path,
244                config_deleted: false,
245            })
246        }
247    }
248
249    /// Verify that patches are working by running cargo metadata.
250    fn verify_patches(&self) -> Result<()> {
251        use std::process::Command;
252
253        let output = Command::new("cargo")
254            .arg("metadata")
255            .arg("--format-version=1")
256            .current_dir(&self.workspace_root)
257            .output()
258            .context("Failed to run cargo metadata")?;
259
260        if !output.status.success() {
261            let stderr = String::from_utf8_lossy(&output.stderr);
262            anyhow::bail!("cargo metadata failed:\n{}", stderr);
263        }
264
265        Ok(())
266    }
267
268    /// Clean cargo cache (useful after removing patches).
269    pub fn clean_cache(&self) -> Result<()> {
270        use std::process::Command;
271
272        println!("{}", "  Cleaning cargo cache...".dimmed());
273
274        let output = Command::new("cargo")
275            .arg("clean")
276            .current_dir(&self.workspace_root)
277            .output()
278            .context("Failed to run cargo clean")?;
279
280        if !output.status.success() {
281            let stderr = String::from_utf8_lossy(&output.stderr);
282            anyhow::bail!("cargo clean failed:\n{}", stderr);
283        }
284
285        Ok(())
286    }
287}
288
/// Summary of a `PatchManager::apply_patches` run.
#[derive(Debug)]
pub struct PatchReport {
    /// How many `[patch]` entries were written.
    pub patched_count: usize,
    /// The `.cargo/config.toml` file that was written.
    pub config_path: PathBuf,
    /// `true` only when verification was requested and succeeded.
    pub verified: bool,
    /// Message from a failed verification; `None` when it passed or was skipped.
    pub verification_error: Option<String>,
}
297
/// Summary of a `PatchManager::remove_patches` run.
#[derive(Debug)]
pub struct ResetReport {
    /// How many patch entries were removed.
    pub removed_count: usize,
    /// The `.cargo/config.toml` file that was inspected.
    pub config_path: PathBuf,
    /// `true` when the config file was deleted because nothing else was left in it.
    pub config_deleted: bool,
}
305
306impl PatchReport {
307    pub fn print(&self) {
308        println!(
309            "\n{} {} patches written to {}",
310            "✓".green().bold(),
311            self.patched_count,
312            self.config_path.display().to_string().bright_white()
313        );
314
315        if self.verified {
316            println!("{} Patches verified successfully", "✓".green().bold());
317        } else if let Some(err) = &self.verification_error {
318            println!("{} Verification failed: {}", "✗".red().bold(), err);
319            println!(
320                "\n{} Run 'cargo build' to diagnose the issue",
321                "Suggestion:".cyan().bold()
322            );
323        }
324    }
325}
326
327impl ResetReport {
328    pub fn print(&self) {
329        if self.removed_count == 0 {
330            println!("{} No patches found to remove", "Info:".blue().bold());
331        } else {
332            println!(
333                "\n{} {} patches removed",
334                "✓".green().bold(),
335                self.removed_count
336            );
337
338            if self.config_deleted {
339                println!(
340                    "  {} deleted (empty)",
341                    self.config_path.display().to_string().dimmed()
342                );
343            } else {
344                println!(
345                    "  {} updated",
346                    self.config_path.display().to_string().dimmed()
347                );
348            }
349        }
350    }
351}
352
353#[cfg(test)]
354#[path = "patch_tests.rs"]
355mod tests;