use std::collections::HashMap;
use std::fs;
use std::path::Path;
use super::{GeneratorError, GeneratorResult};
use crate::parser::vk_types::ExtensionDefinition;
pub struct ExtensionBuilderStats {
pub device_methods: usize,
pub instance_methods: usize,
}
fn is_enabled_for_vulkan(ext: &ExtensionDefinition) -> bool {
if let Some(s) = &ext.supported
&& s == "disabled"
{
return false;
}
if let Some(api) = &ext.api
&& !api.split(',').any(|a| a.trim() == "vulkan")
{
return false;
}
true
}
fn method_name_for(ext_name: &str) -> String {
ext_name
.strip_prefix("VK_")
.unwrap_or(ext_name)
.to_ascii_lowercase()
}
fn const_name_for(ext_name: &str) -> String {
format!(
"{}_EXTENSION_NAME",
ext_name.to_uppercase().replace("VK_", "")
)
}
fn transitive_requires(
ext: &ExtensionDefinition,
by_name: &HashMap<String, &ExtensionDefinition>,
) -> Vec<String> {
let mut out = Vec::new();
let mut pending: Vec<String> = Vec::new();
if let Some(list) = &ext.requires {
pending.extend(list.split(',').map(|s| s.trim().to_string()));
}
for block in &ext.require_blocks {
if let Some(d) = &block.depends
&& !d.contains('+')
&& !d.contains(',')
&& !d.contains('(')
&& !d.contains("::")
{
pending.push(d.trim().to_string());
}
}
while let Some(name) = pending.pop() {
if name.is_empty() {
continue;
}
if !by_name.contains_key(&name) {
continue;
}
if out.iter().any(|e: &String| e == &name) {
continue;
}
out.push(name.clone());
if let Some(dep) = by_name.get(&name)
&& let Some(req) = &dep.requires
{
pending.extend(req.split(',').map(|s| s.trim().to_string()));
}
}
out.sort();
out
}
fn build_doc(ext: &ExtensionDefinition) -> String {
let mut s = String::new();
s.push_str(&format!("Enable `{}`.", ext.name));
if let Some(p) = &ext.promotedto {
s.push_str(&format!("\n\n*Promoted to {}.*", p));
}
if ext.provisional.is_some() {
s.push_str("\n\n**Provisional — API and semantics may change.**");
}
if let Some(d) = &ext.deprecated {
s.push_str(&format!("\n\n**Deprecated:** {}", d));
}
if let Some(o) = &ext.obsoletedby {
s.push_str(&format!("\n\n**Obsoleted by:** `{}`.", o));
}
if let Some(p) = &ext.platform {
s.push_str(&format!("\n\nPlatform: `{}`.", p));
}
s
}
fn emit_method(ext: &ExtensionDefinition, requires: &[String]) -> String {
let mut out = String::new();
for line in build_doc(ext).lines() {
out.push_str(" /// ");
out.push_str(line);
out.push('\n');
}
let method = method_name_for(&ext.name);
let name_const = const_name_for(&ext.name);
out.push_str(&format!(" pub fn {method}(mut self) -> Self {{\n"));
out.push_str(&format!(
" self.enable(crate::raw::bindings::{name_const});\n"
));
for dep in requires {
out.push_str(&format!(
" self.enable(crate::raw::bindings::{});\n",
const_name_for(dep)
));
}
out.push_str(" self\n }\n");
out
}
fn emit_file(kind_type: &str, impl_target: &str, exts: &[&ExtensionDefinition], by_name: &HashMap<String, &ExtensionDefinition>) -> String {
let mut out = String::new();
out.push_str(&format!(
"// Generated by vulkan_gen::extensions_builder_gen — do not edit.\n\
//\n\
// `impl` block for [`{impl_target}`], which exposes a method per\n\
// {kind_type}-level Vulkan extension. Included from\n\
// `vulkane/src/safe/extensions.rs`.\n\n"
));
out.push_str(&format!(
"#[allow(non_snake_case, rustdoc::bare_urls)]\n\
impl {impl_target} {{\n"
));
for ext in exts {
let requires = transitive_requires(ext, by_name);
out.push_str(&emit_method(ext, &requires));
}
out.push_str("}\n");
out
}
pub fn generate_extensions_builders(
intermediate_dir: &Path,
output_dir: &Path,
) -> GeneratorResult<ExtensionBuilderStats> {
let path = intermediate_dir.join("extensions.json");
let content = fs::read_to_string(&path).map_err(GeneratorError::Io)?;
let extensions: Vec<ExtensionDefinition> = serde_json::from_str(&content)?;
let enabled: Vec<&ExtensionDefinition> =
extensions.iter().filter(|e| is_enabled_for_vulkan(e)).collect();
let by_name: HashMap<String, &ExtensionDefinition> = enabled
.iter()
.map(|e| (e.name.clone(), *e))
.collect();
let mut device: Vec<&ExtensionDefinition> = Vec::new();
let mut instance: Vec<&ExtensionDefinition> = Vec::new();
for e in &enabled {
match e.extension_type.as_deref() {
Some("device") => device.push(*e),
Some("instance") => instance.push(*e),
_ => device.push(*e),
}
}
device.sort_by(|a, b| method_name_for(&a.name).cmp(&method_name_for(&b.name)));
instance.sort_by(|a, b| method_name_for(&a.name).cmp(&method_name_for(&b.name)));
fs::create_dir_all(output_dir).map_err(GeneratorError::Io)?;
let device_code = emit_file("device", "DeviceExtensions", &device, &by_name);
let instance_code = emit_file("instance", "InstanceExtensions", &instance, &by_name);
fs::write(
output_dir.join("device_extensions_generated.rs"),
device_code,
)
.map_err(GeneratorError::Io)?;
fs::write(
output_dir.join("instance_extensions_generated.rs"),
instance_code,
)
.map_err(GeneratorError::Io)?;
Ok(ExtensionBuilderStats {
device_methods: device.len(),
instance_methods: instance.len(),
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::vk_types::{ExtensionRequire, RequireItem};
fn mk_ext(
name: &str,
ext_type: Option<&str>,
requires: Option<&str>,
) -> ExtensionDefinition {
ExtensionDefinition {
name: name.to_string(),
number: None,
extension_type: ext_type.map(|s| s.to_string()),
requires: requires.map(|s| s.to_string()),
requires_core: None,
author: None,
contact: None,
supported: None,
ratified: None,
deprecated: None,
obsoletedby: None,
promotedto: None,
provisional: None,
specialuse: None,
platform: None,
comment: None,
api: None,
sortorder: None,
require_blocks: vec![ExtensionRequire {
api: None,
profile: None,
extension: None,
feature: None,
comment: None,
depends: None,
items: Vec::<RequireItem>::new(),
raw_content: String::new(),
}],
remove_blocks: Vec::new(),
raw_content: String::new(),
source_line: None,
}
}
#[test]
fn method_name_strips_vk_prefix_and_lowercases() {
assert_eq!(method_name_for("VK_KHR_swapchain"), "khr_swapchain");
assert_eq!(
method_name_for("VK_EXT_descriptor_indexing"),
"ext_descriptor_indexing"
);
assert_eq!(method_name_for("VK_NV_cooperative_matrix"), "nv_cooperative_matrix");
}
#[test]
fn const_name_mirrors_existing_generator() {
assert_eq!(const_name_for("VK_KHR_swapchain"), "KHR_SWAPCHAIN_EXTENSION_NAME");
assert_eq!(const_name_for("VK_EXT_debug_utils"), "EXT_DEBUG_UTILS_EXTENSION_NAME");
}
#[test]
fn disabled_extensions_are_filtered() {
let mut ext = mk_ext("VK_KHR_reserved", Some("device"), None);
ext.supported = Some("disabled".to_string());
assert!(!is_enabled_for_vulkan(&ext));
}
#[test]
fn vulkansc_only_extensions_are_filtered() {
let mut ext = mk_ext("VK_KHR_sc_thing", Some("device"), None);
ext.api = Some("vulkansc".to_string());
assert!(!is_enabled_for_vulkan(&ext));
}
#[test]
fn transitive_requires_resolves_chain() {
let a = mk_ext("VK_A", Some("device"), None);
let b = mk_ext("VK_B", Some("device"), Some("VK_A"));
let c = mk_ext("VK_C", Some("device"), Some("VK_B"));
let by_name: HashMap<_, _> = [
("VK_A".to_string(), &a),
("VK_B".to_string(), &b),
("VK_C".to_string(), &c),
]
.into_iter()
.collect();
let deps = transitive_requires(&c, &by_name);
assert!(deps.contains(&"VK_A".to_string()));
assert!(deps.contains(&"VK_B".to_string()));
assert!(!deps.contains(&"VK_C".to_string()));
}
#[test]
fn emitted_method_chains_self_and_deps() {
let ext = mk_ext(
"VK_KHR_cooperative_matrix",
Some("device"),
Some("VK_KHR_get_physical_device_properties2"),
);
let code = emit_method(&ext, &["VK_KHR_get_physical_device_properties2".to_string()]);
assert!(code.contains("pub fn khr_cooperative_matrix(mut self) -> Self"));
assert!(
code.contains(
"self.enable(crate::raw::bindings::KHR_COOPERATIVE_MATRIX_EXTENSION_NAME);"
)
);
assert!(code.contains(
"self.enable(crate::raw::bindings::KHR_GET_PHYSICAL_DEVICE_PROPERTIES2_EXTENSION_NAME);"
));
}
#[test]
fn extensions_without_type_default_to_device() {
let ext = mk_ext("VK_WEIRD_untyped", None, None);
assert!(ext.extension_type.is_none());
}
}