// vm/gpu_layout.rs

1use anyhow::{Result, anyhow, bail};
2use compiler::{SymbolTable, eval_const_int_type};
3use dynamic::{Dynamic, Type};
4use smol_str::SmolStr;
5use std::{collections::BTreeMap, rc::Rc};
6
7#[derive(Debug, Clone)]
8pub struct GpuStructLayout {
9    pub name: SmolStr,
10    pub ty: Type,
11    pub size: usize,
12    pub align: usize,
13    pub fields: Vec<GpuFieldLayout>,
14}
15
16#[derive(Debug, Clone)]
17pub struct GpuFieldLayout {
18    pub name: SmolStr,
19    pub ty: Type,
20    pub offset: usize,
21    pub size: usize,
22    pub align: usize,
23}
24
25impl GpuStructLayout {
26    pub(crate) fn from_symbol_table(symbols: &SymbolTable, name: &str, params: &[Type]) -> Result<Self> {
27        let id = symbols.get_id(name)?;
28        let ty = resolve_gpu_type(symbols, &Type::Symbol { id, params: params.to_vec() })?;
29        Self::from_type(name.into(), ty)
30    }
31
32    pub(crate) fn from_type(name: SmolStr, ty: Type) -> Result<Self> {
33        let Type::Struct { fields, .. } = &ty else {
34            bail!("{name} is not a struct type: {ty:?}");
35        };
36        let (size, offsets) = Type::struct_layout(fields);
37        let fields = fields.iter().zip(offsets).map(|((name, ty), offset)| GpuFieldLayout { name: name.clone(), ty: ty.clone(), offset: offset as usize, size: gpu_type_size(ty), align: gpu_type_align(ty) }).collect();
38        let align = gpu_type_align(&ty);
39        Ok(Self { name, ty, size: size as usize, align, fields })
40    }
41
42    pub fn pack_map(&self, value: &Dynamic) -> Result<Vec<u8>> {
43        let mut bytes = vec![0; self.size];
44        self.write_map(value, &mut bytes)?;
45        Ok(bytes)
46    }
47
48    pub fn write_map(&self, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
49        if dst.len() < self.size {
50            bail!("destination buffer too small for {}: need {}, got {}", self.name, self.size, dst.len());
51        }
52        write_value(&self.ty, value, &mut dst[..self.size])
53    }
54
55    pub fn unpack_map(&self, bytes: &[u8]) -> Result<Dynamic> {
56        if bytes.len() < self.size {
57            bail!("source buffer too small for {}: need {}, got {}", self.name, self.size, bytes.len());
58        }
59        read_value(&self.ty, &bytes[..self.size])
60    }
61}
62
63pub(crate) fn resolve_gpu_type(symbols: &SymbolTable, ty: &Type) -> Result<Type> {
64    let resolved = symbols.get_type(ty).unwrap_or_else(|_| ty.clone());
65    match resolved {
66        Type::Symbol { id, params } => {
67            let params = params.iter().map(|param| resolve_gpu_type(symbols, param)).collect::<Result<Vec<_>>>()?;
68            let resolved = symbols.get_type(&Type::Symbol { id, params })?;
69            resolve_gpu_type(symbols, &resolved)
70        }
71        Type::Ident { name, params } => {
72            let params = params.iter().map(|param| resolve_gpu_type(symbols, param)).collect::<Result<Vec<_>>>()?;
73            let ident = Type::Ident { name, params };
74            match symbols.get_type(&ident) {
75                Ok(resolved) => resolve_gpu_type(symbols, &resolved),
76                Err(_) => Ok(ident),
77            }
78        }
79        Type::Struct { params, fields } => Ok(Type::Struct {
80            params,
81            fields: fields
82                .iter()
83                .filter_map(|(name, ty)| match resolve_gpu_type(symbols, ty) {
84                    Ok(Type::Fn { .. }) => None,
85                    Ok(ty) => Some(Ok((name.clone(), ty))),
86                    Err(err) => Some(Err(err)),
87                })
88                .collect::<Result<Vec<_>>>()?,
89        }),
90        Type::Vec(elem, len) => Ok(Type::Vec(Rc::new(resolve_gpu_type(symbols, &elem)?), len)),
91        Type::Array(elem, len) => Ok(Type::Array(Rc::new(resolve_gpu_type(symbols, &elem)?), len)),
92        Type::ArrayParam(elem, len) => {
93            let elem = resolve_gpu_type(symbols, &elem)?;
94            let len = resolve_gpu_type(symbols, &len)?;
95            if let Some(len) = eval_const_int_type(&len) {
96                let len = u32::try_from(len).map_err(|_| anyhow!("array length out of u32 range"))?;
97                Ok(Type::Array(Rc::new(elem), len))
98            } else {
99                Ok(Type::ArrayParam(Rc::new(elem), Rc::new(len)))
100            }
101        }
102        Type::Fn { tys, ret } => Ok(Type::Fn { tys: tys.iter().map(|ty| resolve_gpu_type(symbols, ty)).collect::<Result<Vec<_>>>()?, ret: Rc::new(resolve_gpu_type(symbols, &ret)?) }),
103        Type::Tuple(items) => Ok(Type::Tuple(items.iter().map(|ty| resolve_gpu_type(symbols, ty)).collect::<Result<Vec<_>>>()?)),
104        other => Ok(other),
105    }
106}
107
108fn gpu_type_align(ty: &Type) -> usize {
109    ty.align().min(8).max(1) as usize
110}
111
112fn gpu_type_size(ty: &Type) -> usize {
113    ty.storage_width() as usize
114}
115
116fn write_value(ty: &Type, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
117    match ty {
118        Type::Bool => write_scalar(dst, if value.is_true() { 1u8 } else { 0u8 }),
119        Type::U8 => write_scalar(dst, u8::try_from(value.clone())?),
120        Type::I8 => write_scalar(dst, i8::try_from(value.clone())?),
121        Type::U16 => write_scalar(dst, u16::try_from(value.clone())?),
122        Type::I16 => write_scalar(dst, i16::try_from(value.clone())?),
123        Type::U32 => write_scalar(dst, u32::try_from(value.clone())?),
124        Type::I32 => write_scalar(dst, i32::try_from(value.clone())?),
125        Type::U64 => write_scalar(dst, u64::try_from(value.clone())?),
126        Type::I64 => write_scalar(dst, i64::try_from(value.clone())?),
127        Type::F32 => write_scalar(dst, f32::try_from(value.clone())?),
128        Type::F64 => write_scalar(dst, f64::try_from(value.clone())?),
129        Type::Struct { fields, .. } => {
130            let (_, offsets) = Type::struct_layout(fields);
131            for ((field_name, field_ty), offset) in fields.iter().zip(offsets) {
132                let field_value = value.get_dynamic(field_name.as_str()).ok_or_else(|| anyhow!("missing struct field {field_name}"))?;
133                let start = offset as usize;
134                let end = start + gpu_type_size(field_ty);
135                write_value(field_ty, &field_value, dst.get_mut(start..end).ok_or_else(|| anyhow!("field {field_name} out of buffer"))?)?;
136            }
137            Ok(())
138        }
139        Type::Array(elem, len) | Type::Vec(elem, len) if *len > 0 => write_sequence(elem, *len as usize, value, dst),
140        Type::Tuple(items) => {
141            for (idx, item_ty) in items.iter().enumerate() {
142                let item = value.get_idx(idx).ok_or_else(|| anyhow!("missing tuple item {idx}"))?;
143                let start = tuple_offset(items, idx);
144                let end = start + gpu_type_size(item_ty);
145                write_value(item_ty, &item, dst.get_mut(start..end).ok_or_else(|| anyhow!("tuple item {idx} out of buffer"))?)?;
146            }
147            Ok(())
148        }
149        Type::Void => Ok(()),
150        other => bail!("unsupported GPU layout write type: {other:?}"),
151    }
152}
153
154fn write_sequence(elem: &Type, len: usize, value: &Dynamic, dst: &mut [u8]) -> Result<()> {
155    let stride = gpu_type_size(elem);
156    if value.len() < len {
157        bail!("sequence needs {len} items, got {}", value.len());
158    }
159    for idx in 0..len {
160        let item = value.get_idx(idx).ok_or_else(|| anyhow!("missing sequence item {idx}"))?;
161        let start = idx * stride;
162        let end = start + stride;
163        write_value(elem, &item, dst.get_mut(start..end).ok_or_else(|| anyhow!("sequence item {idx} out of buffer"))?)?;
164    }
165    Ok(())
166}
167
168fn read_value(ty: &Type, src: &[u8]) -> Result<Dynamic> {
169    match ty {
170        Type::Bool => Ok(Dynamic::Bool(read_scalar::<u8>(src)? != 0)),
171        Type::U8 => Ok(Dynamic::U8(read_scalar(src)?)),
172        Type::I8 => Ok(Dynamic::I8(read_scalar(src)?)),
173        Type::U16 => Ok(Dynamic::U16(read_scalar(src)?)),
174        Type::I16 => Ok(Dynamic::I16(read_scalar(src)?)),
175        Type::U32 => Ok(Dynamic::U32(read_scalar(src)?)),
176        Type::I32 => Ok(Dynamic::I32(read_scalar(src)?)),
177        Type::U64 => Ok(Dynamic::U64(read_scalar(src)?)),
178        Type::I64 => Ok(Dynamic::I64(read_scalar(src)?)),
179        Type::F32 => Ok(Dynamic::F32(read_scalar(src)?)),
180        Type::F64 => Ok(Dynamic::F64(read_scalar(src)?)),
181        Type::Struct { fields, .. } => {
182            let (_, offsets) = Type::struct_layout(fields);
183            let mut map = BTreeMap::new();
184            for ((field_name, field_ty), offset) in fields.iter().zip(offsets) {
185                let start = offset as usize;
186                let end = start + gpu_type_size(field_ty);
187                let value = read_value(field_ty, src.get(start..end).ok_or_else(|| anyhow!("field {field_name} out of buffer"))?)?;
188                map.insert(field_name.clone(), value);
189            }
190            Ok(Dynamic::map(map))
191        }
192        Type::Array(elem, len) | Type::Vec(elem, len) if *len > 0 => read_sequence(elem, *len as usize, src),
193        Type::Tuple(items) => {
194            let values = items
195                .iter()
196                .enumerate()
197                .map(|(idx, item_ty)| {
198                    let start = tuple_offset(items, idx);
199                    let end = start + gpu_type_size(item_ty);
200                    read_value(item_ty, src.get(start..end).ok_or_else(|| anyhow!("tuple item {idx} out of buffer"))?)
201                })
202                .collect::<Result<Vec<_>>>()?;
203            Ok(Dynamic::list(values))
204        }
205        Type::Void => Ok(Dynamic::Null),
206        other => bail!("unsupported GPU layout read type: {other:?}"),
207    }
208}
209
210fn read_sequence(elem: &Type, len: usize, src: &[u8]) -> Result<Dynamic> {
211    let stride = gpu_type_size(elem);
212    let values = (0..len)
213        .map(|idx| {
214            let start = idx * stride;
215            let end = start + stride;
216            read_value(elem, src.get(start..end).ok_or_else(|| anyhow!("sequence item {idx} out of buffer"))?)
217        })
218        .collect::<Result<Vec<_>>>()?;
219    Ok(Dynamic::list(values))
220}
221
222fn tuple_offset(items: &[Type], idx: usize) -> usize {
223    items.iter().take(idx).map(gpu_type_size).sum()
224}
225
226fn write_scalar<T: bytemuck::NoUninit>(dst: &mut [u8], value: T) -> Result<()> {
227    let bytes = bytemuck::bytes_of(&value);
228    let target = dst.get_mut(..bytes.len()).ok_or_else(|| anyhow!("buffer too small for scalar"))?;
229    target.copy_from_slice(bytes);
230    Ok(())
231}
232
233fn read_scalar<T: bytemuck::AnyBitPattern>(src: &[u8]) -> Result<T> {
234    let size = std::mem::size_of::<T>();
235    let bytes = src.get(..size).ok_or_else(|| anyhow!("buffer too small for scalar"))?;
236    Ok(bytemuck::pod_read_unaligned(bytes))
237}