use alloc::format;
use core::{fmt, mem};
use super::{super::FunctionCtx, BackendResult, Error};
use crate::{
proc::{Alignment, NameKey, TypeResolution},
Handle,
};
/// Base name of the temporaries declared by `write_storage_store` when a
/// composite value is stored piecewise; a depth number is appended to it to
/// form each temporary's name.
const STORE_TEMP_NAME: &str = "_value";
/// One additive term of a storage-buffer byte address.
///
/// `write_storage_address` writes a chain of these terms joined with `+`, so a
/// full address is the sum of the chain.
#[derive(Debug)]
pub(super) enum SubAccess {
    /// A constant byte offset, written as a literal.
    Offset(u32),
    /// A dynamic index expression `value`, scaled by `stride` bytes.
    Index {
        value: Handle<crate::Expression>,
        stride: u32,
    },
    /// The `_{offset}` field of `__dynamic_buffer_offsets{group}`.
    BufferOffset { group: u32, offset: u32 },
}
/// The source of a value written by `write_storage_store`.
///
/// Besides a plain expression, a value can name part of a `_value{depth}`
/// temporary that was declared to hold a composite being stored piecewise.
#[derive(Debug)]
pub(super) enum StoreValue {
    /// A value given directly by an expression.
    Expression(Handle<crate::Expression>),
    /// Element `index` of the temporary at `depth`; `ty` is the element type.
    /// Written as `_value{depth}[{index}]`.
    TempIndex {
        depth: usize,
        index: u32,
        ty: TypeResolution,
    },
    /// Member `member_index` of the struct temporary at `depth`, whose struct
    /// type is `base`. Written as `_value{depth}.{member}`.
    TempAccess {
        depth: usize,
        base: Handle<crate::Type>,
        member_index: u32,
    },
    /// Column `column` of the matrix member `member_index` of the struct
    /// temporary at `depth`, whose struct type is `base`.
    /// Written as `_value{depth}.{member}_{column}`.
    TempColumnAccess {
        depth: usize,
        base: Handle<crate::Type>,
        member_index: u32,
        column: u32,
    },
}
impl<W: fmt::Write> super::Writer<'_, W> {
    /// Writes the byte address described by `chain` to the output.
    ///
    /// The terms are written in order and joined with `+`:
    /// - `BufferOffset` as `__dynamic_buffer_offsets{group}._{offset}`;
    /// - `Offset` as a literal byte count;
    /// - `Index` as the index expression followed by `*{stride}`.
    ///
    /// An empty chain is written as `0`.
    pub(super) fn write_storage_address(
        &mut self,
        module: &crate::Module,
        chain: &[SubAccess],
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        if chain.is_empty() {
            write!(self.out, "0")?;
        }
        for (i, access) in chain.iter().enumerate() {
            if i != 0 {
                write!(self.out, "+")?;
            }
            match *access {
                SubAccess::BufferOffset { group, offset } => {
                    write!(self.out, "__dynamic_buffer_offsets{group}._{offset}")?;
                }
                SubAccess::Offset(offset) => {
                    write!(self.out, "{offset}")?;
                }
                SubAccess::Index { value, stride } => {
                    self.write_expr(module, value, func_ctx)?;
                    write!(self.out, "*{stride}")?;
                }
            }
        }
        Ok(())
    }

    /// Writes a comma-separated list of storage loads, one for each
    /// `(type, byte offset)` pair yielded by `sequence`.
    ///
    /// Each offset is appended to `self.temp_access_chain` only while its own
    /// load is written. Every element is therefore addressed relative to the
    /// chain that was current on entry, and the chain is left unchanged.
    fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
        &mut self,
        module: &crate::Module,
        var_handle: Handle<crate::GlobalVariable>,
        sequence: I,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        for (i, (ty_resolution, offset)) in sequence.enumerate() {
            self.temp_access_chain.push(SubAccess::Offset(offset));
            if i != 0 {
                write!(self.out, ", ")?;
            }
            self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
            self.temp_access_chain.pop();
        }
        Ok(())
    }

    /// Writes an expression that loads a `result_ty` value from the storage
    /// buffer `var_handle`, at the address held in `self.temp_access_chain`.
    ///
    /// How each kind of type is loaded:
    /// - Scalars and vectors 4 bytes wide use `Load` / `Load{N}`, wrapped in
    ///   the cast returned by `to_hlsl_cast`. Other widths use `Load<T>` /
    ///   `Load<T{N}>`.
    /// - Matrices use an `{T}{C}x{R}(...)` constructor with one vector load
    ///   per column.
    /// - Constant-size arrays and structs call their wrapped-constructor
    ///   function with one load per element or member.
    ///
    /// On return, `self.temp_access_chain` is unchanged.
    ///
    /// # Panics
    ///
    /// Panics for any other type. Also panics if an array or struct
    /// `result_ty` is not a type handle.
    pub(super) fn write_storage_load(
        &mut self,
        module: &crate::Module,
        var_handle: Handle<crate::GlobalVariable>,
        result_ty: TypeResolution,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        match *result_ty.inner_with(&module.types) {
            crate::TypeInner::Scalar(scalar) => {
                // Move the chain out of `self`: `write_storage_address` needs
                // `&mut self` and a borrow of the chain at the same time.
                let chain = mem::take(&mut self.temp_access_chain);
                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                if scalar.width == 4 {
                    let cast = scalar.kind.to_hlsl_cast();
                    write!(self.out, "{cast}({var_name}.Load(")?;
                } else {
                    let ty = scalar.to_hlsl_str()?;
                    write!(self.out, "{var_name}.Load<{ty}>(")?;
                }
                self.write_storage_address(module, &chain, func_ctx)?;
                write!(self.out, ")")?;
                if scalar.width == 4 {
                    // Close the cast opened above.
                    write!(self.out, ")")?;
                }
                self.temp_access_chain = chain;
            }
            crate::TypeInner::Vector { size, scalar } => {
                // See the scalar arm for why the chain is moved out of `self`.
                let chain = mem::take(&mut self.temp_access_chain);
                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                let size = size as u8;
                if scalar.width == 4 {
                    let cast = scalar.kind.to_hlsl_cast();
                    write!(self.out, "{cast}({var_name}.Load{size}(")?;
                } else {
                    let ty = scalar.to_hlsl_str()?;
                    write!(self.out, "{var_name}.Load<{ty}{size}>(")?;
                }
                self.write_storage_address(module, &chain, func_ctx)?;
                write!(self.out, ")")?;
                if scalar.width == 4 {
                    // Close the cast opened above.
                    write!(self.out, ")")?;
                }
                self.temp_access_chain = chain;
            }
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => {
                write!(
                    self.out,
                    "{}{}x{}(",
                    scalar.to_hlsl_str()?,
                    columns as u8,
                    rows as u8,
                )?;
                // Columns are `Alignment::from(rows) * width` bytes apart.
                // NOTE(review): presumably this pads 3-row columns like 4-row
                // ones; confirm against `Alignment::from`.
                let row_stride = Alignment::from(rows) * scalar.width as u32;
                let iter = (0..columns as u32).map(|i| {
                    let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
                    (TypeResolution::Value(ty_inner), i * row_stride)
                });
                self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
                write!(self.out, ")")?;
            }
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Constant(size),
                stride,
            } => {
                let constructor = super::help::WrappedConstructor {
                    ty: result_ty
                        .handle()
                        .expect("array type resolution should be a type handle"),
                };
                self.write_wrapped_constructor_function_name(module, constructor)?;
                write!(self.out, "(")?;
                let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i));
                self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
                write!(self.out, ")")?;
            }
            crate::TypeInner::Struct { ref members, .. } => {
                let constructor = super::help::WrappedConstructor {
                    ty: result_ty
                        .handle()
                        .expect("struct type resolution should be a type handle"),
                };
                self.write_wrapped_constructor_function_name(module, constructor)?;
                write!(self.out, "(")?;
                let iter = members
                    .iter()
                    .map(|m| (TypeResolution::Handle(m.ty), m.offset));
                self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
                write!(self.out, ")")?;
            }
            ref other => {
                unreachable!("cannot load a value of type {other:?} from a storage buffer")
            }
        }
        Ok(())
    }

    /// Writes the HLSL expression for `value`.
    ///
    /// This is either the expression itself, or a reference into a
    /// `_value{depth}` temporary: `[index]`, `.member`, or `.member_column`.
    fn write_store_value(
        &mut self,
        module: &crate::Module,
        value: &StoreValue,
        func_ctx: &FunctionCtx,
    ) -> BackendResult {
        match *value {
            StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
            StoreValue::TempIndex {
                depth,
                index,
                ty: _,
            } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?,
            StoreValue::TempAccess {
                depth,
                base,
                member_index,
            } => {
                let name = &self.names[&NameKey::StructMember(base, member_index)];
                write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
            }
            StoreValue::TempColumnAccess {
                depth,
                base,
                member_index,
                column,
            } => {
                let name = &self.names[&NameKey::StructMember(base, member_index)];
                write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}_{column}")?
            }
        }
        Ok(())
    }

    /// Writes statements, indented at `level`, that store `value` into the
    /// storage buffer `var_handle` at the address held in
    /// `self.temp_access_chain`.
    ///
    /// How each kind of type is stored:
    /// - Scalars and vectors use a single `Store` / `Store{N}` call. The value
    ///   is wrapped in `asuint` when 4 bytes wide; other widths use a plain
    ///   `Store` call.
    /// - Matrices, constant-size arrays and structs are written inside a new
    ///   `{ }` scope. They are copied into a `_value{depth}` temporary, then
    ///   stored one column, element or member at a time by recursing with
    ///   `StoreValue::TempIndex` / `StoreValue::TempAccess`.
    ///
    /// `within_struct` is `Some(struct_ty)` when `value` is a `TempAccess` into
    /// a temporary of type `struct_ty`. In that case two-row matrices are
    /// stored column by column straight from the temporary's
    /// `{member}_{column}` fields, without a second temporary.
    ///
    /// On return, `self.temp_access_chain` is unchanged.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `value` is a `StoreValue::TempColumnAccess`;
    /// - the value's type is not one of the kinds listed above;
    /// - `within_struct` is set for a two-row matrix whose `value` is not a
    ///   `TempAccess`.
    pub(super) fn write_storage_store(
        &mut self,
        module: &crate::Module,
        var_handle: Handle<crate::GlobalVariable>,
        value: StoreValue,
        func_ctx: &FunctionCtx,
        level: crate::back::Level,
        within_struct: Option<Handle<crate::Type>>,
    ) -> BackendResult {
        // Owns the resolution computed for `TempAccess`. The other variants
        // borrow an existing resolution from `func_ctx` or `value`.
        let temp_resolution;
        let ty_resolution = match value {
            StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
            StoreValue::TempIndex { ref ty, .. } => ty,
            StoreValue::TempAccess {
                base,
                member_index,
                ..
            } => {
                let ty_handle = match module.types[base].inner {
                    crate::TypeInner::Struct { ref members, .. } => {
                        members[member_index as usize].ty
                    }
                    _ => unreachable!("TempAccess base must be a struct type"),
                };
                temp_resolution = TypeResolution::Handle(ty_handle);
                &temp_resolution
            }
            StoreValue::TempColumnAccess { .. } => {
                unreachable!("attempting write_storage_store for TempColumnAccess");
            }
        };
        match *ty_resolution.inner_with(&module.types) {
            crate::TypeInner::Scalar(scalar) => {
                // Move the chain out of `self`: `write_storage_address` needs
                // `&mut self` and a borrow of the chain at the same time.
                let chain = mem::take(&mut self.temp_access_chain);
                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                if scalar.width == 4 {
                    write!(self.out, "{level}{var_name}.Store(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", asuint(")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, "));")?;
                } else {
                    write!(self.out, "{level}{var_name}.Store(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, ");")?;
                }
                self.temp_access_chain = chain;
            }
            crate::TypeInner::Vector { size, scalar } => {
                // See the scalar arm for why the chain is moved out of `self`.
                let chain = mem::take(&mut self.temp_access_chain);
                let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                let size = size as u8;
                if scalar.width == 4 {
                    write!(self.out, "{level}{var_name}.Store{size}(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", asuint(")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, "));")?;
                } else {
                    write!(self.out, "{level}{var_name}.Store(")?;
                    self.write_storage_address(module, &chain, func_ctx)?;
                    write!(self.out, ", ")?;
                    self.write_store_value(module, &value, func_ctx)?;
                    writeln!(self.out, ");")?;
                }
                self.temp_access_chain = chain;
            }
            crate::TypeInner::Matrix {
                columns,
                rows,
                scalar,
            } => {
                // Columns are `Alignment::from(rows) * width` bytes apart,
                // matching the layout used by `write_storage_load`.
                let row_stride = Alignment::from(rows) * scalar.width as u32;
                writeln!(self.out, "{level}{{")?;
                match within_struct {
                    Some(containing_struct) if rows == crate::VectorSize::Bi => {
                        // When reached from the `Struct` arm below, the containing
                        // struct already lives in a temporary. Each column is read
                        // from that temporary's `{member}_{column}` field instead of
                        // declaring a new one.
                        // NOTE(review): presumably the struct declaration splits
                        // two-row matrix members into per-column fields; confirm
                        // against the struct writer.
                        let StoreValue::TempAccess { member_index, .. } = value else {
                            unreachable!("write_storage_store within_struct but not TempAccess");
                        };
                        let mut chain = mem::take(&mut self.temp_access_chain);
                        for i in 0..columns as u32 {
                            chain.push(SubAccess::Offset(i * row_stride));
                            // Re-fetched per column: this borrow of `self.names` must
                            // end before `write_storage_address` borrows `self`
                            // mutably.
                            let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
                            let column_value = StoreValue::TempColumnAccess {
                                // Not incremented: no temporary is declared here. This
                                // equals the depth the `Struct` arm gives its temporary.
                                depth: level.0,
                                base: containing_struct,
                                member_index,
                                column: i,
                            };
                            if scalar.width == 4 {
                                write!(
                                    self.out,
                                    "{}{}.Store{}(",
                                    level.next(),
                                    var_name,
                                    rows as u8
                                )?;
                                self.write_storage_address(module, &chain, func_ctx)?;
                                write!(self.out, ", asuint(")?;
                                self.write_store_value(module, &column_value, func_ctx)?;
                                writeln!(self.out, "));")?;
                            } else {
                                write!(self.out, "{}{var_name}.Store(", level.next())?;
                                self.write_storage_address(module, &chain, func_ctx)?;
                                write!(self.out, ", ")?;
                                self.write_store_value(module, &column_value, func_ctx)?;
                                writeln!(self.out, ");")?;
                            }
                            chain.pop();
                        }
                        self.temp_access_chain = chain;
                    }
                    _ => {
                        // Copy the whole matrix into a temporary first...
                        let depth = level.0 + 1;
                        write!(
                            self.out,
                            "{}{}{}x{} {}{} = ",
                            level.next(),
                            scalar.to_hlsl_str()?,
                            columns as u8,
                            rows as u8,
                            STORE_TEMP_NAME,
                            depth,
                        )?;
                        self.write_store_value(module, &value, func_ctx)?;
                        writeln!(self.out, ";")?;
                        // ...then store it one column vector at a time.
                        for i in 0..columns as u32 {
                            self.temp_access_chain
                                .push(SubAccess::Offset(i * row_stride));
                            let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
                            let sv = StoreValue::TempIndex {
                                depth,
                                index: i,
                                ty: TypeResolution::Value(ty_inner),
                            };
                            self.write_storage_store(
                                module,
                                var_handle,
                                sv,
                                func_ctx,
                                level.next(),
                                None,
                            )?;
                            self.temp_access_chain.pop();
                        }
                    }
                }
                writeln!(self.out, "{level}}}")?;
            }
            crate::TypeInner::Array {
                base,
                size: crate::ArraySize::Constant(size),
                stride,
            } => {
                // Copy the whole array into a temporary first...
                writeln!(self.out, "{level}{{")?;
                write!(self.out, "{}", level.next())?;
                self.write_type(module, base)?;
                let depth = level.next().0;
                write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
                self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
                write!(self.out, " = ")?;
                self.write_store_value(module, &value, func_ctx)?;
                writeln!(self.out, ";")?;
                // ...then store it element by element.
                for i in 0..size.get() {
                    self.temp_access_chain.push(SubAccess::Offset(i * stride));
                    let sv = StoreValue::TempIndex {
                        depth,
                        index: i,
                        ty: TypeResolution::Handle(base),
                    };
                    self.write_storage_store(module, var_handle, sv, func_ctx, level.next(), None)?;
                    self.temp_access_chain.pop();
                }
                writeln!(self.out, "{level}}}")?;
            }
            crate::TypeInner::Struct { ref members, .. } => {
                // Copy the whole struct into a temporary first...
                writeln!(self.out, "{level}{{")?;
                let depth = level.next().0;
                let struct_ty = ty_resolution
                    .handle()
                    .expect("struct type resolution should be a type handle");
                let struct_name = &self.names[&NameKey::Type(struct_ty)];
                write!(
                    self.out,
                    "{}{} {}{} = ",
                    level.next(),
                    struct_name,
                    STORE_TEMP_NAME,
                    depth
                )?;
                self.write_store_value(module, &value, func_ctx)?;
                writeln!(self.out, ";")?;
                // ...then store it member by member, telling the recursion which
                // struct temporary the members belong to.
                for (i, member) in members.iter().enumerate() {
                    self.temp_access_chain
                        .push(SubAccess::Offset(member.offset));
                    let sv = StoreValue::TempAccess {
                        depth,
                        base: struct_ty,
                        member_index: i as u32,
                    };
                    self.write_storage_store(
                        module,
                        var_handle,
                        sv,
                        func_ctx,
                        level.next(),
                        Some(struct_ty),
                    )?;
                    self.temp_access_chain.pop();
                }
                writeln!(self.out, "{level}}}")?;
            }
            ref other => {
                unreachable!("cannot store a value of type {other:?} to a storage buffer")
            }
        }
        Ok(())
    }

    /// Resolves the pointer expression `cur_expr` to the global variable it
    /// accesses and returns that variable's handle.
    ///
    /// Along the way, `self.temp_access_chain` is cleared and then filled with
    /// the byte-address term of every `Access` / `AccessIndex` step.
    ///
    /// Terms are recorded from `cur_expr` inward toward the global. If the
    /// global's binding has a dynamic storage buffer offset, a
    /// `SubAccess::BufferOffset` term is appended last. `write_storage_address`
    /// sums the terms, so their order does not change the address.
    ///
    /// # Errors
    ///
    /// Returns `Error::Unimplemented` if the chain contains an expression other
    /// than `GlobalVariable`, `Access` or `AccessIndex`.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - the global's resource binding cannot be resolved;
    /// - a step indexes into something that is not a pointer to a struct,
    ///   array, vector or matrix (or a value pointer);
    /// - a struct is indexed by a dynamic expression.
    pub(super) fn fill_access_chain(
        &mut self,
        module: &crate::Module,
        mut cur_expr: Handle<crate::Expression>,
        func_ctx: &FunctionCtx,
    ) -> Result<Handle<crate::GlobalVariable>, Error> {
        enum AccessIndex {
            Expression(Handle<crate::Expression>),
            Constant(u32),
        }
        enum Parent<'a> {
            Array { stride: u32 },
            Struct(&'a [crate::StructMember]),
        }
        self.temp_access_chain.clear();
        loop {
            let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
                crate::Expression::GlobalVariable(handle) => {
                    if let Some(ref binding) = module.global_variables[handle].binding {
                        let bt = self
                            .options
                            .resolve_resource_binding(binding)
                            .expect("resource binding of a storage global should resolve");
                        if let Some(dynamic_storage_buffer_offsets_index) =
                            bt.dynamic_storage_buffer_offsets_index
                        {
                            self.temp_access_chain.push(SubAccess::BufferOffset {
                                group: binding.group,
                                offset: dynamic_storage_buffer_offsets_index,
                            });
                        }
                    }
                    return Ok(handle);
                }
                crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
                crate::Expression::AccessIndex { base, index } => {
                    (base, AccessIndex::Constant(index))
                }
                ref other => {
                    return Err(Error::Unimplemented(format!("Pointer access of {other:?}")))
                }
            };
            let parent = match *func_ctx.resolve_type(next_expr, &module.types) {
                crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
                    crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
                    crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
                    crate::TypeInner::Vector { scalar, .. } => Parent::Array {
                        stride: scalar.width as u32,
                    },
                    // A matrix is indexed by column, so the stride is the
                    // (aligned) byte length of one column.
                    crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
                        stride: Alignment::from(rows) * scalar.width as u32,
                    },
                    ref other => unreachable!("cannot index through a pointer to {other:?}"),
                },
                crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array {
                    stride: scalar.width as u32,
                },
                ref other => unreachable!("access base is not a pointer: {other:?}"),
            };
            let sub = match (parent, access_index) {
                (Parent::Array { stride }, AccessIndex::Expression(value)) => {
                    SubAccess::Index { value, stride }
                }
                (Parent::Array { stride }, AccessIndex::Constant(index)) => {
                    SubAccess::Offset(stride * index)
                }
                (Parent::Struct(members), AccessIndex::Constant(index)) => {
                    SubAccess::Offset(members[index as usize].offset)
                }
                (Parent::Struct(_), AccessIndex::Expression(_)) => {
                    unreachable!("struct members cannot be accessed with a dynamic index")
                }
            };
            self.temp_access_chain.push(sub);
            cur_expr = next_expr;
        }
    }
}