cargo_autoinherit/
lib.rs

1use crate::dedup::MinimalVersionSet;
2use anyhow::{anyhow, Context};
3use cargo_manifest::{Dependency, DependencyDetail, DepsSet, Manifest, Workspace};
4use guppy::VersionReq;
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Formatter;
7use toml_edit::{Array, Key};
8
9mod dedup;
10
11#[derive(Debug, Default, Clone, clap::Args)]
12pub struct AutoInheritConf {
13    #[arg(
14        long,
15        help = "Represents inherited dependencies as `package.workspace = true` if possible."
16    )]
17    pub prefer_simple_dotted: bool,
18    /// Package name(s) of workspace member(s) to exclude.
19    #[arg(short, long)]
20    exclude_members: Vec<String>,
21}
22
23#[derive(Debug, Default)]
24struct AutoInheritMetadata {
25    exclude_members: Vec<String>,
26}
27
28impl AutoInheritMetadata {
29    fn from_workspace(workspace: &Workspace<toml::Table>) -> Result<Self, anyhow::Error> {
30        fn error() -> anyhow::Error {
31            anyhow!("Excpected value of `exclude` in `workspace.metadata.cargo-autoinherit` to be an array of strings")
32        }
33
34        let Some(exclude) = workspace
35            .metadata
36            .as_ref()
37            .and_then(|m| m.get("cargo-autoinherit"))
38            .and_then(|v| v.as_table())
39            .and_then(|t| t.get("exclude-members").or(t.get("exclude_members")))
40        else {
41            return Ok(Self::default());
42        };
43
44        let exclude: Vec<String> = match exclude {
45            toml::Value::Array(excluded) => excluded
46                .iter()
47                .map(|v| v.as_str().ok_or_else(error).map(|s| s.to_string()))
48                .try_fold(Vec::with_capacity(excluded.len()), |mut res, item| {
49                    res.push(item?);
50                    Ok::<_, anyhow::Error>(res)
51                })?,
52            _ => return Err(error()),
53        };
54        Ok(Self {
55            exclude_members: exclude,
56        })
57    }
58}
59
60/// Rewrites a `path` dependency as being absolute, based on a given path
61fn rewrite_dep_paths_as_absolute<'a, P: AsRef<std::path::Path>>(
62    deps: impl Iterator<Item = &'a mut Dependency>,
63    parent: P,
64) {
65    deps.for_each(|dep| {
66        if let Dependency::Detailed(detail) = dep {
67            detail.path = detail.path.as_mut().map(|path| {
68                parent
69                    .as_ref()
70                    .join(path)
71                    .canonicalize()
72                    .unwrap()
73                    .to_str()
74                    .expect("Canonicalized absolute path contained non-UTF-8 segments.")
75                    .to_string()
76            })
77        }
78    });
79}
80
81/// Rewrites a `path` dependency as being relative, based on a given path
82fn rewrite_dep_path_as_relative<P: AsRef<std::path::Path>>(dep: &mut Dependency, parent: P) {
83    if let Dependency::Detailed(detail) = dep {
84        detail.path = detail.path.as_mut().map(|path| {
85            pathdiff::diff_paths(path, parent.as_ref().canonicalize().unwrap())
86                .expect(
87                    "Error rewriting dependency path as relative: unable to determine path diff.",
88                )
89                .to_str()
90                .expect("Error rewriting dependency path as relative: path diff is not UTF-8.")
91                .to_string()
92        })
93    }
94}
95
/// Looks up `$first` in `$manifest_toml` as a table, falling back to `$second`
/// (e.g. `"dev-dependencies"`, then `"dev_dependencies"`).
///
/// Evaluates to `Ok(&mut toml_edit::Table)` for the first key that exists and is a
/// table. Otherwise it evaluates to an `anyhow::Error` naming both keys.
///
/// This is a macro rather than a function because returning the first mutable borrow
/// from one branch while re-borrowing in the other is rejected by the borrow checker
/// when the borrow has to escape through a function signature.
macro_rules! get_either_table_mut {
    ($first:literal, $second:literal, $manifest_toml:expr) => {
        if let Some(i) = $manifest_toml
            .get_mut($first)
            .and_then(|d| d.as_table_mut())
        {
            Ok(i)
        } else if let Some(i) = $manifest_toml
            .get_mut($second)
            .and_then(|d| d.as_table_mut())
        {
            Ok(i)
        } else {
            // Used for member manifests too, so the message must not say "root".
            Err(anyhow::anyhow!(concat!(
                "Failed to find `[",
                $first,
                "]` or `[",
                $second,
                "]` table in manifest."
            )))
        }
    };
}
122
123pub fn auto_inherit(conf: AutoInheritConf) -> Result<(), anyhow::Error> {
124    let metadata = guppy::MetadataCommand::new().exec().context(
125        "Failed to execute `cargo metadata`. Was the command invoked inside a Rust project?",
126    )?;
127    let graph = metadata
128        .build_graph()
129        .context("Failed to build package graph")?;
130    let workspace_root = graph.workspace().root();
131    let mut root_manifest: Manifest<toml::Value, toml::Table> = {
132        let contents = fs_err::read_to_string(workspace_root.join("Cargo.toml").as_std_path())
133            .context("Failed to read root manifest")?;
134        toml::from_str(&contents).context("Failed to parse root manifest")?
135    };
136    let Some(workspace) = &mut root_manifest.workspace else {
137        anyhow::bail!(
138            "`cargo autoinherit` can only be run in a workspace. \
139            The root manifest ({}) does not have a `workspace` field.",
140            workspace_root
141        )
142    };
143
144    let autoinherit_metadata = AutoInheritMetadata::from_workspace(workspace)?;
145    let excluded_members = BTreeSet::from_iter(
146        conf.exclude_members
147            .into_iter()
148            .chain(autoinherit_metadata.exclude_members),
149    );
150
151    let mut package_name2specs: BTreeMap<String, Action> = BTreeMap::new();
152    if let Some(deps) = &mut workspace.dependencies {
153        rewrite_dep_paths_as_absolute(deps.values_mut(), workspace_root);
154        process_deps(deps, &mut package_name2specs);
155    }
156
157    for member_id in graph.workspace().member_ids() {
158        let package = graph.metadata(member_id)?;
159        assert!(package.in_workspace());
160
161        let mut manifest: Manifest = {
162            if excluded_members.contains(package.name()) {
163                println!("Excluded workspace member `{}`", package.name());
164                continue;
165            }
166            let contents = fs_err::read_to_string(package.manifest_path().as_std_path())
167                .context("Failed to read root manifest")?;
168            toml::from_str(&contents).context("Failed to parse root manifest")?
169        };
170        if let Some(deps) = &mut manifest.dependencies {
171            rewrite_dep_paths_as_absolute(
172                deps.values_mut(),
173                package.manifest_path().parent().unwrap(),
174            );
175            process_deps(deps, &mut package_name2specs);
176        }
177        if let Some(deps) = &mut manifest.dev_dependencies {
178            rewrite_dep_paths_as_absolute(
179                deps.values_mut(),
180                package.manifest_path().parent().unwrap(),
181            );
182            process_deps(deps, &mut package_name2specs);
183        }
184        if let Some(deps) = &mut manifest.build_dependencies {
185            rewrite_dep_paths_as_absolute(
186                deps.values_mut(),
187                package.manifest_path().parent().unwrap(),
188            );
189            process_deps(deps, &mut package_name2specs);
190        }
191    }
192
193    let mut package_name2inherited_source: BTreeMap<String, SharedDependency> = BTreeMap::new();
194    'outer: for (package_name, action) in package_name2specs {
195        let Action::TryInherit(specs) = action else {
196            eprintln!("`{package_name}` won't be auto-inherited because it appears at least once from a source type \
197                that we currently don't support (e.g. private registry, path dependency).");
198            continue;
199        };
200        if specs.len() > 1 {
201            eprintln!("`{package_name}` won't be auto-inherited because there are multiple sources for it:");
202            for spec in specs.into_iter() {
203                eprintln!("  - {}", spec.source);
204            }
205            continue 'outer;
206        }
207
208        let spec = specs.into_iter().next().unwrap();
209        package_name2inherited_source.insert(package_name, spec);
210    }
211
212    // Add new "shared" dependencies to `[workspace.dependencies]`
213    let mut workspace_toml: toml_edit::DocumentMut = {
214        let contents = fs_err::read_to_string(workspace_root.join("Cargo.toml").as_std_path())
215            .context("Failed to read root manifest")?;
216        contents.parse().context("Failed to parse root manifest")?
217    };
218    let workspace_table = workspace_toml.as_table_mut()["workspace"]
219        .as_table_mut()
220        .expect(
221            "Failed to find `[workspace]` table in root manifest. \
222        This is a bug in `cargo_autoinherit`.",
223        );
224    let workspace_deps = workspace_table
225        .entry("dependencies")
226        .or_insert(toml_edit::Item::Table(toml_edit::Table::new()))
227        .as_table_mut()
228        .expect("Failed to find `[workspace.dependencies]` table in root manifest.");
229    let mut was_modified = false;
230    for (package_name, source) in &package_name2inherited_source {
231        if workspace_deps.get(package_name).is_some() {
232            continue;
233        } else {
234            let mut dep = shared2dep(source);
235            rewrite_dep_path_as_relative(&mut dep, workspace_root);
236
237            insert_preserving_decor(workspace_deps, package_name, dep2toml_item(&dep));
238            was_modified = true;
239        }
240    }
241    if was_modified {
242        fs_err::write(
243            workspace_root.join("Cargo.toml").as_std_path(),
244            workspace_toml.to_string(),
245        )
246        .context("Failed to write manifest")?;
247    }
248
249    // Inherit new "shared" dependencies in each member's manifest
250    for member_id in graph.workspace().member_ids() {
251        let package = graph.metadata(member_id)?;
252        if excluded_members.contains(package.name()) {
253            continue;
254        }
255
256        let manifest_contents = fs_err::read_to_string(package.manifest_path().as_std_path())
257            .context("Failed to read root manifest")?;
258        let manifest: Manifest =
259            toml::from_str(&manifest_contents).context("Failed to parse root manifest")?;
260        let mut manifest_toml: toml_edit::DocumentMut = manifest_contents
261            .parse()
262            .context("Failed to parse root manifest")?;
263        let mut was_modified = false;
264        if let Some(deps) = &manifest.dependencies {
265            let deps_toml = manifest_toml["dependencies"]
266                .as_table_mut()
267                .expect("Failed to find `[dependencies]` table in root manifest.");
268            inherit_deps(
269                deps,
270                deps_toml,
271                &package_name2inherited_source,
272                &mut was_modified,
273                conf.prefer_simple_dotted,
274            );
275        }
276        if let Some(deps) = &manifest.dev_dependencies {
277            let deps_toml =
278                get_either_table_mut!("dev-dependencies", "dev_dependencies", manifest_toml)?;
279
280            inherit_deps(
281                deps,
282                deps_toml,
283                &package_name2inherited_source,
284                &mut was_modified,
285                conf.prefer_simple_dotted,
286            );
287        }
288        if let Some(deps) = &manifest.build_dependencies {
289            let deps_toml =
290                get_either_table_mut!("build-dependencies", "build_dependencies", manifest_toml)?;
291
292            inherit_deps(
293                deps,
294                deps_toml,
295                &package_name2inherited_source,
296                &mut was_modified,
297                conf.prefer_simple_dotted,
298            );
299        }
300        if was_modified {
301            fs_err::write(
302                package.manifest_path().as_std_path(),
303                manifest_toml.to_string(),
304            )
305            .context("Failed to write manifest")?;
306        }
307    }
308
309    Ok(())
310}
311
312enum Action {
313    TryInherit(MinimalVersionSet),
314    Skip,
315}
316
317impl Default for Action {
318    fn default() -> Self {
319        Action::TryInherit(MinimalVersionSet::default())
320    }
321}
322
323fn inherit_deps(
324    deps: &DepsSet,
325    toml_deps: &mut toml_edit::Table,
326    package_name2spec: &BTreeMap<String, SharedDependency>,
327    was_modified: &mut bool,
328    prefer_simple_dotted: bool,
329) {
330    for (name, dep) in deps {
331        let package_name = dep.package().unwrap_or(name.as_str());
332        if !package_name2spec.contains_key(package_name) {
333            continue;
334        }
335        match dep {
336            Dependency::Simple(_) => {
337                let mut inherited = toml_edit::InlineTable::new();
338                inherited.insert("workspace", toml_edit::value(true).into_value().unwrap());
339                inherited.set_dotted(prefer_simple_dotted);
340
341                insert_preserving_decor(toml_deps, name, toml_edit::Item::Value(inherited.into()));
342                *was_modified = true;
343            }
344            Dependency::Inherited(_) => {
345                // Nothing to do.
346            }
347            Dependency::Detailed(details) => {
348                let mut inherited = toml_edit::InlineTable::new();
349                inherited.insert("workspace", toml_edit::value(true).into_value().unwrap());
350                if let Some(features) = &details.features {
351                    inherited.insert(
352                        "features",
353                        toml_edit::Value::Array(Array::from_iter(features.iter())),
354                    );
355                }
356                if let Some(optional) = details.optional {
357                    inherited.insert("optional", toml_edit::value(optional).into_value().unwrap());
358                }
359
360                if inherited.len() == 1 {
361                    inherited.set_dotted(prefer_simple_dotted);
362                }
363
364                insert_preserving_decor(toml_deps, name, toml_edit::Item::Value(inherited.into()));
365                *was_modified = true;
366            }
367        }
368    }
369}
370
371fn insert_preserving_decor(table: &mut toml_edit::Table, key: &str, mut value: toml_edit::Item) {
372    fn get_decor(item: &toml_edit::Item) -> Option<toml_edit::Decor> {
373        match item {
374            toml_edit::Item::Value(v) => Some(v.decor().clone()),
375            toml_edit::Item::Table(t) => Some(t.decor().clone()),
376            _ => None,
377        }
378    }
379
380    fn set_decor(item: &mut toml_edit::Item, decor: toml_edit::Decor) {
381        match item {
382            toml_edit::Item::Value(v) => {
383                *v.decor_mut() = decor;
384            }
385            toml_edit::Item::Table(t) => {
386                *t.decor_mut() = decor;
387            }
388            _ => unreachable!(),
389        }
390    }
391
392    let mut new_key = Key::new(key);
393    if let Some((existing_key, existing_value)) = table.get_key_value(key) {
394        new_key = new_key.with_leaf_decor(existing_key.leaf_decor().to_owned());
395
396        if let Some(mut decor) = get_decor(existing_value) {
397            // Tables tend to have newline whitespacing that doesn't agree with other types
398            if existing_value.is_table() && !value.is_table() {
399                decor.set_prefix(" ");
400            }
401            set_decor(&mut value, decor);
402        }
403    }
404    table.insert_formatted(&new_key, value);
405}
406
407fn process_deps(deps: &DepsSet, package_name2specs: &mut BTreeMap<String, Action>) {
408    for (name, details) in deps {
409        match dep2shared_dep(details) {
410            SourceType::Shareable(source) => {
411                let action = package_name2specs.entry(name.clone()).or_default();
412                if let Action::TryInherit(set) = action {
413                    set.insert(source);
414                }
415            }
416            SourceType::Inherited => {}
417            SourceType::MustBeSkipped => {
418                package_name2specs.insert(name.clone(), Action::Skip);
419            }
420        }
421    }
422}
423
424#[derive(Clone, Debug, Eq, PartialEq, Hash)]
425struct SharedDependency {
426    default_features: bool,
427    source: DependencySource,
428}
429
430#[derive(Clone, Debug, Eq, PartialEq, Hash)]
431enum DependencySource {
432    Version(VersionReq),
433    Git {
434        git: String,
435        branch: Option<String>,
436        tag: Option<String>,
437        rev: Option<String>,
438        version: Option<VersionReq>,
439    },
440    Path {
441        path: String,
442        version: Option<VersionReq>,
443    },
444}
445
446impl std::fmt::Display for DependencySource {
447    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
448        match self {
449            DependencySource::Version(version) => write!(f, "version: {}", version),
450            DependencySource::Git {
451                git,
452                branch,
453                tag,
454                rev,
455                version,
456            } => {
457                write!(f, "git: {}", git)?;
458                if let Some(branch) = branch {
459                    write!(f, ", branch: {}", branch)?;
460                }
461                if let Some(tag) = tag {
462                    write!(f, ", tag: {}", tag)?;
463                }
464                if let Some(rev) = rev {
465                    write!(f, ", rev: {}", rev)?;
466                }
467                if let Some(version) = version {
468                    write!(f, ", version: {}", version)?;
469                }
470                Ok(())
471            }
472            DependencySource::Path { path, version } => {
473                write!(f, "path: {}", path)?;
474                if let Some(version) = version {
475                    write!(f, ", version: {}", version)?;
476                }
477                Ok(())
478            }
479        }
480    }
481}
482
483enum SourceType {
484    Shareable(SharedDependency),
485    Inherited,
486    MustBeSkipped,
487}
488
489fn dep2shared_dep(dep: &Dependency) -> SourceType {
490    match dep {
491        Dependency::Simple(version) => {
492            let version_req =
493                VersionReq::parse(version).expect("Failed to parse version requirement");
494            SourceType::Shareable(SharedDependency {
495                default_features: true,
496                source: DependencySource::Version(version_req),
497            })
498        }
499        Dependency::Inherited(_) => SourceType::Inherited,
500        Dependency::Detailed(d) => {
501            let mut source = None;
502            // We ignore custom registries for now.
503            if d.registry.is_some() || d.registry_index.is_some() {
504                return SourceType::MustBeSkipped;
505            }
506            if d.path.is_some() {
507                source = Some(DependencySource::Path {
508                    path: d.path.as_ref().unwrap().to_owned(),
509                    version: d.version.as_ref().map(|v| {
510                        VersionReq::parse(v).expect("Failed to parse version requirement")
511                    }),
512                });
513            } else if let Some(git) = &d.git {
514                source = Some(DependencySource::Git {
515                    git: git.to_owned(),
516                    branch: d.branch.to_owned(),
517                    tag: d.tag.to_owned(),
518                    rev: d.rev.to_owned(),
519                    version: d.version.as_ref().map(|v| {
520                        VersionReq::parse(v).expect("Failed to parse version requirement")
521                    }),
522                });
523            } else if let Some(version) = &d.version {
524                let version_req =
525                    VersionReq::parse(version).expect("Failed to parse version requirement");
526                source = Some(DependencySource::Version(version_req));
527            }
528            match source {
529                None => SourceType::MustBeSkipped,
530                Some(source) => SourceType::Shareable(SharedDependency {
531                    default_features: d.default_features.unwrap_or(true),
532                    source,
533                }),
534            }
535        }
536    }
537}
538
539fn shared2dep(shared_dependency: &SharedDependency) -> Dependency {
540    let SharedDependency {
541        default_features,
542        source,
543    } = shared_dependency;
544    match source {
545        DependencySource::Version(version) => {
546            if *default_features {
547                Dependency::Simple(version.to_string())
548            } else {
549                Dependency::Detailed(DependencyDetail {
550                    version: Some(version.to_string()),
551                    default_features: Some(false),
552                    ..DependencyDetail::default()
553                })
554            }
555        }
556        DependencySource::Git {
557            git,
558            branch,
559            tag,
560            rev,
561            version,
562        } => Dependency::Detailed(DependencyDetail {
563            package: None,
564            version: version.as_ref().map(|v| v.to_string()),
565            registry: None,
566            registry_index: None,
567            path: None,
568            git: Some(git.clone()),
569            branch: branch.clone(),
570            tag: tag.clone(),
571            rev: rev.clone(),
572            features: None,
573            optional: None,
574            default_features: if *default_features { None } else { Some(false) },
575        }),
576        DependencySource::Path { path, version } => Dependency::Detailed(DependencyDetail {
577            package: None,
578            version: version.as_ref().map(|v| v.to_string()),
579            registry: None,
580            registry_index: None,
581            path: Some(path.clone()),
582            git: None,
583            branch: None,
584            tag: None,
585            rev: None,
586            features: None,
587            optional: None,
588            default_features: if *default_features { None } else { Some(false) },
589        }),
590    }
591}
592
593fn dep2toml_item(dependency: &Dependency) -> toml_edit::Item {
594    match dependency {
595        Dependency::Simple(version) => toml_edit::value(version.trim_start_matches('^').to_owned()),
596        Dependency::Inherited(inherited) => {
597            let mut table = toml_edit::InlineTable::new();
598            table.get_or_insert("workspace", true);
599            if let Some(features) = &inherited.features {
600                table.get_or_insert("features", Array::from_iter(features.iter()));
601            }
602            if let Some(optional) = inherited.optional {
603                table.get_or_insert("optional", optional);
604            }
605            toml_edit::Item::Value(toml_edit::Value::InlineTable(table))
606        }
607        Dependency::Detailed(details) => {
608            let mut table = toml_edit::InlineTable::new();
609            let DependencyDetail {
610                version,
611                registry,
612                registry_index,
613                path,
614                git,
615                branch,
616                tag,
617                rev,
618                features,
619                optional,
620                default_features,
621                package,
622            } = details;
623
624            if let Some(version) = version {
625                table.get_or_insert("version", version.trim_start_matches('^'));
626            }
627            if let Some(registry) = registry {
628                table.get_or_insert("registry", registry);
629            }
630            if let Some(registry_index) = registry_index {
631                table.get_or_insert("registry-index", registry_index);
632            }
633            if let Some(path) = path {
634                table.get_or_insert("path", path);
635            }
636            if let Some(git) = git {
637                table.get_or_insert("git", git);
638            }
639            if let Some(branch) = branch {
640                table.get_or_insert("branch", branch);
641            }
642            if let Some(tag) = tag {
643                table.get_or_insert("tag", tag);
644            }
645            if let Some(rev) = rev {
646                table.get_or_insert("rev", rev);
647            }
648            if let Some(features) = features {
649                table.get_or_insert("features", Array::from_iter(features.iter()));
650            }
651            if let Some(optional) = optional {
652                table.get_or_insert(
653                    "optional",
654                    toml_edit::value(*optional).into_value().unwrap(),
655                );
656            }
657            if let Some(default_features) = default_features {
658                table.get_or_insert(
659                    "default-features",
660                    toml_edit::value(*default_features).into_value().unwrap(),
661                );
662            }
663            if let Some(package) = package {
664                table.get_or_insert("package", package);
665            }
666
667            toml_edit::Item::Value(toml_edit::Value::InlineTable(table))
668        }
669    }
670}