use core::ffi::c_void;
use std::any::{Any, TypeId};
use crate::raw::PNextChainable;
use crate::raw::bindings::{VkBaseOutStructure, VkStructureType};
/// Object-safe view over any [`PNextChainable`] value, so structures of
/// different concrete types can be stored together in one [`PNextChain`].
trait ErasedChain: Any {
    /// `TypeId` of the concrete structure behind the trait object (not of the box).
    fn stored_type_id(&self) -> TypeId;
    /// Shared access as `dyn Any`, used to downcast back to the concrete type.
    fn as_any_ref(&self) -> &dyn Any;
    /// Exclusive access as `dyn Any`, used to downcast back to the concrete type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Raw pointer to the structure viewed as a `VkBaseOutStructure` (its `sType`/`pNext` header).
    fn as_base_mut_ptr(&mut self) -> *mut VkBaseOutStructure;
    /// Boxes an independent copy of the structure; the blanket impl clears the copy's `pNext`.
    fn clone_erased(&self) -> Box<dyn ErasedChain>;
}
impl<T: PNextChainable> ErasedChain for T {
    #[inline]
    fn as_base_mut_ptr(&mut self) -> *mut VkBaseOutStructure {
        // Fully qualified: both traits declare a method with this name.
        <T as PNextChainable>::as_base_mut_ptr(self)
    }
    #[inline]
    fn as_any_ref(&self) -> &dyn Any {
        // Unsizes `&T` to `&dyn Any`, keeping `T`'s real type id for downcasts.
        self
    }
    #[inline]
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
    #[inline]
    fn stored_type_id(&self) -> TypeId {
        TypeId::of::<T>()
    }
    /// Boxes a clone of `self` whose `pNext` has been reset to null.
    ///
    /// A plain clone would still point at the *source* chain's next node;
    /// clearing it means the copy never aliases the original, and the owning
    /// chain rebuilds the links afterwards.
    fn clone_erased(&self) -> Box<dyn ErasedChain> {
        let mut cloned = self.clone();
        // Go through the trait accessor instead of a hand-written
        // `*mut T as *mut VkBaseOutStructure` cast, so the header location is
        // defined in one place (the `PNextChainable` impl).
        let base = <T as PNextChainable>::as_base_mut_ptr(&mut cloned);
        // SAFETY: `base` was just derived from `&mut cloned`, a live local that is
        // not otherwise borrowed here; per `PNextChainable`'s role in this module it
        // addresses that struct's `VkBaseOutStructure`-compatible header.
        unsafe {
            (*base).pNext = std::ptr::null_mut();
        }
        Box::new(cloned)
    }
}
/// Owned, type-erased Vulkan `pNext` chain.
///
/// Every structure is boxed individually, so its address does not change when
/// the `Vec` reallocates; the `pNext` pointers stored inside the structures
/// therefore stay valid as nodes are added.
#[derive(Default)]
pub struct PNextChain {
    // Chain order: `relink` points each node's `pNext` at the following node
    // and sets the last node's `pNext` to null.
    nodes: Vec<Box<dyn ErasedChain>>,
}
impl Clone for PNextChain {
fn clone(&self) -> Self {
let mut new = Self {
nodes: self.nodes.iter().map(|n| n.clone_erased()).collect(),
};
new.relink();
new
}
}
impl PNextChain {
    /// Creates an empty chain; [`head`](Self::head) is null until a node is pushed.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `item` at the tail and re-links the whole chain.
    ///
    /// Whatever `pNext` value `item` carried is overwritten.
    pub fn push<T: PNextChainable>(&mut self, item: T) -> &mut Self {
        self.nodes.push(Box::new(item));
        self.relink();
        self
    }

    /// Moves every node of `other` to the tail of this chain (order preserved)
    /// and re-links the whole chain.
    pub fn append(&mut self, mut other: PNextChain) -> &mut Self {
        self.nodes.append(&mut other.nodes);
        self.relink();
        self
    }

    /// Returns `true` when the chain holds no structures.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Number of structures in the chain.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Read-only pointer to the first structure, or null when the chain is empty.
    ///
    /// Nothing may write through this pointer; use [`head_mut`](Self::head_mut)
    /// for output structures. Because it only borrows `self`, this method cannot
    /// repair links clobbered through [`get_mut`](Self::get_mut).
    pub fn head(&self) -> *const c_void {
        self.head_raw().cast()
    }

    /// Mutable pointer to the first structure, or null when the chain is empty.
    ///
    /// The links are rebuilt first, so the returned chain is complete even if a
    /// node was overwritten through [`get_mut`](Self::get_mut).
    pub fn head_mut(&mut self) -> *mut c_void {
        // A struct replaced wholesale via `get_mut` (e.g. `*x = T::new_pnext()`)
        // carries a null `pNext` and would otherwise truncate the chain here.
        self.relink();
        self.head_raw_mut().cast()
    }

    fn head_raw(&self) -> *const VkBaseOutStructure {
        self.nodes.first().map_or(std::ptr::null(), |node| {
            // Dropping the vtable metadata leaves the address of the concrete struct.
            &**node as *const dyn ErasedChain as *const VkBaseOutStructure
        })
    }

    fn head_raw_mut(&mut self) -> *mut VkBaseOutStructure {
        self.nodes
            .first_mut()
            .map_or(std::ptr::null_mut(), |node| node.as_base_mut_ptr())
    }

    /// Returns the first stored structure of type `T`, if any.
    pub fn get<T: PNextChainable>(&self) -> Option<&T> {
        let target = TypeId::of::<T>();
        self.nodes
            .iter()
            .find(|node| node.stored_type_id() == target)
            .and_then(|node| node.as_any_ref().downcast_ref::<T>())
    }

    /// Returns the first stored structure of type `T` mutably, if any.
    ///
    /// Assigning a whole new value (e.g. from `T::new_pnext()`) also resets its
    /// `pNext`. [`head_mut`](Self::head_mut), [`push`](Self::push) and
    /// [`append`](Self::append) restore the links; [`head`](Self::head) does not.
    pub fn get_mut<T: PNextChainable>(&mut self) -> Option<&mut T> {
        let target = TypeId::of::<T>();
        self.nodes
            .iter_mut()
            .find(|node| node.stored_type_id() == target)
            .and_then(|node| node.as_any_mut().downcast_mut::<T>())
    }

    /// Yields each node's `sType` in chain order, head first.
    pub fn structure_types(&self) -> impl Iterator<Item = VkStructureType> + '_ {
        self.nodes.iter().map(|node| {
            let base = &**node as *const dyn ErasedChain as *const VkBaseOutStructure;
            // SAFETY: `base` points at a live, boxed chainable struct borrowed from
            // `self`; only its `sType` header field is read.
            unsafe { (*base).sType }
        })
    }

    /// Points every node's `pNext` at the following node and the tail's at null.
    fn relink(&mut self) {
        // Walk from the tail so each node's successor pointer is already known;
        // this avoids collecting all pointers into a temporary `Vec`.
        let mut next: *mut VkBaseOutStructure = std::ptr::null_mut();
        for node in self.nodes.iter_mut().rev() {
            let base = node.as_base_mut_ptr();
            // SAFETY: `base` was just derived from this node's exclusive borrow and
            // the boxed struct it addresses is live; only its header is written.
            unsafe {
                (*base).pNext = next;
            }
            next = base;
        }
    }
}
impl std::fmt::Debug for PNextChain {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PNextChain")
.field("len", &self.nodes.len())
.field(
"structure_types",
&self.structure_types().collect::<Vec<_>>(),
)
.finish()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::raw::bindings::{
        VkPhysicalDeviceVulkan11Features, VkPhysicalDeviceVulkan12Features,
        VkPhysicalDeviceVulkan13Features,
    };

    /// Follows the raw `pNext` links from `chain.head()` and returns every `sType` seen.
    fn walk(chain: &PNextChain) -> Vec<VkStructureType> {
        let mut seen = Vec::new();
        let mut cur = chain.head() as *const VkBaseOutStructure;
        // SAFETY: every node is owned by `chain`, which outlives this walk, and the
        // links were last rebuilt by push/append/clone, so they end in null.
        unsafe {
            while !cur.is_null() {
                seen.push((*cur).sType);
                cur = (*cur).pNext as *const VkBaseOutStructure;
            }
        }
        seen
    }

    #[test]
    fn empty_chain_has_null_head() {
        let chain = PNextChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
        assert!(chain.head().is_null());
    }

    #[test]
    fn single_push_links_to_null() {
        let mut chain = PNextChain::new();
        chain.push(VkPhysicalDeviceVulkan12Features::new_pnext());
        assert_eq!(chain.len(), 1);
        assert!(!chain.head().is_null());
        unsafe {
            let head = chain.head() as *const VkBaseOutStructure;
            assert_eq!(
                (*head).sType,
                VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES
            );
            assert!((*head).pNext.is_null());
        }
    }

    #[test]
    fn multi_push_forms_ordered_chain() {
        let mut chain = PNextChain::new();
        chain.push(VkPhysicalDeviceVulkan11Features::new_pnext());
        chain.push(VkPhysicalDeviceVulkan12Features::new_pnext());
        chain.push(VkPhysicalDeviceVulkan13Features::new_pnext());
        let types: Vec<_> = chain.structure_types().collect();
        assert_eq!(
            types,
            vec![
                VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES,
                VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES,
                VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_3_FEATURES,
            ]
        );
        assert_eq!(walk(&chain), types);
    }

    #[test]
    fn new_pnext_sets_stype_correctly() {
        let f = VkPhysicalDeviceVulkan12Features::new_pnext();
        assert_eq!(
            f.sType,
            VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES
        );
        assert!(f.pNext.is_null());
        assert_eq!(f.timelineSemaphore, 0);
        assert_eq!(f.bufferDeviceAddress, 0);
    }

    #[test]
    fn get_returns_pushed_struct() {
        let mut chain = PNextChain::new();
        let mut v12 = VkPhysicalDeviceVulkan12Features::new_pnext();
        v12.timelineSemaphore = 1;
        v12.bufferDeviceAddress = 1;
        chain.push(v12);
        chain.push(VkPhysicalDeviceVulkan13Features::new_pnext());
        let got = chain
            .get::<VkPhysicalDeviceVulkan12Features>()
            .expect("v12 present");
        assert_eq!(got.timelineSemaphore, 1);
        assert_eq!(got.bufferDeviceAddress, 1);
        assert!(
            chain
                .get::<crate::raw::bindings::VkPhysicalDeviceFeatures2>()
                .is_none()
        );
    }

    #[test]
    fn extension_struct_chains_like_core_features() {
        use crate::raw::bindings::VkPhysicalDeviceCooperativeMatrixFeaturesKHR;
        let mut chain = PNextChain::new();
        chain.push(VkPhysicalDeviceVulkan12Features::new_pnext());
        let mut coop = VkPhysicalDeviceCooperativeMatrixFeaturesKHR::new_pnext();
        coop.cooperativeMatrix = 1;
        chain.push(coop);
        let types: Vec<_> = chain.structure_types().collect();
        assert_eq!(
            types,
            vec![
                VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES,
                VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR,
            ]
        );
        let back = chain
            .get::<VkPhysicalDeviceCooperativeMatrixFeaturesKHR>()
            .expect("coop present");
        assert_eq!(back.cooperativeMatrix, 1);
        assert_eq!(back.cooperativeMatrixRobustBufferAccess, 0);
    }

    #[test]
    fn push_after_head_still_relinks_previous_tail() {
        let mut chain = PNextChain::new();
        chain.push(VkPhysicalDeviceVulkan11Features::new_pnext());
        chain.push(VkPhysicalDeviceVulkan12Features::new_pnext());
        unsafe {
            let head = chain.head() as *const VkBaseOutStructure;
            let second = (*head).pNext as *const VkBaseOutStructure;
            assert!(!second.is_null());
            assert_eq!(
                (*second).sType,
                VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES
            );
        }
    }

    #[test]
    fn clone_builds_independent_chain() {
        let mut chain = PNextChain::new();
        chain.push(VkPhysicalDeviceVulkan11Features::new_pnext());
        let mut v12 = VkPhysicalDeviceVulkan12Features::new_pnext();
        v12.timelineSemaphore = 1;
        chain.push(v12);

        let copy = chain.clone();
        assert_eq!(copy.len(), 2);
        assert_ne!(copy.head(), chain.head());
        assert_eq!(walk(&copy), walk(&chain));
        // The copy's links must stay inside the copy rather than reuse the source's nodes.
        unsafe {
            let src_second = (*(chain.head() as *const VkBaseOutStructure)).pNext;
            let dst_second = (*(copy.head() as *const VkBaseOutStructure)).pNext;
            assert!(!dst_second.is_null());
            assert_ne!(src_second, dst_second);
        }
        let cloned_v12 = copy
            .get::<VkPhysicalDeviceVulkan12Features>()
            .expect("v12 cloned");
        assert_eq!(cloned_v12.timelineSemaphore, 1);
    }

    #[test]
    fn append_moves_nodes_to_tail_in_order() {
        let mut front = PNextChain::new();
        front.push(VkPhysicalDeviceVulkan11Features::new_pnext());
        let mut back = PNextChain::new();
        back.push(VkPhysicalDeviceVulkan12Features::new_pnext());
        back.push(VkPhysicalDeviceVulkan13Features::new_pnext());
        front.append(back);

        let expected = vec![
            VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES,
            VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES,
            VkStructureType::STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_3_FEATURES,
        ];
        assert_eq!(front.len(), 3);
        assert_eq!(front.structure_types().collect::<Vec<_>>(), expected);
        assert_eq!(walk(&front), expected);
    }
}