use std::collections::HashSet;
use std::fs;
use std::path::Path;
use super::{GeneratorError, GeneratorResult};
use crate::codegen::camel_to_snake;
use crate::parser::vk_types::{CommandParam, VulkanCommand};
pub struct SafeErgonomicStats {
pub device_methods: usize,
pub instance_methods: usize,
pub physical_device_methods: usize,
pub queue_methods: usize,
pub command_buffer_methods: usize,
pub skipped: usize,
}
#[derive(Copy, Clone)]
enum Target {
Device,
Instance,
PhysicalDevice,
Queue,
CommandBuffer,
}
impl Target {
fn from_first_param(ty: &str) -> Option<Self> {
match ty {
"VkDevice" => Some(Target::Device),
"VkInstance" => Some(Target::Instance),
"VkPhysicalDevice" => Some(Target::PhysicalDevice),
"VkQueue" => Some(Target::Queue),
"VkCommandBuffer" => Some(Target::CommandBuffer),
_ => None,
}
}
fn trait_name(self) -> &'static str {
match self {
Target::Device => "DeviceSafeExt",
Target::Instance => "InstanceSafeExt",
Target::PhysicalDevice => "PhysicalDeviceSafeExt",
Target::Queue => "QueueSafeExt",
Target::CommandBuffer => "CommandBufferRecordingSafeExt",
}
}
fn impl_target(self) -> &'static str {
match self {
Target::Device => "crate::safe::Device",
Target::Instance => "crate::safe::Instance",
Target::PhysicalDevice => "crate::safe::PhysicalDevice",
Target::Queue => "crate::safe::Queue",
Target::CommandBuffer => "crate::safe::CommandBufferRecording<'_>",
}
}
fn raw_handle_expr(self) -> &'static str {
match self {
Target::Device | Target::Instance => "self.inner.handle",
Target::PhysicalDevice => "self.handle",
Target::Queue => "self.handle",
Target::CommandBuffer => "self.raw_cmd()",
}
}
fn dispatch_expr(self) -> &'static str {
match self {
Target::Device | Target::Instance => "self.inner.dispatch",
Target::PhysicalDevice => "self.instance.dispatch",
Target::Queue => "self.device.dispatch",
Target::CommandBuffer => "self.device_dispatch()",
}
}
fn self_kind(self) -> &'static str {
match self {
Target::CommandBuffer => "&mut self",
_ => "&self",
}
}
fn file_name(self) -> &'static str {
match self {
Target::Device => "auto_device_safe_generated.rs",
Target::Instance => "auto_instance_safe_generated.rs",
Target::PhysicalDevice => "auto_physical_device_safe_generated.rs",
Target::Queue => "auto_queue_safe_generated.rs",
Target::CommandBuffer => "auto_command_buffer_safe_generated.rs",
}
}
fn target_ty(self) -> &'static str {
match self {
Target::Device => "VkDevice",
Target::Instance => "VkInstance",
Target::PhysicalDevice => "VkPhysicalDevice",
Target::Queue => "VkQueue",
Target::CommandBuffer => "VkCommandBuffer",
}
}
}
fn is_phase1_handled_command(name: &str) -> bool {
name.starts_with("vkCreate")
|| name.starts_with("vkDestroy")
|| name.starts_with("vkAllocate")
|| name.starts_with("vkFree")
}
fn method_name(cmd_name: &str) -> String {
let stripped = cmd_name.strip_prefix("vk").unwrap_or(cmd_name);
camel_to_snake(stripped)
}
fn pointer_level(p: &CommandParam) -> usize {
p.definition.matches('*').count()
}
fn is_const_ptr(p: &CommandParam) -> bool {
pointer_level(p) == 1 && p.definition.contains("const")
}
fn is_mut_ptr(p: &CommandParam) -> bool {
pointer_level(p) == 1 && !p.definition.contains("const")
}
fn is_scalar(p: &CommandParam) -> bool {
pointer_level(p) == 0
}
fn is_allocator_callbacks(p: &CommandParam) -> bool {
p.type_name == "VkAllocationCallbacks" && pointer_level(p) == 1
}
fn rust_type_name(ty: &str) -> String {
match ty {
"void" => "core::ffi::c_void".to_string(),
"char" => "core::ffi::c_char".to_string(),
"uint8_t" => "u8".to_string(),
"uint16_t" => "u16".to_string(),
"uint32_t" => "u32".to_string(),
"uint64_t" => "u64".to_string(),
"int8_t" => "i8".to_string(),
"int16_t" => "i16".to_string(),
"int32_t" => "i32".to_string(),
"int64_t" => "i64".to_string(),
"float" => "f32".to_string(),
"double" => "f64".to_string(),
"size_t" => "usize".to_string(),
"int" => "i32".to_string(),
other => format!("crate::raw::bindings::{}", other),
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum ReturnKind {
Void,
VkResult,
Scalar, }
fn return_kind(ret: &str) -> Option<ReturnKind> {
match ret {
"void" => Some(ReturnKind::Void),
"VkResult" => Some(ReturnKind::VkResult),
"uint32_t" | "uint64_t" | "int32_t" | "int64_t" | "float" | "double" | "size_t" | "int" => {
Some(ReturnKind::Scalar)
}
other if other.starts_with("Vk") => Some(ReturnKind::Scalar),
_ => None,
}
}
fn rust_scalar_return(ret: &str) -> String {
match ret {
"uint32_t" => "u32".to_string(),
"uint64_t" => "u64".to_string(),
"int32_t" => "i32".to_string(),
"int64_t" => "i64".to_string(),
"float" => "f32".to_string(),
"double" => "f64".to_string(),
"size_t" => "usize".to_string(),
"int" => "i32".to_string(),
other => format!("crate::raw::bindings::{}", other),
}
}
fn escape_param_name(name: &str) -> String {
match name {
"type" | "match" | "impl" | "fn" | "let" | "mut" | "const" | "static" | "if" | "else"
| "while" | "for" | "loop" | "break" | "continue" | "return" | "struct" | "enum"
| "trait" | "mod" | "pub" | "use" | "extern" | "crate" | "self" | "Self" | "super"
| "where" | "async" | "await" | "dyn" | "abstract" | "become" | "box" | "do" | "final"
| "macro" | "override" | "priv" | "typeof" | "unsized" | "virtual" | "yield" | "try"
| "union" | "ref" => format!("r#{}", name),
_ => name.to_string(),
}
}
#[derive(Debug, Clone)]
enum ParamRole {
Handle,
Scalar { rust_ty: String, name: String },
ConstPtr { rust_ty: String, name: String },
MutPtr { rust_ty: String, name: String },
AllocatorPtr,
SliceCount {
data_name: String,
count_rust_ty: String,
},
SliceData { rust_elem_ty: String, name: String },
EnumerateCount,
EnumerateData { rust_elem_ty: String },
Unsupported,
}
fn named_len(p: &CommandParam) -> Option<&str> {
let s = p.len.as_deref()?.trim();
if s.is_empty() || s == "null-terminated" || s == "1" {
return None;
}
if s.parse::<u64>().is_ok() {
return None;
}
if !s.chars().all(|c| c.is_alphanumeric() || c == '_') {
return None;
}
Some(s)
}
fn classify_roles(params: &[CommandParam]) -> Vec<ParamRole> {
let n = params.len();
let mut roles: Vec<ParamRole> = (0..n).map(|_| ParamRole::Unsupported).collect();
for (j, data) in params.iter().enumerate() {
let Some(len_name) = named_len(data) else {
continue;
};
let Some(i) = params.iter().position(|q| q.name == len_name) else {
continue;
};
if i == j {
continue;
}
let count = ¶ms[i];
if is_const_ptr(data)
&& is_scalar(count)
&& matches!(
count.type_name.as_str(),
"uint32_t" | "uint64_t" | "size_t" | "int32_t" | "int64_t" | "int"
)
&& matches!(roles[i], ParamRole::Unsupported)
&& matches!(roles[j], ParamRole::Unsupported)
{
roles[i] = ParamRole::SliceCount {
data_name: escape_param_name(&data.name),
count_rust_ty: rust_type_name(&count.type_name),
};
roles[j] = ParamRole::SliceData {
rust_elem_ty: rust_type_name(&data.type_name),
name: escape_param_name(&data.name),
};
continue;
}
if is_mut_ptr(data)
&& is_mut_ptr(count)
&& count.type_name == "uint32_t"
&& matches!(roles[i], ParamRole::Unsupported)
&& matches!(roles[j], ParamRole::Unsupported)
{
roles[i] = ParamRole::EnumerateCount;
roles[j] = ParamRole::EnumerateData {
rust_elem_ty: rust_type_name(&data.type_name),
};
}
}
for (i, p) in params.iter().enumerate() {
if !matches!(roles[i], ParamRole::Unsupported) {
continue;
}
if i == 0 {
roles[i] = ParamRole::Handle;
continue;
}
if is_allocator_callbacks(p) {
roles[i] = ParamRole::AllocatorPtr;
continue;
}
if pointer_level(p) >= 2 {
roles[i] = ParamRole::Unsupported;
continue;
}
if named_len(p).is_some() {
roles[i] = ParamRole::Unsupported;
continue;
}
if is_scalar(p) {
roles[i] = ParamRole::Scalar {
rust_ty: rust_type_name(&p.type_name),
name: escape_param_name(&p.name),
};
} else if is_const_ptr(p) {
roles[i] = ParamRole::ConstPtr {
rust_ty: rust_type_name(&p.type_name),
name: escape_param_name(&p.name),
};
} else if is_mut_ptr(p) {
roles[i] = ParamRole::MutPtr {
rust_ty: rust_type_name(&p.type_name),
name: escape_param_name(&p.name),
};
} else {
roles[i] = ParamRole::Unsupported;
}
}
roles
}
fn emit_method(cmd: &VulkanCommand, target: Target) -> Option<(String, String)> {
if cmd.parameters.is_empty() {
return None;
}
let roles = classify_roles(&cmd.parameters);
if roles.iter().any(|r| matches!(r, ParamRole::Unsupported)) {
return None;
}
let mut_ptr_count = roles
.iter()
.filter(|r| matches!(r, ParamRole::MutPtr { .. }))
.count();
let enumerate_data_count = roles
.iter()
.filter(|r| matches!(r, ParamRole::EnumerateData { .. }))
.count();
let enumerate_count_count = roles
.iter()
.filter(|r| matches!(r, ParamRole::EnumerateCount))
.count();
if enumerate_data_count != enumerate_count_count {
return None;
}
if enumerate_data_count > 1 {
return None;
}
if mut_ptr_count > 1 {
return None;
}
if mut_ptr_count == 1 && enumerate_data_count == 1 {
return None;
}
let ret_kind = return_kind(&cmd.return_type)?;
if matches!(ret_kind, ReturnKind::Scalar) && (mut_ptr_count > 0 || enumerate_data_count > 0) {
return None;
}
let is_enumerate = enumerate_data_count == 1;
let method = method_name(&cmd.name);
let self_kind = target.self_kind();
let raw_handle = target.raw_handle_expr();
let dispatch_expr = target.dispatch_expr();
let fp_name = &cmd.name;
let mut sig_params: Vec<String> = Vec::new();
let mut call_args_fill: Vec<String> = Vec::new();
let mut call_args_count: Vec<String> = Vec::new();
let mut out_var_name: Option<String> = None;
let mut out_rust_ty: Option<String> = None;
let mut enum_elem_ty: Option<String> = None;
for role in &roles {
match role {
ParamRole::Handle => {
call_args_fill.push(raw_handle.to_string());
call_args_count.push(raw_handle.to_string());
}
ParamRole::Scalar { rust_ty, name } => {
sig_params.push(format!("{name}: {rust_ty}"));
call_args_fill.push(name.clone());
call_args_count.push(name.clone());
}
ParamRole::ConstPtr { rust_ty, name } => {
sig_params.push(format!("{name}: &{rust_ty}"));
call_args_fill.push(format!("{name} as *const _"));
call_args_count.push(format!("{name} as *const _"));
}
ParamRole::MutPtr { rust_ty, name } => {
call_args_fill.push(format!("&mut {name}"));
call_args_count.push(format!("&mut {name}"));
out_var_name = Some(name.clone());
out_rust_ty = Some(rust_ty.clone());
}
ParamRole::AllocatorPtr => {
call_args_fill.push("std::ptr::null()".to_string());
call_args_count.push("std::ptr::null()".to_string());
}
ParamRole::SliceCount {
data_name,
count_rust_ty,
} => {
let expr = format!("{data_name}.len() as {count_rust_ty}");
call_args_fill.push(expr.clone());
call_args_count.push(expr);
}
ParamRole::SliceData { rust_elem_ty, name } => {
sig_params.push(format!("{name}: &[{rust_elem_ty}]"));
let expr = format!(
"if {name}.is_empty() {{ std::ptr::null() }} else {{ {name}.as_ptr() }}"
);
call_args_fill.push(expr.clone());
call_args_count.push(expr);
}
ParamRole::EnumerateCount => {
call_args_fill.push("&mut __enumerate_count".to_string());
call_args_count.push("&mut __enumerate_count".to_string());
}
ParamRole::EnumerateData { rust_elem_ty } => {
call_args_count.push("std::ptr::null_mut()".to_string());
call_args_fill.push("__enumerate_buf.as_mut_ptr()".to_string());
enum_elem_ty = Some(rust_elem_ty.clone());
}
ParamRole::Unsupported => unreachable!(),
}
}
let (return_type, body_tail) = if is_enumerate {
let elem = enum_elem_ty.as_ref().unwrap();
match ret_kind {
ReturnKind::VkResult => (
format!("crate::safe::Result<Vec<{}>>", elem),
format!(
" let f = {dispatch_expr}.{fp_name}.expect(\"{fp_name} not loaded — did you enable its extension?\");\n\
\x20 let mut __enumerate_count: u32 = 0;\n\
\x20 let r = unsafe {{ f({}) }};\n\
\x20 if (r as i32) < 0 {{ return Err(crate::safe::Error::Vk(r)); }}\n\
\x20 let mut __enumerate_buf: Vec<{elem}> = Vec::with_capacity(__enumerate_count as usize);\n\
\x20 let r = unsafe {{ f({}) }};\n\
\x20 if (r as i32) < 0 {{ return Err(crate::safe::Error::Vk(r)); }}\n\
\x20 unsafe {{ __enumerate_buf.set_len(__enumerate_count as usize); }}\n\
\x20 Ok(__enumerate_buf)\n",
call_args_count.join(", "),
call_args_fill.join(", ")
),
),
ReturnKind::Void => (
format!("Vec<{}>", elem),
format!(
" let f = {dispatch_expr}.{fp_name}.expect(\"{fp_name} not loaded — did you enable its extension?\");\n\
\x20 let mut __enumerate_count: u32 = 0;\n\
\x20 unsafe {{ f({}) }};\n\
\x20 let mut __enumerate_buf: Vec<{elem}> = Vec::with_capacity(__enumerate_count as usize);\n\
\x20 unsafe {{ f({}) }};\n\
\x20 unsafe {{ __enumerate_buf.set_len(__enumerate_count as usize); }}\n\
\x20 __enumerate_buf\n",
call_args_count.join(", "),
call_args_fill.join(", ")
),
),
ReturnKind::Scalar => return None,
}
} else {
match (ret_kind, &out_rust_ty) {
(ReturnKind::Void, None) => (
"()".to_string(),
format!(
" let f = {dispatch_expr}.{fp_name}.expect(\"{fp_name} not loaded — did you enable its extension?\");\n\
\x20 unsafe {{ f({}) }};\n",
call_args_fill.join(", ")
),
),
(ReturnKind::Void, Some(out_ty)) => {
let out_var = out_var_name.as_ref().unwrap();
(
out_ty.clone(),
format!(
" let f = {dispatch_expr}.{fp_name}.expect(\"{fp_name} not loaded — did you enable its extension?\");\n\
\x20 let mut {out_var}: {out_ty} = unsafe {{ std::mem::zeroed() }};\n\
\x20 unsafe {{ f({}) }};\n\
\x20 {out_var}\n",
call_args_fill.join(", ")
),
)
}
(ReturnKind::VkResult, None) => (
"crate::safe::Result<()>".to_string(),
format!(
" let f = {dispatch_expr}.{fp_name}.expect(\"{fp_name} not loaded — did you enable its extension?\");\n\
\x20 let r = unsafe {{ f({}) }};\n\
\x20 if (r as i32) < 0 {{ Err(crate::safe::Error::Vk(r)) }} else {{ Ok(()) }}\n",
call_args_fill.join(", ")
),
),
(ReturnKind::VkResult, Some(out_ty)) => {
let out_var = out_var_name.as_ref().unwrap();
(
format!("crate::safe::Result<{}>", out_ty),
format!(
" let f = {dispatch_expr}.{fp_name}.expect(\"{fp_name} not loaded — did you enable its extension?\");\n\
\x20 let mut {out_var}: {out_ty} = unsafe {{ std::mem::zeroed() }};\n\
\x20 let r = unsafe {{ f({}) }};\n\
\x20 if (r as i32) < 0 {{ Err(crate::safe::Error::Vk(r)) }} else {{ Ok({out_var}) }}\n",
call_args_fill.join(", ")
),
)
}
(ReturnKind::Scalar, None) => {
let ret_rust = rust_scalar_return(&cmd.return_type);
(
ret_rust,
format!(
" let f = {dispatch_expr}.{fp_name}.expect(\"{fp_name} not loaded — did you enable its extension?\");\n\
\x20 unsafe {{ f({}) }}\n",
call_args_fill.join(", ")
),
)
}
_ => return None,
}
};
let params_str = if sig_params.is_empty() {
String::new()
} else {
format!(", {}", sig_params.join(", "))
};
let sig = format!(" fn {method}({self_kind}{params_str}) -> {return_type};\n");
let body = format!(
" fn {method}({self_kind}{params_str}) -> {return_type} {{\n{body_tail} }}\n"
);
Some((sig, body))
}
fn unique_commands(commands: &[VulkanCommand]) -> Vec<&VulkanCommand> {
let mut seen: HashSet<&str> = HashSet::new();
commands
.iter()
.filter(|c| !c.is_alias && c.deprecated.is_none() && seen.insert(c.name.as_str()))
.collect()
}
fn emit_file(target: Target, cmds: &[&VulkanCommand]) -> (String, usize, usize) {
let trait_name = target.trait_name();
let impl_target = target.impl_target();
let mut out = String::new();
out.push_str(&format!(
"// Generated by vulkan_gen::safe_ergonomic_gen — do not edit.\n\
//\n\
// Phase-3 auto-safe-layer: one *ergonomic* method per Vulkan\n\
// command whose first argument is `{target_ty}` and whose\n\
// remaining parameters match a recognised shape (scalar args,\n\
// one input struct by reference, at most one trailing output\n\
// value). Commands that don't match fall through to the\n\
// raw-pointer trait in `auto_{target_mod}_ext_generated.rs`.\n\n",
target_ty = target.target_ty(),
target_mod = match target {
Target::Device => "device",
Target::Instance => "instance",
Target::PhysicalDevice => "physical_device",
Target::Queue => "queue",
Target::CommandBuffer => "command_buffer",
}
));
out.push_str(&format!(
"#[allow(non_snake_case, clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref, clippy::unused_unit)]\n\
pub trait {trait_name} {{\n"
));
let mut trait_body = String::new();
let mut impl_body = String::new();
let mut emitted = 0usize;
let mut skipped = 0usize;
for cmd in cmds {
if let Some((sig, body)) = emit_method(cmd, target) {
trait_body.push_str(&sig);
impl_body.push_str(&body);
emitted += 1;
} else {
skipped += 1;
}
}
out.push_str(&trait_body);
out.push_str("}\n\n");
out.push_str(&format!(
"#[allow(non_snake_case, clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref, clippy::unused_unit)]\n\
impl {trait_name} for {impl_target} {{\n"
));
out.push_str(&impl_body);
out.push_str("}\n");
(out, emitted, skipped)
}
pub fn generate_safe_ergonomic(
intermediate_dir: &Path,
output_dir: &Path,
) -> GeneratorResult<SafeErgonomicStats> {
let fns_path = intermediate_dir.join("functions.json");
let content = fs::read_to_string(&fns_path).map_err(GeneratorError::Io)?;
let commands: Vec<VulkanCommand> = serde_json::from_str(&content)?;
let uniq = unique_commands(&commands);
let mut by_target: std::collections::HashMap<&'static str, Vec<&VulkanCommand>> =
std::collections::HashMap::new();
let mut skipped = 0usize;
for cmd in &uniq {
if is_phase1_handled_command(&cmd.name) {
skipped += 1;
continue;
}
if cmd.parameters.is_empty() {
skipped += 1;
continue;
}
let target = match Target::from_first_param(&cmd.parameters[0].type_name) {
Some(t) => t,
None => {
skipped += 1;
continue;
}
};
by_target.entry(target.trait_name()).or_default().push(cmd);
}
fs::create_dir_all(output_dir).map_err(GeneratorError::Io)?;
let mut stats = SafeErgonomicStats {
device_methods: 0,
instance_methods: 0,
physical_device_methods: 0,
queue_methods: 0,
command_buffer_methods: 0,
skipped,
};
for target in [
Target::Device,
Target::Instance,
Target::PhysicalDevice,
Target::Queue,
Target::CommandBuffer,
] {
let cmds: Vec<&VulkanCommand> = by_target
.get(target.trait_name())
.cloned()
.unwrap_or_default();
let (code, emitted, cmd_skipped) = emit_file(target, &cmds);
stats.skipped += cmd_skipped;
match target {
Target::Device => stats.device_methods = emitted,
Target::Instance => stats.instance_methods = emitted,
Target::PhysicalDevice => stats.physical_device_methods = emitted,
Target::Queue => stats.queue_methods = emitted,
Target::CommandBuffer => stats.command_buffer_methods = emitted,
}
fs::write(output_dir.join(target.file_name()), code).map_err(GeneratorError::Io)?;
}
Ok(stats)
}
#[cfg(test)]
mod tests {
use super::*;
fn mk_param(name: &str, type_name: &str, definition: &str) -> CommandParam {
CommandParam {
name: name.to_string(),
type_name: type_name.to_string(),
optional: None,
len: None,
altlen: None,
externsync: None,
noautovalidity: None,
objecttype: None,
stride: None,
validstructs: None,
api: None,
deprecated: None,
comment: None,
definition: definition.to_string(),
raw_content: String::new(),
source_line: None,
}
}
fn mk_cmd(name: &str, ret: &str, params: Vec<CommandParam>) -> VulkanCommand {
VulkanCommand {
name: name.to_string(),
return_type: ret.to_string(),
comment: None,
successcodes: None,
errorcodes: None,
alias: None,
api: None,
deprecated: None,
cmdbufferlevel: None,
pipeline: None,
queues: None,
renderpass: None,
videocoding: None,
parameters: params,
raw_content: String::new(),
is_alias: false,
source_line: None,
}
}
#[test]
fn method_name_strips_vk_prefix() {
assert_eq!(method_name("vkDeviceWaitIdle"), "device_wait_idle");
assert_eq!(method_name("vkCmdTraceRaysKHR"), "cmd_trace_rays_khr");
assert_eq!(
method_name("vkGetBufferDeviceAddress"),
"get_buffer_device_address"
);
}
#[test]
fn shape_a_scalar_void() {
let cmd = mk_cmd(
"vkCmdDraw",
"void",
vec![
mk_param(
"commandBuffer",
"VkCommandBuffer",
"VkCommandBuffer commandBuffer",
),
mk_param("vertexCount", "uint32_t", "uint32_t vertexCount"),
mk_param("instanceCount", "uint32_t", "uint32_t instanceCount"),
mk_param("firstVertex", "uint32_t", "uint32_t firstVertex"),
mk_param("firstInstance", "uint32_t", "uint32_t firstInstance"),
],
);
let (sig, body) = emit_method(&cmd, Target::CommandBuffer).expect("emitted");
assert!(sig.contains("fn cmd_draw(&mut self"));
assert!(sig.contains("vertexCount: u32"));
assert!(sig.contains("-> ()"));
assert!(body.contains("self.raw_cmd()"));
assert!(body.contains("vkCmdDraw"));
}
#[test]
fn shape_b_scalar_vk_result() {
let cmd = mk_cmd(
"vkDeviceWaitIdle",
"VkResult",
vec![mk_param("device", "VkDevice", "VkDevice device")],
);
let (sig, body) = emit_method(&cmd, Target::Device).expect("emitted");
assert!(sig.contains("fn device_wait_idle(&self)"));
assert!(sig.contains("crate::safe::Result<()>"));
assert!(body.contains("Err(crate::safe::Error::Vk(r))"));
assert!(body.contains("Ok(())"));
}
#[test]
fn shape_c_single_input_struct_void() {
let cmd = mk_cmd(
"vkCmdBeginRendering",
"void",
vec![
mk_param(
"commandBuffer",
"VkCommandBuffer",
"VkCommandBuffer commandBuffer",
),
mk_param(
"pRenderingInfo",
"VkRenderingInfo",
"const VkRenderingInfo* pRenderingInfo",
),
],
);
let (sig, body) = emit_method(&cmd, Target::CommandBuffer).expect("emitted");
assert!(sig.contains("fn cmd_begin_rendering(&mut self"));
assert!(sig.contains("pRenderingInfo: &crate::raw::bindings::VkRenderingInfo"));
assert!(sig.contains("-> ()"));
assert!(body.contains("pRenderingInfo as *const _"));
}
#[test]
fn shape_d_scalar_plus_single_output_void() {
let cmd = mk_cmd(
"vkGetBufferMemoryRequirements",
"void",
vec![
mk_param("device", "VkDevice", "VkDevice device"),
mk_param("buffer", "VkBuffer", "VkBuffer buffer"),
mk_param(
"pMemoryRequirements",
"VkMemoryRequirements",
"VkMemoryRequirements* pMemoryRequirements",
),
],
);
let (sig, body) = emit_method(&cmd, Target::Device).expect("emitted");
assert!(sig.contains("fn get_buffer_memory_requirements(&self"));
assert!(sig.contains("buffer: crate::raw::bindings::VkBuffer"));
assert!(sig.contains("-> crate::raw::bindings::VkMemoryRequirements"));
assert!(body.contains("std::mem::zeroed"));
assert!(body.contains("&mut pMemoryRequirements"));
}
#[test]
fn shape_e_input_plus_output_vk_result() {
let cmd = mk_cmd(
"vkGetMemoryWin32HandleKHR",
"VkResult",
vec![
mk_param("device", "VkDevice", "VkDevice device"),
mk_param(
"pGetWin32HandleInfo",
"VkMemoryGetWin32HandleInfoKHR",
"const VkMemoryGetWin32HandleInfoKHR* pGetWin32HandleInfo",
),
mk_param("pHandle", "HANDLE", "HANDLE* pHandle"),
],
);
let (sig, body) = emit_method(&cmd, Target::Device).expect("emitted");
assert!(sig.contains("fn get_memory_win32_handle_khr(&self"));
assert!(
sig.contains(
"pGetWin32HandleInfo: &crate::raw::bindings::VkMemoryGetWin32HandleInfoKHR"
)
);
assert!(sig.contains("-> crate::safe::Result<crate::raw::bindings::HANDLE>"));
assert!(body.contains("pGetWin32HandleInfo as *const _"));
assert!(body.contains("&mut pHandle"));
assert!(body.contains("Err(crate::safe::Error::Vk(r))"));
}
#[test]
fn shape_f_scalar_return_with_input_struct() {
let cmd = mk_cmd(
"vkGetBufferDeviceAddress",
"VkDeviceAddress",
vec![
mk_param("device", "VkDevice", "VkDevice device"),
mk_param(
"pInfo",
"VkBufferDeviceAddressInfo",
"const VkBufferDeviceAddressInfo* pInfo",
),
],
);
let (sig, body) = emit_method(&cmd, Target::Device).expect("emitted");
assert!(sig.contains("fn get_buffer_device_address(&self"));
assert!(sig.contains("pInfo: &crate::raw::bindings::VkBufferDeviceAddressInfo"));
assert!(sig.contains("-> crate::raw::bindings::VkDeviceAddress"));
assert!(!body.contains("Err("));
}
#[test]
fn shape_g_multi_input_struct_void() {
let cmd = mk_cmd(
"vkCmdTraceRaysKHR",
"void",
vec![
mk_param(
"commandBuffer",
"VkCommandBuffer",
"VkCommandBuffer commandBuffer",
),
mk_param(
"pRaygenShaderBindingTable",
"VkStridedDeviceAddressRegionKHR",
"const VkStridedDeviceAddressRegionKHR* pRaygenShaderBindingTable",
),
mk_param(
"pMissShaderBindingTable",
"VkStridedDeviceAddressRegionKHR",
"const VkStridedDeviceAddressRegionKHR* pMissShaderBindingTable",
),
mk_param(
"pHitShaderBindingTable",
"VkStridedDeviceAddressRegionKHR",
"const VkStridedDeviceAddressRegionKHR* pHitShaderBindingTable",
),
mk_param(
"pCallableShaderBindingTable",
"VkStridedDeviceAddressRegionKHR",
"const VkStridedDeviceAddressRegionKHR* pCallableShaderBindingTable",
),
mk_param("width", "uint32_t", "uint32_t width"),
mk_param("height", "uint32_t", "uint32_t height"),
mk_param("depth", "uint32_t", "uint32_t depth"),
],
);
let (sig, _body) = emit_method(&cmd, Target::CommandBuffer).expect("emitted");
assert!(sig.contains("fn cmd_trace_rays_khr(&mut self"));
assert!(sig.contains(
"pRaygenShaderBindingTable: &crate::raw::bindings::VkStridedDeviceAddressRegionKHR"
));
assert!(sig.contains(
"pCallableShaderBindingTable: &crate::raw::bindings::VkStridedDeviceAddressRegionKHR"
));
assert!(sig.contains("width: u32"));
}
#[test]
fn shape_h_slice_coalesces_count_and_data() {
let mut p = mk_param("pSubmits", "VkSubmitInfo", "const VkSubmitInfo* pSubmits");
p.len = Some("submitCount".to_string());
let cmd = mk_cmd(
"vkQueueSubmit",
"VkResult",
vec![
mk_param("queue", "VkQueue", "VkQueue queue"),
mk_param("submitCount", "uint32_t", "uint32_t submitCount"),
p,
mk_param("fence", "VkFence", "VkFence fence"),
],
);
let (sig, body) = emit_method(&cmd, Target::Queue).expect("emitted");
assert!(sig.contains("fn queue_submit(&self"));
assert!(sig.contains("pSubmits: &[crate::raw::bindings::VkSubmitInfo]"));
assert!(sig.contains("fence: crate::raw::bindings::VkFence"));
assert!(!sig.contains("submitCount"));
assert!(body.contains("pSubmits.len() as u32"));
assert!(body.contains("pSubmits.is_empty()"));
assert!(body.contains("pSubmits.as_ptr()"));
}
#[test]
fn shape_h_multi_slice_sharing_count_coalesces_both() {
let mut buffers = mk_param("pBuffers", "VkBuffer", "const VkBuffer* pBuffers");
buffers.len = Some("bindingCount".to_string());
let mut offsets = mk_param("pOffsets", "VkDeviceSize", "const VkDeviceSize* pOffsets");
offsets.len = Some("bindingCount".to_string());
let cmd = mk_cmd(
"vkCmdBindVertexBuffers",
"void",
vec![
mk_param(
"commandBuffer",
"VkCommandBuffer",
"VkCommandBuffer commandBuffer",
),
mk_param("firstBinding", "uint32_t", "uint32_t firstBinding"),
mk_param("bindingCount", "uint32_t", "uint32_t bindingCount"),
buffers,
offsets,
],
);
assert!(emit_method(&cmd, Target::CommandBuffer).is_none());
}
#[test]
fn shape_i_enumerate_emits_vec_return() {
let mut data = mk_param(
"pPhysicalDevices",
"VkPhysicalDevice",
"VkPhysicalDevice* pPhysicalDevices",
);
data.len = Some("pPhysicalDeviceCount".to_string());
let cmd = mk_cmd(
"vkEnumeratePhysicalDevices",
"VkResult",
vec![
mk_param("instance", "VkInstance", "VkInstance instance"),
mk_param(
"pPhysicalDeviceCount",
"uint32_t",
"uint32_t* pPhysicalDeviceCount",
),
data,
],
);
let (sig, body) = emit_method(&cmd, Target::Instance).expect("emitted");
assert!(sig.contains("fn enumerate_physical_devices(&self)"));
assert!(
sig.contains("-> crate::safe::Result<Vec<crate::raw::bindings::VkPhysicalDevice>>")
);
assert_eq!(body.matches("unsafe {").count(), 3); assert!(body.contains("Vec::with_capacity"));
assert!(body.contains("std::ptr::null_mut()"));
assert!(body.contains("__enumerate_buf.as_mut_ptr()"));
assert!(body.contains("set_len"));
}
#[test]
fn shape_i_enumerate_with_extra_scalar_arg() {
let mut data = mk_param(
"pQueueFamilyProperties",
"VkQueueFamilyProperties",
"VkQueueFamilyProperties* pQueueFamilyProperties",
);
data.len = Some("pQueueFamilyPropertyCount".to_string());
let cmd = mk_cmd(
"vkGetPhysicalDeviceQueueFamilyProperties",
"void",
vec![
mk_param(
"physicalDevice",
"VkPhysicalDevice",
"VkPhysicalDevice physicalDevice",
),
mk_param(
"pQueueFamilyPropertyCount",
"uint32_t",
"uint32_t* pQueueFamilyPropertyCount",
),
data,
],
);
let (sig, _body) = emit_method(&cmd, Target::PhysicalDevice).expect("emitted");
assert!(sig.contains("fn get_physical_device_queue_family_properties(&self)"));
assert!(sig.contains("-> Vec<crate::raw::bindings::VkQueueFamilyProperties>"));
}
#[test]
fn numeric_len_skipped_not_coalesced() {
let mut p = mk_param("blendConstants", "float", "const float* blendConstants");
p.len = Some("4".to_string());
let cmd = mk_cmd(
"vkCmdSetBlendConstants",
"void",
vec![
mk_param(
"commandBuffer",
"VkCommandBuffer",
"VkCommandBuffer commandBuffer",
),
p,
],
);
let (sig, _body) = emit_method(&cmd, Target::CommandBuffer).expect("emitted");
assert!(
sig.contains("blendConstants: &core::ffi::c_float")
|| sig.contains("blendConstants: &f32")
);
}
#[test]
fn allocator_callback_is_omitted() {
let cmd = mk_cmd(
"vkDestroyFence",
"void",
vec![
mk_param("device", "VkDevice", "VkDevice device"),
mk_param("fence", "VkFence", "VkFence fence"),
mk_param(
"pAllocator",
"VkAllocationCallbacks",
"const VkAllocationCallbacks* pAllocator",
),
],
);
assert!(is_phase1_handled_command(&cmd.name));
let cmd_renamed = mk_cmd("vkTrimCommandPool", "void", cmd.parameters.clone());
let (sig, body) = emit_method(&cmd_renamed, Target::Device).expect("emitted");
assert!(sig.contains("fence: crate::raw::bindings::VkFence"));
assert!(!sig.contains("pAllocator"));
assert!(body.contains("std::ptr::null()"));
}
}