1use anyhow::{Result, anyhow, bail};
2use compiler::{SymbolTable, eval_const_int_type};
3use dynamic::{Dynamic, Type};
4use smol_str::SmolStr;
5use std::{collections::BTreeMap, rc::Rc};
6
7#[derive(Debug, Clone)]
8pub struct GpuStructLayout {
9 pub name: SmolStr,
10 pub ty: Type,
11 pub size: usize,
12 pub align: usize,
13 pub fields: Vec<GpuFieldLayout>,
14}
15
16#[derive(Debug, Clone)]
17pub struct GpuFieldLayout {
18 pub name: SmolStr,
19 pub ty: Type,
20 pub offset: usize,
21 pub size: usize,
22 pub align: usize,
23}
24
25impl GpuStructLayout {
26 pub(crate) fn from_symbol_table(symbols: &SymbolTable, name: &str, params: &[Type]) -> Result<Self> {
27 let id = symbols.get_id(name)?;
28 let ty = resolve_gpu_type(symbols, &Type::Symbol { id, params: params.to_vec() })?;
29 Self::from_type(name.into(), ty)
30 }
31
32 pub(crate) fn from_type(name: SmolStr, ty: Type) -> Result<Self> {
33 let Type::Struct { fields, .. } = &ty else {
34 bail!("{name} is not a struct type: {ty:?}");
35 };
36 let (size, offsets) = Type::struct_layout(fields);
37 let fields = fields.iter().zip(offsets).map(|((name, ty), offset)| GpuFieldLayout { name: name.clone(), ty: ty.clone(), offset: offset as usize, size: gpu_type_size(ty), align: gpu_type_align(ty) }).collect();
38 let align = gpu_type_align(&ty);
39 Ok(Self { name, ty, size: size as usize, align, fields })
40 }
41
42 pub fn pack_map(&self, value: &Dynamic) -> Result<Vec<u8>> {
43 let mut bytes = vec![0; self.size];
44 self.write_map(value, &mut bytes)?;
45 Ok(bytes)
46 }
47
48 pub fn write_map(&self, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
49 if dst.len() < self.size {
50 bail!("destination buffer too small for {}: need {}, got {}", self.name, self.size, dst.len());
51 }
52 write_value(&self.ty, value, &mut dst[..self.size])
53 }
54
55 pub fn unpack_map(&self, bytes: &[u8]) -> Result<Dynamic> {
56 if bytes.len() < self.size {
57 bail!("source buffer too small for {}: need {}, got {}", self.name, self.size, bytes.len());
58 }
59 read_value(&self.ty, &bytes[..self.size])
60 }
61}
62
63pub(crate) fn resolve_gpu_type(symbols: &SymbolTable, ty: &Type) -> Result<Type> {
64 let resolved = symbols.get_type(ty).unwrap_or_else(|_| ty.clone());
65 match resolved {
66 Type::Symbol { id, params } => {
67 let params = params.iter().map(|param| resolve_gpu_type(symbols, param)).collect::<Result<Vec<_>>>()?;
68 let resolved = symbols.get_type(&Type::Symbol { id, params })?;
69 resolve_gpu_type(symbols, &resolved)
70 }
71 Type::Ident { name, params } => {
72 let params = params.iter().map(|param| resolve_gpu_type(symbols, param)).collect::<Result<Vec<_>>>()?;
73 let ident = Type::Ident { name, params };
74 match symbols.get_type(&ident) {
75 Ok(resolved) => resolve_gpu_type(symbols, &resolved),
76 Err(_) => Ok(ident),
77 }
78 }
79 Type::Struct { params, fields } => Ok(Type::Struct {
80 params,
81 fields: fields
82 .iter()
83 .filter_map(|(name, ty)| match resolve_gpu_type(symbols, ty) {
84 Ok(Type::Fn { .. }) => None,
85 Ok(ty) => Some(Ok((name.clone(), ty))),
86 Err(err) => Some(Err(err)),
87 })
88 .collect::<Result<Vec<_>>>()?,
89 }),
90 Type::Vec(elem, len) => Ok(Type::Vec(Rc::new(resolve_gpu_type(symbols, &elem)?), len)),
91 Type::Array(elem, len) => Ok(Type::Array(Rc::new(resolve_gpu_type(symbols, &elem)?), len)),
92 Type::ArrayParam(elem, len) => {
93 let elem = resolve_gpu_type(symbols, &elem)?;
94 let len = resolve_gpu_type(symbols, &len)?;
95 if let Some(len) = eval_const_int_type(&len) {
96 let len = u32::try_from(len).map_err(|_| anyhow!("array length out of u32 range"))?;
97 Ok(Type::Array(Rc::new(elem), len))
98 } else {
99 Ok(Type::ArrayParam(Rc::new(elem), Rc::new(len)))
100 }
101 }
102 Type::Fn { tys, ret } => Ok(Type::Fn { tys: tys.iter().map(|ty| resolve_gpu_type(symbols, ty)).collect::<Result<Vec<_>>>()?, ret: Rc::new(resolve_gpu_type(symbols, &ret)?) }),
103 Type::Tuple(items) => Ok(Type::Tuple(items.iter().map(|ty| resolve_gpu_type(symbols, ty)).collect::<Result<Vec<_>>>()?)),
104 other => Ok(other),
105 }
106}
107
108fn gpu_type_align(ty: &Type) -> usize {
109 ty.align().min(8).max(1) as usize
110}
111
112fn gpu_type_size(ty: &Type) -> usize {
113 ty.storage_width() as usize
114}
115
116fn write_value(ty: &Type, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
117 match ty {
118 Type::Bool => write_scalar(dst, if value.is_true() { 1u8 } else { 0u8 }),
119 Type::U8 => write_scalar(dst, u8::try_from(value.clone())?),
120 Type::I8 => write_scalar(dst, i8::try_from(value.clone())?),
121 Type::U16 => write_scalar(dst, u16::try_from(value.clone())?),
122 Type::I16 => write_scalar(dst, i16::try_from(value.clone())?),
123 Type::U32 => write_scalar(dst, u32::try_from(value.clone())?),
124 Type::I32 => write_scalar(dst, i32::try_from(value.clone())?),
125 Type::U64 => write_scalar(dst, u64::try_from(value.clone())?),
126 Type::I64 => write_scalar(dst, i64::try_from(value.clone())?),
127 Type::F32 => write_scalar(dst, f32::try_from(value.clone())?),
128 Type::F64 => write_scalar(dst, f64::try_from(value.clone())?),
129 Type::Struct { fields, .. } => {
130 let (_, offsets) = Type::struct_layout(fields);
131 for ((field_name, field_ty), offset) in fields.iter().zip(offsets) {
132 let field_value = value.get_dynamic(field_name.as_str()).ok_or_else(|| anyhow!("missing struct field {field_name}"))?;
133 let start = offset as usize;
134 let end = start + gpu_type_size(field_ty);
135 write_value(field_ty, &field_value, dst.get_mut(start..end).ok_or_else(|| anyhow!("field {field_name} out of buffer"))?)?;
136 }
137 Ok(())
138 }
139 Type::Array(elem, len) | Type::Vec(elem, len) if *len > 0 => write_sequence(elem, *len as usize, value, dst),
140 Type::Tuple(items) => {
141 for (idx, item_ty) in items.iter().enumerate() {
142 let item = value.get_idx(idx).ok_or_else(|| anyhow!("missing tuple item {idx}"))?;
143 let start = tuple_offset(items, idx);
144 let end = start + gpu_type_size(item_ty);
145 write_value(item_ty, &item, dst.get_mut(start..end).ok_or_else(|| anyhow!("tuple item {idx} out of buffer"))?)?;
146 }
147 Ok(())
148 }
149 Type::Void => Ok(()),
150 other => bail!("unsupported GPU layout write type: {other:?}"),
151 }
152}
153
154fn write_sequence(elem: &Type, len: usize, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
155 let stride = gpu_type_size(elem);
156 if value.len() < len {
157 bail!("sequence needs {len} items, got {}", value.len());
158 }
159 for idx in 0..len {
160 let item = value.get_idx(idx).ok_or_else(|| anyhow!("missing sequence item {idx}"))?;
161 let start = idx * stride;
162 let end = start + stride;
163 write_value(elem, &item, dst.get_mut(start..end).ok_or_else(|| anyhow!("sequence item {idx} out of buffer"))?)?;
164 }
165 Ok(())
166}
167
168fn read_value(ty: &Type, src: &[u8]) -> Result<Dynamic> {
169 match ty {
170 Type::Bool => Ok(Dynamic::Bool(read_scalar::<u8>(src)? != 0)),
171 Type::U8 => Ok(Dynamic::U8(read_scalar(src)?)),
172 Type::I8 => Ok(Dynamic::I8(read_scalar(src)?)),
173 Type::U16 => Ok(Dynamic::U16(read_scalar(src)?)),
174 Type::I16 => Ok(Dynamic::I16(read_scalar(src)?)),
175 Type::U32 => Ok(Dynamic::U32(read_scalar(src)?)),
176 Type::I32 => Ok(Dynamic::I32(read_scalar(src)?)),
177 Type::U64 => Ok(Dynamic::U64(read_scalar(src)?)),
178 Type::I64 => Ok(Dynamic::I64(read_scalar(src)?)),
179 Type::F32 => Ok(Dynamic::F32(read_scalar(src)?)),
180 Type::F64 => Ok(Dynamic::F64(read_scalar(src)?)),
181 Type::Struct { fields, .. } => {
182 let (_, offsets) = Type::struct_layout(fields);
183 let mut map = BTreeMap::new();
184 for ((field_name, field_ty), offset) in fields.iter().zip(offsets) {
185 let start = offset as usize;
186 let end = start + gpu_type_size(field_ty);
187 let value = read_value(field_ty, src.get(start..end).ok_or_else(|| anyhow!("field {field_name} out of buffer"))?)?;
188 map.insert(field_name.clone(), value);
189 }
190 Ok(Dynamic::map(map))
191 }
192 Type::Array(elem, len) | Type::Vec(elem, len) if *len > 0 => read_sequence(elem, *len as usize, src),
193 Type::Tuple(items) => {
194 let values = items
195 .iter()
196 .enumerate()
197 .map(|(idx, item_ty)| {
198 let start = tuple_offset(items, idx);
199 let end = start + gpu_type_size(item_ty);
200 read_value(item_ty, src.get(start..end).ok_or_else(|| anyhow!("tuple item {idx} out of buffer"))?)
201 })
202 .collect::<Result<Vec<_>>>()?;
203 Ok(Dynamic::list(values))
204 }
205 Type::Void => Ok(Dynamic::Null),
206 other => bail!("unsupported GPU layout read type: {other:?}"),
207 }
208}
209
210fn read_sequence(elem: &Type, len: usize, src: &[u8]) -> Result<Dynamic> {
211 let stride = gpu_type_size(elem);
212 let values = (0..len)
213 .map(|idx| {
214 let start = idx * stride;
215 let end = start + stride;
216 read_value(elem, src.get(start..end).ok_or_else(|| anyhow!("sequence item {idx} out of buffer"))?)
217 })
218 .collect::<Result<Vec<_>>>()?;
219 Ok(Dynamic::list(values))
220}
221
222fn tuple_offset(items: &[Type], idx: usize) -> usize {
223 items.iter().take(idx).map(gpu_type_size).sum()
224}
225
226fn write_scalar<T: bytemuck::NoUninit>(dst: &mut [u8], value: T) -> Result<()> {
227 let bytes = bytemuck::bytes_of(&value);
228 let target = dst.get_mut(..bytes.len()).ok_or_else(|| anyhow!("buffer too small for scalar"))?;
229 target.copy_from_slice(bytes);
230 Ok(())
231}
232
233fn read_scalar<T: bytemuck::AnyBitPattern>(src: &[u8]) -> Result<T> {
234 let size = std::mem::size_of::<T>();
235 let bytes = src.get(..size).ok_or_else(|| anyhow!("buffer too small for scalar"))?;
236 Ok(bytemuck::pod_read_unaligned(bytes))
237}