use super::{write_file, IndexMap, VkRegistryData};
use heck::ToSnakeCase;
use proc_macro2::{Ident, Literal, TokenStream};
use quote::{format_ident, quote};
use std::fmt::Write as _;
use vk_parse::Extension;
fn required_if_supported(name: &str) -> bool {
#[allow(clippy::match_like_matches_macro)]
match name {
"VK_KHR_portability_subset" => true,
_ => false,
}
}
fn conflicts_extensions(name: &str) -> &'static [&'static str] {
match name {
"VK_KHR_buffer_device_address" => &["VK_EXT_buffer_device_address"],
"VK_EXT_buffer_device_address" => &["VK_KHR_buffer_device_address"],
_ => &[],
}
}
pub fn write(vk_data: &VkRegistryData) {
write_device_extensions(vk_data);
write_instance_extensions(vk_data);
}
#[derive(Clone, Debug)]
struct ExtensionsMember {
name: Ident,
doc: String,
raw: String,
required_if_supported: bool,
requires: Vec<RequiresOneOf>,
conflicts_device_extensions: Vec<Ident>,
status: Option<ExtensionStatus>,
}
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RequiresOneOf {
pub api_version: Option<(String, String)>,
pub device_extensions: Vec<Ident>,
pub instance_extensions: Vec<Ident>,
}
#[derive(Clone, Debug)]
enum Replacement {
Core((String, String)),
DeviceExtension(Ident),
InstanceExtension(Ident),
}
#[derive(Clone, Debug)]
enum ExtensionStatus {
Promoted(Replacement),
Deprecated(Option<Replacement>),
}
fn write_device_extensions(vk_data: &VkRegistryData) {
write_file(
"device_extensions.rs",
format!(
"vk.xml header version {}.{}.{}",
vk_data.header_version.0, vk_data.header_version.1, vk_data.header_version.2
),
device_extensions_output(&extensions_members("device", &vk_data.extensions)),
);
}
fn write_instance_extensions(vk_data: &VkRegistryData) {
write_file(
"instance_extensions.rs",
format!(
"vk.xml header version {}.{}.{}",
vk_data.header_version.0, vk_data.header_version.1, vk_data.header_version.2
),
instance_extensions_output(&extensions_members("instance", &vk_data.extensions)),
);
}
fn device_extensions_output(members: &[ExtensionsMember]) -> TokenStream {
let common = extensions_common_output(format_ident!("DeviceExtensions"), members);
let check_requirements_items = members.iter().map(|ExtensionsMember {
name,
requires,
conflicts_device_extensions,
required_if_supported,
..
}| {
let name_string = name.to_string();
let requires_items = requires.iter().map(|require| {
let require_items = require.api_version.iter().map(|version| {
let version = format_ident!("V{}_{}", version.0, version.1);
quote! { api_version >= crate::Version::#version }
}).chain(require.instance_extensions.iter().map(|ext| {
quote! { instance_extensions.#ext }
})).chain(require.device_extensions.iter().map(|ext| {
quote! { device_extensions.#ext }
}));
let api_version_items = require.api_version.as_ref().map(|version| {
let version = format_ident!("V{}_{}", version.0, version.1);
quote! { Some(crate::Version::#version) }
}).unwrap_or_else(|| quote!{ None });
let device_extensions_items = require.device_extensions.iter().map(|ext| ext.to_string());
let instance_extensions_items = require.instance_extensions.iter().map(|ext| ext.to_string());
quote! {
if !(#(#require_items)||*) {
return Err(crate::device::ExtensionRestrictionError {
extension: #name_string,
restriction: crate::device::ExtensionRestriction::Requires(crate::RequiresOneOf {
api_version: #api_version_items,
device_extensions: &[#(#device_extensions_items),*],
instance_extensions: &[#(#instance_extensions_items),*],
..Default::default()
}),
})
}
}
});
let conflicts_device_extensions_items = conflicts_device_extensions.iter().map(|extension| {
let string = extension.to_string();
quote! {
if self.#extension {
return Err(crate::device::ExtensionRestrictionError {
extension: #name_string,
restriction: crate::device::ExtensionRestriction::ConflictsDeviceExtension(#string),
});
}
}
});
let required_if_supported = if *required_if_supported {
quote! {
if supported.#name {
return Err(crate::device::ExtensionRestrictionError {
extension: #name_string,
restriction: crate::device::ExtensionRestriction::RequiredIfSupported,
});
}
}
} else {
quote! {}
};
quote! {
if self.#name {
if !supported.#name {
return Err(crate::device::ExtensionRestrictionError {
extension: #name_string,
restriction: crate::device::ExtensionRestriction::NotSupported,
});
}
#(#requires_items)*
#(#conflicts_device_extensions_items)*
} else {
#required_if_supported
}
}
});
quote! {
#common
impl DeviceExtensions {
pub(super) fn check_requirements(
&self,
supported: &DeviceExtensions,
api_version: crate::Version,
instance_extensions: &crate::instance::InstanceExtensions,
) -> Result<(), crate::device::ExtensionRestrictionError> {
let device_extensions = self;
#(#check_requirements_items)*
Ok(())
}
}
}
}
fn instance_extensions_output(members: &[ExtensionsMember]) -> TokenStream {
let common = extensions_common_output(format_ident!("InstanceExtensions"), members);
let check_requirements_items =
members
.iter()
.map(|ExtensionsMember { name, requires, .. }| {
let name_string = name.to_string();
let requires_items = requires.iter().map(|require| {
let require_items = require
.api_version
.iter()
.map(|version| {
let version = format_ident!("V{}_{}", version.0, version.1);
quote! { api_version >= crate::Version::#version }
})
.chain(require.instance_extensions.iter().map(|ext| {
quote! { instance_extensions.#ext }
}))
.chain(require.device_extensions.iter().map(|ext| {
quote! { device_extensions.#ext }
}));
let api_version_items = require
.api_version
.as_ref()
.map(|version| {
let version = format_ident!("V{}_{}", version.0, version.1);
quote! { Some(crate::Version::#version) }
})
.unwrap_or_else(|| quote! { None });
let device_extensions_items =
require.device_extensions.iter().map(|ext| ext.to_string());
let instance_extensions_items = require
.instance_extensions
.iter()
.map(|ext| ext.to_string());
quote! {
if !(#(#require_items)||*) {
return Err(crate::instance::ExtensionRestrictionError {
extension: #name_string,
restriction: crate::instance::ExtensionRestriction::Requires(crate::RequiresOneOf {
api_version: #api_version_items,
device_extensions: &[#(#device_extensions_items),*],
instance_extensions: &[#(#instance_extensions_items),*],
..Default::default()
}),
})
}
}
});
quote! {
if self.#name {
if !supported.#name {
return Err(crate::instance::ExtensionRestrictionError {
extension: #name_string,
restriction: crate::instance::ExtensionRestriction::NotSupported,
});
}
#(#requires_items)*
}
}
});
quote! {
#common
impl InstanceExtensions {
pub(super) fn check_requirements(
&self,
supported: &InstanceExtensions,
api_version: crate::Version,
) -> Result<(), crate::instance::ExtensionRestrictionError> {
let instance_extensions = self;
#(#check_requirements_items)*
Ok(())
}
}
}
}
fn extensions_common_output(struct_name: Ident, members: &[ExtensionsMember]) -> TokenStream {
let struct_items = members.iter().map(|ExtensionsMember { name, doc, .. }| {
quote! {
#[doc = #doc]
pub #name: bool,
}
});
let empty_items = members.iter().map(|ExtensionsMember { name, .. }| {
quote! {
#name: false,
}
});
let intersects_items = members.iter().map(|ExtensionsMember { name, .. }| {
quote! {
(self.#name && other.#name)
}
});
let contains_items = members.iter().map(|ExtensionsMember { name, .. }| {
quote! {
(self.#name || !other.#name)
}
});
let union_items = members.iter().map(|ExtensionsMember { name, .. }| {
quote! {
#name: self.#name || other.#name,
}
});
let intersection_items = members.iter().map(|ExtensionsMember { name, .. }| {
quote! {
#name: self.#name && other.#name,
}
});
let difference_items = members.iter().map(|ExtensionsMember { name, .. }| {
quote! {
#name: self.#name && !other.#name,
}
});
let symmetric_difference_items = members.iter().map(|ExtensionsMember { name, .. }| {
quote! {
#name: self.#name ^ other.#name,
}
});
let debug_items = members.iter().map(|ExtensionsMember { name, raw, .. }| {
quote! {
if self.#name {
if !first { write!(f, ", ")? }
else { first = false; }
f.write_str(#raw)?;
}
}
});
let from_str_for_extensions_items =
members.iter().map(|ExtensionsMember { name, raw, .. }| {
let raw = Literal::string(raw);
quote! {
#raw => { extensions.#name = true; }
}
});
let from_extensions_for_vec_cstring_items =
members.iter().map(|ExtensionsMember { name, raw, .. }| {
quote! {
if x.#name { data.push(std::ffi::CString::new(#raw).unwrap()); }
}
});
quote! {
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct #struct_name {
#(#struct_items)*
pub _ne: crate::NonExhaustive,
}
impl Default for #struct_name {
#[inline]
fn default() -> Self {
Self::empty()
}
}
impl #struct_name {
#[inline]
pub const fn empty() -> Self {
Self {
#(#empty_items)*
_ne: crate::NonExhaustive(()),
}
}
#[deprecated(since = "0.31.0", note = "Use `empty` instead.")]
#[inline]
pub const fn none() -> Self {
Self::empty()
}
#[inline]
pub const fn intersects(&self, other: &Self) -> bool {
#(#intersects_items)||*
}
#[inline]
pub const fn contains(&self, other: &Self) -> bool {
#(#contains_items)&&*
}
#[deprecated(since = "0.31.0", note = "Use `contains` instead.")]
#[inline]
pub const fn is_superset_of(&self, other: &Self) -> bool {
self.contains(other)
}
#[inline]
pub const fn union(&self, other: &Self) -> Self {
Self {
#(#union_items)*
_ne: crate::NonExhaustive(()),
}
}
#[inline]
pub const fn intersection(&self, other: &Self) -> Self {
Self {
#(#intersection_items)*
_ne: crate::NonExhaustive(()),
}
}
#[inline]
pub const fn difference(&self, other: &Self) -> Self {
Self {
#(#difference_items)*
_ne: crate::NonExhaustive(()),
}
}
#[inline]
pub const fn symmetric_difference(&self, other: &Self) -> Self {
Self {
#(#symmetric_difference_items)*
_ne: crate::NonExhaustive(()),
}
}
}
impl std::ops::BitAnd for #struct_name {
type Output = #struct_name;
#[inline]
fn bitand(self, rhs: Self) -> Self::Output {
self.union(&rhs)
}
}
impl std::ops::BitAndAssign for #struct_name {
#[inline]
fn bitand_assign(&mut self, rhs: Self) {
*self = self.union(&rhs);
}
}
impl std::ops::BitOr for #struct_name {
type Output = #struct_name;
#[inline]
fn bitor(self, rhs: Self) -> Self::Output {
self.intersection(&rhs)
}
}
impl std::ops::BitOrAssign for #struct_name {
#[inline]
fn bitor_assign(&mut self, rhs: Self) {
*self = self.intersection(&rhs);
}
}
impl std::ops::BitXor for #struct_name {
type Output = #struct_name;
#[inline]
fn bitxor(self, rhs: Self) -> Self::Output {
self.symmetric_difference(&rhs)
}
}
impl std::ops::BitXorAssign for #struct_name {
#[inline]
fn bitxor_assign(&mut self, rhs: Self) {
*self = self.symmetric_difference(&rhs);
}
}
impl std::ops::Sub for #struct_name {
type Output = #struct_name;
#[inline]
fn sub(self, rhs: Self) -> Self::Output {
self.difference(&rhs)
}
}
impl std::ops::SubAssign for #struct_name {
#[inline]
fn sub_assign(&mut self, rhs: Self) {
*self = self.difference(&rhs);
}
}
impl std::fmt::Debug for #struct_name {
#[allow(unused_assignments)]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(f, "[")?;
let mut first = true;
#(#debug_items)*
write!(f, "]")
}
}
impl<'a> FromIterator<&'a str> for #struct_name {
fn from_iter<I>(iter: I) -> Self
where I: IntoIterator<Item = &'a str>
{
let mut extensions = Self::empty();
for name in iter {
match name {
#(#from_str_for_extensions_items)*
_ => (),
}
}
extensions
}
}
impl<'a> From<&'a #struct_name> for Vec<std::ffi::CString> {
fn from(x: &'a #struct_name) -> Self {
let mut data = Self::new();
#(#from_extensions_for_vec_cstring_items)*
data
}
}
}
}
fn extensions_members(ty: &str, extensions: &IndexMap<&str, &Extension>) -> Vec<ExtensionsMember> {
extensions
.values()
.filter(|ext| ext.ext_type.as_ref().unwrap() == ty)
.map(|ext| {
let raw = ext.name.to_owned();
let name = raw.strip_prefix("VK_").unwrap().to_snake_case();
let mut requires = Vec::new();
if let Some(core) = ext.requires_core.as_ref() {
let (major, minor) = core.split_once('.').unwrap();
requires.push(RequiresOneOf {
api_version: Some((major.to_owned(), minor.to_owned())),
..Default::default()
});
}
if let Some(req) = ext.requires.as_ref() {
requires.extend(req.split(',').map(|mut vk_name| {
let mut dependencies = RequiresOneOf::default();
loop {
if let Some(version) = vk_name.strip_prefix("VK_VERSION_") {
let (major, minor) = version.split_once('_').unwrap();
dependencies.api_version = Some((major.to_owned(), minor.to_owned()));
break;
} else {
let ident = format_ident!(
"{}",
vk_name.strip_prefix("VK_").unwrap().to_snake_case()
);
let extension = extensions[vk_name];
match extension.ext_type.as_deref() {
Some("device") => &mut dependencies.device_extensions,
Some("instance") => &mut dependencies.instance_extensions,
_ => unreachable!(),
}
.insert(0, ident);
if let Some(promotedto) = extension.promotedto.as_ref() {
vk_name = promotedto.as_str();
} else {
break;
}
}
}
dependencies
}));
}
let conflicts_extensions = conflicts_extensions(&ext.name);
let mut member = ExtensionsMember {
name: format_ident!("{}", name),
doc: String::new(),
raw,
required_if_supported: required_if_supported(ext.name.as_str()),
requires,
conflicts_device_extensions: conflicts_extensions
.iter()
.filter(|&&vk_name| extensions[vk_name].ext_type.as_ref().unwrap() == "device")
.map(|vk_name| {
format_ident!("{}", vk_name.strip_prefix("VK_").unwrap().to_snake_case())
})
.collect(),
status: ext
.promotedto
.as_deref()
.and_then(|pr| {
if let Some(version) = pr.strip_prefix("VK_VERSION_") {
let (major, minor) = version.split_once('_').unwrap();
Some(ExtensionStatus::Promoted(Replacement::Core((
major.to_owned(),
minor.to_owned(),
))))
} else {
let member = pr.strip_prefix("VK_").unwrap().to_snake_case();
match extensions[pr].ext_type.as_ref().unwrap().as_str() {
"device" => Some(ExtensionStatus::Promoted(
Replacement::DeviceExtension(format_ident!("{}", member)),
)),
"instance" => Some(ExtensionStatus::Promoted(
Replacement::InstanceExtension(format_ident!("{}", member)),
)),
_ => unreachable!(),
}
}
})
.or_else(|| {
ext.deprecatedby.as_deref().and_then(|depr| {
if depr.is_empty() {
Some(ExtensionStatus::Deprecated(None))
} else if let Some(version) = depr.strip_prefix("VK_VERSION_") {
let (major, minor) = version.split_once('_').unwrap();
Some(ExtensionStatus::Deprecated(Some(Replacement::Core((
major.parse().unwrap(),
minor.parse().unwrap(),
)))))
} else {
let member = depr.strip_prefix("VK_").unwrap().to_snake_case();
match extensions[depr].ext_type.as_ref().unwrap().as_str() {
"device" => Some(ExtensionStatus::Deprecated(Some(
Replacement::DeviceExtension(format_ident!("{}", member)),
))),
"instance" => Some(ExtensionStatus::Deprecated(Some(
Replacement::InstanceExtension(format_ident!("{}", member)),
))),
_ => unreachable!(),
}
}
})
}),
};
make_doc(&mut member);
member
})
.collect()
}
fn make_doc(ext: &mut ExtensionsMember) {
let writer = &mut ext.doc;
write!(writer, "- [Vulkan documentation](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/{}.html)", ext.raw).unwrap();
if ext.required_if_supported {
write!(
writer,
"\n- Must be enabled if it is supported by the physical device",
)
.unwrap();
}
if let Some(status) = ext.status.as_ref() {
match status {
ExtensionStatus::Promoted(replacement) => {
write!(writer, "\n- Promoted to ",).unwrap();
match replacement {
Replacement::Core(version) => {
write!(writer, "Vulkan {}.{}", version.0, version.1).unwrap();
}
Replacement::DeviceExtension(ext) => {
write!(writer, "[`{}`](crate::device::DeviceExtensions::{0})", ext)
.unwrap();
}
Replacement::InstanceExtension(ext) => {
write!(
writer,
"[`{}`](crate::instance::InstanceExtensions::{0})",
ext
)
.unwrap();
}
}
}
ExtensionStatus::Deprecated(replacement) => {
write!(writer, "\n- Deprecated ",).unwrap();
match replacement {
Some(Replacement::Core(version)) => {
write!(writer, "by Vulkan {}.{}", version.0, version.1).unwrap();
}
Some(Replacement::DeviceExtension(ext)) => {
write!(
writer,
"by [`{}`](crate::device::DeviceExtensions::{0})",
ext
)
.unwrap();
}
Some(Replacement::InstanceExtension(ext)) => {
write!(
writer,
"by [`{}`](crate::instance::InstanceExtensions::{0})",
ext
)
.unwrap();
}
None => {
write!(writer, "without a replacement").unwrap();
}
}
}
}
}
if !ext.requires.is_empty() {
write!(writer, "\n- Requires:").unwrap();
}
for require in &ext.requires {
let mut line = Vec::new();
if let Some((major, minor)) = require.api_version.as_ref() {
line.push(format!("Vulkan API version {}.{}", major, minor));
}
line.extend(require.device_extensions.iter().map(|ext| {
format!(
"device extension [`{}`](crate::device::DeviceExtensions::{0})",
ext
)
}));
line.extend(require.instance_extensions.iter().map(|ext| {
format!(
"instance extension [`{}`](crate::instance::InstanceExtensions::{0})",
ext
)
}));
if line.len() == 1 {
write!(writer, "\n - {}", line[0]).unwrap();
} else {
write!(writer, "\n - One of: {}", line.join(", ")).unwrap();
}
}
if !ext.conflicts_device_extensions.is_empty() {
let links: Vec<_> = ext
.conflicts_device_extensions
.iter()
.map(|ext| format!("[`{}`](crate::device::DeviceExtensions::{0})", ext))
.collect();
write!(
writer,
"\n- Conflicts with device extension{}: {}",
if ext.conflicts_device_extensions.len() > 1 {
"s"
} else {
""
},
links.join(", ")
)
.unwrap();
}
}