use std::collections::{BTreeSet, HashSet};
use std::path::PathBuf;
use crate::diagnostic::Severity;
use clap::{ArgAction, Parser};
use guppy::PackageId;
use guppy::graph::DependencyDirection;
use std::collections::HashMap;
use crate::analysis::ExternItemCoordinates;
use crate::analysis::partitioning::{self, HeaderFilenames, default_header_base_name};
use crate::config::{Language, PackageConfig, PackageTypeMode, Style};
use crate::diagnostic::{DiagnosticSink, render_diagnostics};
use crate::{analysis, codegen, config, metadata, topological_sort};
use super::input::{PackageSelection, filter_library_targets, resolve_input, select_packages};
use crate::Collection;
// CLI arguments for the `generate` subcommand.
//
// NOTE: plain `//` comments are used deliberately: clap's derive turns `///` doc
// comments into `--help` text, so doc comments here would change CLI output.
#[derive(Debug, Parser)]
pub(super) struct GenerateArgs {
// Optional positional input; resolved with `resolve_input`, and its directory is
// used as the metadata directory (`.` when absent).
input: Option<PathBuf>,
// Which workspace packages to generate headers for.
#[command(flatten)]
package_selection: PackageSelection,
// Repeatable `-v` counter. NOTE(review): not read anywhere in this file — confirm
// whether it is consumed elsewhere.
#[arg(short, action = ArgAction::Count)]
verbose: u8,
// Suppresses the progress messages printed to stderr.
#[arg(short, long)]
quiet: bool,
// Path to a TOML config file; `RawConfig::default()` is used when omitted.
#[arg(short, long)]
config: Option<PathBuf>,
// Output language; also determines the generated files' extension.
#[arg(short, long)]
lang: Language,
// Forwarded to the config as a CLI override (`CliOverrides::cpp_compat`).
#[arg(long)]
cpp_compat: bool,
// Forwarded to the config as a CLI override (`CliOverrides::style`) when set.
#[arg(short, long)]
style: Option<Style>,
// Directory the headers are written into (created if missing).
#[arg(short, long = "output-dir")]
output_dir: PathBuf,
// Explicit metadata path passed to `metadata::load_package_graph`.
#[arg(long)]
metadata: Option<PathBuf>,
// When set, exported symbol names are collected and written to this file.
#[arg(long)]
symbol_file: Option<PathBuf>,
// Do not write headers; only valid together with `--symbol-file`.
#[arg(long)]
no_header: bool,
// Generate a single header for a single target package instead of partitioned
// per-crate headers. The loaded config can also enable bundling.
#[arg(long)]
bundle: bool,
// Partitioned mode only: delete files in the output directory that have the
// language's extension but were not written by this run.
#[arg(long)]
prune_orphans: bool,
// Partitioned mode only: do not write headers with no types, functions,
// statics, or constants.
#[arg(long, conflicts_with = "bundle")]
skip_empty: bool,
}
/// Per-package type-handling settings resolved from `[package."…"]` config entries.
pub(crate) struct PackageTypeOverrides {
/// Packages configured with `types = "opaque"`.
pub opaque: HashSet<guppy::PackageId>,
/// Packages configured with `types = "skip"`.
pub skipped: HashSet<guppy::PackageId>,
/// Per-package `usize_is_size_t` settings; packages missing from this map use
/// `global_usize_is_size_t` (see `PackageTypeOverrides::usize_is_size_t`).
pub usize_is_size_t: HashMap<guppy::PackageId, bool>,
/// The `usize_is_size_t` setting from the active language config's common section.
pub global_usize_is_size_t: bool,
}
impl PackageTypeOverrides {
    /// Reports whether `usize` is emitted as `size_t` for `package_id`.
    ///
    /// A per-package setting wins; without one, the global setting applies.
    pub fn usize_is_size_t(&self, package_id: &guppy::PackageId) -> bool {
        match self.usize_is_size_t.get(package_id) {
            Some(&per_package) => per_package,
            None => self.global_usize_is_size_t,
        }
    }
}
/// Splits a `[package."…"]` key into a crate name and an optional version requirement.
///
/// `"name"` yields `(name, None)`; `"name@req"` yields `(name, Some(req))` with
/// `req` parsed as a `guppy::VersionReq`. The split happens at the first `@`.
///
/// # Errors
/// Returns a `config::ConfigError` when the text after `@` is not a valid
/// version requirement.
fn parse_package_key(key: &str) -> Result<(&str, Option<guppy::VersionReq>), config::ConfigError> {
    let Some((name, version_str)) = key.split_once('@') else {
        return Ok((key, None));
    };
    guppy::VersionReq::parse(version_str)
        .map(|requirement| (name, Some(requirement)))
        .map_err(|e| config::ConfigError {
            message: format!(
                "invalid version requirement `{version_str}` in [package.\"{key}\"]: {e}"
            ),
        })
}
/// Resolves the `[package."…"]` config entries into per-package type overrides.
///
/// Each key is a bare crate name or `name@version-req` (see `parse_package_key`).
/// A key that matches no package, matches no version, or matches several versions
/// without a requirement is reported as an error on `diagnostics` and skipped, so
/// every problem is reported in a single run. Conflicting settings for one package
/// are also reported as errors. Conflicts are either `types = "opaque"` together
/// with `types = "skip"`, or different `usize_is_size_t` values from two keys that
/// resolve to the same package.
///
/// # Errors
/// Returns `Err` only when a key contains a malformed version requirement.
fn resolve_package_overrides(
    package_configs: &HashMap<String, PackageConfig>,
    global_usize_is_size_t: bool,
    collection: &Collection,
    diagnostics: &mut DiagnosticSink,
) -> Result<PackageTypeOverrides, anyhow::Error> {
    let mut opaque = HashSet::new();
    let mut skipped = HashSet::new();
    let mut usize_is_size_t: HashMap<guppy::PackageId, bool> = HashMap::new();
    let graph = collection.package_graph();

    // `HashMap` iteration order is randomized per process. Visit keys in sorted
    // order so diagnostics are stable across runs and conflict handling below does
    // not depend on which key happens to be seen first.
    let mut entries: Vec<(&String, &PackageConfig)> = package_configs.iter().collect();
    entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

    for (key, config) in entries {
        let (name, version_req) = parse_package_key(key)?;
        let package_set = graph.resolve_package_name(name);
        if package_set.is_empty() {
            diagnostics
                .error(format!(
                    "package `{name}` not found in the dependency graph"
                ))
                .emit();
            continue;
        }
        let matching: Vec<_> = package_set
            .packages(DependencyDirection::Forward)
            .filter(|pkg| match &version_req {
                Some(req) => req.matches(pkg.version()),
                None => true,
            })
            .collect();
        if matching.is_empty() {
            diagnostics
                .error(format!(
                    "no version of `{name}` matches requirement `{}`",
                    version_req
                        .as_ref()
                        // Without a requirement the filter keeps every package of
                        // the (non-empty) set, so an empty match implies `Some`.
                        .expect("an empty match is only possible with a version requirement")
                ))
                .emit();
            continue;
        }
        if version_req.is_none() && matching.len() > 1 {
            let versions: Vec<_> = matching
                .iter()
                .map(|p| format!("v{}", p.version()))
                .collect();
            diagnostics
                .error(format!(
                    "package name `{name}` is ambiguous: matches {}; \
                     use [package.\"{name}@<version>\"] to disambiguate",
                    versions.join(" and ")
                ))
                .emit();
            continue;
        }
        if let Some(types) = config.types {
            let target = match types {
                PackageTypeMode::Opaque => &mut opaque,
                PackageTypeMode::Skip => &mut skipped,
            };
            for pkg in &matching {
                target.insert(pkg.id().clone());
            }
        }
        if let Some(value) = config.usize_is_size_t {
            for pkg in &matching {
                let id = pkg.id();
                // Two keys (e.g. `foo` and `foo@1`) may resolve to the same
                // package; disagreeing values are a config error, not a silent
                // last-writer-wins.
                if let Some(&existing) = usize_is_size_t.get(id) {
                    if existing != value {
                        diagnostics
                            .error(format!(
                                "package `{name}` v{} has conflicting `usize_is_size_t` \
                                 values across [package] entries (including \
                                 [package.\"{key}\"]); pick one",
                                pkg.version()
                            ))
                            .emit();
                    }
                    continue;
                }
                usize_is_size_t.insert(id.clone(), value);
            }
        }
    }

    // Sorted so the conflict report order is deterministic.
    let mut conflicts: Vec<&guppy::PackageId> = opaque.intersection(&skipped).collect();
    conflicts.sort_unstable_by(|a, b| a.repr().cmp(b.repr()));
    for id in conflicts {
        diagnostics
            .error(format!(
                "package `{}` is configured with both `types = \"opaque\"` and `types = \"skip\"`; \
                 pick one",
                id.repr()
            ))
            .emit();
    }

    Ok(PackageTypeOverrides {
        opaque,
        skipped,
        usize_is_size_t,
        global_usize_is_size_t,
    })
}
/// Resolves `header_name` renames (keyed by `name` or `name@version-req`) to
/// concrete package IDs.
///
/// Keys that match no package, match no version, or ambiguously match several
/// versions are reported as errors on `diagnostics` and skipped. If two keys
/// resolve to the same package with different header names, an error is
/// reported and the first key in sorted order is kept.
///
/// # Errors
/// Returns `Err` only when a key contains a malformed version requirement.
fn resolve_header_renames(
    renames: &HashMap<String, String>,
    collection: &Collection,
    diagnostics: &mut DiagnosticSink,
) -> Result<HashMap<PackageId, String>, anyhow::Error> {
    let graph = collection.package_graph();
    let mut resolved = HashMap::new();

    // `HashMap` iteration order is randomized per process; sort so diagnostics and
    // conflict resolution are reproducible across runs.
    let mut entries: Vec<(&String, &String)> = renames.iter().collect();
    entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

    for (key, header_name) in entries {
        let (name, version_req) = parse_package_key(key)?;
        let package_set = graph.resolve_package_name(name);
        if package_set.is_empty() {
            diagnostics
                .error(format!(
                    "package `{name}` (from `header_name` rename) not found in the \
                     dependency graph"
                ))
                .emit();
            continue;
        }
        let matching: Vec<_> = package_set
            .packages(DependencyDirection::Forward)
            .filter(|pkg| match &version_req {
                Some(req) => req.matches(pkg.version()),
                None => true,
            })
            .collect();
        if matching.is_empty() {
            diagnostics
                .error(format!(
                    "no version of `{name}` matches requirement `{}` (from `header_name` rename)",
                    version_req
                        .as_ref()
                        // Without a requirement every package of the non-empty set
                        // matches, so an empty match implies `Some`.
                        .expect("an empty match is only possible with a version requirement")
                ))
                .emit();
            continue;
        }
        if version_req.is_none() && matching.len() > 1 {
            let versions: Vec<_> = matching
                .iter()
                .map(|p| format!("v{}", p.version()))
                .collect();
            diagnostics
                .error(format!(
                    "package name `{name}` is ambiguous for `header_name` rename: matches {}; \
                     use [package.\"{name}@<version>\"] to disambiguate",
                    versions.join(" and ")
                ))
                .emit();
            continue;
        }
        for pkg in &matching {
            let id = pkg.id();
            // Two keys (e.g. `foo` and `foo@1`) may target the same package; a
            // disagreement would otherwise be settled by hash order.
            if let Some(existing) = resolved.get(id) {
                if existing != header_name {
                    diagnostics
                        .error(format!(
                            "package `{name}` v{} is renamed to both `{existing}` and \
                             `{header_name}` via `header_name`; pick one",
                            pkg.version()
                        ))
                        .emit();
                }
                continue;
            }
            resolved.insert(id.clone(), header_name.clone());
        }
    }
    Ok(resolved)
}
/// Chooses the header base name (no extension) for a package.
///
/// Precedence: an explicit `header_name` rename, then
/// `default_header_base_name`, then `fallback_name` with every `-` turned into `_`.
fn target_base_name(
    graph: &guppy::graph::PackageGraph,
    pkg_id: &PackageId,
    fallback_name: &str,
    renames: &HashMap<PackageId, String>,
) -> String {
    if let Some(renamed) = renames.get(pkg_id) {
        return renamed.clone();
    }
    match default_header_base_name(graph, pkg_id) {
        Some(base) => base,
        None => fallback_name.replace('-', "_"),
    }
}
/// Entry point of the `generate` subcommand.
///
/// Validates flag combinations, loads the config and package graph, selects and
/// filters the target packages, and resolves `[package."…"]` overrides and
/// `header_name` renames. It then writes either one bundled header (`--bundle`)
/// or partitioned per-crate headers. Generation failures are recorded as
/// diagnostics instead of aborting early. All diagnostics are rendered at the
/// end, after the symbol file (if requested) is written.
///
/// # Errors
/// Fails on invalid flag combinations, config/metadata/collection errors, I/O
/// errors writing the symbol file, or when any error diagnostic was recorded.
pub(super) fn generate(cli: &GenerateArgs) -> anyhow::Result<()> {
    if cli.no_header && cli.symbol_file.is_none() {
        anyhow::bail!("--no-header requires --symbol-file");
    }
    let resolved_input = cli.input.as_ref().map(|p| resolve_input(p)).transpose()?;
    let raw_config = if let Some(ref config_path) = cli.config {
        config::RawConfig::from_toml_file(config_path)?
    } else {
        config::RawConfig::default()
    };
    let overrides = config::CliOverrides {
        style: cli.style.clone(),
        cpp_compat: cli.cpp_compat,
    };
    let mut config_set = raw_config.into_config(&cli.lang, &overrides)?;
    // `--bundle` can only turn bundling on; the config may have enabled it already.
    if cli.bundle {
        config_set.bundle = true;
    }
    // Partitioned-only flags are rejected as soon as the final mode is known. This
    // happens before the (slow) metadata load, and even when no package ends up
    // selected, because the early return further down would otherwise skip a
    // later check.
    if cli.skip_empty && config_set.bundle {
        anyhow::bail!(
            "--skip-empty is only valid in partitioned mode; remove --bundle to use it"
        );
    }
    if cli.prune_orphans && config_set.bundle {
        anyhow::bail!(
            "--prune-orphans is only valid in partitioned mode; remove --bundle to use it"
        );
    }
    let metadata_dir = resolved_input
        .as_ref()
        .map(|r| r.dir().clone())
        .unwrap_or_else(|| PathBuf::from("."));
    let package_graph = metadata::load_package_graph(cli.metadata.as_ref(), Some(&metadata_dir))?;
    let packages = select_packages(
        resolved_input.as_ref(),
        &cli.package_selection,
        &package_graph.workspace(),
    )?;
    let ws_root: PathBuf = package_graph.workspace().root().to_path_buf().into();
    let debug = std::env::var("CHEADERGEN_DEBUG").is_ok_and(|v| v == "true" || v == "1");
    let mut diagnostics = DiagnosticSink::new(ws_root, debug);
    let explicit_names: HashSet<String> =
        cli.package_selection.packages.iter().cloned().collect();
    let packages =
        filter_library_targets(packages, &package_graph, &explicit_names, &mut diagnostics);
    if packages.is_empty() {
        return render_diagnostics_or_bail(&mut diagnostics, debug);
    }
    if config_set.bundle && packages.len() > 1 {
        anyhow::bail!(
            "--bundle is only valid with a single target package, but {} were selected",
            packages.len()
        );
    }
    let toolchain = std::env::var("CHEADERGEN_DOCS_TOOLCHAIN")
        .unwrap_or_else(|_| metadata::DOCS_TOOLCHAIN.to_string());
    if !cli.quiet {
        eprintln!(
            "Generating headers for {} crate(s) using toolchain `{toolchain}`...",
            packages.len()
        );
    }
    let collection = metadata::create_collection(package_graph)?;
    collection
        .compute_batch(packages.iter().map(|(id, _)| id.clone()))
        .map_err(|e| anyhow::anyhow!(e))?;
    let mut all_symbols = BTreeSet::new();
    let (package_configs, global_usize_is_size_t) = match &config_set.default {
        config::Config::C(c) => (&c.common.package_configs, c.common.usize_is_size_t),
        config::Config::Cxx(c) => (&c.common.package_configs, c.common.usize_is_size_t),
    };
    let type_overrides = resolve_package_overrides(
        package_configs,
        global_usize_is_size_t,
        &collection,
        &mut diagnostics,
    )?;
    let header_renames =
        resolve_header_renames(&config_set.header_renames, &collection, &mut diagnostics)?;
    // Warn about `[header."…"]` sections that will never be applied.
    let package_base_names: HashSet<String> = packages
        .iter()
        .map(|(id, name)| target_base_name(collection.package_graph(), id, name, &header_renames))
        .collect();
    for header_name in config_set.header_names() {
        if !package_base_names.contains(header_name) {
            diagnostics
                .warning(format!(
                    "`[header.\"{header_name}\"]` in config does not match any selected \
                     package's generated header name"
                ))
                .emit();
        }
    }
    if config_set.bundle {
        for (package_id, package_name) in &packages {
            if !cli.quiet {
                eprintln!("Generating header for `{package_name}`...");
            }
            let base_name = target_base_name(
                collection.package_graph(),
                package_id,
                package_name,
                &header_renames,
            );
            let config = config_set.for_header(&base_name);
            match generate_one_crate(
                package_id,
                package_name,
                config,
                &collection,
                cli,
                &type_overrides,
                &mut diagnostics,
            ) {
                Ok(symbols) => {
                    all_symbols.extend(symbols);
                }
                Err(e) => {
                    diagnostics
                        .error(format!("failed to generate header for `{package_name}`"))
                        .with_error_chain(e.as_ref())
                        .emit();
                }
            }
        }
    } else {
        match generate_partitioned(
            &packages,
            &config_set,
            &header_renames,
            &collection,
            cli,
            &type_overrides,
            &mut diagnostics,
        ) {
            Ok(symbols) => {
                all_symbols.extend(symbols);
            }
            Err(e) => {
                diagnostics
                    .error("failed to generate partitioned headers".to_string())
                    .with_error_chain(e.as_ref())
                    .emit();
            }
        }
    }
    if let Some(ref symbol_file) = cli.symbol_file {
        codegen::write_symbol_file(&all_symbols, symbol_file)?;
    }
    render_diagnostics_or_bail(&mut diagnostics, debug)
}
/// Prints all accumulated diagnostics to stderr and fails if any is an error.
///
/// Returns `Ok(())` without printing when the sink is empty. ANSI color is
/// disabled when `NO_COLOR` is set to a non-empty value, per the NO_COLOR
/// convention; an empty `NO_COLOR=` leaves color on. If some causes were hidden
/// and `debug` is off, a hint about `CHEADERGEN_DEBUG` follows the diagnostics.
/// Otherwise a blank line is printed.
///
/// # Errors
/// Returns an error if at least one drained diagnostic has `Severity::Error`.
fn render_diagnostics_or_bail(diagnostics: &mut DiagnosticSink, debug: bool) -> anyhow::Result<()> {
    if diagnostics.is_empty() {
        return Ok(());
    }
    // Queried before `drain()`, which presumably empties the sink.
    let has_hidden_causes = diagnostics.has_hidden_causes();
    let all = diagnostics.drain();
    // Only a *non-empty* NO_COLOR disables color, regardless of its value.
    // `var_os` also honours non-UTF-8 values, which `var` reports as an error.
    let no_color = std::env::var_os("NO_COLOR").is_some_and(|v| !v.is_empty());
    let rendered = render_diagnostics(&all, !no_color);
    eprint!("{rendered}");
    if !debug && has_hidden_causes {
        eprintln!("note: rerun with `CHEADERGEN_DEBUG=true` for more details");
    } else {
        eprintln!();
    }
    if all.iter().any(|d| d.severity == Severity::Error) {
        anyhow::bail!("aborting due to previous error(s)");
    }
    Ok(())
}
/// Generates partitioned headers: one per target package plus one per other crate
/// that `partitioning::partition_types` assigns type definitions to (reported as
/// "dependency" headers).
///
/// Phases:
/// 1. For each target, load its rustdoc data and optionally collect exported
///    symbols. Then resolve its extern items with the C config of its header.
/// 2. Unless `--no-header`, collect type definitions across all targets and
///    partition them by crate. The first target's config supplies
///    `enum_prefix_with_name`.
/// 3. For each partition, append forward declarations, order the types, and
///    write the header. With `--skip-empty`, headers with no content are skipped.
/// 4. With `--prune-orphans`, delete other files with the output extension from
///    the output directory.
///
/// Returns the collected symbol names (empty unless `--symbol-file` is set).
///
/// # Errors
/// Fails if `packages` is empty (when headers are requested), a crate cannot be
/// loaded, any header's config is not C, type collection fails, or an I/O
/// operation fails.
fn generate_partitioned(
    packages: &[(PackageId, String)],
    config_set: &config::ConfigSet,
    header_renames: &HashMap<PackageId, String>,
    collection: &Collection,
    cli: &GenerateArgs,
    type_overrides: &PackageTypeOverrides,
    diagnostics: &mut DiagnosticSink,
) -> anyhow::Result<BTreeSet<String>> {
    let mut all_symbols = BTreeSet::new();
    let graph = collection.package_graph();

    // Phase 1: per-target loading, symbol collection and extern-item resolution.
    let mut target_extern_items: Vec<(PackageId, analysis::extern_items::ExternItems)> = Vec::new();
    for (package_id, package_name) in packages {
        let krate = collection
            .get_or_compute(package_id)
            .map_err(|e| anyhow::anyhow!(e))?;
        if !cli.quiet {
            let root_item = krate.core.krate.index.get(&krate.core.krate.root_item_id);
            let root_name = root_item
                .as_ref()
                .and_then(|item| item.name.as_deref())
                .unwrap_or("<unknown>");
            eprintln!(
                "Successfully loaded rustdoc JSON for `{package_name}`: root module `{root_name}`"
            );
        }
        let coordinates = ExternItemCoordinates::collect(collection, package_id, diagnostics)
            .map_err(|e| anyhow::anyhow!(e))?;
        if cli.symbol_file.is_some() {
            all_symbols.extend(analysis::collect_symbols(&coordinates, krate));
        }
        let base_name = target_base_name(graph, package_id, package_name, header_renames);
        let config = config_set.for_header(&base_name);
        let c_config = match config {
            config::Config::C(c) => c,
            _ => anyhow::bail!("Only C output is currently supported"),
        };
        let extern_items =
            coordinates.resolve(collection, &c_config.common, type_overrides, diagnostics);
        if !cli.quiet {
            eprintln!(
                "`{package_name}`: {} function(s), {} static(s), {} constant(s)",
                extern_items.fns.len(),
                extern_items.statics.len(),
                extern_items.constants.len()
            );
        }
        target_extern_items.push((package_id.clone(), extern_items));
    }
    if cli.no_header {
        return Ok(all_symbols);
    }

    // Phase 2: cross-target type collection and partitioning.
    // The caller only passes non-empty `packages`; report an error instead of
    // panicking on `packages[0]` should that ever change.
    let (first_id, first_name) = packages
        .first()
        .ok_or_else(|| anyhow::anyhow!("no target packages selected for header generation"))?;
    let first_base_name = target_base_name(graph, first_id, first_name, header_renames);
    let first_config = config_set.for_header(&first_base_name);
    let c_config = match first_config {
        config::Config::C(c) => c,
        _ => anyhow::bail!("Only C output is currently supported"),
    };
    let all_type_defs = analysis::collect_type_definitions_multi(
        &target_extern_items,
        collection,
        c_config.enum_prefix_with_name,
        type_overrides,
        diagnostics,
    )?;
    if !cli.quiet {
        eprintln!(
            "Collected {} type definitions across all targets",
            all_type_defs.len()
        );
    }
    let partitioned =
        partitioning::partition_types(all_type_defs, &target_extern_items, type_overrides);
    let all_header_pkg_ids: Vec<&PackageId> = partitioned.per_crate.keys().collect();
    let filenames = HeaderFilenames::new(&all_header_pkg_ids, graph, header_renames)
        .map_err(|e| anyhow::anyhow!(e))?;
    let header_deps = partitioning::compute_header_deps(
        &partitioned,
        &target_extern_items,
        type_overrides,
        &filenames,
        cli.lang.extension(),
    );
    let target_ids: HashSet<&PackageId> = packages.iter().map(|(id, _)| id).collect();
    let multi_header = partitioned.per_crate.len() > 1;

    // Phase 3: write one header per partition.
    fs_err::create_dir_all(&cli.output_dir)?;
    let mut written: HashSet<String> = HashSet::new();
    for (pkg_id, mut type_defs) in partitioned.per_crate {
        let is_target = target_ids.contains(&pkg_id);
        let config = config_set.for_header(filenames.base_name(&pkg_id));
        // Non-C configs are rejected here just as for targets above, rather than
        // skipped. A skipped header may still be `#include`d by other headers
        // through `compute_header_deps`, which would leave a dangling include.
        let c_cfg = match config {
            config::Config::C(c) => c,
            _ => anyhow::bail!("Only C output is currently supported"),
        };
        let deps = header_deps.get(&pkg_id);
        if let Some(deps) = deps {
            for fwd in &deps.forward_decls {
                type_defs.push(fwd.clone());
            }
        }
        // Only target headers carry functions, statics and constants.
        let (fns, statics, constants) = if is_target {
            let items = target_extern_items
                .iter()
                .find(|(id, _)| *id == pkg_id)
                .map(|(_, items)| items)
                .expect("every target package has resolved extern items from phase 1");
            (&items.fns[..], &items.statics[..], &items.constants[..])
        } else {
            (&[][..], &[][..], &[][..])
        };
        if cli.skip_empty
            && type_defs.is_empty()
            && fns.is_empty()
            && statics.is_empty()
            && constants.is_empty()
        {
            continue;
        }
        analysis::sort_by_key(&mut type_defs, config::SortKey::SourceOrder, collection);
        topological_sort::topological_sort(&mut type_defs, collection, diagnostics);
        let krate_data = collection
            .get_or_compute(&pkg_id)
            .map_err(|e| anyhow::anyhow!(e))?;
        let assoc_constants =
            analysis::find_assoc_constants(&type_defs, krate_data, collection, diagnostics);
        let dep_includes = deps.map(|d| &d.includes[..]).unwrap_or(&[]);
        let type_hints = deps.map(|d| &d.type_hints[..]).unwrap_or(&[]);
        let mut header = String::new();
        codegen::generate_c_header(
            c_cfg,
            &type_defs,
            constants,
            &assoc_constants,
            fns,
            statics,
            dep_includes,
            type_hints,
            &type_overrides.skipped,
            multi_header,
            collection,
            &mut header,
        );
        let filename = filenames.filename(&pkg_id, cli.lang.extension());
        fs_err::write(cli.output_dir.join(&filename), &header)?;
        if !cli.quiet {
            let kind = if is_target { "target" } else { "dependency" };
            eprintln!("Wrote {kind} header: {filename}");
        }
        written.insert(filename);
    }

    // Phase 4: optional cleanup of headers not produced by this run.
    if cli.prune_orphans {
        prune_orphan_headers(&cli.output_dir, cli.lang.extension(), &written, cli.quiet)?;
    }
    Ok(all_symbols)
}
fn prune_orphan_headers(
output_dir: &std::path::Path,
lang_extension: &str,
keep: &HashSet<String>,
quiet: bool,
) -> anyhow::Result<()> {
let target_ext = std::ffi::OsStr::new(lang_extension);
for entry in fs_err::read_dir(output_dir)? {
let entry = entry?;
if !entry.file_type()?.is_file() {
continue;
}
let path = entry.path();
if path.extension() != Some(target_ext) {
continue;
}
let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
continue;
};
if keep.contains(name) {
continue;
}
let name = name.to_string();
fs_err::remove_file(&path)?;
if !quiet {
eprintln!("Removed orphan header: {name}");
}
}
Ok(())
}
fn generate_one_crate(
package_id: &guppy::PackageId,
package_name: &str,
config: &config::Config,
collection: &Collection,
cli: &GenerateArgs,
overrides: &PackageTypeOverrides,
diagnostics: &mut DiagnosticSink,
) -> anyhow::Result<BTreeSet<String>> {
let krate = collection
.get_or_compute(package_id)
.map_err(|e| anyhow::anyhow!(e))?;
if !cli.quiet {
let root_item = krate.core.krate.index.get(&krate.core.krate.root_item_id);
let root_name = root_item
.as_ref()
.and_then(|item| item.name.as_deref())
.unwrap_or("<unknown>");
eprintln!(
"Successfully loaded rustdoc JSON for `{package_name}`: root module `{root_name}`"
);
}
let extern_items = ExternItemCoordinates::collect(collection, package_id, diagnostics)
.map_err(|e| anyhow::anyhow!(e))?;
if !cli.quiet {
eprintln!(
"Found {} extern \"C\" function(s):",
extern_items.fn_ids.len()
);
for id in &extern_items.fn_ids {
let name = krate
.core
.krate
.index
.get(id)
.and_then(|item| item.name.clone())
.unwrap_or_else(|| "<unnamed>".to_string());
eprintln!(" - {name}");
}
eprintln!(
"Found {} exported static(s):",
extern_items.static_ids.len()
);
for id in &extern_items.static_ids {
let name = krate
.core
.krate
.index
.get(id)
.and_then(|item| item.name.clone())
.unwrap_or_else(|| "<unnamed>".to_string());
eprintln!(" - {name}");
}
}
let symbols = if cli.symbol_file.is_some() {
analysis::collect_symbols(&extern_items, krate)
} else {
BTreeSet::new()
};
if !cli.no_header {
let c_config = match config {
config::Config::C(c) => c,
_ => anyhow::bail!("Only C output is currently supported"),
};
let extern_items = extern_items.resolve(collection, &c_config.common, overrides, diagnostics);
if !cli.quiet {
eprintln!("Resolved {} function(s) to IR", extern_items.fns.len());
eprintln!("Resolved {} static(s) to IR", extern_items.statics.len());
eprintln!(
"Resolved {} constant(s) to IR",
extern_items.constants.len()
);
}
let mut type_defs = analysis::collect_type_definitions(
&extern_items,
collection,
c_config.enum_prefix_with_name,
overrides,
diagnostics,
)?;
analysis::sort_by_key(&mut type_defs, config::SortKey::SourceOrder, collection);
topological_sort::topological_sort(&mut type_defs, collection, diagnostics);
let assoc_constants =
analysis::find_assoc_constants(&type_defs, krate, collection, diagnostics);
let mut header = String::new();
codegen::generate_c_header(
c_config,
&type_defs,
&extern_items.constants,
&assoc_constants,
&extern_items.fns,
&extern_items.statics,
&[],
&[],
&overrides.skipped,
false,
collection,
&mut header,
);
let base = default_header_base_name(collection.package_graph(), package_id)
.unwrap_or_else(|| package_name.replace('-', "_"));
let filename = format!("{}.{}", base, cli.lang.extension());
fs_err::create_dir_all(&cli.output_dir)?;
fs_err::write(cli.output_dir.join(filename), &header)?;
}
Ok(symbols)
}